Skip to content

Commit d90822f

Browse files
jstacclaude
andcommitted
Rename grid to s_grid for consistency with NumPy version
Changed all references from 'grid' to 's_grid' to match the NumPy implementation and clarify that this is the exogenous savings grid: - Model.grid → Model.s_grid - Updated comment: "state grid" → "exogenous savings grid" - Updated all variable names throughout (K, K_crra, initializations) - Also renamed loop variable from 'k' to 's' for consistency This makes the JAX version consistent with the NumPy version's naming conventions and makes it clearer that we're working with the exogenous grid for savings (not the endogenous grid for wealth x). Tested successfully with identical results. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 4eac159 commit d90822f

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

lectures/cake_eating_egm_jax.md

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class Model(NamedTuple):
8585
β: float # discount factor
8686
μ: float # shock location parameter
8787
s: float # shock scale parameter
88-
grid: jnp.ndarray # state grid
88+
s_grid: jnp.ndarray # exogenous savings grid
8989
shocks: jnp.ndarray # shock draws
9090
α: float # production function parameter
9191
@@ -101,14 +101,14 @@ def create_model(β: float = 0.96,
101101
"""
102102
Creates an instance of the cake eating model.
103103
"""
104-
# Set up grid
105-
grid = jnp.linspace(1e-4, grid_max, grid_size)
104+
# Set up exogenous savings grid
105+
s_grid = jnp.linspace(1e-4, grid_max, grid_size)
106106
107107
# Store shocks (with a seed, so results are reproducible)
108108
key = jax.random.PRNGKey(seed)
109109
shocks = jnp.exp(μ + s * jax.random.normal(key, shape=(shock_size,)))
110110
111-
return Model(β=β, μ=μ, s=s, grid=grid, shocks=shocks, α=α)
111+
return Model(β=β, μ=μ, s=s, s_grid=s_grid, shocks=shocks, α=α)
112112
```
113113

114114
Here's the Coleman-Reffett operator using EGM.
@@ -128,22 +128,22 @@ def K(
128128
129129
# Simplify names
130130
β, α = model.β, model.α
131-
grid, shocks = model.grid, model.shocks
131+
s_grid, shocks = model.s_grid, model.shocks
132132
133133
# Linear interpolation of policy using endogenous grid
134134
σ = lambda x_val: jnp.interp(x_val, x_in, c_in)
135135
136136
# Define function to compute consumption at a single grid point
137-
def compute_c(k):
138-
vals = u_prime(σ(f(k, α) * shocks)) * f_prime(k, α) * shocks
137+
def compute_c(s):
138+
vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks
139139
return u_prime_inv(β * jnp.mean(vals))
140140
141141
# Vectorize over grid using vmap
142142
compute_c_vectorized = jax.vmap(compute_c)
143-
c_out = compute_c_vectorized(grid)
143+
c_out = compute_c_vectorized(s_grid)
144144
145145
# Determine corresponding endogenous grid
146-
x_out = grid + c_out # x_i = k_i + c_i
146+
x_out = s_grid + c_out # x_i = s_i + c_i
147147
148148
return c_out, x_out
149149
```
@@ -165,7 +165,7 @@ Now we create a model instance.
165165

166166
```{code-cell} python3
167167
model = create_model()
168-
grid = model.grid
168+
s_grid = model.s_grid
169169
```
170170

171171
The solver uses JAX's `jax.lax.while_loop` for the iteration and is JIT-compiled for speed.
@@ -203,8 +203,8 @@ def solve_model_time_iter(model: Model,
203203
We solve the model starting from an initial guess.
204204

205205
```{code-cell} python3
206-
c_init = jnp.copy(grid)
207-
x_init = grid + c_init
206+
c_init = jnp.copy(s_grid)
207+
x_init = s_grid + c_init
208208
c, x = solve_model_time_iter(model, c_init, x_init)
209209
```
210210

@@ -296,22 +296,22 @@ def K_crra(
296296
"""
297297
# Simplify names
298298
β, α = model.β, model.α
299-
grid, shocks = model.grid, model.shocks
299+
s_grid, shocks = model.s_grid, model.shocks
300300
301301
# Linear interpolation of policy using endogenous grid
302302
σ = lambda x_val: jnp.interp(x_val, x_in, c_in)
303303
304304
# Define function to compute consumption at a single grid point
305-
def compute_c(k):
306-
vals = u_prime_crra(σ(f(k, α) * shocks), γ) * f_prime(k, α) * shocks
305+
def compute_c(s):
306+
vals = u_prime_crra(σ(f(s, α) * shocks), γ) * f_prime(s, α) * shocks
307307
return u_prime_inv_crra(β * jnp.mean(vals), γ)
308308
309309
# Vectorize over grid using vmap
310310
compute_c_vectorized = jax.vmap(compute_c)
311-
c_out = compute_c_vectorized(grid)
311+
c_out = compute_c_vectorized(s_grid)
312312
313313
# Determine corresponding endogenous grid
314-
x_out = grid + c_out
314+
x_out = s_grid + c_out
315315
316316
return c_out, x_out
317317
```
@@ -359,8 +359,8 @@ endogenous_grids = {}
359359
model_crra = create_model()
360360
361361
for γ in γ_values:
362-
c_init = jnp.copy(model_crra.grid)
363-
x_init = model_crra.grid + c_init
362+
c_init = jnp.copy(model_crra.s_grid)
363+
x_init = model_crra.s_grid + c_init
364364
c_gamma, x_gamma = solve_model_crra(model_crra, c_init, x_init, γ)
365365
jax.block_until_ready(c_gamma)
366366
policies[γ] = c_gamma

0 commit comments

Comments
 (0)