diff --git a/lectures/opt_invest.md b/lectures/opt_invest.md index 472ebd50..9bf6c092 100644 --- a/lectures/opt_invest.md +++ b/lectures/opt_invest.md @@ -250,7 +250,7 @@ def T_σ(v, σ, constants, sizes, arrays): # Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[i, j, jp] Ev = jnp.sum(V * Q, axis=2) - return r_σ + β * jnp.sum(V * Q, axis=2) + return r_σ + β * Ev T_σ = jax.jit(T_σ, static_argnums=(3,)) ``` @@ -361,7 +361,6 @@ def value_iteration(model, tol=1e-5): def policy_iteration(model, maxiter=250): constants, sizes, arrays = model - vz = jnp.zeros(sizes) σ = jnp.zeros(sizes, dtype=int) i, error = 0, 1.0 while error > 0 and i < maxiter: diff --git a/lectures/opt_savings.md b/lectures/opt_savings.md index 1719941a..327de924 100644 --- a/lectures/opt_savings.md +++ b/lectures/opt_savings.md @@ -57,7 +57,7 @@ $$ W_{t+1} + C_t \leq R W_t + Y_t $$ We assume that labor income $(Y_t)$ is a discretized AR(1) process. -The right-hand side of the Bellman equation is +The right-hand side of the Bellman equation is $$ B((w, y), w', v) = u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y'). $$ @@ -75,7 +75,7 @@ def successive_approx(T, # Operator (callable) tolerance=1e-6, # Error tolerance max_iter=10_000, # Max iteration bound print_step=25, # Print at multiples - verbose=False): + verbose=False): x = x_0 error = tolerance + 1 k = 1 @@ -98,7 +98,7 @@ def successive_approx(T, # Operator (callable) Here’s a `namedtuple` definition for storing parameters and grids. ```{code-cell} ipython3 -Model = namedtuple('Model', +Model = namedtuple('Model', ('β', 'R', 'γ', 'w_grid', 'y_grid', 'Q')) ``` @@ -114,7 +114,7 @@ def create_consumption_model(R=1.01, # Gross interest rate A function that takes in parameters and returns an instance of Model that contains data for the optimal savings problem. """ - w_grid = jnp.linspace(w_min, w_max, w_size) + w_grid = jnp.linspace(w_min, w_max, w_size) mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν) y_grid, Q = jnp.exp(mc.state_values), mc.P return Model(β=β, R=R, γ=γ, w_grid=w_grid, y_grid=y_grid, Q=Q) @@ -129,7 +129,7 @@ def create_consumption_model_jax(R=1.01, # Gross interest rate w_size=150, # Grid side ρ=0.9, ν=0.1, y_size=100): # Income parameters """ - A function that takes in parameters and returns a JAX-compatible version of + A function that takes in parameters and returns a JAX-compatible version of Model that contains data for the optimal savings problem. """ w_grid = jnp.linspace(w_min, w_max, w_size) @@ -146,7 +146,7 @@ Here's the right hand side of the Bellman equation: ```{code-cell} ipython3 def B(v, constants, sizes, arrays): """ - A vectorized version of the right-hand side of the Bellman equation + A vectorized version of the right-hand side of the Bellman equation (before maximization), which is a 3D array representing B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′) @@ -154,7 +154,7 @@ def B(v, constants, sizes, arrays): for all (w, y, w′). """ - # Unpack + # Unpack β, R, γ = constants w_size, y_size = sizes w_grid, y_grid, Q = arrays @@ -215,15 +215,15 @@ def T_σ(v, σ, constants, sizes, arrays): yp_idx = jnp.arange(y_size) yp_idx = jnp.reshape(yp_idx, (1, 1, y_size)) σ = jnp.reshape(σ, (w_size, y_size, 1)) - V = v[σ, yp_idx] + V = v[σ, yp_idx] - # Convert Q[j, jp] to Q[i, j, jp] + # Convert Q[j, jp] to Q[i, j, jp] Q = jnp.reshape(Q, (1, y_size, y_size)) # Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[i, j, jp] Ev = jnp.sum(V * Q, axis=2) - return r_σ + β * jnp.sum(V * Q, axis=2) + return r_σ + β * Ev ``` and the Bellman operator $T$ @@ -248,9 +248,9 @@ The basic problem is to solve the linear system $$ v(w,y ) = u(Rw + y - \sigma(w, y)) + β \sum_{y'} v(\sigma(w, y), y') Q(y, y) $$ -for $v$. +for $v$. -It turns out to be helpful to rewrite this as +It turns out to be helpful to rewrite this as $$ v(w,y) = r(w, y, \sigma(w, y)) + β \sum_{w', y'} v(w', y') P_\sigma(w, y, w', y') $$ @@ -260,40 +260,23 @@ We want to write this as $v = r_\sigma + P_\sigma v$ and then solve for $v$ Note, however, -* $v$ is a 2 index array, rather than a single vector. -* $P_\sigma$ has four indices rather than 2 +* $v$ is a 2 index array, rather than a single vector. +* $P_\sigma$ has four indices rather than 2 -The code below +The code below 1. reshapes $v$ and $r_\sigma$ to 1D arrays and $P_\sigma$ to a matrix 2. solves the linear system 3. converts back to multi-index arrays. -```{code-cell} ipython3 -def get_value(σ, constants, sizes, arrays): - "Get the value v_σ of policy σ by inverting the linear map R_σ." - - # Unpack - β, R, γ = constants - w_size, y_size = sizes - w_grid, y_grid, Q = arrays - - r_σ = compute_r_σ(σ, constants, sizes, arrays) - - # Reduce R_σ to a function in v - partial_R_σ = lambda v: R_σ(v, σ, constants, sizes, arrays) - - return jax.scipy.sparse.linalg.bicgstab(partial_R_σ, r_σ)[0] -``` - ```{code-cell} ipython3 def R_σ(v, σ, constants, sizes, arrays): """ - The value v_σ of a policy σ is defined as + The value v_σ of a policy σ is defined as v_σ = (I - β P_σ)^{-1} r_σ - Here we set up the linear map v -> R_σ v, where R_σ := I - β P_σ. + Here we set up the linear map v -> R_σ v, where R_σ := I - β P_σ. In the consumption problem, this map can be expressed as @@ -322,6 +305,23 @@ def R_σ(v, σ, constants, sizes, arrays): return v - β * jnp.sum(V * Q, axis=2) ``` +```{code-cell} ipython3 +def get_value(σ, constants, sizes, arrays): + "Get the value v_σ of policy σ by inverting the linear map R_σ." + + # Unpack + β, R, γ = constants + w_size, y_size = sizes + w_grid, y_grid, Q = arrays + + r_σ = compute_r_σ(σ, constants, sizes, arrays) + + # Reduce R_σ to a function in v + partial_R_σ = lambda v: R_σ(v, σ, constants, sizes, arrays) + + return jax.scipy.sparse.linalg.bicgstab(partial_R_σ, r_σ)[0] +``` + ## JIT compiled versions ```{code-cell} ipython3 @@ -354,7 +354,6 @@ def value_iteration(model, tol=1e-5): def policy_iteration(model): "Howard policy iteration routine." constants, sizes, arrays = model - vz = jnp.zeros(sizes) σ = jnp.zeros(sizes, dtype=int) i, error = 0, 1.0 while error > 0: @@ -387,14 +386,19 @@ def optimistic_policy_iteration(model, tol=1e-5, m=10): Create a JAX model for consumption, perform policy iteration, and plot the resulting optimal policy function. ```{code-cell} ipython3 -fontsize=12 +fontsize = 12 model = create_consumption_model_jax() -# Unpack + +# Unpack constants, sizes, arrays = model β, R, γ = constants w_size, y_size = sizes w_grid, y_grid, Q = arrays +``` + +```{code-cell} ipython3 σ_star = policy_iteration(model) + fig, ax = plt.subplots(figsize=(9, 5.2)) ax.plot(w_grid, w_grid, "k--", label="45") ax.plot(w_grid, w_grid[σ_star[:, 1]], label="$\\sigma^*(\cdot, y_1)$") @@ -443,7 +447,9 @@ def run_algorithm(algorithm, model, **kwargs): elapsed_time = end_time - start_time print(f"{algorithm.__name__} completed in {elapsed_time:.2f} seconds.") return result, elapsed_time +``` +```{code-cell} ipython3 model = create_consumption_model_jax() σ_pi, pi_time = run_algorithm(policy_iteration, model) σ_vfi, vfi_time = run_algorithm(value_iteration, model, tol=1e-5) @@ -453,7 +459,9 @@ opi_times = [] for m in m_vals: σ_opi, opi_time = run_algorithm(optimistic_policy_iteration, model, m=m, tol=1e-5) opi_times.append(opi_time) +``` +```{code-cell} ipython3 fig, ax = plt.subplots(figsize=(9, 5.2)) ax.plot(m_vals, jnp.full(len(m_vals), pi_time), lw=2, label="Howard policy iteration") ax.plot(m_vals, jnp.full(len(m_vals), vfi_time), lw=2, label="value function iteration")