Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions lectures/opt_invest.md
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def T_σ(v, σ, constants, sizes, arrays):
# Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[i, j, jp]
Ev = jnp.sum(V * Q, axis=2)

return r_σ + β * jnp.sum(V * Q, axis=2)
return r_σ + β * Ev

T_σ = jax.jit(T_σ, static_argnums=(3,))
```
Expand Down Expand Up @@ -361,7 +361,6 @@ def value_iteration(model, tol=1e-5):

def policy_iteration(model, maxiter=250):
constants, sizes, arrays = model
vz = jnp.zeros(sizes)
σ = jnp.zeros(sizes, dtype=int)
i, error = 0, 1.0
while error > 0 and i < maxiter:
Expand Down
82 changes: 45 additions & 37 deletions lectures/opt_savings.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ $$ W_{t+1} + C_t \leq R W_t + Y_t $$

We assume that labor income $(Y_t)$ is a discretized AR(1) process.

The right-hand side of the Bellman equation is
The right-hand side of the Bellman equation is

$$ B((w, y), w', v) = u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y'). $$

Expand All @@ -75,7 +75,7 @@ def successive_approx(T, # Operator (callable)
tolerance=1e-6, # Error tolerance
max_iter=10_000, # Max iteration bound
print_step=25, # Print at multiples
verbose=False):
verbose=False):
x = x_0
error = tolerance + 1
k = 1
Expand All @@ -98,7 +98,7 @@ def successive_approx(T, # Operator (callable)
Here’s a `namedtuple` definition for storing parameters and grids.

```{code-cell} ipython3
Model = namedtuple('Model',
Model = namedtuple('Model',
('β', 'R', 'γ', 'w_grid', 'y_grid', 'Q'))
```

Expand All @@ -114,7 +114,7 @@ def create_consumption_model(R=1.01, # Gross interest rate
A function that takes in parameters and returns an instance of Model that
contains data for the optimal savings problem.
"""
w_grid = jnp.linspace(w_min, w_max, w_size)
w_grid = jnp.linspace(w_min, w_max, w_size)
mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
y_grid, Q = jnp.exp(mc.state_values), mc.P
return Model(β=β, R=R, γ=γ, w_grid=w_grid, y_grid=y_grid, Q=Q)
Expand All @@ -129,7 +129,7 @@ def create_consumption_model_jax(R=1.01, # Gross interest rate
w_size=150, # Grid side
ρ=0.9, ν=0.1, y_size=100): # Income parameters
"""
A function that takes in parameters and returns a JAX-compatible version of
A function that takes in parameters and returns a JAX-compatible version of
Model that contains data for the optimal savings problem.
"""
w_grid = jnp.linspace(w_min, w_max, w_size)
Expand All @@ -146,15 +146,15 @@ Here's the right hand side of the Bellman equation:
```{code-cell} ipython3
def B(v, constants, sizes, arrays):
"""
A vectorized version of the right-hand side of the Bellman equation
A vectorized version of the right-hand side of the Bellman equation
(before maximization), which is a 3D array representing

B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′)

for all (w, y, w′).
"""

# Unpack
# Unpack
β, R, γ = constants
w_size, y_size = sizes
w_grid, y_grid, Q = arrays
Expand Down Expand Up @@ -215,15 +215,15 @@ def T_σ(v, σ, constants, sizes, arrays):
yp_idx = jnp.arange(y_size)
yp_idx = jnp.reshape(yp_idx, (1, 1, y_size))
σ = jnp.reshape(σ, (w_size, y_size, 1))
V = v[σ, yp_idx]
V = v[σ, yp_idx]

# Convert Q[j, jp] to Q[i, j, jp]
# Convert Q[j, jp] to Q[i, j, jp]
Q = jnp.reshape(Q, (1, y_size, y_size))

# Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[i, j, jp]
Ev = jnp.sum(V * Q, axis=2)

return r_σ + β * jnp.sum(V * Q, axis=2)
return r_σ + β * Ev
```

and the Bellman operator $T$
Expand All @@ -248,9 +248,9 @@ The basic problem is to solve the linear system

$$ v(w,y) = u(Rw + y - \sigma(w, y)) + β \sum_{y'} v(\sigma(w, y), y') Q(y, y') $$

for $v$.
for $v$.

It turns out to be helpful to rewrite this as
It turns out to be helpful to rewrite this as

$$ v(w,y) = r(w, y, \sigma(w, y)) + β \sum_{w', y'} v(w', y') P_\sigma(w, y, w', y') $$

Expand All @@ -260,40 +260,23 @@ We want to write this as $v = r_\sigma + P_\sigma v$ and then solve for $v$

Note, however,

* $v$ is a 2 index array, rather than a single vector.
* $P_\sigma$ has four indices rather than 2
* $v$ is a two-index array, rather than a single vector.
* $P_\sigma$ has four indices rather than two.

The code below
The code below

1. reshapes $v$ and $r_\sigma$ to 1D arrays and $P_\sigma$ to a matrix
2. solves the linear system
3. converts back to multi-index arrays.

```{code-cell} ipython3
def get_value(σ, constants, sizes, arrays):
    """
    Get the value v_σ of policy σ by inverting the linear map R_σ.

    The linear system R_σ v = r_σ is solved without assembling R_σ as a
    matrix: the map v -> R_σ v is handed to the BiCGSTAB solver as a
    callable, with r_σ as the right-hand side.

    Parameters
    ----------
    σ : the policy; forwarded unchanged to compute_r_σ and R_σ.
    constants, sizes, arrays : the model components; forwarded unchanged
        to compute_r_σ and R_σ.

    Returns
    -------
    The solution v_σ, i.e. the first element of the pair returned by
    jax.scipy.sparse.linalg.bicgstab.
    """
    # Right-hand side of the linear system
    r_σ = compute_r_σ(σ, constants, sizes, arrays)

    # Reduce R_σ to a function of v alone, as the solver expects
    def apply_R_σ(v):
        return R_σ(v, σ, constants, sizes, arrays)

    # bicgstab returns (solution, info); only the solution is needed
    return jax.scipy.sparse.linalg.bicgstab(apply_R_σ, r_σ)[0]
```

```{code-cell} ipython3
def R_σ(v, σ, constants, sizes, arrays):
"""
The value v_σ of a policy σ is defined as
The value v_σ of a policy σ is defined as

v_σ = (I - β P_σ)^{-1} r_σ

Here we set up the linear map v -> R_σ v, where R_σ := I - β P_σ.
Here we set up the linear map v -> R_σ v, where R_σ := I - β P_σ.

In the consumption problem, this map can be expressed as

Expand Down Expand Up @@ -322,6 +305,23 @@ def R_σ(v, σ, constants, sizes, arrays):
return v - β * jnp.sum(V * Q, axis=2)
```

```{code-cell} ipython3
def get_value(σ, constants, sizes, arrays):
    """
    Get the value v_σ of policy σ by inverting the linear map R_σ.

    The linear system R_σ v = r_σ is solved without assembling R_σ as a
    matrix: the map v -> R_σ v is handed to the BiCGSTAB solver as a
    callable, with r_σ as the right-hand side.

    Parameters
    ----------
    σ : the policy; forwarded unchanged to compute_r_σ and R_σ.
    constants, sizes, arrays : the model components; forwarded unchanged
        to compute_r_σ and R_σ.

    Returns
    -------
    The solution v_σ, i.e. the first element of the pair returned by
    jax.scipy.sparse.linalg.bicgstab.
    """
    # Right-hand side of the linear system
    r_σ = compute_r_σ(σ, constants, sizes, arrays)

    # Reduce R_σ to a function of v alone, as the solver expects
    def apply_R_σ(v):
        return R_σ(v, σ, constants, sizes, arrays)

    # bicgstab returns (solution, info); only the solution is needed
    return jax.scipy.sparse.linalg.bicgstab(apply_R_σ, r_σ)[0]
```

## JIT compiled versions

```{code-cell} ipython3
Expand Down Expand Up @@ -354,7 +354,6 @@ def value_iteration(model, tol=1e-5):
def policy_iteration(model):
"Howard policy iteration routine."
constants, sizes, arrays = model
vz = jnp.zeros(sizes)
σ = jnp.zeros(sizes, dtype=int)
i, error = 0, 1.0
while error > 0:
Expand Down Expand Up @@ -387,14 +386,19 @@ def optimistic_policy_iteration(model, tol=1e-5, m=10):
Create a JAX model for consumption, perform policy iteration, and plot the resulting optimal policy function.

```{code-cell} ipython3
fontsize=12
fontsize = 12
model = create_consumption_model_jax()
# Unpack

# Unpack
constants, sizes, arrays = model
β, R, γ = constants
w_size, y_size = sizes
w_grid, y_grid, Q = arrays
```

```{code-cell} ipython3
σ_star = policy_iteration(model)

fig, ax = plt.subplots(figsize=(9, 5.2))
ax.plot(w_grid, w_grid, "k--", label="45")
ax.plot(w_grid, w_grid[σ_star[:, 1]], label="$\\sigma^*(\cdot, y_1)$")
Expand Down Expand Up @@ -443,7 +447,9 @@ def run_algorithm(algorithm, model, **kwargs):
elapsed_time = end_time - start_time
print(f"{algorithm.__name__} completed in {elapsed_time:.2f} seconds.")
return result, elapsed_time
```

```{code-cell} ipython3
model = create_consumption_model_jax()
σ_pi, pi_time = run_algorithm(policy_iteration, model)
σ_vfi, vfi_time = run_algorithm(value_iteration, model, tol=1e-5)
Expand All @@ -453,7 +459,9 @@ opi_times = []
for m in m_vals:
σ_opi, opi_time = run_algorithm(optimistic_policy_iteration, model, m=m, tol=1e-5)
opi_times.append(opi_time)
```

```{code-cell} ipython3
fig, ax = plt.subplots(figsize=(9, 5.2))
ax.plot(m_vals, jnp.full(len(m_vals), pi_time), lw=2, label="Howard policy iteration")
ax.plot(m_vals, jnp.full(len(m_vals), vfi_time), lw=2, label="value function iteration")
Expand Down