@@ -70,27 +70,7 @@ $$ u(c) = \frac{c^{1-\gamma}}{1-\gamma} $$
We use successive approximation for VFI.
7171
7272``` {code-cell} ipython3
def successive_approx(T,                     # Operator (callable)
                      x_0,                   # Initial condition
                      tolerance=1e-6,        # Error tolerance
                      max_iter=10_000,       # Max iteration bound
                      print_step=25,         # Print at multiples
                      verbose=False):
    """
    Compute an approximate fixed point of T by iterating x_{k+1} = T(x_k)
    starting from x_0, stopping when the sup-norm distance between
    successive iterates falls below `tolerance` or after `max_iter`
    iterations.  Returns the final iterate.
    """
    x = x_0
    error = tolerance + 1   # force at least one iteration
    k = 1
    while error > tolerance and k <= max_iter:
        x_new = T(x)
        error = jnp.max(jnp.abs(x_new - x))
        if verbose and k % print_step == 0:
            print(f"Completed iteration {k} with error {error}.")
        x = x_new
        k += 1
    if error > tolerance:
        print(f"Warning: Iteration hit upper bound {max_iter}.")
    elif verbose:
        # k is incremented after the final (converging) iteration, so the
        # true number of iterations performed is k - 1, not k.
        print(f"Terminated successfully in {k - 1} iterations.")
    return x
73+ :load: _static/lecture_specific/successive_approx.py
9474```
9575
9676## Model primitives
@@ -339,46 +319,15 @@ R_σ = jax.jit(R_σ, static_argnums=(3,))
Now we define the solvers, which implement VFI, HPI and OPI.
340320
341321``` {code-cell} ipython3
def value_iteration(model, tol=1e-5):
    """
    Value function iteration: iterate the Bellman operator T from a zero
    initial guess until convergence, then return the greedy policy for
    the approximate fixed point.
    """
    constants, sizes, arrays = model

    def bellman(v):
        return T(v, constants, sizes, arrays)

    v_init = jnp.zeros(sizes)
    v_star = successive_approx(bellman, v_init, tolerance=tol)
    σ_star = get_greedy(v_star, constants, sizes, arrays)
    return σ_star
322+ :load: _static/lecture_specific/vfi.py
351323```
352324
353325``` {code-cell} ipython3
def policy_iteration(model):
    """
    Howard policy iteration: alternate exact policy evaluation
    (get_value) and policy improvement (get_greedy) until the policy
    stops changing.
    """
    constants, sizes, arrays = model
    current_policy = jnp.zeros(sizes, dtype=int)
    loop_count = 0
    while True:
        values = get_value(current_policy, constants, sizes, arrays)
        improved_policy = get_greedy(values, constants, sizes, arrays)
        # Largest change in the (integer) policy; zero means convergence.
        error = jnp.max(jnp.abs(improved_policy - current_policy))
        current_policy = improved_policy
        loop_count += 1
        print(f"Concluded loop {loop_count} with error {error}.")
        if not error > 0:
            break
    return current_policy
326+ :load: _static/lecture_specific/hpi.py
367327```
368328
369329``` {code-cell} ipython3
def optimistic_policy_iteration(model, tol=1e-5, m=10):
    """
    Optimistic policy iteration: at each step take the greedy policy for
    the current value estimate and apply the policy operator T_σ m
    times; stop when successive value estimates are within tol in the
    sup norm.  Returns the greedy policy for the final estimate.
    """
    constants, sizes, arrays = model
    v_current = jnp.zeros(sizes)
    error = tol + 1   # ensure the loop body runs at least once
    while error > tol:
        v_previous = v_current
        greedy_policy = get_greedy(v_current, constants, sizes, arrays)
        for _ in range(m):
            v_current = T_σ(v_current, greedy_policy, constants, sizes, arrays)
        error = jnp.max(jnp.abs(v_current - v_previous))
    return get_greedy(v_current, constants, sizes, arrays)
330+ :load: _static/lecture_specific/opi.py
382331```
383332
384333## Plots
@@ -388,7 +337,6 @@ Create a JAX model for consumption, perform policy iteration, and plot the resul
388337``` {code-cell} ipython3
389338fontsize = 12
390339model = create_consumption_model_jax()
391-
392340# Unpack
393341constants, sizes, arrays = model
394342β, R, γ = constants
0 commit comments