diff --git a/lectures/aiyagari.md b/lectures/aiyagari.md
index 683dd0861..1436c20f6 100644
--- a/lectures/aiyagari.md
+++ b/lectures/aiyagari.md
@@ -231,7 +231,7 @@ Below we provide code to solve the household problem, taking $r$ and $w$ as fixe
 
 ### Primitives and operators
 
-We will solve the household problem using Howard policy iteration (see Ch 5 of [Dynamic Programming](https://dp.quantecon.org/)).
+We will solve the household problem using value function iteration.
 
 First we set up a `NamedTuple` to store the parameters that define a household asset accumulation problem, as well as the grids used to solve it
 
@@ -245,8 +245,8 @@ class Household(NamedTuple):
 def create_household(β=0.96,                      # Discount factor
                      Π=[[0.9, 0.1], [0.1, 0.9]],  # Markov chain
                      z_grid=[0.1, 1.0],           # Exogenous states
-                     a_min=1e-10, a_max=20,       # Asset grid
-                     a_size=200):
+                     a_min=1e-10, a_max=12.5,     # Asset grid
+                     a_size=100):
     """
     Create a Household namedtuple with custom grids.
     """
@@ -278,7 +278,6 @@ $$
 for all $(a, z, a')$.
 
 ```{code-cell} ipython3
-@jax.jit
 def B(v, household, prices):
     # Unpack
     β, a_grid, z_grid, Π = household
@@ -303,125 +302,54 @@ def B(v, household, prices):
 
 The next function computes greedy policies
 
 ```{code-cell} ipython3
-@jax.jit
 def get_greedy(v, household, prices):
     """
-    Computes a v-greedy policy σ, returned as a set of indices.  If
+    Computes a v-greedy policy σ, returned as a set of indices. If
     σ[i, j] equals ip, then a_grid[ip] is the maximizer at i, j.
     """
     # argmax over ap
     return jnp.argmax(B(v, household, prices), axis=-1)
 ```
 
-The following function computes the array $r_{\sigma}$ which gives current rewards given policy $\sigma$
+We define the Bellman operator $T$, which takes a value function $v$ and returns $Tv$, as given by the Bellman equation
 
 ```{code-cell} ipython3
-@jax.jit
-def compute_r_σ(σ, household, prices):
+def T(v, household, prices):
     """
-    Compute current rewards at each i, j under policy σ. In particular,
-
-        r_σ[i, j] = u((1 + r)a[i] + wz[j] - a'[ip])
-
-    when ip = σ[i, j].
+    The Bellman operator. Takes a value function v and returns Tv.
     """
-    # Unpack
-    β, a_grid, z_grid, Π = household
-    a_size, z_size = len(a_grid), len(z_grid)
-    r, w = prices
-
-    # Compute r_σ[i, j]
-    a = jnp.reshape(a_grid, (a_size, 1))
-    z = jnp.reshape(z_grid, (1, z_size))
-    ap = a_grid[σ]
-    c = (1 + r) * a + w * z - ap
-    r_σ = u(c)
-
-    return r_σ
+    return jnp.max(B(v, household, prices), axis=-1)
 ```
 
-The value $v_{\sigma}$ of a policy $\sigma$ is defined as
-
-$$
-v_{\sigma} = (I - \beta P_{\sigma})^{-1} r_{\sigma}
-$$
-
-(See Ch 5 of [Dynamic Programming](https://dp.quantecon.org/) for notation and background on Howard policy iteration.)
-
-To compute this vector, we set up the linear map $v \rightarrow R_{\sigma} v$, where $R_{\sigma} := I - \beta P_{\sigma}$.
-
-This map can be expressed as
-
-$$
-(R_{\sigma} v)(a, z) = v(a, z) - \beta \sum_{z'} v(\sigma(a, z), z') \Pi(z, z')
-$$
-
-(Notice that $R_\sigma$ is expressed as a linear operator rather than a matrix—this is much easier and cleaner to code, and also exploits sparsity.)
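+
+Since $T$ is a contraction of modulus $\beta$ with respect to the supremum norm, the gap between successive iterates contracts at rate $\beta$: each gap is at most $\beta$ times the previous one. Here is a quick numerical illustration of this property (a sketch only: `test_household` and `test_prices` are throwaway objects with placeholder prices, not the values used later in the lecture)
+
+```{code-cell} ipython3
+# Illustrative contraction check with placeholder prices
+test_household = create_household()
+test_prices = 0.04, 1.0   # stands in for an (r, w) pair
+v = jnp.zeros((len(test_household.a_grid), len(test_household.z_grid)))
+for _ in range(5):
+    v_new = T(v, test_household, test_prices)
+    # Each successive gap should be at most β times the previous one
+    print(jnp.max(jnp.abs(v_new - v)))
+    v = v_new
+```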
+Here's value function iteration, which repeatedly applies the Bellman operator until convergence
 
 ```{code-cell} ipython3
 @jax.jit
-def R_σ(v, σ, household):
-    # Unpack
+def value_function_iteration(household, prices, tol=1e-4, max_iter=10_000):
+    """
+    Implements value function iteration using a compiled JAX loop.
+    """
     β, a_grid, z_grid, Π = household
     a_size, z_size = len(a_grid), len(z_grid)
 
-    # Set up the array v[σ[i, j], jp]
-    zp_idx = jnp.arange(z_size)
-    zp_idx = jnp.reshape(zp_idx, (1, 1, z_size))
-    σ = jnp.reshape(σ, (a_size, z_size, 1))
-    V = v[σ, zp_idx]
-
-    # Expand Π[j, jp] to Π[i, j, jp]
-    Π = jnp.reshape(Π, (1, z_size, z_size))
-
-    # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Π[j, jp]
-    return v - β * jnp.sum(V * Π, axis=-1)
-```
+    def condition_function(loop_state):
+        i, v, error = loop_state
+        return jnp.logical_and(error > tol, i < max_iter)
 
-The next function computes the lifetime value of a given policy
+    def update(loop_state):
+        i, v, error = loop_state
+        v_new = T(v, household, prices)
+        error = jnp.max(jnp.abs(v_new - v))
+        return i + 1, v_new, error
 
-```{code-cell} ipython3
-@jax.jit
-def get_value(σ, household, prices):
-    """
-    Get the lifetime value of policy σ by computing
+    # Initial loop state
+    v_init = jnp.zeros((a_size, z_size))
+    loop_state_init = (0, v_init, tol + 1)
 
-    v_σ = R_σ^{-1} r_σ
-    """
-    r_σ = compute_r_σ(σ, household, prices)
-
-    # Reduce R_σ to a function in v
-    _R_σ = lambda v: R_σ(v, σ, household)
+    # Run the fixed point iteration
+    i, v, error = jax.lax.while_loop(condition_function, update, loop_state_init)
 
-    # Compute v_σ = R_σ^{-1} r_σ using an iterative routine.
-    return jax.scipy.sparse.linalg.bicgstab(_R_σ, r_σ)[0]
-```
-
-Here's the Howard policy iteration
-
-```{code-cell} ipython3
-def howard_policy_iteration(household, prices,
-                            tol=1e-4, max_iter=10_000, verbose=False):
-    """
-    Howard policy iteration routine.
-    """
-    β, a_grid, z_grid, Π = household
-    a_size, z_size = len(a_grid), len(z_grid)
-    σ = jnp.zeros((a_size, z_size), dtype=int)
-
-    v_σ = get_value(σ, household, prices)
-    i = 0
-    error = tol + 1
-    while error > tol and i < max_iter:
-        σ_new = get_greedy(v_σ, household, prices)
-        v_σ_new = get_value(σ_new, household, prices)
-        error = jnp.max(jnp.abs(v_σ_new - v_σ))
-        σ = σ_new
-        v_σ = v_σ_new
-        i = i + 1
-        if verbose:
-            print(f"iteration {i} with error {error}.")
-    return σ
+    return get_greedy(v, household, prices)
 ```
 
 As a first example of what we can do, let's compute and plot an optimal accumulation policy at fixed prices
@@ -437,8 +365,7 @@ print(f"Interest rate: {r}, Wage: {w}")
 
 ```{code-cell} ipython3
 with qe.Timer():
-    σ_star = howard_policy_iteration(
-        household, prices, verbose=True).block_until_ready()
+    σ_star = value_function_iteration(household, prices).block_until_ready()
 ```
 
 The next plot shows asset accumulation policies at different values of the exogenous state
@@ -560,7 +487,7 @@ def G(K, firm, household):
     # Generate a household object with these prices, compute
     # aggregate capital.
     prices = Prices(r=r, w=w)
-    σ_star = howard_policy_iteration(household, prices)
+    σ_star = value_function_iteration(household, prices)
    return capital_supply(σ_star, household)
 ```
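+
+The equilibrium capital stock is a fixed point of this map, that is, a value $K$ satisfying $K = G(K)$. Before solving for that fixed point, we can tabulate $G$ at a few trial capital stocks (an illustrative sketch: the trial values and the names `firm_demo` and `household_demo` are chosen just for this check)
+
+```{code-cell} ipython3
+# Evaluate the map K ↦ G(K) at a few illustrative capital stocks
+firm_demo = Firm()
+household_demo = create_household()
+for K_trial in (2.0, 4.0, 8.0):
+    print(K_trial, G(K_trial, firm_demo, household_demo))
+```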
@@ -640,8 +567,8 @@ def prices_to_capital_stock(household, r, firm):
     prices = Prices(r=r, w=w)
 
     # Compute the optimal policy
-    σ_star = howard_policy_iteration(household, prices)
-
+    σ_star = value_function_iteration(household, prices)
+
     # Compute capital supply
     return capital_supply(σ_star, household)
 
@@ -752,3 +679,189 @@ plt.show()
 ```
 
 ```{solution-end}
 ```
+
+```{exercise-start}
+:label: aiyagari_ex3
+```
+
+In this lecture, we used value function iteration to solve the household problem.
+
+An alternative is Howard policy iteration (HPI), which is discussed in detail in [Dynamic Programming](https://dp.quantecon.org/).
+
+HPI can be faster than VFI for some problems because it uses fewer, but more computationally intensive, iterations.
+
+Your task is to implement Howard policy iteration and compare the results with value function iteration.
+
+**Key concepts you'll need:**
+
+Howard policy iteration requires computing the value $v_{\sigma}$ of a policy $\sigma$, defined as
+
+$$
+v_{\sigma} = (I - \beta P_{\sigma})^{-1} r_{\sigma}
+$$
+
+where $r_{\sigma}$ is the reward vector under policy $\sigma$, and $P_{\sigma}$ is the transition matrix induced by $\sigma$.
+
+To solve this, you'll need to:
+1. Compute current rewards $r_{\sigma}(a, z) = u((1 + r)a + wz - \sigma(a, z))$
+2. Set up the linear operator $R_{\sigma}$, where $(R_{\sigma} v)(a, z) = v(a, z) - \beta \sum_{z'} v(\sigma(a, z), z') \Pi(z, z')$
+3. Solve $v_{\sigma} = R_{\sigma}^{-1} r_{\sigma}$ using `jax.scipy.sparse.linalg.bicgstab`, as sketched just after this list
+
+You can use the `get_greedy` function that's already defined in this lecture.
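+
+Since the state space can be large, it pays to hand `bicgstab` the operator $R_{\sigma}$ as a function rather than as a dense matrix. Here is a minimal, self-contained sketch of that pattern (the matrix `P_demo`, the discount factor `0.95`, and the right-hand side `b_demo` are made up purely for illustration)
+
+```{code-cell} ipython3
+from jax.scipy.sparse.linalg import bicgstab
+
+# Solve (I - 0.95 P) x = b without ever forming the matrix I - 0.95 P
+P_demo = jnp.array([[0.9, 0.1],
+                    [0.2, 0.8]])
+b_demo = jnp.ones(2)
+A_demo = lambda x: x - 0.95 * P_demo @ x   # matrix-free linear operator
+x_demo, info = bicgstab(A_demo, b_demo)
+print(x_demo)
+```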
+
+Implement the following Howard policy iteration routine:
+
+```python
+def howard_policy_iteration(household, prices, tol=1e-4, max_iter=10_000):
+    """
+    Howard policy iteration routine.
+    """
+    # Your code here
+    pass
+```
+
+Once implemented, compute the equilibrium capital stock using HPI and verify that it produces approximately the same result as VFI at the default parameter values.
+
+```{exercise-end}
+```
+
+```{solution-start} aiyagari_ex3
+:class: dropdown
+```
+
+First, we need to implement the helper functions for Howard policy iteration.
+
+The following function computes the array $r_{\sigma}$, which gives current rewards given policy $\sigma$:
+
+```{code-cell} ipython3
+def compute_r_σ(σ, household, prices):
+    """
+    Compute current rewards at each i, j under policy σ. In particular,
+
+        r_σ[i, j] = u((1 + r)a[i] + wz[j] - a'[ip])
+
+    when ip = σ[i, j].
+    """
+    # Unpack
+    β, a_grid, z_grid, Π = household
+    a_size, z_size = len(a_grid), len(z_grid)
+    r, w = prices
+
+    # Compute r_σ[i, j]
+    a = jnp.reshape(a_grid, (a_size, 1))
+    z = jnp.reshape(z_grid, (1, z_size))
+    ap = a_grid[σ]
+    c = (1 + r) * a + w * z - ap
+    r_σ = u(c)
+
+    return r_σ
+```
+
+The linear operator $R_{\sigma}$ is implemented matrix-free, as a function acting directly on $v$:
+
+```{code-cell} ipython3
+def R_σ(v, σ, household):
+    # Unpack
+    β, a_grid, z_grid, Π = household
+    a_size, z_size = len(a_grid), len(z_grid)
+
+    # Set up the array V[i, j, jp] = v[σ[i, j], jp]
+    zp_idx = jnp.arange(z_size)
+    zp_idx = jnp.reshape(zp_idx, (1, 1, z_size))
+    σ = jnp.reshape(σ, (a_size, z_size, 1))
+    V = v[σ, zp_idx]
+
+    # Expand Π[j, jp] to Π[i, j, jp]
+    Π = jnp.reshape(Π, (1, z_size, z_size))
+
+    # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Π[j, jp]
+    return v - β * jnp.sum(V * Π, axis=-1)
+```
+
+The next function computes the lifetime value of a given policy:
+
+```{code-cell} ipython3
+def get_value(σ, household, prices):
+    """
+    Get the lifetime value of policy σ by computing
+
+        v_σ = R_σ^{-1} r_σ
+    """
+    r_σ = compute_r_σ(σ, household, prices)
+
+    # Reduce R_σ to a function in v
+    _R_σ = lambda v: R_σ(v, σ, household)
+
+    # Compute v_σ = R_σ^{-1} r_σ using an iterative routine.
+    return jax.scipy.sparse.linalg.bicgstab(_R_σ, r_σ)[0]
+```
+
+Now we can implement Howard policy iteration:
+
+```{code-cell} ipython3
+@jax.jit
+def howard_policy_iteration(household, prices, tol=1e-4, max_iter=10_000):
+    """
+    Howard policy iteration routine using a compiled JAX loop.
+    """
+    β, a_grid, z_grid, Π = household
+    a_size, z_size = len(a_grid), len(z_grid)
+
+    def condition_function(loop_state):
+        i, σ, v_σ, error = loop_state
+        return jnp.logical_and(error > tol, i < max_iter)
+
+    def update(loop_state):
+        i, σ, v_σ, error = loop_state
+        σ_new = get_greedy(v_σ, household, prices)
+        v_σ_new = get_value(σ_new, household, prices)
+        error = jnp.max(jnp.abs(v_σ_new - v_σ))
+        return i + 1, σ_new, v_σ_new, error
+
+    # Initial loop state
+    σ_init = jnp.zeros((a_size, z_size), dtype=int)
+    v_σ_init = get_value(σ_init, household, prices)
+    loop_state_init = (0, σ_init, v_σ_init, tol + 1)
+
+    # Run the fixed point iteration
+    i, σ, v_σ, error = jax.lax.while_loop(condition_function, update,
+                                          loop_state_init)
+
+    return σ
+```
+
+Now let's create a modified version of the `G` function that uses HPI:
+
+```{code-cell} ipython3
+def G_hpi(K, firm, household):
+    # Get prices r, w associated with K
+    r = r_given_k(K, firm)
+    w = r_to_w(r, firm)
+
+    # Generate prices and compute aggregate capital using HPI.
+    prices = Prices(r=r, w=w)
+    σ_star = howard_policy_iteration(household, prices)
+    return capital_supply(σ_star, household)
+```
+
+And compute the equilibrium using HPI:
+
+```{code-cell} ipython3
+def compute_equilibrium_bisect_hpi(firm, household, a=1.0, b=20.0):
+    K = bisect(lambda k: k - G_hpi(k, firm, household), a, b, xtol=1e-4)
+    return K
+
+firm = Firm()
+household = create_household()
+print("\nComputing equilibrium capital stock using HPI")
+with qe.Timer():
+    K_star_hpi = compute_equilibrium_bisect_hpi(firm, household)
+print(f"Computed equilibrium capital stock with HPI: {K_star_hpi:.5}")
+print(f"Previous equilibrium capital stock with VFI: {K_star:.5}")
+print(f"Difference: {abs(K_star_hpi - K_star):.6}")
+```
+
+The results show that both methods produce approximately the same equilibrium capital stock, confirming that HPI is a valid alternative to VFI for this model.
+
+```{solution-end}
+```
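+
+As a final comparison (a rough benchmark sketch: timings depend on hardware, and `prices_demo` below is an illustrative price pair rather than the equilibrium prices), we can time the two routines on a common problem and confirm that they return the same policy:
+
+```{code-cell} ipython3
+prices_demo = Prices(r=0.04, w=1.0)   # illustrative prices only
+
+with qe.Timer():
+    σ_vfi = value_function_iteration(household, prices_demo).block_until_ready()
+
+with qe.Timer():
+    σ_hpi = howard_policy_iteration(household, prices_demo).block_until_ready()
+
+# The two policies should agree at (almost) every grid point
+print(jnp.max(jnp.abs(σ_vfi - σ_hpi)))
+```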