diff --git a/lectures/cake_eating_egm.md b/lectures/cake_eating_egm.md index f7c465968..f54ec590d 100644 --- a/lectures/cake_eating_egm.md +++ b/lectures/cake_eating_egm.md @@ -243,7 +243,7 @@ def K( # Solve for updated consumption value for i, s in enumerate(s_grid): - vals = u_prime(σ(f(s) * shocks)) * f_prime(s) * shocks + vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks c_out[i] = u_prime_inv(β * np.mean(vals)) # Determine corresponding endogenous grid @@ -266,14 +266,13 @@ First we create an instance. ```{code-cell} python3 # Define utility and production functions with derivatives -α = 0.4 u = lambda c: np.log(c) u_prime = lambda c: 1 / c u_prime_inv = lambda x: 1 / x -f = lambda k: k**α -f_prime = lambda k: α * k**(α - 1) +f = lambda k, α: k**α +f_prime = lambda k, α: α * k**(α - 1) -model = create_model(u=u, f=f, α=α, u_prime=u_prime, +model = create_model(u=u, f=f, u_prime=u_prime, f_prime=f_prime, u_prime_inv=u_prime_inv) s_grid = model.s_grid ``` diff --git a/lectures/cake_eating_egm_jax.md b/lectures/cake_eating_egm_jax.md index 2fc3be430..5a649ec4a 100644 --- a/lectures/cake_eating_egm_jax.md +++ b/lectures/cake_eating_egm_jax.md @@ -85,7 +85,7 @@ class Model(NamedTuple): β: float # discount factor μ: float # shock location parameter s: float # shock scale parameter - grid: jnp.ndarray # state grid + s_grid: jnp.ndarray # exogenous savings grid shocks: jnp.ndarray # shock draws α: float # production function parameter @@ -101,14 +101,14 @@ def create_model(β: float = 0.96, """ Creates an instance of the cake eating model. """ - # Set up grid - grid = jnp.linspace(1e-4, grid_max, grid_size) + # Set up exogenous savings grid + s_grid = jnp.linspace(1e-4, grid_max, grid_size) # Store shocks (with a seed, so results are reproducible) key = jax.random.PRNGKey(seed) shocks = jnp.exp(μ + s * jax.random.normal(key, shape=(shock_size,))) - return Model(β=β, μ=μ, s=s, grid=grid, shocks=shocks, α=α) + return Model(β=β, μ=μ, s=s, s_grid=s_grid, shocks=shocks, α=α) ``` Here's the Coleman-Reffett operator using EGM. @@ -116,7 +116,11 @@ Here's the Coleman-Reffett operator using EGM. The key JAX feature here is `vmap`, which vectorizes the computation over the grid points. ```{code-cell} python3 -def K(σ_array: jnp.ndarray, model: Model) -> jnp.ndarray: +def K( + c_in: jnp.ndarray, # Consumption values on the endogenous grid + x_in: jnp.ndarray, # Current endogenous grid + model: Model # Model specification + ): """ The Coleman-Reffett operator using EGM @@ -124,24 +128,24 @@ def K(σ_array: jnp.ndarray, model: Model) -> jnp.ndarray: # Simplify names β, α = model.β, model.α - grid, shocks = model.grid, model.shocks - - # Determine endogenous grid - x = grid + σ_array # x_i = k_i + c_i + s_grid, shocks = model.s_grid, model.shocks # Linear interpolation of policy using endogenous grid - σ = lambda x_val: jnp.interp(x_val, x, σ_array) + σ = lambda x_val: jnp.interp(x_val, x_in, c_in) # Define function to compute consumption at a single grid point - def compute_c(k): - vals = u_prime(σ(f(k, α) * shocks)) * f_prime(k, α) * shocks + def compute_c(s): + vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks return u_prime_inv(β * jnp.mean(vals)) # Vectorize over grid using vmap compute_c_vectorized = jax.vmap(compute_c) - c = compute_c_vectorized(grid) + c_out = compute_c_vectorized(s_grid) + + # Determine corresponding endogenous grid + x_out = s_grid + c_out # x_i = s_i + c_i - return c + return c_out, x_out ``` We define utility and production functions globally. @@ -160,10 +164,8 @@ f_prime = lambda k, α: α * k**(α - 1) Now we create a model instance. ```{code-cell} python3 -α = 0.4 - -model = create_model(α=α) -grid = model.grid +model = create_model() +s_grid = model.s_grid ``` The solver uses JAX's `jax.lax.while_loop` for the iteration and is JIT-compiled for speed. @@ -171,47 +173,47 @@ The solver uses JAX's `jax.lax.while_loop` for the iteration and is JIT-compiled ```{code-cell} python3 @jax.jit def solve_model_time_iter(model: Model, - σ_init: jnp.ndarray, + c_init: jnp.ndarray, + x_init: jnp.ndarray, tol: float = 1e-5, - max_iter: int = 1000) -> jnp.ndarray: + max_iter: int = 1000): """ Solve the model using time iteration with EGM. """ def condition(loop_state): - i, σ, error = loop_state + i, c, x, error = loop_state return (error > tol) & (i < max_iter) def body(loop_state): - i, σ, error = loop_state - σ_new = K(σ, model) - error = jnp.max(jnp.abs(σ_new - σ)) - return i + 1, σ_new, error + i, c, x, error = loop_state + c_new, x_new = K(c, x, model) + error = jnp.max(jnp.abs(c_new - c)) + return i + 1, c_new, x_new, error # Initialize loop state - initial_state = (0, σ_init, tol + 1) + initial_state = (0, c_init, x_init, tol + 1) # Run the loop - i, σ, error = jax.lax.while_loop(condition, body, initial_state) + i, c, x, error = jax.lax.while_loop(condition, body, initial_state) - return σ + return c, x ``` We solve the model starting from an initial guess. ```{code-cell} python3 -σ_init = jnp.copy(grid) -σ = solve_model_time_iter(model, σ_init) +c_init = jnp.copy(s_grid) +x_init = s_grid + c_init +c, x = solve_model_time_iter(model, c_init, x_init) ``` Let's plot the resulting policy against the analytical solution. ```{code-cell} python3 -x = grid + σ # x_i = k_i + c_i - fig, ax = plt.subplots() -ax.plot(x, σ, lw=2, +ax.plot(x, c, lw=2, alpha=0.8, label='approximate policy function') ax.plot(x, σ_star(x, model.α, model.β), 'k--', @@ -224,7 +226,7 @@ plt.show() The fit is very good. ```{code-cell} python3 -max_dev = jnp.max(jnp.abs(σ - σ_star(x, model.α, model.β))) +max_dev = jnp.max(jnp.abs(c - σ_star(x, model.α, model.β))) print(f"Maximum absolute deviation: {max_dev:.7}") ``` @@ -232,7 +234,8 @@ The JAX implementation is very fast thanks to JIT compilation and vectorization. ```{code-cell} python3 with qe.Timer(precision=8): - σ = solve_model_time_iter(model, σ_init).block_until_ready() + c, x = solve_model_time_iter(model, c_init, x_init) + jax.block_until_ready(c) ``` This speed comes from: @@ -282,30 +285,35 @@ def u_prime_inv_crra(x, γ): Now we create a version of the Coleman-Reffett operator that takes $\gamma$ as a parameter. ```{code-cell} python3 -def K_crra(σ_array: jnp.ndarray, model: Model, γ: float) -> jnp.ndarray: +def K_crra( + c_in: jnp.ndarray, # Consumption values on the endogenous grid + x_in: jnp.ndarray, # Current endogenous grid + model: Model, # Model specification + γ: float # CRRA parameter + ): """ The Coleman-Reffett operator using EGM with CRRA utility """ # Simplify names β, α = model.β, model.α - grid, shocks = model.grid, model.shocks - - # Determine endogenous grid - x = grid + σ_array + s_grid, shocks = model.s_grid, model.shocks # Linear interpolation of policy using endogenous grid - σ = lambda x_val: jnp.interp(x_val, x, σ_array) + σ = lambda x_val: jnp.interp(x_val, x_in, c_in) # Define function to compute consumption at a single grid point - def compute_c(k): - vals = u_prime_crra(σ(f(k, α) * shocks), γ) * f_prime(k, α) * shocks + def compute_c(s): + vals = u_prime_crra(σ(f(s, α) * shocks), γ) * f_prime(s, α) * shocks return u_prime_inv_crra(β * jnp.mean(vals), γ) # Vectorize over grid using vmap compute_c_vectorized = jax.vmap(compute_c) - c = compute_c_vectorized(grid) + c_out = compute_c_vectorized(s_grid) + + # Determine corresponding endogenous grid + x_out = s_grid + c_out # x_i = s_i + c_i - return c + return c_out, x_out ``` We also need a solver that uses this operator. @@ -313,31 +321,32 @@ We also need a solver that uses this operator. ```{code-cell} python3 @jax.jit def solve_model_crra(model: Model, - σ_init: jnp.ndarray, + c_init: jnp.ndarray, + x_init: jnp.ndarray, γ: float, tol: float = 1e-5, - max_iter: int = 1000) -> jnp.ndarray: + max_iter: int = 1000): """ Solve the model using time iteration with EGM and CRRA utility. """ def condition(loop_state): - i, σ, error = loop_state + i, c, x, error = loop_state return (error > tol) & (i < max_iter) def body(loop_state): - i, σ, error = loop_state - σ_new = K_crra(σ, model, γ) - error = jnp.max(jnp.abs(σ_new - σ)) - return i + 1, σ_new, error + i, c, x, error = loop_state + c_new, x_new = K_crra(c, x, model, γ) + error = jnp.max(jnp.abs(c_new - c)) + return i + 1, c_new, x_new, error # Initialize loop state - initial_state = (0, σ_init, tol + 1) + initial_state = (0, c_init, x_init, tol + 1) # Run the loop - i, σ, error = jax.lax.while_loop(condition, body, initial_state) + i, c, x, error = jax.lax.while_loop(condition, body, initial_state) - return σ + return c, x ``` Now we solve for $\gamma = 1$ (log utility) and values approaching 1 from above. @@ -345,13 +354,17 @@ Now we solve for $\gamma = 1$ (log utility) and values approaching 1 from above. ```{code-cell} python3 γ_values = [1.0, 1.05, 1.1, 1.2] policies = {} +endogenous_grids = {} -model_crra = create_model(α=α) +model_crra = create_model() for γ in γ_values: - σ_init = jnp.copy(model_crra.grid) - σ_gamma = solve_model_crra(model_crra, σ_init, γ).block_until_ready() - policies[γ] = σ_gamma + c_init = jnp.copy(model_crra.s_grid) + x_init = model_crra.s_grid + c_init + c_gamma, x_gamma = solve_model_crra(model_crra, c_init, x_init, γ) + jax.block_until_ready(c_gamma) + policies[γ] = c_gamma + endogenous_grids[γ] = x_gamma print(f"Solved for γ = {γ}") ``` @@ -361,7 +374,7 @@ Plot the policies on their endogenous grids. fig, ax = plt.subplots() for γ in γ_values: - x = model_crra.grid + policies[γ] + x = endogenous_grids[γ] if γ == 1.0: ax.plot(x, policies[γ], 'k-', linewidth=2, label=f'γ = {γ:.2f} (log utility)', alpha=0.8) @@ -377,7 +390,7 @@ plt.show() Note that the plots for $\gamma > 1$ do not cover the entire x-axis range shown. -This is because the endogenous grid $x = k + \sigma(k)$ depends on the consumption policy, which varies with $\gamma$. +This is because the endogenous grid $x = s + \sigma(s)$ depends on the consumption policy, which varies with $\gamma$. Let's check the maximum deviation between the log utility case ($\gamma = 1.0$) and values approaching from above.