From 8baace514d18307da0033f69e46eb76bd5486723 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Thu, 14 Mar 2024 13:46:33 +1100 Subject: [PATCH 1/2] misc --- lectures/opt_savings_1.md | 32 ++++++++++++++++++++++++-------- lectures/opt_savings_2.md | 4 +--- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/lectures/opt_savings_1.md b/lectures/opt_savings_1.md index 960acdd4..b7cd3ed1 100644 --- a/lectures/opt_savings_1.md +++ b/lectures/opt_savings_1.md @@ -76,15 +76,14 @@ Let's start with a standard NumPy version, running on the CPU. This is a traditional approach using relatively old technologies. -One reason we start with NumPy is that switching from NumPy to JAX will be -relatively trivial. - -The other reason is that we want to know the speed gain associated with -switching to JAX. +Starting with NumPy will allow us to record the speed gain associated with switching to JAX. (NumPy operations are similar to MATLAB operations, so this also serves as a rough comparison with MATLAB.) + + + ### Functions and operators The following function contains default parameters and returns tuples that @@ -106,7 +105,6 @@ def create_consumption_model(R=1.01, # Gross interest rate w_grid = np.linspace(w_min, w_max, w_size) mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν) y_grid, Q = np.exp(mc.state_values), mc.P - w_grid, y_grid, Q = tuple(map(jax.device_put, [w_grid, y_grid, Q])) sizes = w_size, y_size return (β, R, γ), sizes, (w_grid, y_grid, Q) ``` @@ -397,9 +395,20 @@ The relative speed gain is print(f"Relative speed gain = {numpy_elapsed / jax_elapsed}") ``` + +This is an impressive speed up and in fact we can do better still by switching +to alternative algorithms that are better suited to parallelization. + +These algorithms are discussed in a {doct}`separate lecture ` in +this series. + + ## Switching to vmap -For this simple optimal savings problem direct vectorization is relatively easy. +Before we discuss alternative algorithms, let's take another look at value +function iteration. + +For this simple optimal savings problem, direct vectorization is relatively easy. In particular, it's straightforward to express the right hand side of the Bellman equation as an array that stores evaluations of the function at every @@ -497,8 +506,15 @@ print(jnp.allclose(v_star_vmap, v_star_jax)) print(jnp.allclose(σ_star_vmap, σ_star_jax)) ``` -The relative speed is +Here's how long the `vmap` code takes relative to the first JAX implementation +(which used direct vectorization). ```{code-cell} ipython3 print(f"Relative speed = {jax_vmap_elapsed / jax_elapsed}") ``` + +The execution times are relatively similar. + +However, as emphasized above, having a second method up our sleeves (i.e, the +`vmap` approach) will be helpful when confronting dynamic programs with more +sophisticated Bellman equations. diff --git a/lectures/opt_savings_2.md b/lectures/opt_savings_2.md index 6e3dfaff..a05fbe10 100644 --- a/lectures/opt_savings_2.md +++ b/lectures/opt_savings_2.md @@ -110,8 +110,6 @@ def create_consumption_model(R=1.01, # Gross interest rate Here's the right hand side of the Bellman equation: ```{code-cell} ipython3 -:tags: [hide-input] - def B(v, constants, sizes, arrays): """ A vectorized version of the right-hand side of the Bellman equation @@ -427,4 +425,4 @@ ax.legend(frameon=False) ax.set_xlabel("$m$") ax.set_ylabel("time") plt.show() -``` \ No newline at end of file +``` From 456d38e2088fa7bbd960785e0d89c1ab05649788 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Thu, 14 Mar 2024 14:10:04 +1100 Subject: [PATCH 2/2] misc --- lectures/opt_savings_1.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lectures/opt_savings_1.md b/lectures/opt_savings_1.md index b7cd3ed1..d0f5bad9 100644 --- a/lectures/opt_savings_1.md +++ b/lectures/opt_savings_1.md @@ -399,8 +399,7 @@ print(f"Relative speed gain = {numpy_elapsed / jax_elapsed}") This is an impressive speed up and in fact we can do better still by switching to alternative algorithms that are better suited to parallelization. -These algorithms are discussed in a {doct}`separate lecture ` in -this series. +These algorithms are discussed in a {doc}`separate lecture `. ## Switching to vmap