diff --git a/lectures/ifp_advanced.md b/lectures/ifp_advanced.md index edf7ab5d7..f6df09b27 100644 --- a/lectures/ifp_advanced.md +++ b/lectures/ifp_advanced.md @@ -60,11 +60,12 @@ We require the following imports: ```{code-cell} ipython3 import matplotlib.pyplot as plt import numpy as np -from quantecon import MarkovChain +import quantecon as qe import jax import jax.numpy as jnp from jax import vmap from typing import NamedTuple +from functools import partial ``` @@ -129,7 +130,7 @@ does not grow too quickly. When $\{R_t\}$ was constant we required that $\beta R < 1$. -Since it is now stochastic, we require that +Since it is now stochastic, we require (see {cite}`ma2020income`) that ```{math} :label: fpbc2 @@ -140,15 +141,15 @@ G_R := \lim_{n \to \infty} \left(\mathbb E \prod_{t=1}^n R_t \right)^{1/n} ``` -Notice that, when $\{R_t\}$ takes some constant value $R$, this -reduces to the previous restriction $\beta R < 1$. - The value $G_R$ can be thought of as the long run (geometric) average gross rate of return. -More intuition behind {eq}`fpbc2` is provided in {cite}`ma2020income`. +To simplify this lecture, we will *assume that the interest rate process is +IID*. + +In that case, it is clear from the definition of $G_R$ that $G_R$ is just $\mathbb E R_t$. -Discussion on how to check it is given below. +We test the condition $\beta \mathbb E R_t < 1$ in the code below. Finally, we impose some routine technical restrictions on non-financial income. @@ -309,28 +310,6 @@ obtained by interpolating $\{a_i, c_i\}$ at each $z$. In what follows, we use linear interpolation. -### Testing the Assumptions - -Convergence of time iteration is dependent on the condition $\beta G_R < 1$ being satisfied. 
- -One can check this using the fact that $G_R$ is equal to the spectral -radius of the matrix $L$ defined by - -$$ -L(z, \hat z) := P(z, \hat z) \int R(\hat z, x) \phi(x) dx -$$ - -This identity is proved in {cite}`ma2020income`, where $\phi$ is the -density of the innovation $\zeta_t$ to returns on assets. - -(Remember that $\mathsf Z$ is a finite set, so this expression defines a matrix.) - -Checking the condition is even easier when $\{R_t\}$ is IID. - -In that case, it is clear from the definition of $G_R$ that $G_R$ -is just $\mathbb E R_t$. - -We test the condition $\beta \mathbb E R_t < 1$ in the code below. ## Implementation @@ -354,32 +333,31 @@ class IFP(NamedTuple): ζ_draws: jnp.ndarray -def create_ifp(γ=1.5, - β=0.96, - P=np.array([(0.9, 0.1), - (0.1, 0.9)]), - a_r=0.16, - b_r=0.0, - a_y=0.2, - b_y=0.5, - shock_draw_size=100, - grid_max=100, - grid_size=100, - seed=1234): +def create_ifp( + γ=1.5, # Utility parameter + β=0.96, # Discount factor + P=jnp.array([(0.9, 0.1), # Default Markov chain for Z + (0.1, 0.9)]), + a_r=0.16, # Volatility term in R shock + b_r=0.0, # Mean shift R shock + a_y=0.2, # Volatility term in Y shock + b_y=0.5, # Mean shift Y shock + shock_draw_size=100, # For Monte Carlo + grid_max=100, # Exogenous grid max + grid_size=100, # Exogenous grid size + seed=1234 # Random seed + ): """ Create an instance of IFP with the given parameters. + """ - # Test stability assuming {R_t} is IID and adopts the lognormal - # specification given below. The test is then β E R_t < 1. + # Test stability assuming {R_t} is IID and ln R ~ N(b_r, a_r) ER = np.exp(b_r + a_r**2 / 2) assert β * ER < 1, "Stability condition failed." 
- # Convert to JAX arrays - P = jnp.array(P) - # Generate random draws using JAX key = jax.random.PRNGKey(seed) - key, subkey1, subkey2 = jax.random.split(key, 3) + subkey1, subkey2 = jax.random.split(key) η_draws = jax.random.normal(subkey1, (shock_draw_size,)) ζ_draws = jax.random.normal(subkey2, (shock_draw_size,)) s_grid = jnp.linspace(0, grid_max, grid_size) @@ -409,96 +387,83 @@ def Y(z, η, a_y, b_y): Here's the Coleman-Reffett operator using JAX: ```{code-cell} ipython3 -@jax.jit -def K(ae_vals, c_vals, ifp): +def K( + a_in: jnp.array, # a_in[i, z] is an asset grid + c_in: jnp.array, # c_in[i, z] = consumption at a_in[i, z] + ifp: IFP + ): """ The Coleman--Reffett operator for the income fluctuation problem, using the endogenous grid method with JAX. - * ifp is an instance of IFP - * ae_vals[i, z] is an asset grid - * c_vals[i, z] is consumption at ae_vals[i, z] """ # Extract parameters from ifp - γ, β, P = ifp.γ, ifp.β, ifp.P - a_r, b_r, a_y, b_y = ifp.a_r, ifp.b_r, ifp.a_y, ifp.b_y - s_grid, η_draws, ζ_draws = ifp.s_grid, ifp.η_draws, ifp.ζ_draws + γ, β, P, a_r, b_r, a_y, b_y, s_grid, η_draws, ζ_draws = ifp n = len(P) - # Allocate memory - c_out = jnp.empty_like(c_vals) - - # Obtain c_i at each s_i, z, store in c_out[i, z], computing - # the expectation term by Monte Carlo def compute_expectation(s, z): - """Compute expectation for given s and z""" def inner_expectation(z_hat): - # Vectorize over shocks def compute_term(η, ζ): R_hat = R(z_hat, ζ, a_r, b_r) Y_hat = Y(z_hat, η, a_y, b_y) a_val = R_hat * s + Y_hat # Interpolate consumption - c_interp = jnp.interp(a_val, ae_vals[:, z_hat], c_vals[:, z_hat]) - U = u_prime(c_interp, γ) - return R_hat * U - + c_interp = jnp.interp(a_val, a_in[:, z_hat], c_in[:, z_hat]) + mu = u_prime(c_interp, γ) + return R_hat * mu # Vectorize over all shock combinations η_grid, ζ_grid = jnp.meshgrid(η_draws, ζ_draws, indexing='ij') terms = vmap(vmap(compute_term))(η_grid, ζ_grid) return P[z, z_hat] * jnp.mean(terms) - # Sum 
over z_hat states Ez = jnp.sum(vmap(inner_expectation)(jnp.arange(n))) return u_prime_inv(β * Ez, γ) # Vectorize over s_grid and z - c_out = vmap(vmap(compute_expectation, in_axes=(None, 0)), - in_axes=(0, None))(s_grid, jnp.arange(n)) - + compute_exp_v1 = vmap(compute_expectation, in_axes=(None, 0)) + compute_exp_v2 = vmap(compute_exp_v1, in_axes=(0, None)) + c_out = compute_exp_v2(s_grid, jnp.arange(n)) # Calculate endogenous asset grid - ae_out = s_grid[:, None] + c_out - - # Fixing a consumption-asset pair at (0, 0) improves interpolation + a_out = s_grid[:, None] + c_out + # Fix consumption-asset pair at (0, 0) c_out = c_out.at[0, :].set(0) - ae_out = ae_out.at[0, :].set(0) + a_out = a_out.at[0, :].set(0) - return ae_out, c_out + return a_out, c_out ``` The next function solves for an approximation of the optimal consumption policy via time iteration using JAX: ```{code-cell} ipython3 -def solve_model_time_iter( - model, # Class with model information - a_vec, # Initial condition for assets - σ_vec, # Initial condition for consumption - tol=1e-4, - max_iter=1000, - verbose=True, - print_skip=25 - ): - - # Set up loop - i = 0 - error = tol + 1 - - while i < max_iter and error > tol: - a_new, σ_new = K(a_vec, σ_vec, model) - error = jnp.max(jnp.abs(σ_vec - σ_new)) +@jax.jit +def solve_model( + ifp: IFP, + c_init: jnp.ndarray, # Initial guess of σ on the endogenous grid + a_init: jnp.ndarray, # Initial endogenous grid + tol: float = 1e-5, + max_iter: int = 1000 + ) -> jnp.ndarray: + " Solve the model using time iteration with EGM. 
" + + def condition(loop_state): + c_in, a_in, i, error = loop_state + return (error > tol) & (i < max_iter) + + def body(loop_state): + c_in, a_in, i, error = loop_state + c_out, a_out = K(c_in, a_in, ifp) + error = jnp.max(jnp.abs(c_out - c_in)) i += 1 - if verbose and i % print_skip == 0: - print(f"Error at iteration {i} is {error}.") - a_vec, σ_vec = a_new, σ_new + return c_out, a_out, i, error - if error > tol: - print("Failed to converge!") - elif verbose: - print(f"\nConverged in {i} iterations.") + i, error = 0, tol + 1 + initial_state = (c_init, a_init, i, error) + final_loop_state = jax.lax.while_loop(condition, body, initial_state) + c_out, a_out, i, error = final_loop_state - return a_new, σ_new + return c_out, a_out ``` Now we can create an instance and solve the model using JAX: @@ -522,10 +487,16 @@ a_init = σ_init.copy() Let's generate an approximation solution with JAX: ```{code-cell} ipython3 -a_star, σ_star = solve_model_time_iter(ifp, a_init, σ_init, print_skip=5) +a_star, σ_star = solve_model(ifp, a_init, σ_init) ``` +Let's try it again with a timer. +```{code-cell} ipython3 +with qe.Timer(precision=8): + a_star, σ_star = solve_model(ifp, a_init, σ_init) + a_star.block_until_ready() +``` ## Simulation @@ -543,7 +514,6 @@ The function takes a solution pair `c_vec` and `a_vec`, understanding them as representing an optimal policy associated with a given model `ifp` ```{code-cell} ipython3 -@jax.jit def simulate_household( key, a_0, z_idx_0, c_vec, a_vec, ifp, T ): @@ -593,6 +563,7 @@ def simulate_household( Now we write a function to simulate many households in parallel. 
```{code-cell} ipython3 +@partial(jax.jit, static_argnums=(3, 4, 5)) def compute_asset_stationary( c_vec, a_vec, ifp, num_households=50_000, T=500, seed=1234 ): @@ -623,7 +594,7 @@ def compute_asset_stationary( ) assets = sim_all_households(keys, a_0_vector, z_idx_0_vector, c_vec, a_vec, ifp, T) - return np.array(assets) + return jnp.array(assets) ``` We'll need some inequality measures for visualization, so let's define them first: @@ -671,7 +642,7 @@ s_grid = ifp.s_grid n_z = len(ifp.P) a_init = s_grid[:, None] * jnp.ones(n_z) c_init = a_init -a_vec, c_vec = solve_model_time_iter(ifp, a_init, c_init) +a_vec, c_vec = solve_model(ifp, a_init, c_init) assets = compute_asset_stationary(c_vec, a_vec, ifp, num_households=200_000) # Compute Gini coefficient for the plot @@ -763,8 +734,8 @@ for a_r in a_r_vals: n_z_temp = len(ifp_temp.P) a_init_temp = s_grid_temp[:, None] * jnp.ones(n_z_temp) c_init_temp = a_init_temp - a_vec_temp, c_vec_temp = solve_model_time_iter( - ifp_temp, a_init_temp, c_init_temp, verbose=False + a_vec_temp, c_vec_temp = solve_model( + ifp_temp, a_init_temp, c_init_temp ) # Simulate households @@ -840,8 +811,8 @@ for a_y in a_y_vals: n_z_temp = len(ifp_temp.P) a_init_temp = s_grid_temp[:, None] * jnp.ones(n_z_temp) c_init_temp = a_init_temp - a_vec_temp, c_vec_temp = solve_model_time_iter( - ifp_temp, a_init_temp, c_init_temp, verbose=False + a_vec_temp, c_vec_temp = solve_model( + ifp_temp, a_init_temp, c_init_temp ) # Simulate households