Merged
14 changes: 14 additions & 0 deletions lectures/_static/lecture_specific/hpi.py
@@ -0,0 +1,14 @@
# Implements HPI (Howard policy iteration)

def policy_iteration(model, maxiter=250):
    constants, sizes, arrays = model
    σ = jnp.zeros(sizes, dtype=int)
    i, error = 0, 1.0
    while error > 0 and i < maxiter:
        v_σ = get_value(σ, constants, sizes, arrays)
        σ_new = get_greedy(v_σ, constants, sizes, arrays)
        error = jnp.max(jnp.abs(σ_new - σ))
        σ = σ_new
        i = i + 1
        print(f"Concluded loop {i} with error {error}.")
    return σ
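This file is loaded into lectures that already define `jnp` (`jax.numpy`), `get_value`, and `get_greedy`, so it carries no imports of its own. A minimal usage sketch, borrowing `create_consumption_model_jax()` from `opt_savings.md`:

```python
# Sketch only: create_consumption_model_jax, get_value and get_greedy
# are defined in the lecture notebook that loads this file.
model = create_consumption_model_jax()
σ_star = policy_iteration(model)   # array of optimal action indices
```

HPI stops as soon as two successive policies coincide (`error == 0`), so each pass is relatively expensive (it evaluates the policy exactly) but the number of passes is typically small.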
13 changes: 13 additions & 0 deletions lectures/_static/lecture_specific/opi.py
@@ -0,0 +1,13 @@
# Implements OPI (optimistic policy iteration)

def optimistic_policy_iteration(model, tol=1e-5, m=10):
    constants, sizes, arrays = model
    v = jnp.zeros(sizes)
    error = tol + 1
    while error > tol:
        last_v = v
        σ = get_greedy(v, constants, sizes, arrays)
        for _ in range(m):
            v = T_σ(v, σ, constants, sizes, arrays)
        error = jnp.max(jnp.abs(v - last_v))
    return get_greedy(v, constants, sizes, arrays)
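OPI replaces the exact policy evaluation of HPI with `m` applications of the policy operator `T_σ`: with `m = 1` each update reduces to a value function iteration step, and for large `m` it approximates Howard policy iteration. A hedged sketch of comparing run times across `m`, assuming the lecture's `create_consumption_model_jax()` and the helpers loaded alongside this file (timings are illustrative, not from the source):

```python
import time

model = create_consumption_model_jax()
for m in (1, 10, 100):
    start = time.time()
    σ = optimistic_policy_iteration(model, m=m)
    print(f"m = {m:4d} finished in {time.time() - start:.2f} seconds")
```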
21 changes: 21 additions & 0 deletions lectures/_static/lecture_specific/successive_approx.py
@@ -0,0 +1,21 @@
def successive_approx(T,                     # Operator (callable)
                      x_0,                   # Initial condition
                      tolerance=1e-6,        # Error tolerance
                      max_iter=10_000,       # Max iteration bound
                      print_step=25,         # Print at multiples
                      verbose=False):
    x = x_0
    error = tolerance + 1
    k = 1
    while error > tolerance and k <= max_iter:
        x_new = T(x)
        error = jnp.max(jnp.abs(x_new - x))
        if verbose and k % print_step == 0:
            print(f"Completed iteration {k} with error {error}.")
        x = x_new
        k += 1
    if error > tolerance:
        print(f"Warning: Iteration hit upper bound {max_iter}.")
    elif verbose:
        print(f"Terminated successfully in {k} iterations.")
    return x
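The routine only assumes that `T` maps arrays to arrays and that `jnp` (`jax.numpy`) is in scope in the notebook that loads it. A self-contained sketch on a scalar contraction whose fixed point is known (f(x) = 0.5x + 1 has fixed point x* = 2):

```python
import jax.numpy as jnp

f = lambda x: 0.5 * x + 1.0          # contraction with fixed point x* = 2
x_star = successive_approx(f, jnp.array(0.0))
print(x_star)                        # ≈ 2.0 (within the default tolerance)
```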
9 changes: 9 additions & 0 deletions lectures/_static/lecture_specific/vfi.py
@@ -0,0 +1,9 @@
# Implements VFI (value function iteration)

def value_iteration(model, tol=1e-5):
    constants, sizes, arrays = model
    _T = lambda v: T(v, constants, sizes, arrays)
    vz = jnp.zeros(sizes)

    v_star = successive_approx(_T, vz, tolerance=tol)
    return get_greedy(v_star, constants, sizes, arrays)
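The lambda `_T` fixes the model data so that `successive_approx` sees a one-argument operator. A sketch of the same pattern with verbose progress output, assuming the lecture notebook's `T`, `get_greedy`, `jnp`, and `create_consumption_model_jax()`:

```python
# Sketch only: all names except successive_approx come from the lecture notebook.
constants, sizes, arrays = create_consumption_model_jax()
_T = lambda v: T(v, constants, sizes, arrays)          # close over model data
v_star = successive_approx(_T, jnp.zeros(sizes), tolerance=1e-5, verbose=True)
σ_star = get_greedy(v_star, constants, sizes, arrays)
```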
61 changes: 4 additions & 57 deletions lectures/opt_invest.md
@@ -81,27 +81,7 @@ jax.config.update("jax_enable_x64", True)
We need the following successive approximation function.

```{code-cell} ipython3
def successive_approx(T,                     # Operator (callable)
                      x_0,                   # Initial condition
                      tolerance=1e-6,        # Error tolerance
                      max_iter=10_000,       # Max iteration bound
                      print_step=25,         # Print at multiples
                      verbose=False):
    x = x_0
    error = tolerance + 1
    k = 1
    while error > tolerance and k <= max_iter:
        x_new = T(x)
        error = jnp.max(jnp.abs(x_new - x))
        if verbose and k % print_step == 0:
            print(f"Completed iteration {k} with error {error}.")
        x = x_new
        k += 1
    if error > tolerance:
        print(f"Warning: Iteration hit upper bound {max_iter}.")
    elif verbose:
        print(f"Terminated successfully in {k} iterations.")
    return x
:load: _static/lecture_specific/successive_approx.py
```


@@ -345,48 +325,15 @@ get_value = jax.jit(get_value, static_argnums=(2,))
Now we define the solvers, which implement VFI, HPI and OPI.

```{code-cell} ipython3
# Implements VFI-Value Function iteration

def value_iteration(model, tol=1e-5):
    constants, sizes, arrays = model
    _T = lambda v: T(v, constants, sizes, arrays)
    vz = jnp.zeros(sizes)

    v_star = successive_approx(_T, vz, tolerance=tol)
    return get_greedy(v_star, constants, sizes, arrays)
:load: _static/lecture_specific/vfi.py
```

```{code-cell} ipython3
# Implements HPI-Howard policy iteration routine

def policy_iteration(model, maxiter=250):
    constants, sizes, arrays = model
    σ = jnp.zeros(sizes, dtype=int)
    i, error = 0, 1.0
    while error > 0 and i < maxiter:
        v_σ = get_value(σ, constants, sizes, arrays)
        σ_new = get_greedy(v_σ, constants, sizes, arrays)
        error = jnp.max(jnp.abs(σ_new - σ))
        σ = σ_new
        i = i + 1
        print(f"Concluded loop {i} with error {error}.")
    return σ
:load: _static/lecture_specific/hpi.py
```

```{code-cell} ipython3
# Implements the OPI-Optimal policy Iteration routine

def optimistic_policy_iteration(model, tol=1e-5, m=10):
    constants, sizes, arrays = model
    v = jnp.zeros(sizes)
    error = tol + 1
    while error > tol:
        last_v = v
        σ = get_greedy(v, constants, sizes, arrays)
        for _ in range(m):
            v = T_σ(v, σ, constants, sizes, arrays)
        error = jnp.max(jnp.abs(v - last_v))
    return get_greedy(v, constants, sizes, arrays)
:load: _static/lecture_specific/opi.py
```

60 changes: 4 additions & 56 deletions lectures/opt_savings.md
@@ -70,27 +70,7 @@ $$ u(c) = \frac{c^{1-\gamma}}{1-\gamma} $$
We use successive approximation for VFI.

```{code-cell} ipython3
def successive_approx(T,                     # Operator (callable)
                      x_0,                   # Initial condition
                      tolerance=1e-6,        # Error tolerance
                      max_iter=10_000,       # Max iteration bound
                      print_step=25,         # Print at multiples
                      verbose=False):
    x = x_0
    error = tolerance + 1
    k = 1
    while error > tolerance and k <= max_iter:
        x_new = T(x)
        error = jnp.max(jnp.abs(x_new - x))
        if verbose and k % print_step == 0:
            print(f"Completed iteration {k} with error {error}.")
        x = x_new
        k += 1
    if error > tolerance:
        print(f"Warning: Iteration hit upper bound {max_iter}.")
    elif verbose:
        print(f"Terminated successfully in {k} iterations.")
    return x
:load: _static/lecture_specific/successive_approx.py
```

## Model primitives
@@ -339,46 +319,15 @@ R_σ = jax.jit(R_σ, static_argnums=(3,))
Now we define the solvers, which implement VFI, HPI and OPI.

```{code-cell} ipython3
def value_iteration(model, tol=1e-5):
    "Implements VFI."

    constants, sizes, arrays = model
    _T = lambda v: T(v, constants, sizes, arrays)
    vz = jnp.zeros(sizes)

    v_star = successive_approx(_T, vz, tolerance=tol)
    return get_greedy(v_star, constants, sizes, arrays)
:load: _static/lecture_specific/vfi.py
```

```{code-cell} ipython3
def policy_iteration(model):
    "Howard policy iteration routine."
    constants, sizes, arrays = model
    σ = jnp.zeros(sizes, dtype=int)
    i, error = 0, 1.0
    while error > 0:
        v_σ = get_value(σ, constants, sizes, arrays)
        σ_new = get_greedy(v_σ, constants, sizes, arrays)
        error = jnp.max(jnp.abs(σ_new - σ))
        σ = σ_new
        i = i + 1
        print(f"Concluded loop {i} with error {error}.")
    return σ
:load: _static/lecture_specific/hpi.py
```

```{code-cell} ipython3
def optimistic_policy_iteration(model, tol=1e-5, m=10):
    "Implements the OPI routine."
    constants, sizes, arrays = model
    v = jnp.zeros(sizes)
    error = tol + 1
    while error > tol:
        last_v = v
        σ = get_greedy(v, constants, sizes, arrays)
        for _ in range(m):
            v = T_σ(v, σ, constants, sizes, arrays)
        error = jnp.max(jnp.abs(v - last_v))
    return get_greedy(v, constants, sizes, arrays)
:load: _static/lecture_specific/opi.py
```
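
Since all three solvers target the same Bellman equation, the policies they return should nearly coincide. A minimal consistency check, a sketch only: VFI and OPI stop at a finite tolerance, so a few grid points may differ from the HPI policy. It assumes the lecture's `create_consumption_model_jax()` and `jnp` are already defined.

```python
model = create_consumption_model_jax()
σ_vfi = value_iteration(model)
σ_hpi = policy_iteration(model)
σ_opi = optimistic_policy_iteration(model)

# Fraction of states at which each approximate policy matches HPI's
for name, σ in (("VFI", σ_vfi), ("OPI", σ_opi)):
    agree = jnp.mean((σ == σ_hpi).astype(float))
    print(f"{name} matches HPI at {100 * float(agree):.2f}% of states")
```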

## Plots
@@ -388,7 +337,6 @@ Create a JAX model for consumption, perform policy iteration, and plot the results
```{code-cell} ipython3
fontsize = 12
model = create_consumption_model_jax()

# Unpack
constants, sizes, arrays = model
β, R, γ = constants