Skip to content

Commit e2646e4

Browse files
authored
Extract common code from opt_invest and savings (#83)
* Extract out common code * Use load * Remove merge issue
1 parent b56e91d commit e2646e4

File tree

6 files changed

+65
-113
lines changed

6 files changed

+65
-113
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Implements HPI-Howard policy iteration routine


def policy_iteration(model, maxiter=250):
    """Howard policy iteration.

    Alternates policy evaluation (via ``get_value``) with greedy
    improvement (via ``get_greedy``) until the policy stops changing,
    or until ``maxiter`` improvement steps have been performed.

    Parameters
    ----------
    model : tuple
        A ``(constants, sizes, arrays)`` triple describing the model.
    maxiter : int, optional
        Upper bound on the number of improvement sweeps.

    Returns
    -------
    The (approximately) optimal policy array.
    """
    constants, sizes, arrays = model
    policy = jnp.zeros(sizes, dtype=int)
    i, error = 0, 1.0
    while error > 0 and i < maxiter:
        policy_value = get_value(policy, constants, sizes, arrays)
        improved = get_greedy(policy_value, constants, sizes, arrays)
        # The policy is integer-valued, so any change yields error >= 1
        # and exact stability yields error == 0.
        error = jnp.max(jnp.abs(improved - policy))
        policy = improved
        i += 1
        print(f"Concluded loop {i} with error {error}.")
    return policy
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Implements the OPI-Optimal policy Iteration routine


def optimistic_policy_iteration(model, tol=1e-5, m=10):
    """Optimistic policy iteration.

    Each pass computes the greedy policy for the current value guess
    and then applies the policy operator ``T_σ`` a fixed number of
    times (``m``), iterating until successive value guesses differ by
    at most ``tol`` in the sup norm.

    Parameters
    ----------
    model : tuple
        A ``(constants, sizes, arrays)`` triple describing the model.
    tol : float, optional
        Convergence tolerance on the value guess.
    m : int, optional
        Number of policy-operator applications per pass.

    Returns
    -------
    The policy that is greedy with respect to the final value guess.
    """
    constants, sizes, arrays = model
    current = jnp.zeros(sizes)
    error = tol + 1
    while error > tol:
        previous = current
        policy = get_greedy(current, constants, sizes, arrays)
        # Partial policy evaluation: m applications of T_σ rather
        # than solving for the exact policy value.
        for _ in range(m):
            current = T_σ(current, policy, constants, sizes, arrays)
        error = jnp.max(jnp.abs(current - previous))
    return get_greedy(current, constants, sizes, arrays)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
def successive_approx(T,                     # Operator (callable)
                      x_0,                   # Initial condition
                      tolerance=1e-6,        # Error tolerance
                      max_iter=10_000,       # Max iteration bound
                      print_step=25,         # Print at multiples
                      verbose=False):
    """Iterate x_{k+1} = T(x_k) from x_0 until the sup-norm change
    between successive iterates falls below ``tolerance``, or until
    ``max_iter`` iterations have been performed.

    Parameters
    ----------
    T : callable
        The operator to iterate; must map an iterate to the next one.
    x_0 : array-like
        Starting point of the iteration.
    tolerance : float, optional
        Stop once ``max|x_new - x| <= tolerance``.
    max_iter : int, optional
        Hard cap on the number of iterations.
    print_step : int, optional
        When ``verbose``, print progress every ``print_step`` iterations.
    verbose : bool, optional
        Whether to print progress and the final success message.

    Returns
    -------
    The final iterate (an approximate fixed point of ``T`` on success).
    """
    x = x_0
    error = tolerance + 1
    k = 1
    while error > tolerance and k <= max_iter:
        x_new = T(x)
        error = jnp.max(jnp.abs(x_new - x))
        if verbose and k % print_step == 0:
            print(f"Completed iteration {k} with error {error}.")
        x = x_new
        k += 1
    if error > tolerance:
        print(f"Warning: Iteration hit upper bound {max_iter}.")
    elif verbose:
        # k was incremented once more after the last successful update,
        # so the number of iterations actually performed is k - 1.
        print(f"Terminated successfully in {k - 1} iterations.")
    return x
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Implements VFI-Value Function iteration


def value_iteration(model, tol=1e-5):
    """Value function iteration.

    Runs successive approximation on the Bellman operator ``T``
    starting from the zero value function, then returns the policy
    that is greedy with respect to the resulting fixed point.

    Parameters
    ----------
    model : tuple
        A ``(constants, sizes, arrays)`` triple describing the model.
    tol : float, optional
        Convergence tolerance passed to ``successive_approx``.

    Returns
    -------
    The (approximately) optimal policy array.
    """
    constants, sizes, arrays = model

    # Close over the model data so the operator takes a value guess only.
    def bellman(v):
        return T(v, constants, sizes, arrays)

    v_init = jnp.zeros(sizes)
    v_star = successive_approx(bellman, v_init, tolerance=tol)
    return get_greedy(v_star, constants, sizes, arrays)

lectures/opt_invest.md

Lines changed: 4 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -81,27 +81,7 @@ jax.config.update("jax_enable_x64", True)
8181
We need the following successive approximation function.
8282

8383
```{code-cell} ipython3
84-
def successive_approx(T, # Operator (callable)
85-
x_0, # Initial condition
86-
tolerance=1e-6, # Error tolerance
87-
max_iter=10_000, # Max iteration bound
88-
print_step=25, # Print at multiples
89-
verbose=False):
90-
x = x_0
91-
error = tolerance + 1
92-
k = 1
93-
while error > tolerance and k <= max_iter:
94-
x_new = T(x)
95-
error = jnp.max(jnp.abs(x_new - x))
96-
if verbose and k % print_step == 0:
97-
print(f"Completed iteration {k} with error {error}.")
98-
x = x_new
99-
k += 1
100-
if error > tolerance:
101-
print(f"Warning: Iteration hit upper bound {max_iter}.")
102-
elif verbose:
103-
print(f"Terminated successfully in {k} iterations.")
104-
return x
84+
:load: _static/lecture_specific/successive_approx.py
10585
```
10686

10787

@@ -345,48 +325,15 @@ get_value = jax.jit(get_value, static_argnums=(2,))
345325
Now we define the solvers, which implement VFI, HPI and OPI.
346326

347327
```{code-cell} ipython3
348-
# Implements VFI-Value Function iteration
349-
350-
def value_iteration(model, tol=1e-5):
351-
constants, sizes, arrays = model
352-
_T = lambda v: T(v, constants, sizes, arrays)
353-
vz = jnp.zeros(sizes)
354-
355-
v_star = successive_approx(_T, vz, tolerance=tol)
356-
return get_greedy(v_star, constants, sizes, arrays)
328+
:load: _static/lecture_specific/vfi.py
357329
```
358330

359331
```{code-cell} ipython3
360-
# Implements HPI-Howard policy iteration routine
361-
362-
def policy_iteration(model, maxiter=250):
363-
constants, sizes, arrays = model
364-
σ = jnp.zeros(sizes, dtype=int)
365-
i, error = 0, 1.0
366-
while error > 0 and i < maxiter:
367-
v_σ = get_value(σ, constants, sizes, arrays)
368-
σ_new = get_greedy(v_σ, constants, sizes, arrays)
369-
error = jnp.max(jnp.abs(σ_new - σ))
370-
σ = σ_new
371-
i = i + 1
372-
print(f"Concluded loop {i} with error {error}.")
373-
return σ
332+
:load: _static/lecture_specific/hpi.py
374333
```
375334

376335
```{code-cell} ipython3
377-
# Implements the OPI-Optimal policy Iteration routine
378-
379-
def optimistic_policy_iteration(model, tol=1e-5, m=10):
380-
constants, sizes, arrays = model
381-
v = jnp.zeros(sizes)
382-
error = tol + 1
383-
while error > tol:
384-
last_v = v
385-
σ = get_greedy(v, constants, sizes, arrays)
386-
for _ in range(m):
387-
v = T_σ(v, σ, constants, sizes, arrays)
388-
error = jnp.max(jnp.abs(v - last_v))
389-
return get_greedy(v, constants, sizes, arrays)
336+
:load: _static/lecture_specific/opi.py
390337
```
391338

392339
```{code-cell} ipython3

lectures/opt_savings.md

Lines changed: 4 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -70,27 +70,7 @@ $$ u(c) = \frac{c^{1-\gamma}}{1-\gamma} $$
7070
We use successive approximation for VFI.
7171

7272
```{code-cell} ipython3
73-
def successive_approx(T, # Operator (callable)
74-
x_0, # Initial condition
75-
tolerance=1e-6, # Error tolerance
76-
max_iter=10_000, # Max iteration bound
77-
print_step=25, # Print at multiples
78-
verbose=False):
79-
x = x_0
80-
error = tolerance + 1
81-
k = 1
82-
while error > tolerance and k <= max_iter:
83-
x_new = T(x)
84-
error = jnp.max(jnp.abs(x_new - x))
85-
if verbose and k % print_step == 0:
86-
print(f"Completed iteration {k} with error {error}.")
87-
x = x_new
88-
k += 1
89-
if error > tolerance:
90-
print(f"Warning: Iteration hit upper bound {max_iter}.")
91-
elif verbose:
92-
print(f"Terminated successfully in {k} iterations.")
93-
return x
73+
:load: _static/lecture_specific/successive_approx.py
9474
```
9575

9676
## Model primitives
@@ -339,46 +319,15 @@ R_σ = jax.jit(R_σ, static_argnums=(3,))
339319
Now we define the solvers, which implement VFI, HPI and OPI.
340320

341321
```{code-cell} ipython3
342-
def value_iteration(model, tol=1e-5):
343-
"Implements VFI."
344-
345-
constants, sizes, arrays = model
346-
_T = lambda v: T(v, constants, sizes, arrays)
347-
vz = jnp.zeros(sizes)
348-
349-
v_star = successive_approx(_T, vz, tolerance=tol)
350-
return get_greedy(v_star, constants, sizes, arrays)
322+
:load: _static/lecture_specific/vfi.py
351323
```
352324

353325
```{code-cell} ipython3
354-
def policy_iteration(model):
355-
"Howard policy iteration routine."
356-
constants, sizes, arrays = model
357-
σ = jnp.zeros(sizes, dtype=int)
358-
i, error = 0, 1.0
359-
while error > 0:
360-
v_σ = get_value(σ, constants, sizes, arrays)
361-
σ_new = get_greedy(v_σ, constants, sizes, arrays)
362-
error = jnp.max(jnp.abs(σ_new - σ))
363-
σ = σ_new
364-
i = i + 1
365-
print(f"Concluded loop {i} with error {error}.")
366-
return σ
326+
:load: _static/lecture_specific/hpi.py
367327
```
368328

369329
```{code-cell} ipython3
370-
def optimistic_policy_iteration(model, tol=1e-5, m=10):
371-
"Implements the OPI routine."
372-
constants, sizes, arrays = model
373-
v = jnp.zeros(sizes)
374-
error = tol + 1
375-
while error > tol:
376-
last_v = v
377-
σ = get_greedy(v, constants, sizes, arrays)
378-
for _ in range(m):
379-
v = T_σ(v, σ, constants, sizes, arrays)
380-
error = jnp.max(jnp.abs(v - last_v))
381-
return get_greedy(v, constants, sizes, arrays)
330+
:load: _static/lecture_specific/opi.py
382331
```
383332

384333
## Plots
@@ -388,7 +337,6 @@ Create a JAX model for consumption, perform policy iteration, and plot the resul
388337
```{code-cell} ipython3
389338
fontsize = 12
390339
model = create_consumption_model_jax()
391-
392340
# Unpack
393341
constants, sizes, arrays = model
394342
β, R, γ = constants

0 commit comments

Comments
 (0)