@@ -76,15 +76,14 @@ Let's start with a standard NumPy version, running on the CPU.
 
 This is a traditional approach using relatively old technologies.
 
-One reason we start with NumPy is that switching from NumPy to JAX will be
-relatively trivial.
-
-The other reason is that we want to know the speed gain associated with
-switching to JAX.
+Starting with NumPy will allow us to record the speed gain associated with switching to JAX.
 
 (NumPy operations are similar to MATLAB operations, so this also serves as a
 rough comparison with MATLAB.)
 
+
+
+
 ### Functions and operators
 
 The following function contains default parameters and returns tuples that
@@ -106,7 +105,6 @@ def create_consumption_model(R=1.01, # Gross interest rate
     w_grid = np.linspace(w_min, w_max, w_size)
     mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
     y_grid, Q = np.exp(mc.state_values), mc.P
-    w_grid, y_grid, Q = tuple(map(jax.device_put, [w_grid, y_grid, Q]))
     sizes = w_size, y_size
     return (β, R, γ), sizes, (w_grid, y_grid, Q)
 ```
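
For reference, here is a minimal usage sketch of the function touched by this hunk. It assumes the imports (`numpy as np`, `quantecon as qe`) and the full `create_consumption_model` definition from the lecture, since only part of the body is visible above; the unpacking simply follows the `return` statement shown.

```python
# Hypothetical usage sketch; assumes the complete create_consumption_model
# definition from the lecture plus its imports (numpy, quantecon).
constants, sizes, arrays = create_consumption_model()   # default parameters

β, R, γ = constants            # preferences and gross interest rate
w_size, y_size = sizes         # grid dimensions
w_grid, y_grid, Q = arrays     # wealth grid, income grid, Markov matrix

print(w_grid.shape, y_grid.shape, Q.shape)
```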
@@ -397,9 +395,19 @@ The relative speed gain is
 print(f"Relative speed gain = {numpy_elapsed / jax_elapsed}")
 ```
 
+
+This is an impressive speed-up, and in fact we can do better still by switching
+to alternative algorithms that are better suited to parallelization.
+
+These algorithms are discussed in a {doc}`separate lecture <opt_savings_2>`.
+
+
 ## Switching to vmap
 
-For this simple optimal savings problem direct vectorization is relatively easy.
+Before we discuss alternative algorithms, let's take another look at value
+function iteration.
+
+For this simple optimal savings problem, direct vectorization is relatively easy.
 
 In particular, it's straightforward to express the right hand side of the
 Bellman equation as an array that stores evaluations of the function at every
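
To make the "direct vectorization" idea concrete, here is a rough, self-contained sketch on a stand-in model (the utility function, grids, and discount factor below are assumptions, not the lecture's actual code): broadcasting builds an array holding the right-hand side of the Bellman equation at every state and action.

```python
# Stand-in sketch of direct vectorization via broadcasting (not the lecture's model).
import jax.numpy as jnp

β = 0.96                                  # assumed discount factor
w_grid = jnp.linspace(1.0, 10.0, 6)       # assumed wealth grid (states and actions)
v = jnp.zeros(6)                          # a trial value function

# consumption when current wealth is w_grid[i] and next-period wealth is w_grid[j]
c = w_grid[:, None] - w_grid[None, :] / 1.01
# Bellman right-hand side at every (state, action) pair; infeasible choices get -inf
B = jnp.where(c > 0, jnp.log(jnp.maximum(c, 1e-10)) + β * v[None, :], -jnp.inf)

print(B.shape)          # (6, 6): one entry per (state, action) pair
print(B.max(axis=1))    # maximizing over the action axis gives the Bellman update
```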
@@ -497,8 +505,15 @@ print(jnp.allclose(v_star_vmap, v_star_jax))
 print(jnp.allclose(σ_star_vmap, σ_star_jax))
 ```
 
-The relative speed is
+Here's how long the `vmap` code takes relative to the first JAX implementation
+(which used direct vectorization).
 
 ```{code-cell} ipython3
 print(f"Relative speed = {jax_vmap_elapsed / jax_elapsed}")
 ```
+
+The execution times are relatively similar.
+
+However, as emphasized above, having a second method up our sleeves (i.e., the
+`vmap` approach) will be helpful when confronting dynamic programs with more
+sophisticated Bellman equations.
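
For comparison, the `vmap` route to the same array can be sketched on the stand-in model used above (again an assumption-laden illustration, not the lecture's code): write the right-hand side for a single (state, action) pair and let `jax.vmap` map it over the grids.

```python
# Stand-in sketch of the vmap approach (same assumed model as the sketch above).
import jax
import jax.numpy as jnp

β = 0.96
w_grid = jnp.linspace(1.0, 10.0, 6)
v = jnp.zeros(6)

def B_at(i, j):
    # Bellman right-hand side at one (state, action) pair of grid indices
    c = w_grid[i] - w_grid[j] / 1.01
    return jnp.where(c > 0, jnp.log(jnp.maximum(c, 1e-10)) + β * v[j], -jnp.inf)

# vmap over actions (inner) and states (outer) to recover the full (6, 6) array
B = jax.vmap(jax.vmap(B_at, in_axes=(None, 0)), in_axes=(0, None))(
        jnp.arange(6), jnp.arange(6))
print(B.shape)
```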