diff --git a/lectures/ifp_advanced.md b/lectures/ifp_advanced.md index ee9ded0e1..936ce3a60 100644 --- a/lectures/ifp_advanced.md +++ b/lectures/ifp_advanced.md @@ -25,7 +25,7 @@ kernelspec: In addition to what's in Anaconda, this lecture will need the following libraries: -```{code-cell} ipython +```{code-cell} ipython3 --- tags: [hide-output] --- @@ -54,15 +54,19 @@ endogenous grid method to solve the model quickly and accurately. We require the following imports: -```{code-cell} ipython +```{code-cell} ipython3 import matplotlib.pyplot as plt import numpy as np -from numba import jit, float64 -from numba.experimental import jitclass from quantecon import MarkovChain +import jax +import jax.numpy as jnp +from jax import vmap +from typing import NamedTuple ``` -## The Savings Problem + + +## The Model In this section we review the household problem and optimality results. @@ -89,33 +93,28 @@ a_{t+1} = R_{t+1} (a_t - c_t) + Y_{t+1} with initial condition $(a_0, Z_0)=(a,z)$ treated as given. -Note that $\{R_t\}_{t \geq 1}$, the gross rate of return on wealth, is allowed to be stochastic. +The only difference from {doc}`ifp_egm_transient_shocks` is that $\{R_t\}_{t \geq 1}$, the gross rate of return on wealth, is allowed to be stochastic. -The sequence $\{Y_t \}_{t \geq 1}$ is non-financial income. - -The stochastic components of the problem obey +In particular, we assume that ```{math} :label: eq:RY_func -R_t = R(Z_t, \zeta_t) - \quad \text{and} \quad -Y_t = Y(Z_t, \eta_t), + R_t = R(Z_t, \zeta_t) + \quad \text{and} \quad + Y_t = Y(Z_t, \eta_t), ``` where -* the maps $R$ and $Y$ are time-invariant nonnegative functions, +* $R$ and $Y$ are time-invariant nonnegative functions, * the innovation processes $\{\zeta_t\}$ and $\{\eta_t\}$ are IID and independent of each other, and -* $\{Z_t\}_{t \geq 0}$ is an irreducible time-homogeneous Markov chain on a finite set $\mathsf Z$ +* $\{Z_t\}_{t \geq 0}$ is a Markov chain on a finite set $\mathsf Z$ Let $P$ represent the Markov matrix for the chain $\{Z_t\}_{t \geq 0}$. -Our assumptions on preferences are the same as in {doc}`ifp_egm`. - -As before, $\mathbb E_z \hat X$ means expectation of next period value -$\hat X$ given current value $Z = z$. +In what follows, $\mathbb E_z \hat X$ means expectation of next period value $\hat X$ given current value $Z = z$. ### Assumptions @@ -330,317 +329,12 @@ is just $\mathbb E R_t$. We test the condition $\beta \mathbb E R_t < 1$ in the code below. -## Numba Implementation - -We will assume that $R_t = \exp(a_r \zeta_t + b_r)$ where $a_r, b_r$ -are constants and $\{ \zeta_t\}$ is IID standard normal. - -We allow labor income to be correlated, with - -$$ -Y_t = \exp(a_y \eta_t + Z_t b_y) -$$ - -where $\{ \eta_t\}$ is also IID standard normal -and $\{ Z_t\}$ is a Markov chain taking values in $\{0, 1\}$. - -```{code-cell} ipython -ifp_data = [ - ('γ', float64), # utility parameter - ('β', float64), # discount factor - ('P', float64[:, :]), # transition probs for z_t - ('a_r', float64), # scale parameter for R_t - ('b_r', float64), # additive parameter for R_t - ('a_y', float64), # scale parameter for Y_t - ('b_y', float64), # additive parameter for Y_t - ('s_grid', float64[:]), # Grid over savings - ('η_draws', float64[:]), # Draws of innovation η for MC - ('ζ_draws', float64[:]) # Draws of innovation ζ for MC -] -``` - -```{code-cell} ipython -@jitclass(ifp_data) -class IFP: - """ - A class that stores primitives for the income fluctuation - problem. 
- """ - - def __init__(self, - γ=1.5, - β=0.96, - P=np.array([(0.9, 0.1), - (0.1, 0.9)]), - a_r=0.1, - b_r=0.0, - a_y=0.2, - b_y=0.5, - shock_draw_size=100, - grid_max=10, - grid_size=100, - seed=1234): - - np.random.seed(seed) # arbitrary seed - - self.P, self.γ, self.β = P, γ, β - self.a_r, self.b_r, self.a_y, self.b_y = a_r, b_r, a_y, b_y - self.η_draws = np.random.randn(shock_draw_size) - self.ζ_draws = np.random.randn(shock_draw_size) - self.s_grid = np.linspace(0, grid_max, grid_size) - - # Test stability assuming {R_t} is IID and adopts the lognormal - # specification given below. The test is then β E R_t < 1. - ER = np.exp(b_r + a_r**2 / 2) - assert β * ER < 1, "Stability condition failed." - - # Marginal utility - def u_prime(self, c): - return c**(-self.γ) - - # Inverse of marginal utility - def u_prime_inv(self, c): - return c**(-1/self.γ) - - def R(self, z, ζ): - return np.exp(self.a_r * ζ + self.b_r) - - def Y(self, z, η): - return np.exp(self.a_y * η + (z * self.b_y)) -``` - -Here's the Coleman-Reffett operator based on EGM: - -### Implementation Details - -The implementation of operator $K$ maps directly to equation {eq}`k_opr`. +## Implementation -The left side $u'(\xi)$ becomes `u_prime_inv(β * Ez)` after solving for $\xi$. +Here's the model as a `NamedTuple`. -The expectation term $\mathbb E_z \hat{R} (u' \circ \sigma)[\hat{R}(a - \xi) + \hat{Y}, \hat{Z}]$ is computed via Monte Carlo averaging over future states and shocks. - -The max with $u'(a)$ is handled implicitly—the endogenous grid method naturally handles the liquidity constraint since we only solve for interior consumption where $c < a$. - -```{code-cell} ipython -@jit -def K(ae_vals, c_vals, ifp): - """ - The Coleman--Reffett operator for the income fluctuation problem, - using the endogenous grid method. - - * ifp is an instance of IFP - * ae_vals[i, z] is an asset grid - * c_vals[i, z] is consumption at ae_vals[i, z] - """ - - # Simplify names - u_prime, u_prime_inv = ifp.u_prime, ifp.u_prime_inv - R, Y, P, β = ifp.R, ifp.Y, ifp.P, ifp.β - s_grid, η_draws, ζ_draws = ifp.s_grid, ifp.η_draws, ifp.ζ_draws - n = len(P) - - # Create consumption function by linear interpolation - σ = lambda a, z: np.interp(a, ae_vals[:, z], c_vals[:, z]) - - # Allocate memory - c_out = np.empty_like(c_vals) - - # Obtain c_i at each s_i, z, store in c_out[i, z], computing - # the expectation term by Monte Carlo - for i, s in enumerate(s_grid): - for z in range(n): - # Compute expectation - Ez = 0.0 - for z_hat in range(n): - for η in ifp.η_draws: - for ζ in ifp.ζ_draws: - R_hat = R(z_hat, ζ) - Y_hat = Y(z_hat, η) - U = u_prime(σ(R_hat * s + Y_hat, z_hat)) - Ez += R_hat * U * P[z, z_hat] - Ez = Ez / (len(η_draws) * len(ζ_draws)) - c_out[i, z] = u_prime_inv(β * Ez) - - # Calculate endogenous asset grid - ae_out = np.empty_like(c_out) - for z in range(n): - ae_out[:, z] = s_grid + c_out[:, z] - - # Fixing a consumption-asset pair at (0, 0) improves interpolation - c_out[0, :] = 0 - ae_out[0, :] = 0 - - return ae_out, c_out -``` - -### Code Walkthrough - -The operator creates a consumption function `σ` by interpolating the input policy, then uses triple nested loops to compute the expectation via Monte Carlo averaging over savings grid points, current states, future states, and shock realizations. - -After computing optimal consumption $c_i$ at each savings level $s_i$ by inverting marginal utility, we construct the endogenous asset grid using $a_i = s_i + c_i$. 
- -Setting consumption and assets to zero at the origin ensures smooth interpolation near zero assets, where the household consumes everything. - -The next function solves for an approximation of the optimal consumption policy via time iteration. - -```{code-cell} ipython -def solve_model_time_iter(model, # Class with model information - a_vec, # Initial condition for assets - σ_vec, # Initial condition for consumption - tol=1e-4, - max_iter=1000, - verbose=True, - print_skip=25): - - # Set up loop - i = 0 - error = tol + 1 - - while i < max_iter and error > tol: - a_new, σ_new = K(a_vec, σ_vec, model) - error = np.max(np.abs(σ_vec - σ_new)) - i += 1 - if verbose and i % print_skip == 0: - print(f"Error at iteration {i} is {error}.") - a_vec, σ_vec = np.copy(a_new), np.copy(σ_new) - - if error > tol: - print("Failed to converge!") - elif verbose: - print(f"\nConverged in {i} iterations.") - - return a_new, σ_new -``` - -This function implements fixed-point iteration by repeatedly applying the operator $K$ until the policy converges. - -Convergence is measured by the maximum absolute change in consumption across all states. - -The operator is guaranteed to converge due to the contraction property discussed earlier. - -Now we are ready to create an instance at the default parameters. - -```{code-cell} ipython -ifp = IFP() -``` - -The default parameters represent a calibration with moderate risk aversion ($\gamma = 1.5$, CRRA utility) and a quarterly discount factor ($\beta = 0.96$, corresponding to roughly 4% annual discounting). - -The Markov chain has high persistence (90% probability of staying in the current state), while returns have 10% volatility around a zero mean log return ($a_r = 0.1$, $b_r = 0.0$). - -Labor income is state-dependent: $Y_t = \exp(0.2 \eta_t + 0.5 Z_t)$ implies higher expected income in the good state ($Z_t = 1$) compared to the bad state ($Z_t = 0$). - -Next we set up an initial condition, which corresponds to consuming all -assets. - -```{code-cell} ipython -# Initial guess of σ = consume all assets -k = len(ifp.s_grid) -n = len(ifp.P) -σ_init = np.empty((k, n)) -for z in range(n): - σ_init[:, z] = ifp.s_grid -a_init = np.copy(σ_init) -``` - -Let's generate an approximation solution. - -```{code-cell} ipython -a_star, σ_star = solve_model_time_iter(ifp, a_init, σ_init, print_skip=5) -``` - -Here's a plot of the resulting consumption policy. - -```{code-cell} ipython -fig, ax = plt.subplots() -for z in range(len(ifp.P)): - ax.plot(a_star[:, z], σ_star[:, z], label=f"consumption when $z={z}$") - -plt.legend() -plt.show() -``` - -Notice that we consume all assets in the lower range of the asset space. - -This is because we anticipate income $Y_{t+1}$ tomorrow, which makes the need to save less urgent. - -Observe that consuming all assets ends earlier (at lower asset levels) when $z=0$ compared to $z=1$. - -This occurs because expected future income is lower in the bad state ($z=0$), so the household begins precautionary saving at lower wealth levels. - -In contrast, when $z=1$ (good state), higher expected future income allows the household to consume all assets up to a higher threshold before savings become optimal. - -### Law of Motion - -Let's try to get some idea of what will happen to assets over the long run -under this consumption policy. 
- -As in {doc}`ifp_egm`, we -begin by producing a 45 degree diagram showing the law of motion for assets - -```{code-cell} python3 -# Good and bad state mean labor income -Y_mean = [np.mean(ifp.Y(z, ifp.η_draws)) for z in (0, 1)] -# Mean returns -R_mean = np.mean(ifp.R(z, ifp.ζ_draws)) - -a = a_star -fig, ax = plt.subplots() -for z, lb in zip((0, 1), ('bad state', 'good state')): - ax.plot(a[:, z], R_mean * (a[:, z] - σ_star[:, z]) + Y_mean[z] , label=lb) - -ax.plot(a[:, 0], a[:, 0], 'k--') -ax.set(xlabel='current assets', ylabel='next period assets') - -ax.legend() -plt.show() -``` - -The unbroken lines represent, for each $z$, an average update function -for assets, given by - -$$ -a \mapsto \bar R (a - \sigma^*(a, z)) + \bar Y(z) -$$ - -Here - -* $\bar R = \mathbb E R_t$, which is mean returns and -* $\bar Y(z) = \mathbb E_z Y(z, \eta_t)$, which is mean labor income in state $z$. - -The dashed line is the 45 degree line. - -We can see from the figure that the dynamics will be stable --- assets do not -diverge even in the highest state. - -## JAX Implementation - -We now provide a JAX implementation of the model. - -JAX is a high-performance numerical computing library that provides automatic differentiation and JIT compilation, with support for GPU/TPU acceleration. - -First we need to import JAX and related libraries: - -```{code-cell} ipython -import jax -import jax.numpy as jnp -from jax import vmap -from typing import NamedTuple - -# Import jax.jit with a different name to avoid conflict with numba.jit -jax_jit = jax.jit -``` - -We enable 64-bit precision in JAX to ensure accurate results that match the Numba implementation: - -```{code-cell} ipython -jax.config.update("jax_enable_x64", True) -``` - -Here's the JAX version of the IFP class using NamedTuple for compatibility with JAX's JIT compilation: - -```{code-cell} ipython -class IFP_JAX(NamedTuple): +```{code-cell} ipython3 +class IFP(NamedTuple): """ A NamedTuple that stores primitives for the income fluctuation problem, using JAX. @@ -657,20 +351,20 @@ class IFP_JAX(NamedTuple): ζ_draws: jnp.ndarray -def create_ifp_jax(γ=1.5, +def create_ifp(γ=1.5, β=0.96, P=np.array([(0.9, 0.1), (0.1, 0.9)]), - a_r=0.1, + a_r=0.16, b_r=0.0, a_y=0.2, b_y=0.5, shock_draw_size=100, - grid_max=10, + grid_max=100, grid_size=100, seed=1234): """ - Create an instance of IFP_JAX with the given parameters. + Create an instance of IFP with the given parameters. """ # Test stability assuming {R_t} is IID and adopts the lognormal # specification given below. The test is then β E R_t < 1. @@ -678,7 +372,7 @@ def create_ifp_jax(γ=1.5, assert β * ER < 1, "Stability condition failed." 
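    # Sanity check under the default parameters: with a_r=0.16, b_r=0.0 and
    # β=0.96 we have E R = exp(0.16**2 / 2) ≈ 1.013, so β * E R ≈ 0.972 < 1
    # and the stability condition holds.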
# Convert to JAX arrays - P_jax = jnp.array(P) + P = jnp.array(P) # Generate random draws using JAX key = jax.random.PRNGKey(seed) @@ -687,11 +381,10 @@ def create_ifp_jax(γ=1.5, ζ_draws = jax.random.normal(subkey2, (shock_draw_size,)) s_grid = jnp.linspace(0, grid_max, grid_size) - return IFP_JAX(γ=γ, β=β, P=P_jax, a_r=a_r, b_r=b_r, a_y=a_y, b_y=b_y, - s_grid=s_grid, η_draws=η_draws, ζ_draws=ζ_draws) - + return IFP( + γ, β, P, a_r, b_r, a_y, b_y, s_grid, η_draws, ζ_draws + ) -# Utility functions for the IFP model def u_prime(c, γ): """Marginal utility""" @@ -712,14 +405,14 @@ def Y(z, η, a_y, b_y): Here's the Coleman-Reffett operator using JAX: -```{code-cell} ipython -@jax_jit -def K_jax(ae_vals, c_vals, ifp): +```{code-cell} ipython3 +@jax.jit +def K(ae_vals, c_vals, ifp): """ The Coleman--Reffett operator for the income fluctuation problem, using the endogenous grid method with JAX. - * ifp is an instance of IFP_JAX + * ifp is an instance of IFP * ae_vals[i, z] is an asset grid * c_vals[i, z] is consumption at ae_vals[i, z] """ @@ -771,23 +464,26 @@ def K_jax(ae_vals, c_vals, ifp): return ae_out, c_out ``` -The next function solves for an approximation of the optimal consumption policy via time iteration using JAX: +The next function solves for an approximation of the optimal consumption policy +via time iteration using JAX: -```{code-cell} ipython -def solve_model_time_iter_jax(model, # Class with model information - a_vec, # Initial condition for assets - σ_vec, # Initial condition for consumption - tol=1e-4, - max_iter=1000, - verbose=True, - print_skip=25): +```{code-cell} ipython3 +def solve_model_time_iter( + model, # Class with model information + a_vec, # Initial condition for assets + σ_vec, # Initial condition for consumption + tol=1e-4, + max_iter=1000, + verbose=True, + print_skip=25 + ): # Set up loop i = 0 error = tol + 1 while i < max_iter and error > tol: - a_new, σ_new = K_jax(a_vec, σ_vec, model) + a_new, σ_new = K(a_vec, σ_vec, model) error = jnp.max(jnp.abs(σ_vec - σ_new)) i += 1 if verbose and i % print_skip == 0: @@ -804,195 +500,308 @@ def solve_model_time_iter_jax(model, # Class with model information Now we can create an instance and solve the model using JAX: -```{code-cell} ipython -ifp_jax = create_ifp_jax() +```{code-cell} ipython3 +ifp = create_ifp() ``` Set up the initial condition: -```{code-cell} ipython +```{code-cell} ipython3 # Initial guess of σ = consume all assets -k = len(ifp_jax.s_grid) -n = len(ifp_jax.P) -σ_init_jax = jnp.empty((k, n)) +k = len(ifp.s_grid) +n = len(ifp.P) +σ_init = jnp.empty((k, n)) for z in range(n): - σ_init_jax = σ_init_jax.at[:, z].set(ifp_jax.s_grid) -a_init_jax = σ_init_jax.copy() + σ_init = σ_init.at[:, z].set(ifp.s_grid) +a_init = σ_init.copy() ``` Let's generate an approximation solution with JAX: -```{code-cell} ipython -a_star_jax, σ_star_jax = solve_model_time_iter_jax(ifp_jax, a_init_jax, σ_init_jax, print_skip=5) +```{code-cell} ipython3 +a_star, σ_star = solve_model_time_iter(ifp, a_init, σ_init, print_skip=5) ``` -Here's a plot comparing the JAX solution with the Numba solution: -```{code-cell} ipython -fig, ax = plt.subplots() -for z in range(len(ifp_jax.P)): - ax.plot(np.array(a_star_jax[:, z]), np.array(σ_star_jax[:, z]), - label=f"JAX: consumption when $z={z}$", linestyle='--') - ax.plot(a_star[:, z], σ_star[:, z], - label=f"Numba: consumption when $z={z}$", linestyle='-', alpha=0.6) -plt.legend() -plt.show() -``` +## Simulation -### Comparison of Numba and JAX Solutions +Let's return to the default model 
and study the stationary distribution of assets. -Now let's verify that both implementations produce nearly identical results. +Our plan is to run a large number of households forward for $T$ periods and then +histogram the cross-sectional distribution of assets. -With 64-bit precision enabled in JAX, we expect the solutions to be very close. +Set `num_households=50_000, T=500`. -Let's compute the maximum absolute differences: +First we write a function to run a single household forward in time and record +the final value of assets. -```{code-cell} ipython -# Convert JAX arrays to NumPy for comparison -a_star_jax_np = np.array(a_star_jax) -σ_star_jax_np = np.array(σ_star_jax) +The function takes a solution pair `c_vec` and `a_vec`, understanding them +as representing an optimal policy associated with a given model `ifp` -# Compute differences -a_diff = np.abs(a_star - a_star_jax_np) -σ_diff = np.abs(σ_star - σ_star_jax_np) +```{code-cell} ipython3 +@jax.jit +def simulate_household( + key, a_0, z_idx_0, c_vec, a_vec, ifp, T + ): + """ + Simulates a single household for T periods to approximate the stationary + distribution of assets. -print("Comparison of Numba and JAX solutions:") -print("=" * 50) -print(f"Max absolute difference in asset grid: {np.max(a_diff):.3e}") -print(f"Mean absolute difference in asset grid: {np.mean(a_diff):.3e}") -print(f"Max absolute difference in consumption: {np.max(σ_diff):.3e}") -print(f"Mean absolute difference in consumption: {np.mean(σ_diff):.3e}") -``` + - key is the state of the random number generator + - ifp is an instance of IFP + - c_vec, a_vec are the optimal consumption policy, endogenous grid for ifp -Let's also visualize the differences: + """ + # Extract parameters from ifp + γ, β, P, a_r, b_r, a_y, b_y, s_grid, η_draws, ζ_draws = ifp + n_z = len(P) + + # Create interpolation function for consumption policy + σ = lambda a, z_idx: jnp.interp(a, a_vec[:, z_idx], c_vec[:, z_idx]) + + # Simulate forward T periods + def update(t, state): + a, z_idx = state + # Draw next shock z' from P[z, z'] + current_key = jax.random.fold_in(key, 3*t) + z_next_idx = jax.random.choice(current_key, n_z, p=P[z_idx]).astype(jnp.int32) + # Draw η shock for income + η_key = jax.random.fold_in(key, 3*t + 1) + η = jax.random.normal(η_key) + # Draw ζ shock for return + ζ_key = jax.random.fold_in(key, 3*t + 2) + ζ = jax.random.normal(ζ_key) + # Compute stochastic return + R_next = R(z_next_idx, ζ, a_r, b_r) + # Compute income + Y_next = Y(z_next_idx, η, a_y, b_y) + # Update assets: a' = R' * (a - c) + Y' + a_next = R_next * (a - σ(a, z_idx)) + Y_next + # Return updated state + return a_next, z_next_idx + + initial_state = a_0, z_idx_0 + final_state = jax.lax.fori_loop(0, T, update, initial_state) + a_final, _ = final_state + return a_final +``` + +Now we write a function to simulate many households in parallel. + +```{code-cell} ipython3 +def compute_asset_stationary( + c_vec, a_vec, ifp, num_households=50_000, T=500, seed=1234 + ): + """ + Simulates num_households households for T periods to approximate + the stationary distribution of assets. -```{code-cell} ipython -fig, axes = plt.subplots(1, 2, figsize=(12, 4)) + Returns the final cross-section of asset holdings. -for z in range(len(ifp.P)): - axes[0].plot(a_star[:, z], a_diff[:, z], label=f'z={z}') - axes[1].plot(a_star[:, z], σ_diff[:, z], label=f'z={z}') + - ifp is an instance of IFP + - c_vec, a_vec are the optimal consumption policy and endogenous grid. 
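    - num_households is the number of households simulated in parallel
    - T is the number of periods each household is run forward
    - seed initializes the JAX random number generator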
-axes[0].set_xlabel('assets') -axes[0].set_ylabel('absolute difference') -axes[0].set_title('Asset Grid Differences: |Numba - JAX|') -axes[0].legend() + """ + # Extract parameters from ifp + γ, β, P, a_r, b_r, a_y, b_y, s_grid, η_draws, ζ_draws = ifp -axes[1].set_xlabel('assets') -axes[1].set_ylabel('absolute difference') -axes[1].set_title('Consumption Differences: |Numba - JAX|') -axes[1].legend() + # Start with assets = savings_grid_max / 2 + a_0_vector = jnp.full(num_households, s_grid[-1] / 2) + # Initialize the exogenous state of each household + z_idx_0_vector = jnp.zeros(num_households).astype(jnp.int32) -plt.tight_layout() -plt.show() + # Vectorize over many households + key = jax.random.PRNGKey(seed) + keys = jax.random.split(key, num_households) + # Vectorize simulate_household in (key, a_0, z_idx_0) + sim_all_households = jax.vmap( + simulate_household, in_axes=(0, 0, 0, None, None, None, None) + ) + assets = sim_all_households(keys, a_0_vector, z_idx_0_vector, c_vec, a_vec, ifp, T) + + return np.array(assets) ``` -As we can see, the differences between the two implementations are extremely small (on the order of machine precision), confirming that both methods produce essentially identical results. +We'll need some inequality measures for visualization, so let's define them first: -The tiny differences arise from: -- Different random number generators (NumPy vs JAX) -- Minor differences in floating-point operations order -- Different interpolation implementations +```{code-cell} ipython3 +def gini_coefficient(x): + """ + Compute the Gini coefficient for array x. -Despite these minor numerical differences, both implementations converge to the same optimal policy. + """ + x = jnp.asarray(x) + n = len(x) + x_sorted = jnp.sort(x) + # Compute Gini coefficient + cumsum = jnp.cumsum(x_sorted) + a = (2 * jnp.sum((jnp.arange(1, n+1)) * x_sorted)) / (n * cumsum[-1]) + return a - (n + 1) / n + + +def top_share( + x: jnp.array, # array of wealth values + p: float=0.01 # fraction of top households (default 0.01 for top 1%) + ): + """ + Compute the share of total wealth held by the top p fraction of households. 
-The JAX implementation provides several advantages: + """ + x = jnp.asarray(x) + x_sorted = jnp.sort(x) + # Number of households in top p% + n_top = int(jnp.ceil(len(x) * p)) + # Wealth held by top p% + wealth_top = jnp.sum(x_sorted[-n_top:]) + # Total wealth + wealth_total = jnp.sum(x_sorted) + return wealth_top / wealth_total +``` + +Now we call the function, generate the asset distribution and visualize it: + +```{code-cell} ipython3 +ifp = create_ifp() +# Extract parameters for initialization +s_grid = ifp.s_grid +n_z = len(ifp.P) +a_init = s_grid[:, None] * jnp.ones(n_z) +c_init = a_init +a_vec, c_vec = solve_model_time_iter(ifp, a_init, c_init) +assets = compute_asset_stationary(c_vec, a_vec, ifp, num_households=200_000) + +# Diagnostic: Check extrapolation issues +print(f"\n=== Grid and Asset Diagnostics ===") +print(f"Grid max (s_grid[-1]): {ifp.s_grid[-1]:.2f}") +print(f"Endogenous grid max (a_vec.max()): {a_vec.max():.2f}") +print(f"Simulated assets max: {assets.max():.2f}") +print(f"Simulated assets mean: {assets.mean():.2f}") +print(f"Simulated assets median: {np.median(assets):.2f}") +print(f"Fraction of households beyond grid: {(assets > a_vec.max()).mean():.4f}") +print(f"Fraction beyond 0.9 * grid_max: {(assets > 0.9 * a_vec.max()).mean():.4f}") +print() + +# Compute Gini coefficient for the plot +gini_plot = gini_coefficient(assets) + +# Plot: Histogram with log-scale y-axis +fig, ax = plt.subplots(figsize=(10, 6)) +ax.hist(assets, bins=40, alpha=0.5, density=True) +ax.set_yscale('log') +ax.set(xlabel='assets', ylabel='density (log scale)', + title="Wealth Distribution") +plt.tight_layout() +plt.show() +``` -1. **GPU/TPU acceleration**: JAX can automatically utilize GPU/TPU hardware for faster computation -2. **Automatic differentiation**: JAX provides automatic differentiation, which can be useful for sensitivity analysis -3. **Functional programming**: JAX encourages a functional style that can be easier to reason about and parallelize +The histogram shows the wealth distribution with the y-axis on a log scale, allowing us to see both the mass of households at low wealth levels and the long right tail of the distribution. -## Exercises -```{exercise} -:label: ifpa_ex1 +## Wealth Inequality -Let's repeat our {ref}`earlier exercise ` on the long-run -cross sectional distribution of assets. +Lets' look at wealth inequality by computing some standard measures of this phenomenon. -In that exercise, we used a relatively simple income fluctuation model. +We will also examine how inequality varies with the interest rate. -In the solution, we found the shape of the asset distribution to be unrealistic. -In particular, we failed to match the long right tail of the wealth distribution. +### Measuring Inequality -Your task is to try again, repeating the exercise, but now with our more sophisticated model. +Let's print the Gini coefficient and the top 1% wealth share from our simulation: -Use the default parameters. -``` +```{code-cell} ipython3 +gini = gini_coefficient(assets) +top1 = top_share(assets, p=0.01) -```{solution-start} ifpa_ex1 -:class: dropdown +print(f"Gini coefficient: {gini:.4f}") +print(f"Top 1% wealth share: {top1:.4f}") ``` -First we write a function to compute a long asset series. +Recent numbers suggest that -Because we want to JIT-compile the function, we code the solution in a way -that breaks some rules on good programming style. 
+* the Gini coefficient for wealth in the US is around 0.8 +* the top 1% wealth share is over 0.3 -For example, we will pass in the solutions `a_star, σ_star` along with -`ifp`, even though it would be more natural to just pass in `ifp` and then -solve inside the function. +Our model with stochastic returns generates a Gini coefficient close to the +empirical value, demonstrating that capital income risk is an important factor +in wealth inequality. -The reason we do this is that `solve_model_time_iter` is not -JIT-compiled. +The top 1% wealth share is, however, too large. -```{code-cell} python3 -@jit -def compute_asset_series(ifp, a_star, σ_star, z_seq, T=500_000): - """ - Simulates a time series of length T for assets, given optimal - savings behavior. +Our model needs proper calibration and additional work -- we set these tasks aside for now. - * ifp is an instance of IFP - * a_star is the endogenous grid solution - * σ_star is optimal consumption on the grid - * z_seq is a time path for {Z_t} +## Exercises - """ +```{exercise} +:label: ifp_advanced_ex1 - # Create consumption function by linear interpolation - σ = lambda a, z: np.interp(a, a_star[:, z], σ_star[:, z]) - - # Simulate the asset path - a = np.zeros(T+1) - for t in range(T): - z = z_seq[t] - ζ, η = np.random.randn(), np.random.randn() - R = ifp.R(z, ζ) - Y = ifp.Y(z, η) - a[t+1] = R * (a[t] - σ(a[t], z)) + Y - return a -``` +Plot how the Gini coefficient varies with the volatility of returns on assets. + +Specifically, compute the Gini coefficient for values of `a_r` ranging from 0.10 to 0.16 (use at least 5 different values) and plot the results. -Now we call the function, generate the series and then histogram it, using the -solutions computed above. +What does this tell you about the relationship between capital income risk and wealth inequality? -```{code-cell} python3 -T = 1_000_000 -mc = MarkovChain(ifp.P) -z_seq = mc.simulate(T, random_state=1234) +``` -a = compute_asset_series(ifp, a_star, σ_star, z_seq, T=T) +```{solution-start} ifp_advanced_ex1 +:class: dropdown +``` -fig, ax = plt.subplots() -ax.hist(a, bins=40, alpha=0.5, density=True) -ax.set(xlabel='assets') +We loop over different values of `a_r`, solve the model for each, simulate the wealth distribution, and compute the Gini coefficient. 
+ +```{code-cell} ipython3 +# Range of a_r values to explore +a_r_vals = np.linspace(0.10, 0.16, 7) +gini_vals = [] + +print("Computing Gini coefficients for different return volatilities...\n") + +for a_r in a_r_vals: + print(f"a_r = {a_r:.3f}...", end=" ") + + # Create model with this a_r value + ifp_temp = create_ifp(a_r=a_r, grid_max=100) + + # Solve the model + s_grid_temp = ifp_temp.s_grid + n_z_temp = len(ifp_temp.P) + a_init_temp = s_grid_temp[:, None] * jnp.ones(n_z_temp) + c_init_temp = a_init_temp + a_vec_temp, c_vec_temp = solve_model_time_iter( + ifp_temp, a_init_temp, c_init_temp, verbose=False + ) + + # Simulate households + assets_temp = compute_asset_stationary( + c_vec_temp, a_vec_temp, ifp_temp, num_households=200_000 + ) + + # Compute Gini coefficient + gini_temp = gini_coefficient(assets_temp) + gini_vals.append(gini_temp) + print(f"Gini = {gini_temp:.4f}") + +# Plot the results +fig, ax = plt.subplots(figsize=(10, 6)) +ax.plot(a_r_vals, gini_vals, 'o-', linewidth=2, markersize=8) +ax.set(xlabel='Return volatility (a_r)', + ylabel='Gini coefficient', + title='Wealth Inequality vs Return Volatility') +ax.axhline(y=0.8, color='r', linestyle='--', linewidth=1, + label='Empirical US Gini (~0.8)') +ax.legend() +plt.tight_layout() plt.show() ``` -Now we have managed to successfully replicate the long right tail of the -wealth distribution. +The plot shows that wealth inequality (measured by the Gini coefficient) increases with return volatility. -Here's another view of this using a horizontal violin plot. +This demonstrates that capital income risk is a key driver of wealth inequality. + +When returns are more volatile, lucky households who experience sequences of +high returns accumulate substantially more wealth than unlucky households, +leading to greater inequality in the wealth distribution. -```{code-cell} python3 -fig, ax = plt.subplots() -ax.violinplot(a, vert=False, showmedians=True) -ax.set(xlabel='assets') -plt.show() -``` ```{solution-end} ```