@@ -11,7 +11,6 @@ kernelspec:
   name: python3
 ---
 
-
 # Optimal Savings II: Alternative Algorithms
 
 ```{include} _admonition/gpu.md
@@ -65,7 +64,6 @@ We will use the following imports:
 import quantecon as qe
 import jax
 import jax.numpy as jnp
-from collections import namedtuple
 import matplotlib.pyplot as plt
 import time
 ```
@@ -171,7 +169,6 @@ def get_greedy(v, params, sizes, arrays):
     return jnp.argmax(B(v, params, sizes, arrays), axis=-1)
 
 get_greedy = jax.jit(get_greedy, static_argnums=(2,))
-
 ```
 
 We define a function to compute the current rewards $r_\sigma$ given policy $\sigma$,
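
The rest of this sentence and the function body fall outside the hunk. For orientation, here is a hedged sketch of what `compute_r_σ` plausibly looks like under the CRRA consumption setup this lecture uses (it reuses the `jnp` import above; the details are assumptions, not the committed code):

```python
def compute_r_σ(σ, params, sizes, arrays):
    # Sketch: r_σ[i, j] = u(R w[i] + y[j] - w[σ[i, j]]) with CRRA utility
    β, R, γ = params
    w_size, y_size = sizes
    w_grid, y_grid, Q = arrays

    # Reshape grids so they broadcast over the (w, y) state space
    w = jnp.reshape(w_grid, (w_size, 1))
    y = jnp.reshape(y_grid, (1, y_size))

    # Consumption implied by the savings choice σ[i, j]
    c = R * w + y - w_grid[σ]
    return c**(1 - γ) / (1 - γ)
```
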
@@ -248,7 +245,6 @@ def T_σ(v, σ, params, sizes, arrays):
 T_σ = jax.jit(T_σ, static_argnums=(3,))
 ```
 
-
 The function below computes the value $v_\sigma$ of following policy $\sigma$.
 
 This lifetime value is a function $v_\sigma$ that satisfies
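
The displayed equation falls outside the hunk; given the operators defined above, it is the standard policy-value relation, reproduced here for reference:

$$
v_\sigma = r_\sigma + \beta P_\sigma v_\sigma,
\qquad \text{equivalently} \qquad
(I - \beta P_\sigma) v_\sigma = r_\sigma.
$$
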
@@ -325,11 +321,8 @@ def get_value(σ, params, sizes, arrays):
     return jax.scipy.sparse.linalg.bicgstab(partial_L_σ, r_σ)[0]
 
 get_value = jax.jit(get_value, static_argnums=(2,))
-
 ```
 
-
-
 ## Iteration
 
 
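Note that `get_value` hands `bicgstab` a callable rather than a matrix, so $I - \beta P_\sigma$ is never formed explicitly. A minimal self-contained illustration of that matrix-free pattern (the two-state chain and the numbers are invented for the example):

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

β = 0.96                          # illustrative discount factor
P = jnp.array([[0.9, 0.1],        # toy transition matrix
               [0.2, 0.8]])
r = jnp.array([1.0, 2.0])         # toy rewards

# Matrix-free linear operator v ↦ (I - βP)v
L = lambda v: v - β * (P @ v)

v, info = bicgstab(L, r)          # solve (I - βP)v = r
print(jnp.allclose(v - β * (P @ v), r, atol=1e-4))  # True
```
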
@@ -374,7 +367,6 @@ iterate_policy_operator = jax.jit(iterate_policy_operator,
                                   static_argnums=(4,))
 ```
 
-
 ## Solvers
 
 Now we define the solvers, which implement VFI, HPI and OPI.
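
The three methods are closely related: each OPI step computes a greedy policy and then applies the policy operator $m$ times, so $m = 1$ reproduces VFI's update while large $m$ approaches HPI's exact policy evaluation. A plain-Python sketch of one such step, built from `get_greedy` and `T_σ` above (the committed `opi_loop` below uses compiled JAX control flow instead):

```python
def opi_step(v, m, params, sizes, arrays):
    # Greedy policy for the current value estimate...
    σ = get_greedy(v, params, sizes, arrays)
    # ...then m applications of the policy operator T_σ
    for _ in range(m):
        v = T_σ(v, σ, params, sizes, arrays)
    return v
```
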
@@ -395,7 +387,6 @@ def value_function_iteration(model, tol=1e-5):
 
 For OPI we will use a compiled JAX `lax.while_loop` operation to speed execution.
 
-
 ```{code-cell} ipython3
 def opi_loop(params, sizes, arrays, m, tol, max_iter):
     """
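
For readers new to it, `lax.while_loop(cond_fun, body_fun, init_state)` threads a single state value through `body_fun` for as long as `cond_fun(state)` stays true. A tiny free-standing example, unrelated to the model:

```python
import jax
import jax.numpy as jnp
from jax import lax

def fixed_point_demo():
    tol = 1e-6

    def cond(state):
        _, err = state
        return err > tol              # keep iterating until converged

    def body(state):
        x, _ = state
        x_new = 0.5 * x + 1.0         # a contraction with fixed point 2.0
        return x_new, jnp.abs(x_new - x)

    x, _ = lax.while_loop(cond, body, (0.0, jnp.inf))
    return x

print(jax.jit(fixed_point_demo)())    # ≈ 2.0
```
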
@@ -436,7 +427,6 @@ def optimistic_policy_iteration(model, m=10, tol=1e-5, max_iter=10_000):
     return σ_star
 ```
 
-
 Here's HPI.
 
 ```{code-cell} ipython3
@@ -457,9 +447,9 @@ def howard_policy_iteration(model, maxiter=250):
     return σ
 ```
 
-## Plots
+## Tests
 
-Create a model for consumption, perform policy iteration, and plot the resulting optimal policy function.
+Let's create a consumption model, solve it with all three algorithms, plot the resulting optimal policy in each case, and compare the time each solver takes.
 
 ```{code-cell} ipython3
 model = create_consumption_model()
@@ -470,55 +460,82 @@ w_size, y_size = sizes
 w_grid, y_grid, Q = arrays
 ```
 
+```{code-cell} ipython3
+print("Starting HPI.")
+start_time = time.time()
+σ_star_hpi = howard_policy_iteration(model)
+elapsed = time.time() - start_time
+print(f"HPI completed in {elapsed} seconds.")
+```
+
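A caveat on wall-clock timings like these: JAX dispatches work asynchronously, so it is safer to block on the result before stopping the clock. A hedged variant of the cell above:

```python
start_time = time.time()
σ = howard_policy_iteration(model)
σ.block_until_ready()             # wait for the device computation to finish
elapsed = time.time() - start_time
```
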
 ```{code-cell} ipython3
 ---
 mystnb:
   figure:
-    caption: Optimal policy function
-    name: optimal-policy-function
+    caption: Optimal policy function (HPI)
+    name: optimal-policy-function-hpi
 ---
-σ_star = howard_policy_iteration(model)
 
 fig, ax = plt.subplots()
 ax.plot(w_grid, w_grid, "k--", label="45")
-ax.plot(w_grid, w_grid[σ_star[:, 1]], label="$\\sigma^*(\cdot, y_1)$")
-ax.plot(w_grid, w_grid[σ_star[:, -1]], label="$\\sigma^*(\cdot, y_N)$")
+ax.plot(w_grid, w_grid[σ_star_hpi[:, 1]], label="$\\sigma^{*}_{HPI}(\cdot, y_1)$")
+ax.plot(w_grid, w_grid[σ_star_hpi[:, -1]], label="$\\sigma^{*}_{HPI}(\cdot, y_N)$")
 ax.legend()
 plt.show()
 ```
 
-## Tests
-
-Here's a quick test of the timing of each solver.
-
-```{code-cell} ipython3
-model = create_consumption_model()
-```
-
 ```{code-cell} ipython3
-print("Starting HPI.")
+print("Starting VFI.")
 start_time = time.time()
-out = howard_policy_iteration(model)
+σ_star_vfi = value_function_iteration(model)
 elapsed = time.time() - start_time
-print(f"HPI completed in {elapsed} seconds.")
+print(f"VFI completed in {elapsed} seconds.")
 ```
 
 ```{code-cell} ipython3
-print("Starting VFI.")
-start_time = time.time()
-out = value_function_iteration(model)
-elapsed = time.time() - start_time
-print(f"VFI completed in {elapsed} seconds.")
+---
+mystnb:
+  figure:
+    caption: Optimal policy function (VFI)
+    name: optimal-policy-function-vfi
+---
+
+fig, ax = plt.subplots()
+ax.plot(w_grid, w_grid, "k--", label="45")
+ax.plot(w_grid, w_grid[σ_star_vfi[:, 1]], label="$\\sigma^{*}_{VFI}(\cdot, y_1)$")
+ax.plot(w_grid, w_grid[σ_star_vfi[:, -1]], label="$\\sigma^{*}_{VFI}(\cdot, y_N)$")
+ax.legend()
+plt.show()
 ```
 
 ```{code-cell} ipython3
 print("Starting OPI.")
 start_time = time.time()
-out = optimistic_policy_iteration(model, m=100)
+σ_star_opi = optimistic_policy_iteration(model, m=100)
 elapsed = time.time() - start_time
 print(f"OPI completed in {elapsed} seconds.")
 ```
 
+```{code-cell} ipython3
+---
+mystnb:
+  figure:
+    caption: Optimal policy function (OPI)
+    name: optimal-policy-function-opi
+---
+
+fig, ax = plt.subplots()
+ax.plot(w_grid, w_grid, "k--", label="45")
+ax.plot(w_grid, w_grid[σ_star_opi[:, 1]], label="$\\sigma^{*}_{OPI}(\cdot, y_1)$")
+ax.plot(w_grid, w_grid[σ_star_opi[:, -1]], label="$\\sigma^{*}_{OPI}(\cdot, y_N)$")
+ax.legend()
+plt.show()
+```
+
+The three plots above show that all of the solvers produce the same optimal policy.
+
+Now let's plot the time taken by each algorithm, to compare their speed directly.
+
 ```{code-cell} ipython3
 def run_algorithm(algorithm, model, **kwargs):
     start_time = time.time()
@@ -530,7 +547,6 @@ def run_algorithm(algorithm, model, **kwargs):
 ```
 
 ```{code-cell} ipython3
-model = create_consumption_model()
 σ_pi, pi_time = run_algorithm(howard_policy_iteration, model)
 σ_vfi, vfi_time = run_algorithm(value_function_iteration, model, tol=1e-5)
 
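The diff ends before the comparison plot promised in the text above. A minimal sketch of what it might look like, assuming OPI is timed the same way (the `σ_opi` and `opi_time` names simply extend the pattern already in the cell):

```python
# Time OPI with the same helper, then compare all three solvers
σ_opi, opi_time = run_algorithm(optimistic_policy_iteration, model, m=100)

fig, ax = plt.subplots()
ax.bar(["HPI", "VFI", "OPI"], [pi_time, vfi_time, opi_time])
ax.set_ylabel("execution time (seconds)")
plt.show()
```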