From 7212a1332d4750cda9cedec89d0d4086ce3d612c Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Mon, 24 Nov 2025 04:28:15 +0900 Subject: [PATCH 1/6] Unify K operator signature in cake_eating_egm_jax with cake_eating_egm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update the Coleman-Reffett operator K and solver functions in the JAX implementation to match the signature from the NumPy version: - K now takes (c_in, x_in, model) and returns (c_out, x_out) - solve_model_time_iter now takes (c_init, x_init) and returns (c, x) - Applied same changes to K_crra and solve_model_crra in exercises The efficient JAX implementation is fully preserved: - Vectorization with vmap - JIT compilation - Use of jax.lax.while_loop Tested successfully with maximum deviation of 1.43e-06 from analytical solution and execution time of ~0.009 seconds. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- lectures/cake_eating_egm_jax.md | 105 ++++++++++++++++++-------------- 1 file changed, 60 insertions(+), 45 deletions(-) diff --git a/lectures/cake_eating_egm_jax.md b/lectures/cake_eating_egm_jax.md index 2fc3be430..f93ed2207 100644 --- a/lectures/cake_eating_egm_jax.md +++ b/lectures/cake_eating_egm_jax.md @@ -116,7 +116,11 @@ Here's the Coleman-Reffett operator using EGM. The key JAX feature here is `vmap`, which vectorizes the computation over the grid points. ```{code-cell} python3 -def K(σ_array: jnp.ndarray, model: Model) -> jnp.ndarray: +def K( + c_in: jnp.ndarray, # Consumption values on the endogenous grid + x_in: jnp.ndarray, # Current endogenous grid + model: Model # Model specification + ): """ The Coleman-Reffett operator using EGM @@ -126,11 +130,8 @@ def K(σ_array: jnp.ndarray, model: Model) -> jnp.ndarray: β, α = model.β, model.α grid, shocks = model.grid, model.shocks - # Determine endogenous grid - x = grid + σ_array # x_i = k_i + c_i - # Linear interpolation of policy using endogenous grid - σ = lambda x_val: jnp.interp(x_val, x, σ_array) + σ = lambda x_val: jnp.interp(x_val, x_in, c_in) # Define function to compute consumption at a single grid point def compute_c(k): @@ -139,9 +140,12 @@ def K(σ_array: jnp.ndarray, model: Model) -> jnp.ndarray: # Vectorize over grid using vmap compute_c_vectorized = jax.vmap(compute_c) - c = compute_c_vectorized(grid) + c_out = compute_c_vectorized(grid) + + # Determine corresponding endogenous grid + x_out = grid + c_out # x_i = k_i + c_i - return c + return c_out, x_out ``` We define utility and production functions globally. @@ -171,47 +175,47 @@ The solver uses JAX's `jax.lax.while_loop` for the iteration and is JIT-compiled ```{code-cell} python3 @jax.jit def solve_model_time_iter(model: Model, - σ_init: jnp.ndarray, + c_init: jnp.ndarray, + x_init: jnp.ndarray, tol: float = 1e-5, - max_iter: int = 1000) -> jnp.ndarray: + max_iter: int = 1000): """ Solve the model using time iteration with EGM. """ def condition(loop_state): - i, σ, error = loop_state + i, c, x, error = loop_state return (error > tol) & (i < max_iter) def body(loop_state): - i, σ, error = loop_state - σ_new = K(σ, model) - error = jnp.max(jnp.abs(σ_new - σ)) - return i + 1, σ_new, error + i, c, x, error = loop_state + c_new, x_new = K(c, x, model) + error = jnp.max(jnp.abs(c_new - c)) + return i + 1, c_new, x_new, error # Initialize loop state - initial_state = (0, σ_init, tol + 1) + initial_state = (0, c_init, x_init, tol + 1) # Run the loop - i, σ, error = jax.lax.while_loop(condition, body, initial_state) + i, c, x, error = jax.lax.while_loop(condition, body, initial_state) - return σ + return c, x ``` We solve the model starting from an initial guess. ```{code-cell} python3 -σ_init = jnp.copy(grid) -σ = solve_model_time_iter(model, σ_init) +c_init = jnp.copy(grid) +x_init = grid + c_init +c, x = solve_model_time_iter(model, c_init, x_init) ``` Let's plot the resulting policy against the analytical solution. ```{code-cell} python3 -x = grid + σ # x_i = k_i + c_i - fig, ax = plt.subplots() -ax.plot(x, σ, lw=2, +ax.plot(x, c, lw=2, alpha=0.8, label='approximate policy function') ax.plot(x, σ_star(x, model.α, model.β), 'k--', @@ -224,7 +228,7 @@ plt.show() The fit is very good. ```{code-cell} python3 -max_dev = jnp.max(jnp.abs(σ - σ_star(x, model.α, model.β))) +max_dev = jnp.max(jnp.abs(c - σ_star(x, model.α, model.β))) print(f"Maximum absolute deviation: {max_dev:.7}") ``` @@ -232,7 +236,8 @@ The JAX implementation is very fast thanks to JIT compilation and vectorization. ```{code-cell} python3 with qe.Timer(precision=8): - σ = solve_model_time_iter(model, σ_init).block_until_ready() + c, x = solve_model_time_iter(model, c_init, x_init) + jax.block_until_ready(c) ``` This speed comes from: @@ -282,7 +287,12 @@ def u_prime_inv_crra(x, γ): Now we create a version of the Coleman-Reffett operator that takes $\gamma$ as a parameter. ```{code-cell} python3 -def K_crra(σ_array: jnp.ndarray, model: Model, γ: float) -> jnp.ndarray: +def K_crra( + c_in: jnp.ndarray, # Consumption values on the endogenous grid + x_in: jnp.ndarray, # Current endogenous grid + model: Model, # Model specification + γ: float # CRRA parameter + ): """ The Coleman-Reffett operator using EGM with CRRA utility """ @@ -290,11 +300,8 @@ def K_crra(σ_array: jnp.ndarray, model: Model, γ: float) -> jnp.ndarray: β, α = model.β, model.α grid, shocks = model.grid, model.shocks - # Determine endogenous grid - x = grid + σ_array - # Linear interpolation of policy using endogenous grid - σ = lambda x_val: jnp.interp(x_val, x, σ_array) + σ = lambda x_val: jnp.interp(x_val, x_in, c_in) # Define function to compute consumption at a single grid point def compute_c(k): @@ -303,9 +310,12 @@ def K_crra(σ_array: jnp.ndarray, model: Model, γ: float) -> jnp.ndarray: # Vectorize over grid using vmap compute_c_vectorized = jax.vmap(compute_c) - c = compute_c_vectorized(grid) + c_out = compute_c_vectorized(grid) + + # Determine corresponding endogenous grid + x_out = grid + c_out - return c + return c_out, x_out ``` We also need a solver that uses this operator. @@ -313,31 +323,32 @@ We also need a solver that uses this operator. ```{code-cell} python3 @jax.jit def solve_model_crra(model: Model, - σ_init: jnp.ndarray, + c_init: jnp.ndarray, + x_init: jnp.ndarray, γ: float, tol: float = 1e-5, - max_iter: int = 1000) -> jnp.ndarray: + max_iter: int = 1000): """ Solve the model using time iteration with EGM and CRRA utility. """ def condition(loop_state): - i, σ, error = loop_state + i, c, x, error = loop_state return (error > tol) & (i < max_iter) def body(loop_state): - i, σ, error = loop_state - σ_new = K_crra(σ, model, γ) - error = jnp.max(jnp.abs(σ_new - σ)) - return i + 1, σ_new, error + i, c, x, error = loop_state + c_new, x_new = K_crra(c, x, model, γ) + error = jnp.max(jnp.abs(c_new - c)) + return i + 1, c_new, x_new, error # Initialize loop state - initial_state = (0, σ_init, tol + 1) + initial_state = (0, c_init, x_init, tol + 1) # Run the loop - i, σ, error = jax.lax.while_loop(condition, body, initial_state) + i, c, x, error = jax.lax.while_loop(condition, body, initial_state) - return σ + return c, x ``` Now we solve for $\gamma = 1$ (log utility) and values approaching 1 from above. @@ -345,13 +356,17 @@ Now we solve for $\gamma = 1$ (log utility) and values approaching 1 from above. ```{code-cell} python3 γ_values = [1.0, 1.05, 1.1, 1.2] policies = {} +endogenous_grids = {} model_crra = create_model(α=α) for γ in γ_values: - σ_init = jnp.copy(model_crra.grid) - σ_gamma = solve_model_crra(model_crra, σ_init, γ).block_until_ready() - policies[γ] = σ_gamma + c_init = jnp.copy(model_crra.grid) + x_init = model_crra.grid + c_init + c_gamma, x_gamma = solve_model_crra(model_crra, c_init, x_init, γ) + jax.block_until_ready(c_gamma) + policies[γ] = c_gamma + endogenous_grids[γ] = x_gamma print(f"Solved for γ = {γ}") ``` @@ -361,7 +376,7 @@ Plot the policies on their endogenous grids. fig, ax = plt.subplots() for γ in γ_values: - x = model_crra.grid + policies[γ] + x = endogenous_grids[γ] if γ == 1.0: ax.plot(x, policies[γ], 'k-', linewidth=2, label=f'γ = {γ:.2f} (log utility)', alpha=0.8) From 2d2e99f9cf5d8a1fb10d3515839605953d2ae45d Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Mon, 24 Nov 2025 04:33:47 +0900 Subject: [PATCH 2/6] Remove unnecessary global alpha variable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The α parameter is already defined with a default value (0.4) in the create_model function, so there's no need to set it as a global variable and pass it explicitly. Simplified: - model = create_model() instead of α = 0.4; model = create_model(α=α) - model_crra = create_model() in the CRRA exercise section Tested successfully with same results as before. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- lectures/cake_eating_egm_jax.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lectures/cake_eating_egm_jax.md b/lectures/cake_eating_egm_jax.md index f93ed2207..8d107c31b 100644 --- a/lectures/cake_eating_egm_jax.md +++ b/lectures/cake_eating_egm_jax.md @@ -164,9 +164,7 @@ f_prime = lambda k, α: α * k**(α - 1) Now we create a model instance. ```{code-cell} python3 -α = 0.4 - -model = create_model(α=α) +model = create_model() grid = model.grid ``` @@ -358,7 +356,7 @@ Now we solve for $\gamma = 1$ (log utility) and values approaching 1 from above. policies = {} endogenous_grids = {} -model_crra = create_model(α=α) +model_crra = create_model() for γ in γ_values: c_init = jnp.copy(model_crra.grid) From 1ca20dd39411b838a594ca7d06587f10920420bd Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Mon, 24 Nov 2025 04:35:25 +0900 Subject: [PATCH 3/6] Remove redundant alpha parameter in cake_eating_egm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The α parameter doesn't need to be passed explicitly to create_model since it already has a default value of 0.4. The α = 0.4 line is still needed for the lambda function closures (f and f_prime capture it). Changed: - create_model(u=u, f=f, α=α, ...) + create_model(u=u, f=f, ...) Tested successfully with same convergence behavior. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- lectures/cake_eating_egm.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lectures/cake_eating_egm.md b/lectures/cake_eating_egm.md index f7c465968..5f1fe9df5 100644 --- a/lectures/cake_eating_egm.md +++ b/lectures/cake_eating_egm.md @@ -273,7 +273,7 @@ u_prime_inv = lambda x: 1 / x f = lambda k: k**α f_prime = lambda k: α * k**(α - 1) -model = create_model(u=u, f=f, α=α, u_prime=u_prime, +model = create_model(u=u, f=f, u_prime=u_prime, f_prime=f_prime, u_prime_inv=u_prime_inv) s_grid = model.s_grid ``` From 4eac15975192fb8fa1ace9b84fb5eef650ba7919 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Mon, 24 Nov 2025 04:48:47 +0900 Subject: [PATCH 4/6] Refactor f and f_prime to use explicit alpha parameter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changed from closure-based approach to explicit parameter passing: - f = lambda k, α: k**α (instead of f = lambda k: k**α with global α) - f_prime = lambda k, α: α * k**(α - 1) - Updated K operator to call f(s, α) and f_prime(s, α) This makes the NumPy version consistent with the JAX implementation and ensures the α stored in the Model is actually used in the K operator (previously it was unpacked but unused). Benefits: - Consistency between NumPy and JAX versions - Clearer function dependencies (α is an explicit parameter) - Actually uses model.α instead of relying on closure Tested successfully with same convergence behavior. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- lectures/cake_eating_egm.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/lectures/cake_eating_egm.md b/lectures/cake_eating_egm.md index 5f1fe9df5..f54ec590d 100644 --- a/lectures/cake_eating_egm.md +++ b/lectures/cake_eating_egm.md @@ -243,7 +243,7 @@ def K( # Solve for updated consumption value for i, s in enumerate(s_grid): - vals = u_prime(σ(f(s) * shocks)) * f_prime(s) * shocks + vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks c_out[i] = u_prime_inv(β * np.mean(vals)) # Determine corresponding endogenous grid @@ -266,12 +266,11 @@ First we create an instance. ```{code-cell} python3 # Define utility and production functions with derivatives -α = 0.4 u = lambda c: np.log(c) u_prime = lambda c: 1 / c u_prime_inv = lambda x: 1 / x -f = lambda k: k**α -f_prime = lambda k: α * k**(α - 1) +f = lambda k, α: k**α +f_prime = lambda k, α: α * k**(α - 1) model = create_model(u=u, f=f, u_prime=u_prime, f_prime=f_prime, u_prime_inv=u_prime_inv) From d90822f6a6877e7d8adf8e7fd8761ddabb9e7b22 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Mon, 24 Nov 2025 04:52:56 +0900 Subject: [PATCH 5/6] Rename grid to s_grid for consistency with NumPy version MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changed all references from 'grid' to 's_grid' to match the NumPy implementation and clarify that this is the exogenous savings grid: - Model.grid → Model.s_grid - Updated comment: "state grid" → "exogenous savings grid" - Updated all variable names throughout (K, K_crra, initializations) - Also renamed loop variable from 'k' to 's' for consistency This makes the JAX version consistent with the NumPy version's naming conventions and makes it clearer that we're working with the exogenous grid for savings (not the endogenous grid for wealth x). Tested successfully with identical results. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- lectures/cake_eating_egm_jax.md | 38 ++++++++++++++++----------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/lectures/cake_eating_egm_jax.md b/lectures/cake_eating_egm_jax.md index 8d107c31b..a5189c0aa 100644 --- a/lectures/cake_eating_egm_jax.md +++ b/lectures/cake_eating_egm_jax.md @@ -85,7 +85,7 @@ class Model(NamedTuple): β: float # discount factor μ: float # shock location parameter s: float # shock scale parameter - grid: jnp.ndarray # state grid + s_grid: jnp.ndarray # exogenous savings grid shocks: jnp.ndarray # shock draws α: float # production function parameter @@ -101,14 +101,14 @@ def create_model(β: float = 0.96, """ Creates an instance of the cake eating model. """ - # Set up grid - grid = jnp.linspace(1e-4, grid_max, grid_size) + # Set up exogenous savings grid + s_grid = jnp.linspace(1e-4, grid_max, grid_size) # Store shocks (with a seed, so results are reproducible) key = jax.random.PRNGKey(seed) shocks = jnp.exp(μ + s * jax.random.normal(key, shape=(shock_size,))) - return Model(β=β, μ=μ, s=s, grid=grid, shocks=shocks, α=α) + return Model(β=β, μ=μ, s=s, s_grid=s_grid, shocks=shocks, α=α) ``` Here's the Coleman-Reffett operator using EGM. @@ -128,22 +128,22 @@ def K( # Simplify names β, α = model.β, model.α - grid, shocks = model.grid, model.shocks + s_grid, shocks = model.s_grid, model.shocks # Linear interpolation of policy using endogenous grid σ = lambda x_val: jnp.interp(x_val, x_in, c_in) # Define function to compute consumption at a single grid point - def compute_c(k): - vals = u_prime(σ(f(k, α) * shocks)) * f_prime(k, α) * shocks + def compute_c(s): + vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks return u_prime_inv(β * jnp.mean(vals)) # Vectorize over grid using vmap compute_c_vectorized = jax.vmap(compute_c) - c_out = compute_c_vectorized(grid) + c_out = compute_c_vectorized(s_grid) # Determine corresponding endogenous grid - x_out = grid + c_out # x_i = k_i + c_i + x_out = s_grid + c_out # x_i = s_i + c_i return c_out, x_out ``` @@ -165,7 +165,7 @@ Now we create a model instance. ```{code-cell} python3 model = create_model() -grid = model.grid +s_grid = model.s_grid ``` The solver uses JAX's `jax.lax.while_loop` for the iteration and is JIT-compiled for speed. @@ -203,8 +203,8 @@ def solve_model_time_iter(model: Model, We solve the model starting from an initial guess. ```{code-cell} python3 -c_init = jnp.copy(grid) -x_init = grid + c_init +c_init = jnp.copy(s_grid) +x_init = s_grid + c_init c, x = solve_model_time_iter(model, c_init, x_init) ``` @@ -296,22 +296,22 @@ def K_crra( """ # Simplify names β, α = model.β, model.α - grid, shocks = model.grid, model.shocks + s_grid, shocks = model.s_grid, model.shocks # Linear interpolation of policy using endogenous grid σ = lambda x_val: jnp.interp(x_val, x_in, c_in) # Define function to compute consumption at a single grid point - def compute_c(k): - vals = u_prime_crra(σ(f(k, α) * shocks), γ) * f_prime(k, α) * shocks + def compute_c(s): + vals = u_prime_crra(σ(f(s, α) * shocks), γ) * f_prime(s, α) * shocks return u_prime_inv_crra(β * jnp.mean(vals), γ) # Vectorize over grid using vmap compute_c_vectorized = jax.vmap(compute_c) - c_out = compute_c_vectorized(grid) + c_out = compute_c_vectorized(s_grid) # Determine corresponding endogenous grid - x_out = grid + c_out + x_out = s_grid + c_out return c_out, x_out ``` @@ -359,8 +359,8 @@ endogenous_grids = {} model_crra = create_model() for γ in γ_values: - c_init = jnp.copy(model_crra.grid) - x_init = model_crra.grid + c_init + c_init = jnp.copy(model_crra.s_grid) + x_init = model_crra.s_grid + c_init c_gamma, x_gamma = solve_model_crra(model_crra, c_init, x_init, γ) jax.block_until_ready(c_gamma) policies[γ] = c_gamma From 1d2f54eb868686f4ea5ef4feeb6334414f0d64b5 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Mon, 24 Nov 2025 04:56:35 +0900 Subject: [PATCH 6/6] Fix remaining k references to use s for savings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changed remaining instances where k was used instead of s: - Mathematical notation: x = k + σ(k) → x = s + σ(s) - Added missing inline comment in K_crra: x_i = s_i + c_i This completes the transition to using 's' for savings throughout, maintaining consistency with the exogenous savings grid terminology. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- lectures/cake_eating_egm_jax.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lectures/cake_eating_egm_jax.md b/lectures/cake_eating_egm_jax.md index a5189c0aa..5a649ec4a 100644 --- a/lectures/cake_eating_egm_jax.md +++ b/lectures/cake_eating_egm_jax.md @@ -311,7 +311,7 @@ def K_crra( c_out = compute_c_vectorized(s_grid) # Determine corresponding endogenous grid - x_out = s_grid + c_out + x_out = s_grid + c_out # x_i = s_i + c_i return c_out, x_out ``` @@ -390,7 +390,7 @@ plt.show() Note that the plots for $\gamma > 1$ do not cover the entire x-axis range shown. -This is because the endogenous grid $x = k + \sigma(k)$ depends on the consumption policy, which varies with $\gamma$. +This is because the endogenous grid $x = s + \sigma(s)$ depends on the consumption policy, which varies with $\gamma$. Let's check the maximum deviation between the log utility case ($\gamma = 1.0$) and values approaching from above.