Commit ed8e89a

jstac and claude authored
Unify K operator signature in cake_eating_egm_jax with cake_eating_egm (#732)
* Unify K operator signature in cake_eating_egm_jax with cake_eating_egm

  Update the Coleman-Reffett operator K and the solver functions in the JAX
  implementation to match the signatures in the NumPy version:

  - K now takes (c_in, x_in, model) and returns (c_out, x_out)
  - solve_model_time_iter now takes (c_init, x_init) and returns (c, x)
  - The same changes are applied to K_crra and solve_model_crra in the exercises

  The efficient JAX implementation is fully preserved: vectorization with
  vmap, JIT compilation, and use of jax.lax.while_loop.

  Tested successfully, with a maximum deviation of 1.43e-06 from the
  analytical solution and an execution time of ~0.009 seconds.

* Remove unnecessary global alpha variable

  The α parameter is already defined with a default value (0.4) in the
  create_model function, so there is no need to set it as a global variable
  and pass it explicitly. Simplified to model = create_model() (instead of
  α = 0.4; model = create_model(α=α)), and likewise model_crra = create_model()
  in the CRRA exercise section. Tested successfully with the same results as
  before.

* Remove redundant alpha parameter in cake_eating_egm

  The α parameter does not need to be passed explicitly to create_model since
  it already has a default value of 0.4. The α = 0.4 line is still needed for
  the lambda-function closures (f and f_prime capture it). Changed
  create_model(u=u, f=f, α=α, ...) to create_model(u=u, f=f, ...). Tested
  successfully with the same convergence behavior.

* Refactor f and f_prime to use an explicit alpha parameter

  Changed from a closure-based approach to explicit parameter passing:

  - f = lambda k, α: k**α (instead of f = lambda k: k**α with a global α)
  - f_prime = lambda k, α: α * k**(α - 1)
  - Updated the K operator to call f(s, α) and f_prime(s, α)

  This makes the NumPy version consistent with the JAX implementation and
  ensures that the α stored in the Model is actually used in the K operator
  (previously it was unpacked but unused).

  Benefits:
  - Consistency between the NumPy and JAX versions
  - Clearer function dependencies (α is an explicit parameter)
  - model.α is actually used, instead of relying on a closure

  Tested successfully with the same convergence behavior.

* Rename grid to s_grid for consistency with the NumPy version

  Changed all references from 'grid' to 's_grid' to match the NumPy
  implementation and to clarify that this is the exogenous savings grid:

  - Model.grid → Model.s_grid
  - Comment updated: "state grid" → "exogenous savings grid"
  - All variable names updated throughout (K, K_crra, initializations)
  - The loop variable is also renamed from 'k' to 's' for consistency

  This makes the JAX version consistent with the NumPy version's naming
  conventions and makes it clearer that we are working with the exogenous
  grid for savings, not the endogenous grid for wealth x. Tested successfully
  with identical results.

* Fix remaining k references to use s for savings

  Changed the remaining instances where k was used instead of s: the
  mathematical notation x = k + σ(k) becomes x = s + σ(s), and a missing
  inline comment (x_i = s_i + c_i) is added in K_crra. This completes the
  transition to using 's' for savings throughout, maintaining consistency
  with the exogenous savings grid terminology.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
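In short, the consumption values and the endogenous grid now travel through the iteration as a pair. Below is a minimal runnable sketch of that call pattern, using a toy stand-in for `K` (all names and numbers here are illustrative, not the lecture's code):

```python
import jax.numpy as jnp

s_grid = jnp.linspace(0.1, 1.0, 5)     # stand-in for the exogenous savings grid

def K_toy(c_in, x_in):
    # Any operator with the unified signature: a (c, x) pair in, a pair out.
    c_out = 0.5 * c_in + 0.1 * s_grid  # a contraction; fixed point 0.2 * s_grid
    x_out = s_grid + c_out             # endogenous grid: x_i = s_i + c_i
    return c_out, x_out

c, x = jnp.copy(s_grid), s_grid + s_grid
for _ in range(100):                   # the lecture uses jax.lax.while_loop
    c_new, x_new = K_toy(c, x)
    if jnp.max(jnp.abs(c_new - c)) < 1e-8:
        break
    c, x = c_new, x_new
```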
1 parent c1d524a commit ed8e89a

File tree

2 files changed: +78 -66 lines changed

lectures/cake_eating_egm.md

Lines changed: 4 additions & 5 deletions
````diff
@@ -243,7 +243,7 @@ def K(
 
     # Solve for updated consumption value
     for i, s in enumerate(s_grid):
-        vals = u_prime(σ(f(s) * shocks)) * f_prime(s) * shocks
+        vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks
         c_out[i] = u_prime_inv(β * np.mean(vals))
 
     # Determine corresponding endogenous grid
@@ -266,14 +266,13 @@ First we create an instance.
 
 ```{code-cell} python3
 # Define utility and production functions with derivatives
-α = 0.4
 u = lambda c: np.log(c)
 u_prime = lambda c: 1 / c
 u_prime_inv = lambda x: 1 / x
-f = lambda k: k**α
-f_prime = lambda k: α * k**(α - 1)
+f = lambda k, α: k**α
+f_prime = lambda k, α: α * k**(α - 1)
 
-model = create_model(u=u, f=f, α=α, u_prime=u_prime,
+model = create_model(u=u, f=f, u_prime=u_prime,
                      f_prime=f_prime, u_prime_inv=u_prime_inv)
 s_grid = model.s_grid
 ```
````
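One subtlety worth spelling out about this closure-to-parameter refactor: Python lambdas capture variables, not values, so the closure form of `f` reads the *global* `α` at call time and never consults the `α` stored in the model. A standalone illustration (not the lecture's code):

```python
α = 0.4
f_closure = lambda k: k**α            # reads the global α when called
f_explicit = lambda k, α: k**α        # reads whatever α it is handed

model_α = 0.5                         # imagine a Model built with α=0.5
print(f_closure(2.0))                 # 2**0.4 -- the model's α is ignored
print(f_explicit(2.0, model_α))       # 2**0.5 -- uses the model's α
```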

lectures/cake_eating_egm_jax.md

Lines changed: 74 additions & 61 deletions
````diff
@@ -85,7 +85,7 @@ class Model(NamedTuple):
     β: float             # discount factor
     μ: float             # shock location parameter
     s: float             # shock scale parameter
-    grid: jnp.ndarray    # state grid
+    s_grid: jnp.ndarray  # exogenous savings grid
     shocks: jnp.ndarray  # shock draws
     α: float             # production function parameter
 
````
````diff
@@ -101,47 +101,51 @@ def create_model(β: float = 0.96,
     """
     Creates an instance of the cake eating model.
     """
-    # Set up grid
-    grid = jnp.linspace(1e-4, grid_max, grid_size)
+    # Set up exogenous savings grid
+    s_grid = jnp.linspace(1e-4, grid_max, grid_size)
 
     # Store shocks (with a seed, so results are reproducible)
     key = jax.random.PRNGKey(seed)
     shocks = jnp.exp(μ + s * jax.random.normal(key, shape=(shock_size,)))
 
-    return Model(β=β, μ=μ, s=s, grid=grid, shocks=shocks, α=α)
+    return Model(β=β, μ=μ, s=s, s_grid=s_grid, shocks=shocks, α=α)
 ```
 
 Here's the Coleman-Reffett operator using EGM.
 
 The key JAX feature here is `vmap`, which vectorizes the computation over the grid points.
 
 ```{code-cell} python3
-def K(σ_array: jnp.ndarray, model: Model) -> jnp.ndarray:
+def K(
+        c_in: jnp.ndarray,   # Consumption values on the endogenous grid
+        x_in: jnp.ndarray,   # Current endogenous grid
+        model: Model         # Model specification
+    ):
     """
     The Coleman-Reffett operator using EGM
 
     """
 
     # Simplify names
     β, α = model.β, model.α
-    grid, shocks = model.grid, model.shocks
-
-    # Determine endogenous grid
-    x = grid + σ_array  # x_i = k_i + c_i
+    s_grid, shocks = model.s_grid, model.shocks
 
     # Linear interpolation of policy using endogenous grid
-    σ = lambda x_val: jnp.interp(x_val, x, σ_array)
+    σ = lambda x_val: jnp.interp(x_val, x_in, c_in)
 
     # Define function to compute consumption at a single grid point
-    def compute_c(k):
-        vals = u_prime(σ(f(k, α) * shocks)) * f_prime(k, α) * shocks
+    def compute_c(s):
+        vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks
         return u_prime_inv(β * jnp.mean(vals))
 
     # Vectorize over grid using vmap
     compute_c_vectorized = jax.vmap(compute_c)
-    c = compute_c_vectorized(grid)
+    c_out = compute_c_vectorized(s_grid)
+
+    # Determine corresponding endogenous grid
+    x_out = s_grid + c_out  # x_i = s_i + c_i
 
-    return c
+    return c_out, x_out
 ```
 
 We define utility and production functions globally.
````
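An aside on the `vmap` pattern this refactor keeps: `compute_c` is written for a single scalar grid point and then vectorized over the whole grid without a Python loop. A self-contained toy version of that pattern (the shock values and exponent below are made up):

```python
import jax
import jax.numpy as jnp

def at_one_point(s):
    # Written for a scalar s, exactly like compute_c in the lecture.
    shocks = jnp.array([0.9, 1.0, 1.1])   # toy shock draws
    return jnp.mean(s**0.4 * shocks)      # scalar in, scalar out

at_all_points = jax.vmap(at_one_point)    # vector of grid points in, vector out
print(at_all_points(jnp.linspace(0.1, 1.0, 5)))
```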
````diff
@@ -160,58 +164,56 @@ f_prime = lambda k, α: α * k**(α - 1)
 Now we create a model instance.
 
 ```{code-cell} python3
-α = 0.4
-
-model = create_model(α=α)
-grid = model.grid
+model = create_model()
+s_grid = model.s_grid
 ```
 
 The solver uses JAX's `jax.lax.while_loop` for the iteration and is JIT-compiled for speed.
 
 ```{code-cell} python3
 @jax.jit
 def solve_model_time_iter(model: Model,
-                          σ_init: jnp.ndarray,
+                          c_init: jnp.ndarray,
+                          x_init: jnp.ndarray,
                           tol: float = 1e-5,
-                          max_iter: int = 1000) -> jnp.ndarray:
+                          max_iter: int = 1000):
     """
     Solve the model using time iteration with EGM.
     """
 
     def condition(loop_state):
-        i, σ, error = loop_state
+        i, c, x, error = loop_state
         return (error > tol) & (i < max_iter)
 
     def body(loop_state):
-        i, σ, error = loop_state
-        σ_new = K(σ, model)
-        error = jnp.max(jnp.abs(σ_new - σ))
-        return i + 1, σ_new, error
+        i, c, x, error = loop_state
+        c_new, x_new = K(c, x, model)
+        error = jnp.max(jnp.abs(c_new - c))
+        return i + 1, c_new, x_new, error
 
     # Initialize loop state
-    initial_state = (0, σ_init, tol + 1)
+    initial_state = (0, c_init, x_init, tol + 1)
 
     # Run the loop
-    i, σ, error = jax.lax.while_loop(condition, body, initial_state)
+    i, c, x, error = jax.lax.while_loop(condition, body, initial_state)
 
-    return σ
+    return c, x
 ```
 
 We solve the model starting from an initial guess.
 
 ```{code-cell} python3
-σ_init = jnp.copy(grid)
-σ = solve_model_time_iter(model, σ_init)
+c_init = jnp.copy(s_grid)
+x_init = s_grid + c_init
+c, x = solve_model_time_iter(model, c_init, x_init)
 ```
 
 Let's plot the resulting policy against the analytical solution.
 
 ```{code-cell} python3
-x = grid + σ  # x_i = k_i + c_i
-
 fig, ax = plt.subplots()
 
-ax.plot(x, σ, lw=2,
+ax.plot(x, c, lw=2,
         alpha=0.8, label='approximate policy function')
 
 ax.plot(x, σ_star(x, model.α, model.β), 'k--',
````
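The loop-state change above follows the standard `jax.lax.while_loop` pattern: the carry tuple simply grows from `(i, σ, error)` to `(i, c, x, error)`, and `condition` and `body` must unpack and return the same structure. A self-contained sketch with a toy operator (not the lecture's `K`):

```python
import jax
import jax.numpy as jnp

def K_toy(c, x):
    return 0.5 * c + 1.0, x              # contraction on c; x rides along

def condition(loop_state):
    i, c, x, error = loop_state
    return (error > 1e-8) & (i < 1000)

def body(loop_state):
    i, c, x, error = loop_state
    c_new, x_new = K_toy(c, x)
    error = jnp.max(jnp.abs(c_new - c))
    return i + 1, c_new, x_new, error

init = (0, jnp.zeros(4), jnp.ones(4), 1.0)
i, c, x, error = jax.lax.while_loop(condition, body, init)
print(i, c)                              # c has converged to 2.0
```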
````diff
@@ -224,15 +226,16 @@ plt.show()
 The fit is very good.
 
 ```{code-cell} python3
-max_dev = jnp.max(jnp.abs(σ - σ_star(x, model.α, model.β)))
+max_dev = jnp.max(jnp.abs(c - σ_star(x, model.α, model.β)))
 print(f"Maximum absolute deviation: {max_dev:.7}")
 ```
 
 The JAX implementation is very fast thanks to JIT compilation and vectorization.
 
 ```{code-cell} python3
 with qe.Timer(precision=8):
-    σ = solve_model_time_iter(model, σ_init).block_until_ready()
+    c, x = solve_model_time_iter(model, c_init, x_init)
+    jax.block_until_ready(c)
 ```
 
 This speed comes from:
````
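A note on `σ_star`, which is defined earlier in the lecture (outside this diff's context): for log utility with Cobb-Douglas production $f(k) = k^\alpha$, the known analytical optimal policy consumes a fixed fraction of wealth. Presumably it is the standard Brock-Mirman result:

```python
# Presumed form of the lecture's analytical benchmark (standard result for
# log utility and f(k) = k**α): consume the fraction (1 - αβ) of wealth x.
σ_star = lambda x, α, β: (1 - α * β) * x
```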
````diff
@@ -282,76 +285,86 @@ def u_prime_inv_crra(x, γ):
 Now we create a version of the Coleman-Reffett operator that takes $\gamma$ as a parameter.
 
 ```{code-cell} python3
-def K_crra(σ_array: jnp.ndarray, model: Model, γ: float) -> jnp.ndarray:
+def K_crra(
+        c_in: jnp.ndarray,   # Consumption values on the endogenous grid
+        x_in: jnp.ndarray,   # Current endogenous grid
+        model: Model,        # Model specification
+        γ: float             # CRRA parameter
+    ):
     """
     The Coleman-Reffett operator using EGM with CRRA utility
     """
     # Simplify names
     β, α = model.β, model.α
-    grid, shocks = model.grid, model.shocks
-
-    # Determine endogenous grid
-    x = grid + σ_array
+    s_grid, shocks = model.s_grid, model.shocks
 
     # Linear interpolation of policy using endogenous grid
-    σ = lambda x_val: jnp.interp(x_val, x, σ_array)
+    σ = lambda x_val: jnp.interp(x_val, x_in, c_in)
 
     # Define function to compute consumption at a single grid point
-    def compute_c(k):
-        vals = u_prime_crra(σ(f(k, α) * shocks), γ) * f_prime(k, α) * shocks
+    def compute_c(s):
+        vals = u_prime_crra(σ(f(s, α) * shocks), γ) * f_prime(s, α) * shocks
         return u_prime_inv_crra(β * jnp.mean(vals), γ)
 
     # Vectorize over grid using vmap
     compute_c_vectorized = jax.vmap(compute_c)
-    c = compute_c_vectorized(grid)
+    c_out = compute_c_vectorized(s_grid)
+
+    # Determine corresponding endogenous grid
+    x_out = s_grid + c_out  # x_i = s_i + c_i
 
-    return c
+    return c_out, x_out
 ```
 
 We also need a solver that uses this operator.
 
 ```{code-cell} python3
 @jax.jit
 def solve_model_crra(model: Model,
-                     σ_init: jnp.ndarray,
+                     c_init: jnp.ndarray,
+                     x_init: jnp.ndarray,
                      γ: float,
                      tol: float = 1e-5,
-                     max_iter: int = 1000) -> jnp.ndarray:
+                     max_iter: int = 1000):
     """
     Solve the model using time iteration with EGM and CRRA utility.
     """
 
     def condition(loop_state):
-        i, σ, error = loop_state
+        i, c, x, error = loop_state
         return (error > tol) & (i < max_iter)
 
     def body(loop_state):
-        i, σ, error = loop_state
-        σ_new = K_crra(σ, model, γ)
-        error = jnp.max(jnp.abs(σ_new - σ))
-        return i + 1, σ_new, error
+        i, c, x, error = loop_state
+        c_new, x_new = K_crra(c, x, model, γ)
+        error = jnp.max(jnp.abs(c_new - c))
+        return i + 1, c_new, x_new, error
 
     # Initialize loop state
-    initial_state = (0, σ_init, tol + 1)
+    initial_state = (0, c_init, x_init, tol + 1)
 
     # Run the loop
-    i, σ, error = jax.lax.while_loop(condition, body, initial_state)
+    i, c, x, error = jax.lax.while_loop(condition, body, initial_state)
 
-    return σ
+    return c, x
 ```
 
 Now we solve for $\gamma = 1$ (log utility) and values approaching 1 from above.
 
 ```{code-cell} python3
 γ_values = [1.0, 1.05, 1.1, 1.2]
 policies = {}
+endogenous_grids = {}
 
-model_crra = create_model(α=α)
+model_crra = create_model()
 
 for γ in γ_values:
-    σ_init = jnp.copy(model_crra.grid)
-    σ_gamma = solve_model_crra(model_crra, σ_init, γ).block_until_ready()
-    policies[γ] = σ_gamma
+    c_init = jnp.copy(model_crra.s_grid)
+    x_init = model_crra.s_grid + c_init
+    c_gamma, x_gamma = solve_model_crra(model_crra, c_init, x_init, γ)
+    jax.block_until_ready(c_gamma)
+    policies[γ] = c_gamma
+    endogenous_grids[γ] = x_gamma
     print(f"Solved for γ = {γ}")
 ```
````

````diff
@@ -361,7 +374,7 @@ Plot the policies on their endogenous grids.
 fig, ax = plt.subplots()
 
 for γ in γ_values:
-    x = model_crra.grid + policies[γ]
+    x = endogenous_grids[γ]
     if γ == 1.0:
         ax.plot(x, policies[γ], 'k-', linewidth=2,
                 label=f'γ = {γ:.2f} (log utility)', alpha=0.8)
````
````diff
@@ -377,7 +390,7 @@ plt.show()
 
 Note that the plots for $\gamma > 1$ do not cover the entire x-axis range shown.
 
-This is because the endogenous grid $x = k + \sigma(k)$ depends on the consumption policy, which varies with $\gamma$.
+This is because the endogenous grid $x = s + \sigma(s)$ depends on the consumption policy, which varies with $\gamma$.
 
 Let's check the maximum deviation between the log utility case ($\gamma = 1.0$) and values approaching from above.
 
````
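For reference, `u_prime_crra` and `u_prime_inv_crra` are defined earlier in the lecture (outside this diff's context, under the `@@ -282,76 +285,86 @@` hunk header). The standard CRRA forms they presumably take are below; at $\gamma = 1$ these reduce to $u'(c) = 1/c$, the log-utility case, which is what makes the comparison above meaningful.

```python
# Presumed standard CRRA marginal utility and its inverse; the actual
# definitions sit above this diff's context window.
u_prime_crra = lambda c, γ: c**(-γ)         # from u(c) = c**(1 - γ) / (1 - γ)
u_prime_inv_crra = lambda x, γ: x**(-1/γ)   # solves u'(c) = x for c
```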