Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 79 additions & 108 deletions lectures/ifp_advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,12 @@ We require the following imports:
```{code-cell} ipython3
import matplotlib.pyplot as plt
import numpy as np
from quantecon import MarkovChain
import quantecon as qe
import jax
import jax.numpy as jnp
from jax import vmap
from typing import NamedTuple
from functools import partial
```


Expand Down Expand Up @@ -129,7 +130,7 @@ does not grow too quickly.

When $\{R_t\}$ was constant we required that $\beta R < 1$.

Since it is now stochastic, we require that
Since it is now stochastic, we require (see {cite}`ma2020income`) that

```{math}
:label: fpbc2
Expand All @@ -140,15 +141,15 @@ G_R := \lim_{n \to \infty}
\left(\mathbb E \prod_{t=1}^n R_t \right)^{1/n}
```

Notice that, when $\{R_t\}$ takes some constant value $R$, this
reduces to the previous restriction $\beta R < 1$.

The value $G_R$ can be thought of as the long run (geometric) average
gross rate of return.

More intuition behind {eq}`fpbc2` is provided in {cite}`ma2020income`.
To simplify this lecture, we will *assume that the interest rate process is
IID*.

In that case, it is clear from the definition of $G_R$ that $G_R$ is just $\mathbb E R_t$.

Discussion on how to check it is given below.
We test the condition $\beta \mathbb E R_t < 1$ in the code below.

Finally, we impose some routine technical restrictions on non-financial income.

Expand Down Expand Up @@ -309,28 +310,6 @@ obtained by interpolating $\{a_i, c_i\}$ at each $z$.

In what follows, we use linear interpolation.

### Testing the Assumptions

Convergence of time iteration is dependent on the condition $\beta G_R < 1$ being satisfied.

One can check this using the fact that $G_R$ is equal to the spectral
radius of the matrix $L$ defined by

$$
L(z, \hat z) := P(z, \hat z) \int R(\hat z, x) \phi(x) dx
$$

This identity is proved in {cite}`ma2020income`, where $\phi$ is the
density of the innovation $\zeta_t$ to returns on assets.

(Remember that $\mathsf Z$ is a finite set, so this expression defines a matrix.)

Checking the condition is even easier when $\{R_t\}$ is IID.

In that case, it is clear from the definition of $G_R$ that $G_R$
is just $\mathbb E R_t$.

We test the condition $\beta \mathbb E R_t < 1$ in the code below.

## Implementation

Expand All @@ -354,32 +333,31 @@ class IFP(NamedTuple):
ζ_draws: jnp.ndarray


def create_ifp(γ=1.5,
β=0.96,
P=np.array([(0.9, 0.1),
(0.1, 0.9)]),
a_r=0.16,
b_r=0.0,
a_y=0.2,
b_y=0.5,
shock_draw_size=100,
grid_max=100,
grid_size=100,
seed=1234):
def create_ifp(
γ=1.5, # Utility parameter
β=0.96, # Discount factor
P=jnp.array([(0.9, 0.1), # Default Markov chain for Z
(0.1, 0.9)]),
a_r=0.16, # Volatility term in R shock
b_r=0.0, # Mean shift R shock
a_y=0.2, # Volatility term in Y shock
b_y=0.5, # Mean shift Y shock
shock_draw_size=100, # For Monte Carlo
grid_max=100, # Exogenous grid max
grid_size=100, # Exogenous grid size
seed=1234 # Random seed
):
"""
Create an instance of IFP with the given parameters.

"""
# Test stability assuming {R_t} is IID and adopts the lognormal
# specification given below. The test is then β E R_t < 1.
# Test stability assuming {R_t} is IID and ln R ~ N(b_r, a_r)
ER = np.exp(b_r + a_r**2 / 2)
assert β * ER < 1, "Stability condition failed."

# Convert to JAX arrays
P = jnp.array(P)

# Generate random draws using JAX
key = jax.random.PRNGKey(seed)
key, subkey1, subkey2 = jax.random.split(key, 3)
subkey1, subkey2 = jax.random.split(key)
η_draws = jax.random.normal(subkey1, (shock_draw_size,))
ζ_draws = jax.random.normal(subkey2, (shock_draw_size,))
s_grid = jnp.linspace(0, grid_max, grid_size)
Expand Down Expand Up @@ -409,96 +387,83 @@ def Y(z, η, a_y, b_y):
Here's the Coleman-Reffett operator using JAX:

```{code-cell} ipython3
@jax.jit
def K(ae_vals, c_vals, ifp):
def K(
a_in: jnp.array, # a_in[i, z] is an asset grid
c_in: jnp.array, # c_in[i, z] = consumption at a_in[i, z]
ifp: IFP
):
"""
The Coleman--Reffett operator for the income fluctuation problem,
using the endogenous grid method with JAX.

* ifp is an instance of IFP
* ae_vals[i, z] is an asset grid
* c_vals[i, z] is consumption at ae_vals[i, z]
"""

# Extract parameters from ifp
γ, β, P = ifp.γ, ifp.β, ifp.P
a_r, b_r, a_y, b_y = ifp.a_r, ifp.b_r, ifp.a_y, ifp.b_y
s_grid, η_draws, ζ_draws = ifp.s_grid, ifp.η_draws, ifp.ζ_draws
γ, β, P, a_r, b_r, a_y, b_y, s_grid, η_draws, ζ_draws = ifp
n = len(P)

# Allocate memory
c_out = jnp.empty_like(c_vals)

# Obtain c_i at each s_i, z, store in c_out[i, z], computing
# the expectation term by Monte Carlo
def compute_expectation(s, z):
"""Compute expectation for given s and z"""
def inner_expectation(z_hat):
# Vectorize over shocks
def compute_term(η, ζ):
R_hat = R(z_hat, ζ, a_r, b_r)
Y_hat = Y(z_hat, η, a_y, b_y)
a_val = R_hat * s + Y_hat
# Interpolate consumption
c_interp = jnp.interp(a_val, ae_vals[:, z_hat], c_vals[:, z_hat])
U = u_prime(c_interp, γ)
return R_hat * U

c_interp = jnp.interp(a_val, a_in[:, z_hat], c_in[:, z_hat])
mu = u_prime(c_interp, γ)
return R_hat * mu
# Vectorize over all shock combinations
η_grid, ζ_grid = jnp.meshgrid(η_draws, ζ_draws, indexing='ij')
terms = vmap(vmap(compute_term))(η_grid, ζ_grid)
return P[z, z_hat] * jnp.mean(terms)

# Sum over z_hat states
Ez = jnp.sum(vmap(inner_expectation)(jnp.arange(n)))
return u_prime_inv(β * Ez, γ)

# Vectorize over s_grid and z
c_out = vmap(vmap(compute_expectation, in_axes=(None, 0)),
in_axes=(0, None))(s_grid, jnp.arange(n))

compute_exp_v1 = vmap(compute_expectation, in_axes=(None, 0))
compute_exp_v2 = vmap(compute_exp_v1, in_axes=(0, None))
c_out = compute_exp_v2(s_grid, jnp.arange(n))
# Calculate endogenous asset grid
ae_out = s_grid[:, None] + c_out

# Fixing a consumption-asset pair at (0, 0) improves interpolation
a_out = s_grid[:, None] + c_out
# Fix consumption-asset pair at (0, 0)
c_out = c_out.at[0, :].set(0)
ae_out = ae_out.at[0, :].set(0)
a_out = a_out.at[0, :].set(0)

return ae_out, c_out
return a_out, c_out
```

The next function solves for an approximation of the optimal consumption policy
via time iteration using JAX:

```{code-cell} ipython3
def solve_model_time_iter(
model, # Class with model information
a_vec, # Initial condition for assets
σ_vec, # Initial condition for consumption
tol=1e-4,
max_iter=1000,
verbose=True,
print_skip=25
):

# Set up loop
i = 0
error = tol + 1

while i < max_iter and error > tol:
a_new, σ_new = K(a_vec, σ_vec, model)
error = jnp.max(jnp.abs(σ_vec - σ_new))
@jax.jit
def solve_model(
ifp: IFP,
c_init: jnp.ndarray, # Initial guess of σ on grid endogenous grid
a_init: jnp.ndarray, # Initial endogenous grid
tol: float = 1e-5,
max_iter: int = 1000
) -> jnp.ndarray:
" Solve the model using time iteration with EGM. "

def condition(loop_state):
c_in, a_in, i, error = loop_state
return (error > tol) & (i < max_iter)

def body(loop_state):
c_in, a_in, i, error = loop_state
c_out, a_out = K(c_in, a_in, ifp)
error = jnp.max(jnp.abs(c_out - c_in))
i += 1
if verbose and i % print_skip == 0:
print(f"Error at iteration {i} is {error}.")
a_vec, σ_vec = a_new, σ_new
return c_out, a_out, i, error

if error > tol:
print("Failed to converge!")
elif verbose:
print(f"\nConverged in {i} iterations.")
i, error = 0, tol + 1
initial_state = (c_init, a_init, i, error)
final_loop_state = jax.lax.while_loop(condition, body, initial_state)
c_out, a_out, i, error = final_loop_state

return a_new, σ_new
return c_out, a_out
```

Now we can create an instance and solve the model using JAX:
Expand All @@ -522,10 +487,16 @@ a_init = σ_init.copy()
Let's generate an approximation solution with JAX:

```{code-cell} ipython3
a_star, σ_star = solve_model_time_iter(ifp, a_init, σ_init, print_skip=5)
a_star, σ_star = solve_model(ifp, a_init, σ_init)
```

Let's try it again with a timer.

```{code-cell} python3
with qe.Timer(precision=8):
a_star, σ_star = solve_model(ifp, a_init, σ_init)
a_star.block_until_ready()
```

## Simulation

Expand All @@ -543,7 +514,6 @@ The function takes a solution pair `c_vec` and `a_vec`, understanding them
as representing an optimal policy associated with a given model `ifp`

```{code-cell} ipython3
@jax.jit
def simulate_household(
key, a_0, z_idx_0, c_vec, a_vec, ifp, T
):
Expand Down Expand Up @@ -593,6 +563,7 @@ def simulate_household(
Now we write a function to simulate many households in parallel.

```{code-cell} ipython3
@partial(jax.jit, static_argnums=(3, 4, 5))
def compute_asset_stationary(
c_vec, a_vec, ifp, num_households=50_000, T=500, seed=1234
):
Expand Down Expand Up @@ -623,7 +594,7 @@ def compute_asset_stationary(
)
assets = sim_all_households(keys, a_0_vector, z_idx_0_vector, c_vec, a_vec, ifp, T)

return np.array(assets)
return jnp.array(assets)
```

We'll need some inequality measures for visualization, so let's define them first:
Expand Down Expand Up @@ -671,7 +642,7 @@ s_grid = ifp.s_grid
n_z = len(ifp.P)
a_init = s_grid[:, None] * jnp.ones(n_z)
c_init = a_init
a_vec, c_vec = solve_model_time_iter(ifp, a_init, c_init)
a_vec, c_vec = solve_model(ifp, a_init, c_init)
assets = compute_asset_stationary(c_vec, a_vec, ifp, num_households=200_000)

# Compute Gini coefficient for the plot
Expand Down Expand Up @@ -763,8 +734,8 @@ for a_r in a_r_vals:
n_z_temp = len(ifp_temp.P)
a_init_temp = s_grid_temp[:, None] * jnp.ones(n_z_temp)
c_init_temp = a_init_temp
a_vec_temp, c_vec_temp = solve_model_time_iter(
ifp_temp, a_init_temp, c_init_temp, verbose=False
a_vec_temp, c_vec_temp = solve_model(
ifp_temp, a_init_temp, c_init_temp
)

# Simulate households
Expand Down Expand Up @@ -840,8 +811,8 @@ for a_y in a_y_vals:
n_z_temp = len(ifp_temp.P)
a_init_temp = s_grid_temp[:, None] * jnp.ones(n_z_temp)
c_init_temp = a_init_temp
a_vec_temp, c_vec_temp = solve_model_time_iter(
ifp_temp, a_init_temp, c_init_temp, verbose=False
a_vec_temp, c_vec_temp = solve_model(
ifp_temp, a_init_temp, c_init_temp
)

# Simulate households
Expand Down
Loading