9 changes: 4 additions & 5 deletions lectures/cake_eating_egm.md
@@ -243,7 +243,7 @@ def K(

# Solve for updated consumption value
for i, s in enumerate(s_grid):
vals = u_prime(σ(f(s) * shocks)) * f_prime(s) * shocks
vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks
c_out[i] = u_prime_inv(β * np.mean(vals))

# Determine corresponding endogenous grid
@@ -266,14 +266,13 @@ First we create an instance.

```{code-cell} python3
# Define utility and production functions with derivatives
α = 0.4
u = lambda c: np.log(c)
u_prime = lambda c: 1 / c
u_prime_inv = lambda x: 1 / x
f = lambda k: k**α
f_prime = lambda k: α * k**(α - 1)
f = lambda k, α: k**α
f_prime = lambda k, α: α * k**(α - 1)

model = create_model(u=u, f=f, α=α, u_prime=u_prime,
model = create_model(u=u, f=f, u_prime=u_prime,
f_prime=f_prime, u_prime_inv=u_prime_inv)
s_grid = model.s_grid
```
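Since `f` and `f_prime` now take $\alpha$ as an explicit argument instead of closing over a module-level value, a quick illustrative check of the new signatures (the numbers below are arbitrary) looks like this:

```python
# Illustrative check of the refactored signatures: α is now an explicit
# argument of the production function and its derivative
k_test, α_test = 2.0, 0.4
print(f(k_test, α_test))        # k_test**α_test
print(f_prime(k_test, α_test))  # α_test * k_test**(α_test - 1)
```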
135 changes: 74 additions & 61 deletions lectures/cake_eating_egm_jax.md
@@ -85,7 +85,7 @@ class Model(NamedTuple):
β: float # discount factor
μ: float # shock location parameter
s: float # shock scale parameter
grid: jnp.ndarray # state grid
s_grid: jnp.ndarray # exogenous savings grid
shocks: jnp.ndarray # shock draws
α: float # production function parameter

@@ -101,47 +101,51 @@ def create_model(β: float = 0.96,
"""
Creates an instance of the cake eating model.
"""
# Set up grid
grid = jnp.linspace(1e-4, grid_max, grid_size)
# Set up exogenous savings grid
s_grid = jnp.linspace(1e-4, grid_max, grid_size)

# Store shocks (with a seed, so results are reproducible)
key = jax.random.PRNGKey(seed)
shocks = jnp.exp(μ + s * jax.random.normal(key, shape=(shock_size,)))

return Model(β=β, μ=μ, s=s, grid=grid, shocks=shocks, α=α)
return Model(β=β, μ=μ, s=s, s_grid=s_grid, shocks=shocks, α=α)
```

Here's the Coleman-Reffett operator using EGM.

The key JAX feature here is `vmap`, which vectorizes the computation over the grid points.
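As a quick standalone illustration of `vmap` (separate from the lecture code), a scalar function can be lifted to operate over an array of grid points:

```python
import jax
import jax.numpy as jnp

# Toy example: vmap turns a scalar function into one that maps over
# the leading axis of its input
square = lambda z: z**2
square_vec = jax.vmap(square)
print(square_vec(jnp.arange(4.0)))   # [0. 1. 4. 9.]
```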

```{code-cell} python3
def K(σ_array: jnp.ndarray, model: Model) -> jnp.ndarray:
def K(
c_in: jnp.ndarray, # Consumption values on the endogenous grid
x_in: jnp.ndarray, # Current endogenous grid
model: Model # Model specification
):
"""
The Coleman-Reffett operator using EGM

"""

# Simplify names
β, α = model.β, model.α
grid, shocks = model.grid, model.shocks

# Determine endogenous grid
x = grid + σ_array # x_i = k_i + c_i
s_grid, shocks = model.s_grid, model.shocks

# Linear interpolation of policy using endogenous grid
σ = lambda x_val: jnp.interp(x_val, x, σ_array)
σ = lambda x_val: jnp.interp(x_val, x_in, c_in)

# Define function to compute consumption at a single grid point
def compute_c(k):
vals = u_prime(σ(f(k, α) * shocks)) * f_prime(k, α) * shocks
def compute_c(s):
vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks
return u_prime_inv(β * jnp.mean(vals))

# Vectorize over grid using vmap
compute_c_vectorized = jax.vmap(compute_c)
c = compute_c_vectorized(grid)
c_out = compute_c_vectorized(s_grid)

# Determine corresponding endogenous grid
x_out = s_grid + c_out # x_i = s_i + c_i

return c
return c_out, x_out
```
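In equation form, the update implemented by `K` is

$$
c_i
= (u')^{-1}\left( \beta \, \frac{1}{n} \sum_{j=1}^{n}
  u'\bigl(\sigma(f(s_i)\,\xi_j)\bigr)\, f'(s_i)\, \xi_j \right),
\qquad
x_i = s_i + c_i,
$$

where $\xi_1, \ldots, \xi_n$ are the stored shock draws and $\sigma$ is the linear interpolant of the incoming policy on its endogenous grid.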

We define utility and production functions globally.
@@ -160,58 +164,56 @@ f_prime = lambda k, α: α * k**(α - 1)
Now we create a model instance.

```{code-cell} python3
α = 0.4

model = create_model(α=α)
grid = model.grid
model = create_model()
s_grid = model.s_grid
```

The solver uses JAX's `jax.lax.while_loop` for the iteration and is JIT-compiled for speed.
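For readers who have not used `jax.lax.while_loop`, its behavior is roughly that of the following pure-Python reference (a sketch that ignores tracing and compilation):

```python
# Rough Python equivalent of jax.lax.while_loop(cond_fun, body_fun, init_val):
# repeatedly apply body_fun while cond_fun holds, starting from init_val
def while_loop_reference(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val
```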

```{code-cell} python3
@jax.jit
def solve_model_time_iter(model: Model,
σ_init: jnp.ndarray,
c_init: jnp.ndarray,
x_init: jnp.ndarray,
tol: float = 1e-5,
max_iter: int = 1000) -> jnp.ndarray:
max_iter: int = 1000):
"""
Solve the model using time iteration with EGM.
"""

def condition(loop_state):
i, σ, error = loop_state
i, c, x, error = loop_state
return (error > tol) & (i < max_iter)

def body(loop_state):
i, σ, error = loop_state
σ_new = K(σ, model)
error = jnp.max(jnp.abs(σ_new - σ))
return i + 1, σ_new, error
i, c, x, error = loop_state
c_new, x_new = K(c, x, model)
error = jnp.max(jnp.abs(c_new - c))
return i + 1, c_new, x_new, error

# Initialize loop state
initial_state = (0, σ_init, tol + 1)
initial_state = (0, c_init, x_init, tol + 1)

# Run the loop
i, σ, error = jax.lax.while_loop(condition, body, initial_state)
i, c, x, error = jax.lax.while_loop(condition, body, initial_state)

return σ
return c, x
```

We solve the model starting from an initial guess.

```{code-cell} python3
σ_init = jnp.copy(grid)
σ = solve_model_time_iter(model, σ_init)
c_init = jnp.copy(s_grid)
x_init = s_grid + c_init
c, x = solve_model_time_iter(model, c_init, x_init)
```

Let's plot the resulting policy against the analytical solution.
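Here `σ_star` (defined earlier in the lecture, outside this diff) is assumed to return the familiar closed-form policy for log utility with Cobb-Douglas production,

$$
\sigma^*(x) = (1 - \alpha \beta)\, x .
$$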

```{code-cell} python3
x = grid + σ # x_i = k_i + c_i

fig, ax = plt.subplots()

ax.plot(x, σ, lw=2,
ax.plot(x, c, lw=2,
alpha=0.8, label='approximate policy function')

ax.plot(x, σ_star(x, model.α, model.β), 'k--',
@@ -224,15 +226,16 @@ plt.show()
The fit is very good.

```{code-cell} python3
max_dev = jnp.max(jnp.abs(σ - σ_star(x, model.α, model.β)))
max_dev = jnp.max(jnp.abs(c - σ_star(x, model.α, model.β)))
print(f"Maximum absolute deviation: {max_dev:.7}")
```

The JAX implementation is very fast thanks to JIT compilation and vectorization.

```{code-cell} python3
with qe.Timer(precision=8):
σ = solve_model_time_iter(model, σ_init).block_until_ready()
c, x = solve_model_time_iter(model, c_init, x_init)
jax.block_until_ready(c)
```

This speed comes from:
@@ -282,76 +285,86 @@ def u_prime_inv_crra(x, γ):
Now we create a version of the Coleman-Reffett operator that takes $\gamma$ as a parameter.
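The helpers `u_prime_crra` and `u_prime_inv_crra` are defined above but collapsed in this diff; a minimal sketch consistent with standard CRRA utility $u(c) = c^{1-\gamma}/(1-\gamma)$ is:

```python
# Sketch of the CRRA helpers assumed by K_crra below (the lecture's
# actual definitions are collapsed in this diff)
def u_prime_crra(c, γ):
    return c**(-γ)              # marginal utility

def u_prime_inv_crra(x, γ):
    return x**(-1 / γ)          # inverse of marginal utility
```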

```{code-cell} python3
def K_crra(σ_array: jnp.ndarray, model: Model, γ: float) -> jnp.ndarray:
def K_crra(
c_in: jnp.ndarray, # Consumption values on the endogenous grid
x_in: jnp.ndarray, # Current endogenous grid
model: Model, # Model specification
γ: float # CRRA parameter
):
"""
The Coleman-Reffett operator using EGM with CRRA utility
"""
# Simplify names
β, α = model.β, model.α
grid, shocks = model.grid, model.shocks

# Determine endogenous grid
x = grid + σ_array
s_grid, shocks = model.s_grid, model.shocks

# Linear interpolation of policy using endogenous grid
σ = lambda x_val: jnp.interp(x_val, x, σ_array)
σ = lambda x_val: jnp.interp(x_val, x_in, c_in)

# Define function to compute consumption at a single grid point
def compute_c(k):
vals = u_prime_crra(σ(f(k, α) * shocks), γ) * f_prime(k, α) * shocks
def compute_c(s):
vals = u_prime_crra(σ(f(s, α) * shocks), γ) * f_prime(s, α) * shocks
return u_prime_inv_crra(β * jnp.mean(vals), γ)

# Vectorize over grid using vmap
compute_c_vectorized = jax.vmap(compute_c)
c = compute_c_vectorized(grid)
c_out = compute_c_vectorized(s_grid)

# Determine corresponding endogenous grid
x_out = s_grid + c_out # x_i = s_i + c_i

return c
return c_out, x_out
```

We also need a solver that uses this operator.

```{code-cell} python3
@jax.jit
def solve_model_crra(model: Model,
σ_init: jnp.ndarray,
c_init: jnp.ndarray,
x_init: jnp.ndarray,
γ: float,
tol: float = 1e-5,
max_iter: int = 1000) -> jnp.ndarray:
max_iter: int = 1000):
"""
Solve the model using time iteration with EGM and CRRA utility.
"""

def condition(loop_state):
i, σ, error = loop_state
i, c, x, error = loop_state
return (error > tol) & (i < max_iter)

def body(loop_state):
i, σ, error = loop_state
σ_new = K_crra(σ, model, γ)
error = jnp.max(jnp.abs(σ_new - σ))
return i + 1, σ_new, error
i, c, x, error = loop_state
c_new, x_new = K_crra(c, x, model, γ)
error = jnp.max(jnp.abs(c_new - c))
return i + 1, c_new, x_new, error

# Initialize loop state
initial_state = (0, σ_init, tol + 1)
initial_state = (0, c_init, x_init, tol + 1)

# Run the loop
i, σ, error = jax.lax.while_loop(condition, body, initial_state)
i, c, x, error = jax.lax.while_loop(condition, body, initial_state)

return σ
return c, x
```

Now we solve for $\gamma = 1$ (log utility) and values approaching 1 from above.
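Recall that CRRA utility

$$
u(c) = \frac{c^{1-\gamma} - 1}{1 - \gamma}
$$

converges to $\log c$ as $\gamma \to 1$, so the $\gamma = 1$ policy should coincide with the log-utility solution computed above.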

```{code-cell} python3
γ_values = [1.0, 1.05, 1.1, 1.2]
policies = {}
endogenous_grids = {}

model_crra = create_model(α=α)
model_crra = create_model()

for γ in γ_values:
σ_init = jnp.copy(model_crra.grid)
σ_gamma = solve_model_crra(model_crra, σ_init, γ).block_until_ready()
policies[γ] = σ_gamma
c_init = jnp.copy(model_crra.s_grid)
x_init = model_crra.s_grid + c_init
c_gamma, x_gamma = solve_model_crra(model_crra, c_init, x_init, γ)
jax.block_until_ready(c_gamma)
policies[γ] = c_gamma
endogenous_grids[γ] = x_gamma
print(f"Solved for γ = {γ}")
```

@@ -361,7 +374,7 @@ Plot the policies on their endogenous grids.
fig, ax = plt.subplots()

for γ in γ_values:
x = model_crra.grid + policies[γ]
x = endogenous_grids[γ]
if γ == 1.0:
ax.plot(x, policies[γ], 'k-', linewidth=2,
label=f'γ = {γ:.2f} (log utility)', alpha=0.8)
@@ -377,7 +390,7 @@ plt.show()

Note that the plots for $\gamma > 1$ do not cover the entire x-axis range shown.

This is because the endogenous grid $x = k + \sigma(k)$ depends on the consumption policy, which varies with $\gamma$.
This is because the endogenous grid $x = s + \sigma(s)$ depends on the consumption policy, which varies with $\gamma$.

Let's check the maximum deviation between the log utility case ($\gamma = 1.0$) and the values of $\gamma$ approaching 1 from above.
