1+ import numpy as np
2+ from numba import njit , prange
3+ from quantecon import optimize
4+
5+ @njit
6+ def get_grid_nodes (grid ):
7+ """
8+ Get the actual grid points from a grid tuple.
9+ """
10+ x_min , x_max , x_num = grid
11+ return np .linspace (x_min , x_max , x_num )
12+
13+ @njit
14+ def linear_interp_1d_scalar (x_min , x_max , x_num , y_values , x_val ):
15+ """Helper function for scalar interpolation"""
16+ x_nodes = np .linspace (x_min , x_max , x_num )
17+
18+ # Extrapolation with linear extension
19+ if x_val <= x_nodes [0 ]:
20+ # Linear extrapolation using first two points
21+ if x_num >= 2 :
22+ slope = (y_values [1 ] - y_values [0 ]) / (x_nodes [1 ] - x_nodes [0 ])
23+ return y_values [0 ] + slope * (x_val - x_nodes [0 ])
24+ else :
25+ return y_values [0 ]
26+
27+ if x_val >= x_nodes [- 1 ]:
28+ # Linear extrapolation using last two points
29+ if x_num >= 2 :
30+ slope = (y_values [- 1 ] - y_values [- 2 ]) / (x_nodes [- 1 ] - x_nodes [- 2 ])
31+ return y_values [- 1 ] + slope * (x_val - x_nodes [- 1 ])
32+ else :
33+ return y_values [- 1 ]
34+
35+ # Binary search for the right interval
36+ left = 0
37+ right = x_num - 1
38+ while right - left > 1 :
39+ mid = (left + right ) // 2
40+ if x_nodes [mid ] <= x_val :
41+ left = mid
42+ else :
43+ right = mid
44+
45+ # Linear interpolation
46+ x_left = x_nodes [left ]
47+ x_right = x_nodes [right ]
48+ y_left = y_values [left ]
49+ y_right = y_values [right ]
50+
51+ weight = (x_val - x_left ) / (x_right - x_left )
52+ return y_left * (1 - weight ) + y_right * weight
53+
54+ @njit
55+ def linear_interp_1d (x_grid , y_values , x_query ):
56+ """
57+ Perform 1D linear interpolation.
58+ """
59+ x_min , x_max , x_num = x_grid
60+ return linear_interp_1d_scalar (x_min , x_max , x_num , y_values , x_query [0 ])
61+
162class AMSS :
263 # WARNING: THE CODE IS EXTREMELY SENSITIVE TO CHOCIES OF PARAMETERS.
364 # DO NOT CHANGE THE PARAMETERS AND EXPECT IT TO WORK
@@ -78,6 +139,10 @@ def simulate(self, s_hist, b_0):
78139 pref = self .pref
79140 x_grid , g , β , S = self .x_grid , self .g , self .β , self .S
80141 σ_v_star , σ_w_star = self .σ_v_star , self .σ_w_star
142+ Π = self .Π
143+
144+ # Extract the grid tuple from the list
145+ grid_tuple = x_grid [0 ] if isinstance (x_grid , list ) else x_grid
81146
82147 T = len (s_hist )
83148 s_0 = s_hist [0 ]
@@ -111,8 +176,8 @@ def simulate(self, s_hist, b_0):
111176 T = np .zeros (S )
112177 for s in range (S ):
113178 x_arr = np .array ([x_ ])
114- l [s ] = eval_linear ( x_grid , σ_v_star [s_ , :, s ], x_arr )
115- T [s ] = eval_linear ( x_grid , σ_v_star [s_ , :, S + s ], x_arr )
179+ l [s ] = linear_interp_1d ( grid_tuple , σ_v_star [s_ , :, s ], x_arr )
180+ T [s ] = linear_interp_1d ( grid_tuple , σ_v_star [s_ , :, S + s ], x_arr )
116181
117182 c = (1 - l ) - g
118183 u_c = pref .Uc (c , l )
@@ -135,6 +200,8 @@ def simulate(self, s_hist, b_0):
135200
136201def obj_factory (Π , β , x_grid , g ):
137202 S = len (Π )
203+ # Extract the grid tuple from the list
204+ grid_tuple = x_grid [0 ] if isinstance (x_grid , list ) else x_grid
138205
139206 @njit
140207 def obj_V (σ , state , V , pref ):
@@ -152,7 +219,7 @@ def obj_V(σ, state, V, pref):
152219 V_next = np .zeros (S )
153220
154221 for s in range (S ):
155- V_next [s ] = eval_linear ( x_grid , V [s ], np .array ([x [s ]]))
222+ V_next [s ] = linear_interp_1d ( grid_tuple , V [s ], np .array ([x [s ]]))
156223
157224 out = Π [s_ ] @ (pref .U (c , l ) + β * V_next )
158225
@@ -167,7 +234,7 @@ def obj_W(σ, state, V, pref):
167234 c = (1 - l ) - g [s_ ]
168235 x = - pref .Uc (c , l ) * (c - T - b_0 ) + pref .Ul (c , l ) * (1 - l )
169236
170- V_next = eval_linear ( x_grid , V [s_ ], np .array ([x ]))
237+ V_next = linear_interp_1d ( grid_tuple , V [s_ ], np .array ([x ]))
171238
172239 out = pref .U (c , l ) + β * V_next
173240
@@ -178,9 +245,11 @@ def obj_W(σ, state, V, pref):
178245
179246def bellman_operator_factory (Π , β , x_grid , g , bounds_v ):
180247 obj_V , obj_W = obj_factory (Π , β , x_grid , g )
181- n = x_grid [0 ][2 ]
248+ # Extract the grid tuple from the list
249+ grid_tuple = x_grid [0 ] if isinstance (x_grid , list ) else x_grid
250+ n = grid_tuple [2 ]
182251 S = len (Π )
183- x_nodes = nodes ( x_grid )
252+ x_nodes = get_grid_nodes ( grid_tuple )
184253
185254 @njit (parallel = True )
186255 def T_v (V , V_new , σ_star , pref ):
@@ -209,4 +278,4 @@ def T_w(W, σ_star, V, b_0, pref):
209278 W [s_ ] = res .fun
210279 σ_star [s_ ] = res .x
211280
212- return T_v , T_w
281+ return T_v , T_w
0 commit comments