@@ -116,7 +116,11 @@ Here's the Coleman-Reffett operator using EGM.
116116The key JAX feature here is ` vmap ` , which vectorizes the computation over the grid points.
117117
118118``` {code-cell} python3
119- def K(σ_array: jnp.ndarray, model: Model) -> jnp.ndarray:
119+ def K(
120+ c_in: jnp.ndarray, # Consumption values on the endogenous grid
121+ x_in: jnp.ndarray, # Current endogenous grid
122+ model: Model # Model specification
123+ ):
120124 """
121125 The Coleman-Reffett operator using EGM
122126
@@ -126,11 +130,8 @@ def K(σ_array: jnp.ndarray, model: Model) -> jnp.ndarray:
126130 β, α = model.β, model.α
127131 grid, shocks = model.grid, model.shocks
128132
129- # Determine endogenous grid
130- x = grid + σ_array # x_i = k_i + c_i
131-
132133 # Linear interpolation of policy using endogenous grid
133- σ = lambda x_val: jnp.interp(x_val, x, σ_array )
134+ σ = lambda x_val: jnp.interp(x_val, x_in, c_in )
134135
135136 # Define function to compute consumption at a single grid point
136137 def compute_c(k):
@@ -139,9 +140,12 @@ def K(σ_array: jnp.ndarray, model: Model) -> jnp.ndarray:
139140
140141 # Vectorize over grid using vmap
141142 compute_c_vectorized = jax.vmap(compute_c)
142- c = compute_c_vectorized(grid)
143+ c_out = compute_c_vectorized(grid)
144+
145+ # Determine corresponding endogenous grid
146+ x_out = grid + c_out # x_i = k_i + c_i
143147
144- return c
148+ return c_out, x_out
145149```
146150
147151We define utility and production functions globally.
@@ -171,47 +175,47 @@ The solver uses JAX's `jax.lax.while_loop` for the iteration and is JIT-compiled
171175``` {code-cell} python3
172176@jax.jit
173177def solve_model_time_iter(model: Model,
174- σ_init: jnp.ndarray,
178+ c_init: jnp.ndarray,
179+ x_init: jnp.ndarray,
175180 tol: float = 1e-5,
176- max_iter: int = 1000) -> jnp.ndarray :
181+ max_iter: int = 1000):
177182 """
178183 Solve the model using time iteration with EGM.
179184 """
180185
181186 def condition(loop_state):
182- i, σ , error = loop_state
187+ i, c, x , error = loop_state
183188 return (error > tol) & (i < max_iter)
184189
185190 def body(loop_state):
186- i, σ , error = loop_state
187- σ_new = K(σ , model)
188- error = jnp.max(jnp.abs(σ_new - σ ))
189- return i + 1, σ_new , error
191+ i, c, x , error = loop_state
192+ c_new, x_new = K(c, x , model)
193+ error = jnp.max(jnp.abs(c_new - c ))
194+ return i + 1, c_new, x_new , error
190195
191196 # Initialize loop state
192- initial_state = (0, σ_init , tol + 1)
197+ initial_state = (0, c_init, x_init , tol + 1)
193198
194199 # Run the loop
195- i, σ , error = jax.lax.while_loop(condition, body, initial_state)
200+ i, c, x , error = jax.lax.while_loop(condition, body, initial_state)
196201
197- return σ
202+ return c, x
198203```
199204
200205We solve the model starting from an initial guess.
201206
202207``` {code-cell} python3
203- σ_init = jnp.copy(grid)
204- σ = solve_model_time_iter(model, σ_init)
208+ c_init = jnp.copy(grid)
209+ x_init = grid + c_init
210+ c, x = solve_model_time_iter(model, c_init, x_init)
205211```
206212
207213Let's plot the resulting policy against the analytical solution.
208214
209215``` {code-cell} python3
210- x = grid + σ # x_i = k_i + c_i
211-
212216fig, ax = plt.subplots()
213217
214- ax.plot(x, σ , lw=2,
218+ ax.plot(x, c , lw=2,
215219 alpha=0.8, label='approximate policy function')
216220
217221ax.plot(x, σ_star(x, model.α, model.β), 'k--',
@@ -224,15 +228,16 @@ plt.show()
224228The fit is very good.
225229
226230``` {code-cell} python3
227- max_dev = jnp.max(jnp.abs(σ - σ_star(x, model.α, model.β)))
231+ max_dev = jnp.max(jnp.abs(c - σ_star(x, model.α, model.β)))
228232print(f"Maximum absolute deviation: {max_dev:.7}")
229233```
230234
231235The JAX implementation is very fast thanks to JIT compilation and vectorization.
232236
233237``` {code-cell} python3
234238with qe.Timer(precision=8):
235- σ = solve_model_time_iter(model, σ_init).block_until_ready()
239+ c, x = solve_model_time_iter(model, c_init, x_init)
240+ jax.block_until_ready(c)
236241```
237242
238243This speed comes from:
@@ -282,19 +287,21 @@ def u_prime_inv_crra(x, γ):
282287Now we create a version of the Coleman-Reffett operator that takes $\gamma$ as a parameter.
283288
284289``` {code-cell} python3
285- def K_crra(σ_array: jnp.ndarray, model: Model, γ: float) -> jnp.ndarray:
290+ def K_crra(
291+ c_in: jnp.ndarray, # Consumption values on the endogenous grid
292+ x_in: jnp.ndarray, # Current endogenous grid
293+ model: Model, # Model specification
294+ γ: float # CRRA parameter
295+ ):
286296 """
287297 The Coleman-Reffett operator using EGM with CRRA utility
288298 """
289299 # Simplify names
290300 β, α = model.β, model.α
291301 grid, shocks = model.grid, model.shocks
292302
293- # Determine endogenous grid
294- x = grid + σ_array
295-
296303 # Linear interpolation of policy using endogenous grid
297- σ = lambda x_val: jnp.interp(x_val, x, σ_array )
304+ σ = lambda x_val: jnp.interp(x_val, x_in, c_in )
298305
299306 # Define function to compute consumption at a single grid point
300307 def compute_c(k):
@@ -303,55 +310,63 @@ def K_crra(σ_array: jnp.ndarray, model: Model, γ: float) -> jnp.ndarray:
303310
304311 # Vectorize over grid using vmap
305312 compute_c_vectorized = jax.vmap(compute_c)
306- c = compute_c_vectorized(grid)
313+ c_out = compute_c_vectorized(grid)
314+
315+ # Determine corresponding endogenous grid
316+ x_out = grid + c_out
307317
308- return c
318+ return c_out, x_out
309319```
310320
311321We also need a solver that uses this operator.
312322
313323``` {code-cell} python3
314324@jax.jit
315325def solve_model_crra(model: Model,
316- σ_init: jnp.ndarray,
326+ c_init: jnp.ndarray,
327+ x_init: jnp.ndarray,
317328 γ: float,
318329 tol: float = 1e-5,
319- max_iter: int = 1000) -> jnp.ndarray :
330+ max_iter: int = 1000):
320331 """
321332 Solve the model using time iteration with EGM and CRRA utility.
322333 """
323334
324335 def condition(loop_state):
325- i, σ , error = loop_state
336+ i, c, x , error = loop_state
326337 return (error > tol) & (i < max_iter)
327338
328339 def body(loop_state):
329- i, σ , error = loop_state
330- σ_new = K_crra(σ , model, γ)
331- error = jnp.max(jnp.abs(σ_new - σ ))
332- return i + 1, σ_new , error
340+ i, c, x , error = loop_state
341+ c_new, x_new = K_crra(c, x , model, γ)
342+ error = jnp.max(jnp.abs(c_new - c ))
343+ return i + 1, c_new, x_new , error
333344
334345 # Initialize loop state
335- initial_state = (0, σ_init , tol + 1)
346+ initial_state = (0, c_init, x_init , tol + 1)
336347
337348 # Run the loop
338- i, σ , error = jax.lax.while_loop(condition, body, initial_state)
349+ i, c, x , error = jax.lax.while_loop(condition, body, initial_state)
339350
340- return σ
351+ return c, x
341352```
342353
343354Now we solve for $\gamma = 1$ (log utility) and values approaching 1 from above.
344355
345356``` {code-cell} python3
346357γ_values = [1.0, 1.05, 1.1, 1.2]
347358policies = {}
359+ endogenous_grids = {}
348360
349361model_crra = create_model(α=α)
350362
351363for γ in γ_values:
352- σ_init = jnp.copy(model_crra.grid)
353- σ_gamma = solve_model_crra(model_crra, σ_init, γ).block_until_ready()
354- policies[γ] = σ_gamma
364+ c_init = jnp.copy(model_crra.grid)
365+ x_init = model_crra.grid + c_init
366+ c_gamma, x_gamma = solve_model_crra(model_crra, c_init, x_init, γ)
367+ jax.block_until_ready(c_gamma)
368+ policies[γ] = c_gamma
369+ endogenous_grids[γ] = x_gamma
355370 print(f"Solved for γ = {γ}")
356371```
357372
@@ -361,7 +376,7 @@ Plot the policies on their endogenous grids.
361376fig, ax = plt.subplots()
362377
363378for γ in γ_values:
364- x = model_crra.grid + policies [γ]
379+ x = endogenous_grids [γ]
365380 if γ == 1.0:
366381 ax.plot(x, policies[γ], 'k-', linewidth=2,
367382 label=f'γ = {γ:.2f} (log utility)', alpha=0.8)
0 commit comments