From 8acbc3ec7bb2df7c43cda3e03e10e057ce44339f Mon Sep 17 00:00:00 2001
From: Smit-create
Date: Mon, 17 Jul 2023 12:39:30 +0530
Subject: [PATCH 1/3] Extract out common code

---
 lectures/_static/lecture_specific/hpi.py           | 15 +++++++++++++
 lectures/_static/lecture_specific/opi.py           | 13 ++++++++++++
 .../lecture_specific/successive_approx.py          | 21 +++++++++++++++++++
 lectures/_static/lecture_specific/vfi.py           |  9 ++++++++
 4 files changed, 58 insertions(+)
 create mode 100644 lectures/_static/lecture_specific/hpi.py
 create mode 100644 lectures/_static/lecture_specific/opi.py
 create mode 100644 lectures/_static/lecture_specific/successive_approx.py
 create mode 100644 lectures/_static/lecture_specific/vfi.py

diff --git a/lectures/_static/lecture_specific/hpi.py b/lectures/_static/lecture_specific/hpi.py
new file mode 100644
index 00000000..01d1cd60
--- /dev/null
+++ b/lectures/_static/lecture_specific/hpi.py
@@ -0,0 +1,15 @@
+# Implements HPI (Howard policy iteration)
+
+def policy_iteration(model, maxiter=250):
+    constants, sizes, arrays = model
+    vz = jnp.zeros(sizes)
+    σ = jnp.zeros(sizes, dtype=int)
+    i, error = 0, 1.0
+    while error > 0 and i < maxiter:
+        v_σ = get_value(σ, constants, sizes, arrays)
+        σ_new = get_greedy(v_σ, constants, sizes, arrays)
+        error = jnp.max(jnp.abs(σ_new - σ))
+        σ = σ_new
+        i = i + 1
+        print(f"Concluded loop {i} with error {error}.")
+    return σ
diff --git a/lectures/_static/lecture_specific/opi.py b/lectures/_static/lecture_specific/opi.py
new file mode 100644
index 00000000..90df42f4
--- /dev/null
+++ b/lectures/_static/lecture_specific/opi.py
@@ -0,0 +1,13 @@
+# Implements OPI (optimistic policy iteration)
+
+def optimistic_policy_iteration(model, tol=1e-5, m=10):
+    constants, sizes, arrays = model
+    v = jnp.zeros(sizes)
+    error = tol + 1
+    while error > tol:
+        last_v = v
+        σ = get_greedy(v, constants, sizes, arrays)
+        for _ in range(m):
+            v = T_σ(v, σ, constants, sizes, arrays)
+        error = jnp.max(jnp.abs(v - last_v))
+    return get_greedy(v, constants, sizes, arrays)
diff --git a/lectures/_static/lecture_specific/successive_approx.py b/lectures/_static/lecture_specific/successive_approx.py
new file mode 100644
index 00000000..18ddbf4d
--- /dev/null
+++ b/lectures/_static/lecture_specific/successive_approx.py
@@ -0,0 +1,21 @@
+def successive_approx(T,                     # Operator (callable)
+                      x_0,                   # Initial condition
+                      tolerance=1e-6,        # Error tolerance
+                      max_iter=10_000,       # Max iteration bound
+                      print_step=25,         # Print at multiples
+                      verbose=False):
+    x = x_0
+    error = tolerance + 1
+    k = 1
+    while error > tolerance and k <= max_iter:
+        x_new = T(x)
+        error = jnp.max(jnp.abs(x_new - x))
+        if verbose and k % print_step == 0:
+            print(f"Completed iteration {k} with error {error}.")
+        x = x_new
+        k += 1
+    if error > tolerance:
+        print(f"Warning: Iteration hit upper bound {max_iter}.")
+    elif verbose:
+        print(f"Terminated successfully in {k} iterations.")
+    return x
diff --git a/lectures/_static/lecture_specific/vfi.py b/lectures/_static/lecture_specific/vfi.py
new file mode 100644
index 00000000..d4f1f7b7
--- /dev/null
+++ b/lectures/_static/lecture_specific/vfi.py
@@ -0,0 +1,9 @@
+# Implements VFI (value function iteration)
+
+def value_iteration(model, tol=1e-5):
+    constants, sizes, arrays = model
+    _T = lambda v: T(v, constants, sizes, arrays)
+    vz = jnp.zeros(sizes)
+
+    v_star = successive_approx(_T, vz, tolerance=tol)
+    return get_greedy(v_star, constants, sizes, arrays)
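A note on the snippets above: they deliberately carry no imports, since they are loaded into lecture notebooks where `jnp` (`jax.numpy`) and the model-specific operators (`T`, `T_σ`, `get_value`, `get_greedy`) are already in scope. On its own, `successive_approx` works with any contraction mapping, which gives a quick standalone check; a minimal sketch, assuming the function from `successive_approx.py` above has been defined, with a toy operator that is not part of the patch:

```python
import jax.numpy as jnp

# Toy contraction: T(v) = 0.5 v + 1 has the unique fixed point v* = 2.
T = lambda v: 0.5 * v + 1.0

v_star = successive_approx(T, jnp.zeros(3), tolerance=1e-8, verbose=True)
print(v_star)  # ≈ [2. 2. 2.]
```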
From 30bf3b8b3b789dcfd117d1040f71ef738cfb4c76 Mon Sep 17 00:00:00 2001
From: Smit-create
Date: Mon, 24 Jul 2023 16:32:16 +0530
Subject: [PATCH 2/3] Use load

---
 lectures/_static/lecture_specific/hpi.py |  1 -
 lectures/opt_invest.md                   | 61 ++-----------------
 lectures/opt_savings.md                  | 77 +++++++-----------------
 3 files changed, 25 insertions(+), 114 deletions(-)

diff --git a/lectures/_static/lecture_specific/hpi.py b/lectures/_static/lecture_specific/hpi.py
index 01d1cd60..128e1bed 100644
--- a/lectures/_static/lecture_specific/hpi.py
+++ b/lectures/_static/lecture_specific/hpi.py
@@ -2,7 +2,6 @@
 
 def policy_iteration(model, maxiter=250):
     constants, sizes, arrays = model
-    vz = jnp.zeros(sizes)
     σ = jnp.zeros(sizes, dtype=int)
     i, error = 0, 1.0
     while error > 0 and i < maxiter:
diff --git a/lectures/opt_invest.md b/lectures/opt_invest.md
index 9bf6c092..625cd73e 100644
--- a/lectures/opt_invest.md
+++ b/lectures/opt_invest.md
@@ -81,27 +81,7 @@ jax.config.update("jax_enable_x64", True)
 We need the following successive approximation function.
 
 ```{code-cell} ipython3
-def successive_approx(T,                     # Operator (callable)
-                      x_0,                   # Initial condition
-                      tolerance=1e-6,        # Error tolerance
-                      max_iter=10_000,       # Max iteration bound
-                      print_step=25,         # Print at multiples
-                      verbose=False):
-    x = x_0
-    error = tolerance + 1
-    k = 1
-    while error > tolerance and k <= max_iter:
-        x_new = T(x)
-        error = jnp.max(jnp.abs(x_new - x))
-        if verbose and k % print_step == 0:
-            print(f"Completed iteration {k} with error {error}.")
-        x = x_new
-        k += 1
-    if error > tolerance:
-        print(f"Warning: Iteration hit upper bound {max_iter}.")
-    elif verbose:
-        print(f"Terminated successfully in {k} iterations.")
-    return x
+:load: _static/lecture_specific/successive_approx.py
 ```
 
 
@@ -345,48 +325,15 @@ get_value = jax.jit(get_value, static_argnums=(2,))
 Now we define the solvers, which implement VFI, HPI and OPI.
 
 ```{code-cell} ipython3
-# Implements VFI-Value Function iteration
-
-def value_iteration(model, tol=1e-5):
-    constants, sizes, arrays = model
-    _T = lambda v: T(v, constants, sizes, arrays)
-    vz = jnp.zeros(sizes)
-
-    v_star = successive_approx(_T, vz, tolerance=tol)
-    return get_greedy(v_star, constants, sizes, arrays)
+:load: _static/lecture_specific/vfi.py
 ```
 
 ```{code-cell} ipython3
-# Implements HPI-Howard policy iteration routine
-
-def policy_iteration(model, maxiter=250):
-    constants, sizes, arrays = model
-    σ = jnp.zeros(sizes, dtype=int)
-    i, error = 0, 1.0
-    while error > 0 and i < maxiter:
-        v_σ = get_value(σ, constants, sizes, arrays)
-        σ_new = get_greedy(v_σ, constants, sizes, arrays)
-        error = jnp.max(jnp.abs(σ_new - σ))
-        σ = σ_new
-        i = i + 1
-        print(f"Concluded loop {i} with error {error}.")
-    return σ
+:load: _static/lecture_specific/hpi.py
 ```
 
 ```{code-cell} ipython3
-# Implements the OPI-Optimal policy Iteration routine
-
-def optimistic_policy_iteration(model, tol=1e-5, m=10):
-    constants, sizes, arrays = model
-    v = jnp.zeros(sizes)
-    error = tol + 1
-    while error > tol:
-        last_v = v
-        σ = get_greedy(v, constants, sizes, arrays)
-        for _ in range(m):
-            v = T_σ(v, σ, constants, sizes, arrays)
-        error = jnp.max(jnp.abs(v - last_v))
-        return get_greedy(v, constants, sizes, arrays)
+:load: _static/lecture_specific/opi.py
 ```
 
 ```{code-cell} ipython3
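For context before the next file: the `:load:` option used above is MyST-NB's mechanism for populating a `{code-cell}` from an external file at build time, so the executed lecture is unchanged; only the source of the cell moves. Once loaded, the three solvers share one calling convention. A hypothetical smoke test, assuming the `create_consumption_model_jax()` constructor from `opt_savings.md` below and that all three methods converge to the same optimal policy:

```python
import jax.numpy as jnp

# Hypothetical usage of the shared solvers (names from the lectures).
model = create_consumption_model_jax()            # (constants, sizes, arrays)

σ_vfi = value_iteration(model, tol=1e-5)          # iterate T to a fixed point
σ_hpi = policy_iteration(model)                   # exact policy evaluation per step
σ_opi = optimistic_policy_iteration(model, m=10)  # m steps of T_σ per pass

print(jnp.all(σ_vfi == σ_hpi), jnp.all(σ_hpi == σ_opi))  # expect True True
```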
diff --git a/lectures/opt_savings.md b/lectures/opt_savings.md
index 327de924..c71c943b 100644
--- a/lectures/opt_savings.md
+++ b/lectures/opt_savings.md
@@ -70,27 +70,7 @@ $$ u(c) = \frac{c^{1-\gamma}}{1-\gamma} $$
 We use successive approximation for VFI.
 
 ```{code-cell} ipython3
-def successive_approx(T,                     # Operator (callable)
-                      x_0,                   # Initial condition
-                      tolerance=1e-6,        # Error tolerance
-                      max_iter=10_000,       # Max iteration bound
-                      print_step=25,         # Print at multiples
-                      verbose=False):
-    x = x_0
-    error = tolerance + 1
-    k = 1
-    while error > tolerance and k <= max_iter:
-        x_new = T(x)
-        error = jnp.max(jnp.abs(x_new - x))
-        if verbose and k % print_step == 0:
-            print(f"Completed iteration {k} with error {error}.")
-        x = x_new
-        k += 1
-    if error > tolerance:
-        print(f"Warning: Iteration hit upper bound {max_iter}.")
-    elif verbose:
-        print(f"Terminated successfully in {k} iterations.")
-    return x
+:load: _static/lecture_specific/successive_approx.py
 ```
 
 ## Model primitives
@@ -269,6 +249,23 @@ The code below
 2. solves the linear system
 3. converts back to multi-index arrays.
 
+```{code-cell} ipython3
+def get_value(σ, constants, sizes, arrays):
+    "Get the value v_σ of policy σ by inverting the linear map R_σ."
+
+    # Unpack
+    β, R, γ = constants
+    w_size, y_size = sizes
+    w_grid, y_grid, Q = arrays
+
+    r_σ = compute_r_σ(σ, constants, sizes, arrays)
+
+    # Reduce R_σ to a function in v
+    partial_R_σ = lambda v: R_σ(v, σ, constants, sizes, arrays)
+
+    return jax.scipy.sparse.linalg.bicgstab(partial_R_σ, r_σ)[0]
+```
+
 ```{code-cell} ipython3
 def R_σ(v, σ, constants, sizes, arrays):
     """
@@ -339,46 +336,15 @@ R_σ = jax.jit(R_σ, static_argnums=(3,))
 Now we define the solvers, which implement VFI, HPI and OPI.
 
 ```{code-cell} ipython3
-def value_iteration(model, tol=1e-5):
-    "Implements VFI."
-
-    constants, sizes, arrays = model
-    _T = lambda v: T(v, constants, sizes, arrays)
-    vz = jnp.zeros(sizes)
-
-    v_star = successive_approx(_T, vz, tolerance=tol)
-    return get_greedy(v_star, constants, sizes, arrays)
+:load: _static/lecture_specific/vfi.py
 ```
 
 ```{code-cell} ipython3
-def policy_iteration(model):
-    "Howard policy iteration routine."
-    constants, sizes, arrays = model
-    σ = jnp.zeros(sizes, dtype=int)
-    i, error = 0, 1.0
-    while error > 0:
-        v_σ = get_value(σ, constants, sizes, arrays)
-        σ_new = get_greedy(v_σ, constants, sizes, arrays)
-        error = jnp.max(jnp.abs(σ_new - σ))
-        σ = σ_new
-        i = i + 1
-        print(f"Concluded loop {i} with error {error}.")
-    return σ
+:load: _static/lecture_specific/hpi.py
 ```
 
 ```{code-cell} ipython3
-def optimistic_policy_iteration(model, tol=1e-5, m=10):
-    "Implements the OPI routine."
-    constants, sizes, arrays = model
-    v = jnp.zeros(sizes)
-    error = tol + 1
-    while error > tol:
-        last_v = v
-        σ = get_greedy(v, constants, sizes, arrays)
-        for _ in range(m):
-            v = T_σ(v, σ, constants, sizes, arrays)
-        error = jnp.max(jnp.abs(v - last_v))
-    return get_greedy(v, constants, sizes, arrays)
+:load: _static/lecture_specific/opi.py
 ```
 
 ## Plots
@@ -388,7 +354,6 @@ Create a JAX model for consumption, perform policy iteration, and plot the resul
 ```{code-cell} ipython3
 fontsize = 12
 model = create_consumption_model_jax()
-
 # Unpack
 constants, sizes, arrays = model
 β, R, γ = constants
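The `get_value` cell introduced above (and deleted again in the next patch as a duplicate) evaluates a fixed policy $\sigma$ by solving the linear system $(I - \beta P_\sigma) v_\sigma = r_\sigma$ with a matrix-free `bicgstab` call, never forming the operator as a matrix. A self-contained toy version of the same idea, with made-up `P_σ` and `r_σ`:

```python
import jax
import jax.numpy as jnp

# Hypothetical policy transition matrix and rewards, for illustration only.
β = 0.96
P_σ = jnp.array([[0.9, 0.1],
                 [0.2, 0.8]])
r_σ = jnp.array([1.0, 2.0])

# Matrix-free operator v ↦ (I - β P_σ) v, mirroring R_σ in the lecture.
R = lambda v: v - β * (P_σ @ v)

v_σ, info = jax.scipy.sparse.linalg.bicgstab(R, r_σ)
print(v_σ)  # matches jnp.linalg.solve(jnp.eye(2) - β * P_σ, r_σ)
```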
From 21ebfc2b8fc0441c1bb4ea574eeecb6e7fa0ca7d Mon Sep 17 00:00:00 2001
From: Smit-create
Date: Mon, 24 Jul 2023 16:34:13 +0530
Subject: [PATCH 3/3] Remove merge issue

---
 lectures/opt_savings.md | 17 -----------------
 1 file changed, 17 deletions(-)

diff --git a/lectures/opt_savings.md b/lectures/opt_savings.md
index c71c943b..6ceb9157 100644
--- a/lectures/opt_savings.md
+++ b/lectures/opt_savings.md
@@ -249,23 +249,6 @@ The code below
 2. solves the linear system
 3. converts back to multi-index arrays.
 
-```{code-cell} ipython3
-def get_value(σ, constants, sizes, arrays):
-    "Get the value v_σ of policy σ by inverting the linear map R_σ."
-
-    # Unpack
-    β, R, γ = constants
-    w_size, y_size = sizes
-    w_grid, y_grid, Q = arrays
-
-    r_σ = compute_r_σ(σ, constants, sizes, arrays)
-
-    # Reduce R_σ to a function in v
-    partial_R_σ = lambda v: R_σ(v, σ, constants, sizes, arrays)
-
-    return jax.scipy.sparse.linalg.bicgstab(partial_R_σ, r_σ)[0]
-```
-
 ```{code-cell} ipython3
 def R_σ(v, σ, constants, sizes, arrays):
     """
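The `R_σ` cell that survives this deletion is the counterpart of the policy operator `T_σ` used by OPI: iterating $v \leftarrow r_\sigma + \beta P_\sigma v$ converges to the same $v_\sigma$ that the direct solve produces, which is why `optimistic_policy_iteration` with a large `m` behaves like Howard policy iteration. A toy check, reusing the hypothetical arrays from the previous sketch:

```python
import jax.numpy as jnp

# Same hypothetical arrays as above.
β = 0.96
P_σ = jnp.array([[0.9, 0.1],
                 [0.2, 0.8]])
r_σ = jnp.array([1.0, 2.0])

# Repeated T_σ steps approximate the exact policy value.
v = jnp.zeros(2)
for _ in range(1_000):
    v = r_σ + β * (P_σ @ v)

v_exact = jnp.linalg.solve(jnp.eye(2) - β * P_σ, r_σ)
print(jnp.max(jnp.abs(v - v_exact)))  # ≈ 0 at machine precision
```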