From fad3450da2ceae383e99876579121d9de8168c1b Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Thu, 27 Nov 2025 06:51:34 +0900 Subject: [PATCH] Fix variable naming and add timing comparison to IFP EGM lecture MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Standardize function signatures: solve_model and solve_model_numpy now use (c_init, a_init) order consistently - Fix internal variable naming: use c_in/c_out and a_in/a_out for iteration variables - Update β from 0.96 to 0.94 to support higher interest rates (up to 6%) - Adjust interest rate ranges to [0, 0.05] to avoid instability - Add new Timing subsection comparing NumPy vs JAX performance - Simplify notation: a^e_{ij} → a_{ij} throughout - Improve variable naming consistency: c_vals/ae_vals → c_vec/a_vec, c_in/a_in → c_out/a_out 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- lectures/ifp_egm.md | 318 ++++++++++++++++++++++++-------------------- 1 file changed, 176 insertions(+), 142 deletions(-) diff --git a/lectures/ifp_egm.md b/lectures/ifp_egm.md index d5685cfdb..2eec55754 100644 --- a/lectures/ifp_egm.md +++ b/lectures/ifp_egm.md @@ -321,10 +321,10 @@ $$ We then obtain a corresponding endogenous grid of current assets via $$ - a^e_{ij} := c_{ij} + s_i. + a_{ij} := c_{ij} + s_i. $$ -Notice that, for each $j$, we have $a^e_{0j} = c_{0j} = 0$. +Notice that, for each $j$, we have $a_{0j} = c_{0j} = 0$. This anchors the interpolation at the correct value at the origin, since, without borrowing, consumption is zero when assets are zero. @@ -332,7 +332,7 @@ without borrowing, consumption is zero when assets are zero. Our next guess of the policy function, which we write as $K\sigma$, is the linear interpolation of the interpolation points -$$ \{(a^e_{0j}, c_{0j}), \ldots, (a^e_{mj}, c_{mj})\} $$ +$$ \{(a_{0j}, c_{0j}), \ldots, (a_{mj}, c_{mj})\} $$ for each $j$. @@ -353,17 +353,6 @@ $$ u(c) = \frac{c^{1 - \gamma}} {1 - \gamma} $$ -Here are the utility-related functions: - -```{code-cell} ipython3 -@numba.jit -def u_prime(c, γ): - return c**(-γ) - -@numba.jit -def u_prime_inv(c, γ): - return c**(-1/γ) -``` ### Set Up @@ -415,8 +404,8 @@ guess $K\sigma$. In practice, it takes in -* a guess of optimal consumption values $c_{ij}$, stored as `c_vals` -* and a corresponding set of endogenous grid points $a^e_{ij}$, stored as `ae_vals` +* a guess of optimal consumption values $c_{ij}$, stored as `c_vec` +* and a corresponding set of endogenous grid points $a^e_{ij}$, stored as `a_vec` These are converted into a consumption policy $a \mapsto \sigma(a, z_j)$ by linear interpolation of $(a^e_{ij}, c_{ij})$ over $i$ for each $j$. @@ -440,8 +429,8 @@ with each $\eta_{\ell}$ being a standard normal draw. ```{code-cell} ipython3 @numba.jit def K_numpy( - c_vals: np.ndarray, # Initial guess of σ on grid endogenous grid - ae_vals: np.ndarray, # Initial endogenous grid + c_in: np.ndarray, # Initial guess of σ on grid endogenous grid + a_in: np.ndarray, # Initial endogenous grid ifp_numpy: IFPNumPy ) -> np.ndarray: """ @@ -466,7 +455,7 @@ def K_numpy( def y(z, η): return np.exp(a_y * η + z * b_y) - new_c_vals = np.zeros_like(c_vals) + c_out = np.zeros_like(c_in) for i in range(1, n_a): # Start from 1 for positive savings levels for j in range(n_z): @@ -481,7 +470,7 @@ def K_numpy( # Calculate next period assets next_a = R * s[i] + y(z_prime, η) # Interpolate to get σ(R s_i + y(z', η), z') - next_c = np.interp(next_a, ae_vals[:, k], c_vals[:, k]) + next_c = np.interp(next_a, a_in[:, k], c_in[:, k]) # Add to the inner sum inner_sum += u_prime(next_c) # Average over η draws to approximate the integral @@ -491,11 +480,11 @@ def K_numpy( expectation += inner_mean_k * Π[j, k] # Calculate updated c_{ij} values - new_c_vals[i, j] = u_prime_inv(β * R * expectation) + c_out[i, j] = u_prime_inv(β * R * expectation) - new_ae_vals = new_c_vals + s[:, None] + a_out = c_out + s[:, None] - return new_c_vals, new_ae_vals + return c_out, a_out ``` To solve the model we use a simple while loop. @@ -503,8 +492,8 @@ To solve the model we use a simple while loop. ```{code-cell} ipython3 def solve_model_numpy( ifp_numpy: IFPNumPy, - ae_vals_init: np.ndarray, - c_vals_init: np.ndarray, + c_init: np.ndarray, + a_init: np.ndarray, tol: float = 1e-5, max_iter: int = 1_000 ) -> np.ndarray: @@ -512,17 +501,17 @@ def solve_model_numpy( Solve the model using time iteration with EGM. """ - c_vals, ae_vals = c_vals_init, ae_vals_init + c_in, a_in = c_init, a_init i = 0 error = tol + 1 while error > tol and i < max_iter: - new_c_vals, new_ae_vals = K_numpy(c_vals, ae_vals, ifp_numpy) - error = np.max(np.abs(new_c_vals - c_vals)) + c_out, a_out = K_numpy(c_in, a_in, ifp_numpy) + error = np.max(np.abs(c_out - c_in)) i = i + 1 - c_vals, ae_vals = new_c_vals, new_ae_vals + c_in, a_in = c_out, a_out - return c_vals, ae_vals + return c_out, a_out ``` Let's road test the EGM code. @@ -531,11 +520,11 @@ Let's road test the EGM code. ifp_numpy = create_ifp() R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp_numpy # Initial conditions -- agent consumes everything -ae_vals_init = s[:, None] * np.ones(len(z_grid)) -c_vals_init = ae_vals_init +a_init = s[:, None] * np.ones(len(z_grid)) +c_init = a_init # Solve from these initial conditions -c_vals, ae_vals = solve_model_numpy( - ifp_numpy, c_vals_init, ae_vals_init +c_vec, a_vec = solve_model_numpy( + ifp_numpy, c_init, a_init ) ``` @@ -544,8 +533,8 @@ Here's a plot of the optimal consumption policy for each $z$ state ```{code-cell} ipython3 fig, ax = plt.subplots() -ax.plot(ae_vals[:, 0], c_vals[:, 0], label='bad state') -ax.plot(ae_vals[:, 1], c_vals[:, 1], label='good state') +ax.plot(a_vec[:, 0], c_vec[:, 0], label='bad state') +ax.plot(a_vec[:, 1], c_vec[:, 1], label='good state') ax.set(xlabel='assets', ylabel='consumption') ax.legend() plt.show() @@ -578,7 +567,7 @@ class IFP(NamedTuple): def create_ifp(r=0.01, - β=0.96, + β=0.94, γ=1.5, Π=((0.6, 0.4), (0.05, 0.95)), @@ -608,8 +597,8 @@ guess $K\sigma$. ```{code-cell} ipython3 def K( - c_vals: jnp.ndarray, - ae_vals: jnp.ndarray, + c_in: jnp.ndarray, + a_in: jnp.ndarray, ifp: IFP ) -> jnp.ndarray: """ @@ -634,24 +623,28 @@ def K( def y(z, η): return jnp.exp(a_y * η + z * b_y) - def compute_c_ij(i, j): + def compute_c(i, j): " Function to compute consumption for one (i, j) pair where i >= 1. " - # For each k (future z state), compute the integral over η def compute_expectation_k(k): - z_prime = z_grid[k] - - # For each η draw, compute u'(σ(R * s_i + y(z', η), z')) - def compute_for_eta(η): - next_a = R * s[i] + y(z_prime, η) - # Interpolate to get σ(R * s_i + y(z', η), z') - next_c = jnp.interp(next_a, ae_vals[:, k], c_vals[:, k]) - # Return u'(σ(R * s_i + y(z', η), z')) + """ + For each k, approximate the integral + + ∫ u'(σ(R s_i + y(z_k, η'), z_k)) φ(η') dη' + """ + + def compute_mu_at_eta(η): + " For each η draw, compute u'(σ(R * s_i + y(z_k, η), z_k)) " + next_a = R * s[i] + y(z_grid[k], η) + # Interpolate to get σ(R * s_i + y(z_k, η), z_k) + next_c = jnp.interp(next_a, a_in[:, k], c_in[:, k]) + # Return u'(σ(R * s_i + y(z_k, η), z_k)) return u_prime(next_c) - # Average over η draws to approximate the integral - # ∫ u'(σ(R s_i + y(z', η'), z')) φ(η') dη' when z' = z_grid[k] - return jnp.mean(jax.vmap(compute_for_eta)(η_draws)) + # Average over η draws to approximate the inner integral + # ∫ u'(σ(R s_i + y(z_k, η'), z_k)) φ(η') dη' + all_draws = jax.vmap(compute_mu_at_eta)(η_draws) + return jnp.mean(all_draws) # Compute expectation: Σ_k [∫ u'(σ(...)) φ(η) dη] * Π[j, k] expectations = jax.vmap(compute_expectation_k)(jnp.arange(n_z)) @@ -665,23 +658,21 @@ def K( j_grid = jnp.arange(n_z) # vmap over j for each i - compute_c_i = jax.vmap(compute_c_ij, in_axes=(None, 0)) + compute_c_i = jax.vmap(compute_c, in_axes=(None, 0)) # vmap over i compute_c = jax.vmap(lambda i: compute_c_i(i, j_grid)) - # Compute consumption for i >= 1 - new_c_interior = compute_c(i_grid) # Shape: (n_a-1, n_z) - + c_out_interior = compute_c(i_grid) # Shape: (n_a-1, n_z) # For i = 0, set consumption to 0 - new_c_boundary = jnp.zeros((1, n_z)) + c_out_boundary = jnp.zeros((1, n_z)) # Concatenate boundary and interior - new_c_vals = jnp.concatenate([new_c_boundary, new_c_interior], axis=0) + c_out = jnp.concatenate([c_out_boundary, c_out_interior], axis=0) # Compute endogenous asset grid: a^e_{ij} = c_{ij} + s_i - new_ae_vals = new_c_vals + s[:, None] + a_out = c_out + s[:, None] - return new_c_vals, new_ae_vals + return c_out, a_out ``` @@ -691,8 +682,8 @@ Here's a jit-accelerated iterative routine to solve the model using this operato @jax.jit def solve_model( ifp: IFP, - c_vals_init: jnp.ndarray, # Initial guess of σ on grid endogenous grid - ae_vals_init: jnp.ndarray, # Initial endogenous grid + c_init: jnp.ndarray, # Initial guess of σ on grid endogenous grid + a_init: jnp.ndarray, # Initial endogenous grid tol: float = 1e-5, max_iter: int = 1000 ) -> jnp.ndarray: @@ -702,22 +693,22 @@ def solve_model( """ def condition(loop_state): - c_vals, ae_vals, i, error = loop_state + c_in, a_in, i, error = loop_state return (error > tol) & (i < max_iter) def body(loop_state): - c_vals, ae_vals, i, error = loop_state - new_c_vals, new_ae_vals = K(c_vals, ae_vals, ifp) - error = jnp.max(jnp.abs(new_c_vals - c_vals)) + c_in, a_in, i, error = loop_state + c_out, a_out = K(c_in, a_in, ifp) + error = jnp.max(jnp.abs(c_out - c_in)) i += 1 - return new_c_vals, new_ae_vals, i, error + return c_out, a_out, i, error i, error = 0, tol + 1 - initial_state = (c_vals_init, ae_vals_init, i, error) + initial_state = (c_init, a_init, i, error) final_loop_state = jax.lax.while_loop(condition, body, initial_state) - c_vals, ae_vals, i, error = final_loop_state + c_out, a_out, i, error = final_loop_state - return c_vals, ae_vals + return c_out, a_out ``` @@ -729,18 +720,18 @@ Let's road test the EGM code. ifp = create_ifp() R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp # Set initial conditions where the agent consumes everything -ae_vals_init = s[:, None] * jnp.ones(len(z_grid)) -c_vals_init = ae_vals_init +a_init = s[:, None] * jnp.ones(len(z_grid)) +c_init = a_init # Solve starting from these initial conditions -c_vals_jax, ae_vals_jax = solve_model(ifp, c_vals_init, ae_vals_init) +c_vec_jax, a_vec_jax = solve_model(ifp, c_init, a_init) ``` To verify the correctness of our JAX implementation, let's compare it with the NumPy version we developed earlier. ```{code-cell} ipython3 # Compare the results -max_c_diff = np.max(np.abs(np.array(c_vals) - c_vals_jax)) -max_ae_diff = np.max(np.abs(np.array(ae_vals) - ae_vals_jax)) +max_c_diff = np.max(np.abs(np.array(c_vec) - c_vec_jax)) +max_ae_diff = np.max(np.abs(np.array(a_vec) - a_vec_jax)) print(f"Maximum difference in consumption policy: {max_c_diff:.2e}") print(f"Maximum difference in asset grid: {max_ae_diff:.2e}") @@ -751,12 +742,54 @@ the two approaches. (Remaining differences are mainly due to different Monte Carlo integration outcomes over relatively small samples.) +### Timing + +Now let's compare the execution time between NumPy and JAX implementations. + +```{code-cell} ipython3 +import time + +# Set up initial conditions for NumPy version +s_np = np.array(s) +z_grid_np = np.array(z_grid) +a_init_np = s_np[:, None] * np.ones(len(z_grid_np)) +c_init_np = a_init_np.copy() + +# Set up initial conditions for JAX version +a_init_jx = s[:, None] * jnp.ones(len(z_grid)) +c_init_jx = a_init_jx + +# Time NumPy version +start = time.time() +c_vec_np, a_vec_np = solve_model_numpy(ifp_numpy, c_init_np, a_init_np) +numpy_time = time.time() - start + +# Time JAX version (with compilation) +start = time.time() +c_vec_jx, a_vec_jx = solve_model(ifp, c_init_jx, a_init_jx) +c_vec_jx.block_until_ready() +jax_time_with_compile = time.time() - start + +# Time JAX version (without compilation - second run) +start = time.time() +c_vec_jx, a_vec_jx = solve_model(ifp, c_init_jx, a_init_jx) +c_vec_jx.block_until_ready() +jax_time = time.time() - start + +print(f"NumPy time: {numpy_time:.4f} seconds") +print(f"JAX time (with compile): {jax_time_with_compile:.4f} seconds") +print(f"JAX time (without compile): {jax_time:.4f} seconds") +print(f"Speedup (NumPy/JAX): {numpy_time/jax_time:.2f}x") +``` + +The JAX implementation is significantly faster due to JIT compilation and GPU/TPU acceleration (if available). + Here's a plot of the optimal policy for each $z$ state ```{code-cell} ipython3 fig, ax = plt.subplots() -ax.plot(ae_vals[:, 0], c_vals[:, 0], label='bad state') -ax.plot(ae_vals[:, 1], c_vals[:, 1], label='good state') +ax.plot(a_vec[:, 0], c_vec[:, 0], label='bad state') +ax.plot(a_vec[:, 1], c_vec[:, 1], label='good state') ax.set(xlabel='assets', ylabel='consumption') ax.legend() plt.show() @@ -771,9 +804,6 @@ default parameters, let's look at the ```{code-cell} ipython3 fig, ax = plt.subplots() -# Compute mean labor income at each z state -R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp - def y(z, η): return jnp.exp(a_y * η + z * b_y) @@ -782,6 +812,8 @@ def y_bar(k): Taking z = z_grid[k], compute an approximation to E_z Y' = Σ_{z'} ∫ y(z', η') φ(η') dη' Π[z, z'] + + This is the expectation of Y_{t+1} given Z_t = z. """ # Approximate ∫ y(z', η') φ(η') dη' at given z' def mean_y_at_z(z_prime): @@ -793,7 +825,7 @@ def y_bar(k): for k, label in zip((0, 1), ('low income', 'high income')): # Interpolate consumption policy on the savings grid - c_on_grid = jnp.interp(s, ae_vals[:, k], c_vals[:, k]) + c_on_grid = jnp.interp(s, a_vec[:, k], c_vec[:, k]) ax.plot(s, R * (s - c_on_grid) + y_bar(k) , label=label) ax.plot(s, s, 'k--') @@ -855,14 +887,14 @@ Let's see if we match up: ```{code-cell} ipython3 ifp_cake_eating = create_ifp(r=0.0, z_grid=(-jnp.inf, -jnp.inf)) R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp_cake_eating -ae_vals_init = s[:, None] * jnp.ones(len(z_grid)) -c_vals_init = ae_vals_init -c_vals, ae_vals = solve_model(ifp_cake_eating, c_vals_init, ae_vals_init) +a_init = s[:, None] * jnp.ones(len(z_grid)) +c_init = a_init +c_vec, a_vec = solve_model(ifp_cake_eating, c_init, a_init) fig, ax = plt.subplots() -ax.plot(ae_vals[:, 0], c_vals[:, 0], label='numerical') -ax.plot(ae_vals[:, 0], - c_star(ae_vals[:, 0], ifp_cake_eating.β, ifp_cake_eating.γ), +ax.plot(a_vec[:, 0], c_vec[:, 0], label='numerical') +ax.plot(a_vec[:, 0], + c_star(a_vec[:, 0], ifp_cake_eating.β, ifp_cake_eating.γ), '--', label='analytical') ax.set(xlabel='assets', ylabel='consumption') ax.legend() @@ -886,13 +918,13 @@ Set `num_households=50_000, T=500`. First we write a function to run a single household forward in time and record the final value of assets. -The function takes a solution pair `c_vals` and `ae_vals`, understanding them +The function takes a solution pair `c_vec` and `a_vec`, understanding them as representing an optimal policy associated with a given model `ifp` ```{code-cell} ipython3 @jax.jit def simulate_household( - key, a_0, z_idx_0, c_vals, ae_vals, ifp, T + key, a_0, z_idx_0, c_vec, a_vec, ifp, T ): """ Simulates a single household for T periods to approximate the stationary @@ -900,7 +932,7 @@ def simulate_household( - key is the state of the random number generator - ifp is an instance of IFP - - c_vals, ae_vals are the optimal consumption policy, endogenous grid for ifp + - c_vec, a_vec are the optimal consumption policy, endogenous grid for ifp """ R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp @@ -910,7 +942,7 @@ def simulate_household( return jnp.exp(a_y * η + z * b_y) # Create interpolation function for consumption policy - σ = lambda a, z_idx: jnp.interp(a, ae_vals[:, z_idx], c_vals[:, z_idx]) + σ = lambda a, z_idx: jnp.interp(a, a_vec[:, z_idx], c_vec[:, z_idx]) # Simulate forward T periods def update(t, state): @@ -937,7 +969,7 @@ Now we write a function to simulate many households in parallel. ```{code-cell} ipython3 def compute_asset_stationary( - c_vals, ae_vals, ifp, num_households=50_000, T=500, seed=1234 + c_vec, a_vec, ifp, num_households=50_000, T=500, seed=1234 ): """ Simulates num_households households for T periods to approximate @@ -946,7 +978,7 @@ def compute_asset_stationary( Returns the final cross-section of asset holdings. - ifp is an instance of IFP - - c_vals, ae_vals are the optimal consumption policy and endogenous grid. + - c_vec, a_vec are the optimal consumption policy and endogenous grid. """ R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp @@ -954,7 +986,7 @@ def compute_asset_stationary( # Create interpolation function for consumption policy # Interpolate on the endogenous grid - σ = lambda a, z_idx: jnp.interp(a, ae_vals[:, z_idx], c_vals[:, z_idx]) + σ = lambda a, z_idx: jnp.interp(a, a_vec[:, z_idx], c_vec[:, z_idx]) # Start with assets = savings_grid_max / 2 a_0_vector = jnp.full(num_households, s[-1] / 2) @@ -968,7 +1000,7 @@ def compute_asset_stationary( sim_all_households = jax.vmap( simulate_household, in_axes=(0, 0, 0, None, None, None, None) ) - assets = sim_all_households(keys, a_0_vector, z_idx_0_vector, c_vals, ae_vals, ifp, T) + assets = sim_all_households(keys, a_0_vector, z_idx_0_vector, c_vec, a_vec, ifp, T) return np.array(assets) ``` @@ -978,31 +1010,29 @@ Now we call the function, generate the asset distribution and histogram it: ```{code-cell} ipython3 ifp = create_ifp() R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp -ae_vals_init = s[:, None] * jnp.ones(len(z_grid)) -c_vals_init = ae_vals_init -c_vals, ae_vals = solve_model(ifp, c_vals_init, ae_vals_init) -assets = compute_asset_stationary(c_vals, ae_vals, ifp) +a_init = s[:, None] * jnp.ones(len(z_grid)) +c_init = a_init +c_vec, a_vec = solve_model(ifp, c_init, a_init) +assets = compute_asset_stationary(c_vec, a_vec, ifp) fig, ax = plt.subplots() ax.hist(assets, bins=20, alpha=0.5, density=True) -ax.set(xlabel='assets') +ax.set(xlabel='assets', title="Cross-sectional distribution of wealth") plt.show() ``` -The asset distribution now shows more realistic features compared to the simple -model without transient income shocks. +Some aspects of the wealth distribution look implausible, such as its left skew. -The addition of the IID income shock $\eta_t$ creates more income volatility, -which induces households to save more for precautionary reasons. - -This helps generate more wealth inequality compared to a model with only the -Markov component. +In the next section we study additional features of this distribution, including +measures of inequality. ## Wealth Inequality -In this section we examine wealth inequality in more detail by computing -standard measures of inequality and examining how they vary with the interest rate. +Lets' wealth inequality by computing some standard measures of this phenomenon. + +We will also examine how inequality varies with the interest rate. + ### Measuring Inequality @@ -1019,28 +1049,23 @@ def gini_coefficient(x): """ Compute the Gini coefficient for array x. - The Gini coefficient is a measure of inequality that ranges from - 0 (perfect equality) to 1 (perfect inequality). """ x = jnp.asarray(x) n = len(x) - # Sort values x_sorted = jnp.sort(x) # Compute Gini coefficient cumsum = jnp.cumsum(x_sorted) - return (2 * jnp.sum((jnp.arange(1, n+1)) * x_sorted)) / (n * cumsum[-1]) - (n + 1) / n + a = (2 * jnp.sum((jnp.arange(1, n+1)) * x_sorted)) / (n * cumsum[-1]) + return a - (n + 1) / n -def top_share(x, p=0.01): +def top_share( + x: jnp.array, # array of wealth values + p: float=0.01 # fraction of top households (default 0.01 for top 1%) + ): """ Compute the share of total wealth held by the top p fraction of households. - Parameters: - x: array of wealth values - p: fraction of top households (default 0.01 for top 1%) - - Returns: - Share of total wealth held by top p fraction """ x = jnp.asarray(x) x_sorted = jnp.sort(x) @@ -1050,7 +1075,7 @@ def top_share(x, p=0.01): wealth_top = jnp.sum(x_sorted[-n_top:]) # Total wealth wealth_total = jnp.sum(x_sorted) - return wealth_top / wealth_total if wealth_total > 0 else 0.0 + return wealth_top / wealth_total ``` Let's compute these measures for our baseline simulation: @@ -1070,6 +1095,13 @@ Recent numbers suggest that * the Gini coefficient for wealth in the US is around 0.8 * the top 1% wealth share is over 0.3 +Of course we have not made much effort to accurately estimate or calibrate our +parameters. + +But actually the cause is deeper --- a model with this structure [will always +struggle](https://arxiv.org/pdf/1807.08404) to replicate the observed wealth +distribution. + In a {doc}`later lecture ` we'll see if we can improve on these numbers. @@ -1079,7 +1111,7 @@ numbers. Let's examine how wealth inequality varies with the interest rate $r$. -Economic intuition suggests that higher interest rates might increase wealth +We conjecture that higher interest rates will increase wealth inequality, as wealthier households benefit more from returns on their assets. Let's investigate empirically: @@ -1087,7 +1119,7 @@ Let's investigate empirically: ```{code-cell} ipython3 # Test over 8 interest rate values M = 8 -r_vals = np.linspace(0, 0.015, M) +r_vals = np.linspace(0, 0.05, M) gini_vals = [] top1_vals = [] @@ -1097,19 +1129,19 @@ for r in r_vals: print(f'Analyzing inequality at r = {r:.4f}') ifp = create_ifp(r=r) R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp - ae_vals_init = s[:, None] * jnp.ones(len(z_grid)) - c_vals_init = ae_vals_init - c_vals, ae_vals = solve_model(ifp, c_vals_init, ae_vals_init) + a_init = s[:, None] * jnp.ones(len(z_grid)) + c_init = a_init + c_vec, a_vec = solve_model(ifp, c_init, a_init) assets = compute_asset_stationary( - c_vals, ae_vals, ifp, num_households=50_000, T=500 + c_vec, a_vec, ifp, num_households=50_000, T=500 ) gini = gini_coefficient(assets) top1 = top_share(assets, p=0.01) gini_vals.append(gini) top1_vals.append(top1) # Use last solution as initial conditions for the policy solver - c_vals_init = c_vals - ae_vals_init = ae_vals + c_init = c_vec + a_init = a_vec ``` Now let's visualize the results: @@ -1172,14 +1204,14 @@ fig, ax = plt.subplots() for r_val in r_vals: ifp = create_ifp(r=r_val) R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp - ae_vals_init = s[:, None] * jnp.ones(len(z_grid)) - c_vals_init = ae_vals_init - c_vals, ae_vals = solve_model(ifp, c_vals_init, ae_vals_init) + a_init = s[:, None] * jnp.ones(len(z_grid)) + c_init = a_init + c_vec, a_vec = solve_model(ifp, c_init, a_init) # Plot policy - ax.plot(ae_vals[:, 0], c_vals[:, 0], label=f'$r = {r_val:.3f}$') + ax.plot(a_vec[:, 0], c_vec[:, 0], label=f'$r = {r_val:.3f}$') # Start next round with last solution - c_vals_init = c_vals - ae_vals_init = ae_vals + c_init = c_vec + a_init = a_vec ax.set(xlabel='asset level', ylabel='consumption (low income)') ax.legend() @@ -1217,7 +1249,7 @@ For the interest rate grid, use ```{code-cell} ipython3 M = 8 -r_vals = np.linspace(0, 0.015, M) +r_vals = np.linspace(0, 0.05, M) ``` ```{exercise-end} @@ -1238,16 +1270,18 @@ for r in r_vals: print(f'Solving model at r = {r}') ifp = create_ifp(r=r) R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp - ae_vals_init = s[:, None] * jnp.ones(len(z_grid)) - c_vals_init = ae_vals_init - c_vals, ae_vals = solve_model(ifp, c_vals_init, ae_vals_init) - assets = compute_asset_stationary(c_vals, ae_vals, ifp, num_households=10_000, T=500) + a_init = s[:, None] * jnp.ones(len(z_grid)) + c_init = a_init + c_vec, a_vec = solve_model(ifp, c_init, a_init) + assets = compute_asset_stationary( + c_vec, a_vec, ifp, num_households=10_000, T=500 + ) mean = np.mean(assets) asset_mean.append(mean) print(f' Mean assets: {mean:.4f}') # Start next round with last solution - c_vals_init = c_vals - ae_vals_init = ae_vals + c_init = c_vec + a_init = a_vec ax.plot(r_vals, asset_mean) ax.set(xlabel='interest rate', ylabel='capital')