Skip to content

Commit 7212a13

Browse files
jstacclaude
andcommitted
Unify K operator signature in cake_eating_egm_jax with cake_eating_egm
Update the Coleman-Reffett operator K and solver functions in the JAX implementation to match the signature from the NumPy version: - K now takes (c_in, x_in, model) and returns (c_out, x_out) - solve_model_time_iter now takes (c_init, x_init) and returns (c, x) - Applied same changes to K_crra and solve_model_crra in exercises The efficient JAX implementation is fully preserved: - Vectorization with vmap - JIT compilation - Use of jax.lax.while_loop Tested successfully with maximum deviation of 1.43e-06 from analytical solution and execution time of ~0.009 seconds. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent c1d524a commit 7212a13

File tree

1 file changed

+60
-45
lines changed

1 file changed

+60
-45
lines changed

lectures/cake_eating_egm_jax.md

Lines changed: 60 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,11 @@ Here's the Coleman-Reffett operator using EGM.
116116
The key JAX feature here is `vmap`, which vectorizes the computation over the grid points.
117117

118118
```{code-cell} python3
119-
def K(σ_array: jnp.ndarray, model: Model) -> jnp.ndarray:
119+
def K(
120+
c_in: jnp.ndarray, # Consumption values on the endogenous grid
121+
x_in: jnp.ndarray, # Current endogenous grid
122+
model: Model # Model specification
123+
):
120124
"""
121125
The Coleman-Reffett operator using EGM
122126
@@ -126,11 +130,8 @@ def K(σ_array: jnp.ndarray, model: Model) -> jnp.ndarray:
126130
β, α = model.β, model.α
127131
grid, shocks = model.grid, model.shocks
128132
129-
# Determine endogenous grid
130-
x = grid + σ_array # x_i = k_i + c_i
131-
132133
# Linear interpolation of policy using endogenous grid
133-
σ = lambda x_val: jnp.interp(x_val, x, σ_array)
134+
σ = lambda x_val: jnp.interp(x_val, x_in, c_in)
134135
135136
# Define function to compute consumption at a single grid point
136137
def compute_c(k):
@@ -139,9 +140,12 @@ def K(σ_array: jnp.ndarray, model: Model) -> jnp.ndarray:
139140
140141
# Vectorize over grid using vmap
141142
compute_c_vectorized = jax.vmap(compute_c)
142-
c = compute_c_vectorized(grid)
143+
c_out = compute_c_vectorized(grid)
144+
145+
# Determine corresponding endogenous grid
146+
x_out = grid + c_out # x_i = k_i + c_i
143147
144-
return c
148+
return c_out, x_out
145149
```
146150

147151
We define utility and production functions globally.
@@ -171,47 +175,47 @@ The solver uses JAX's `jax.lax.while_loop` for the iteration and is JIT-compiled
171175
```{code-cell} python3
172176
@jax.jit
173177
def solve_model_time_iter(model: Model,
174-
σ_init: jnp.ndarray,
178+
c_init: jnp.ndarray,
179+
x_init: jnp.ndarray,
175180
tol: float = 1e-5,
176-
max_iter: int = 1000) -> jnp.ndarray:
181+
max_iter: int = 1000):
177182
"""
178183
Solve the model using time iteration with EGM.
179184
"""
180185
181186
def condition(loop_state):
182-
i, σ, error = loop_state
187+
i, c, x, error = loop_state
183188
return (error > tol) & (i < max_iter)
184189
185190
def body(loop_state):
186-
i, σ, error = loop_state
187-
σ_new = K(σ, model)
188-
error = jnp.max(jnp.abs(σ_new - σ))
189-
return i + 1, σ_new, error
191+
i, c, x, error = loop_state
192+
c_new, x_new = K(c, x, model)
193+
error = jnp.max(jnp.abs(c_new - c))
194+
return i + 1, c_new, x_new, error
190195
191196
# Initialize loop state
192-
initial_state = (0, σ_init, tol + 1)
197+
initial_state = (0, c_init, x_init, tol + 1)
193198
194199
# Run the loop
195-
i, σ, error = jax.lax.while_loop(condition, body, initial_state)
200+
i, c, x, error = jax.lax.while_loop(condition, body, initial_state)
196201
197-
return σ
202+
return c, x
198203
```
199204

200205
We solve the model starting from an initial guess.
201206

202207
```{code-cell} python3
203-
σ_init = jnp.copy(grid)
204-
σ = solve_model_time_iter(model, σ_init)
208+
c_init = jnp.copy(grid)
209+
x_init = grid + c_init
210+
c, x = solve_model_time_iter(model, c_init, x_init)
205211
```
206212

207213
Let's plot the resulting policy against the analytical solution.
208214

209215
```{code-cell} python3
210-
x = grid + σ # x_i = k_i + c_i
211-
212216
fig, ax = plt.subplots()
213217
214-
ax.plot(x, σ, lw=2,
218+
ax.plot(x, c, lw=2,
215219
alpha=0.8, label='approximate policy function')
216220
217221
ax.plot(x, σ_star(x, model.α, model.β), 'k--',
@@ -224,15 +228,16 @@ plt.show()
224228
The fit is very good.
225229

226230
```{code-cell} python3
227-
max_dev = jnp.max(jnp.abs(σ - σ_star(x, model.α, model.β)))
231+
max_dev = jnp.max(jnp.abs(c - σ_star(x, model.α, model.β)))
228232
print(f"Maximum absolute deviation: {max_dev:.7}")
229233
```
230234

231235
The JAX implementation is very fast thanks to JIT compilation and vectorization.
232236

233237
```{code-cell} python3
234238
with qe.Timer(precision=8):
235-
σ = solve_model_time_iter(model, σ_init).block_until_ready()
239+
c, x = solve_model_time_iter(model, c_init, x_init)
240+
jax.block_until_ready(c)
236241
```
237242

238243
This speed comes from:
@@ -282,19 +287,21 @@ def u_prime_inv_crra(x, γ):
282287
Now we create a version of the Coleman-Reffett operator that takes $\gamma$ as a parameter.
283288

284289
```{code-cell} python3
285-
def K_crra(σ_array: jnp.ndarray, model: Model, γ: float) -> jnp.ndarray:
290+
def K_crra(
291+
c_in: jnp.ndarray, # Consumption values on the endogenous grid
292+
x_in: jnp.ndarray, # Current endogenous grid
293+
model: Model, # Model specification
294+
γ: float # CRRA parameter
295+
):
286296
"""
287297
The Coleman-Reffett operator using EGM with CRRA utility
288298
"""
289299
# Simplify names
290300
β, α = model.β, model.α
291301
grid, shocks = model.grid, model.shocks
292302
293-
# Determine endogenous grid
294-
x = grid + σ_array
295-
296303
# Linear interpolation of policy using endogenous grid
297-
σ = lambda x_val: jnp.interp(x_val, x, σ_array)
304+
σ = lambda x_val: jnp.interp(x_val, x_in, c_in)
298305
299306
# Define function to compute consumption at a single grid point
300307
def compute_c(k):
@@ -303,55 +310,63 @@ def K_crra(σ_array: jnp.ndarray, model: Model, γ: float) -> jnp.ndarray:
303310
304311
# Vectorize over grid using vmap
305312
compute_c_vectorized = jax.vmap(compute_c)
306-
c = compute_c_vectorized(grid)
313+
c_out = compute_c_vectorized(grid)
314+
315+
# Determine corresponding endogenous grid
316+
x_out = grid + c_out
307317
308-
return c
318+
return c_out, x_out
309319
```
310320

311321
We also need a solver that uses this operator.
312322

313323
```{code-cell} python3
314324
@jax.jit
315325
def solve_model_crra(model: Model,
316-
σ_init: jnp.ndarray,
326+
c_init: jnp.ndarray,
327+
x_init: jnp.ndarray,
317328
γ: float,
318329
tol: float = 1e-5,
319-
max_iter: int = 1000) -> jnp.ndarray:
330+
max_iter: int = 1000):
320331
"""
321332
Solve the model using time iteration with EGM and CRRA utility.
322333
"""
323334
324335
def condition(loop_state):
325-
i, σ, error = loop_state
336+
i, c, x, error = loop_state
326337
return (error > tol) & (i < max_iter)
327338
328339
def body(loop_state):
329-
i, σ, error = loop_state
330-
σ_new = K_crra(σ, model, γ)
331-
error = jnp.max(jnp.abs(σ_new - σ))
332-
return i + 1, σ_new, error
340+
i, c, x, error = loop_state
341+
c_new, x_new = K_crra(c, x, model, γ)
342+
error = jnp.max(jnp.abs(c_new - c))
343+
return i + 1, c_new, x_new, error
333344
334345
# Initialize loop state
335-
initial_state = (0, σ_init, tol + 1)
346+
initial_state = (0, c_init, x_init, tol + 1)
336347
337348
# Run the loop
338-
i, σ, error = jax.lax.while_loop(condition, body, initial_state)
349+
i, c, x, error = jax.lax.while_loop(condition, body, initial_state)
339350
340-
return σ
351+
return c, x
341352
```
342353

343354
Now we solve for $\gamma = 1$ (log utility) and values approaching 1 from above.
344355

345356
```{code-cell} python3
346357
γ_values = [1.0, 1.05, 1.1, 1.2]
347358
policies = {}
359+
endogenous_grids = {}
348360
349361
model_crra = create_model(α=α)
350362
351363
for γ in γ_values:
352-
σ_init = jnp.copy(model_crra.grid)
353-
σ_gamma = solve_model_crra(model_crra, σ_init, γ).block_until_ready()
354-
policies[γ] = σ_gamma
364+
c_init = jnp.copy(model_crra.grid)
365+
x_init = model_crra.grid + c_init
366+
c_gamma, x_gamma = solve_model_crra(model_crra, c_init, x_init, γ)
367+
jax.block_until_ready(c_gamma)
368+
policies[γ] = c_gamma
369+
endogenous_grids[γ] = x_gamma
355370
print(f"Solved for γ = {γ}")
356371
```
357372

@@ -361,7 +376,7 @@ Plot the policies on their endogenous grids.
361376
fig, ax = plt.subplots()
362377
363378
for γ in γ_values:
364-
x = model_crra.grid + policies[γ]
379+
x = endogenous_grids[γ]
365380
if γ == 1.0:
366381
ax.plot(x, policies[γ], 'k-', linewidth=2,
367382
label=f'γ = {γ:.2f} (log utility)', alpha=0.8)

0 commit comments

Comments
 (0)