W_{t+1} + C_t \leq R W_t + Y_t
$$

- We assume that labor income $(Y_t)$ is a discretized AR(1) process.
+ where
+
+ * $C_t$ is consumption and $C_t \geq 0$,
+ * $W_t$ is wealth and $W_t \geq 0$,
+ * $R > 0$ is a gross rate of return, and
+ * $(Y_t)$ is labor income.
+
+ We assume below that labor income is a discretized AR(1) process.

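To make the income assumption concrete, here is a hypothetical Tauchen-style discretization sketch; the name `discretize_ar1` and the default values of `ρ`, `σ` and `n` are purely illustrative, and the lecture itself may rely on a library routine instead.

```{code-cell} ipython3
import numpy as np
from scipy.stats import norm

# Illustrative only: discretize y' = ρ y + σ ε, ε ~ N(0, 1), on an evenly
# spaced grid, returning the grid and a Markov transition matrix Q.
def discretize_ar1(ρ=0.9, σ=0.1, n=9, n_std=3):
    σ_y = σ / np.sqrt(1 - ρ**2)                  # stationary standard deviation
    y_grid = np.linspace(-n_std * σ_y, n_std * σ_y, n)
    h = y_grid[1] - y_grid[0]                    # grid step
    Q = np.empty((n, n))
    for j, y in enumerate(y_grid):
        z = (y_grid - ρ * y) / σ
        Q[j, :] = norm.cdf(z + h / (2 * σ)) - norm.cdf(z - h / (2 * σ))
        Q[j, 0] += norm.cdf(z[0] - h / (2 * σ))            # push tail mass to the endpoints
        Q[j, -1] += 1 - norm.cdf(z[-1] + h / (2 * σ))
    return y_grid, Q
```
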
- The right-hand side of the Bellman equation is
+ The Bellman equation is

$$
- B((w, y), w', v) = u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y').
+ v(w, y) = \max_{0 \leq w' \leq Rw + y}
+ \left\{
+ u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y')
+ \right\}
$$

where

$$
- u(c) = \frac{c^{1-\gamma}}{1-\gamma}
+ u(c) = \frac{c^{1-\gamma}}{1-\gamma}
$$

- ## Starting with NumPy
+ In the code we use the function
+
+ $$
+ B((w, y), w', v) = u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y').
+ $$
+
+ to encapsulate the right-hand side of the Bellman equation.
+
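Before turning to the vectorized implementations, here is a minimal scalar sketch of what such a function might look like; it is illustrative only, and the names `B_scalar`, `w_grid`, `y_grid`, `Q`, `v` and the default parameter values are assumptions rather than the lecture's actual code.

```{code-cell} ipython3
import numpy as np

# Illustrative scalar version of B: evaluate the Bellman RHS at state
# (w_grid[i], y_grid[j]) and savings choice w_grid[ip], where v[ip, k] stores
# the value of next-period state (w_grid[ip], y_grid[k]).
def B_scalar(i, j, ip, v, w_grid, y_grid, Q, R=1.01, β=0.96, γ=2.0):
    c = R * w_grid[i] + y_grid[j] - w_grid[ip]     # implied consumption
    if c <= 0:
        return -np.inf                             # infeasible choice
    EV = v[ip, :] @ Q[j, :]                        # Σ_{y'} v(w', y') Q(y, y')
    return c**(1 - γ) / (1 - γ) + β * EV
```

The implementations used below operate on whole arrays at once rather than looping over indices.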

- Let's start with a standard NumPy version, running on the CPU.

- This is a traditional approach using relatively old technologies.
+ ## Starting with NumPy
+
+ Let's start with a standard NumPy version running on the CPU.

- Starting with NumPy will allow us to record the speed gain associated with switching to JAX.
+ Starting with this traditional approach will allow us to record the speed gain
+ associated with switching to JAX.

(NumPy operations are similar to MATLAB operations, so this also serves as a
rough comparison with MATLAB.)



-
### Functions and operators

The following function contains default parameters and returns tuples that
@@ -218,6 +236,8 @@ ax.legend()
plt.show()
```

+
+
## Switching to JAX

To switch over to JAX, we change `np` to `jnp` throughout and add some
@@ -284,7 +304,6 @@ def B(v, constants, sizes, arrays):
    return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)


- B = jax.jit(B, static_argnums=(2,))
```

Some readers might be concerned that we are creating high dimensional arrays,
@@ -295,6 +314,12 @@ Could they be avoided by more careful vectorization?
In fact this is not necessary: this function will be JIT-compiled by JAX, and
the JIT compiler will optimize compiled code to minimize memory use.

+ ```{code-cell} ipython3
+ B = jax.jit(B, static_argnums=(2,))
+ ```
+
+ In the call above, we indicate to the compiler that `sizes` is static, so the
+ compiler can parallelize optimally while taking array sizes as fixed.
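As a toy illustration of why a shape-determining argument has to be static (this snippet is not part of the lecture, and `make_grid` is a made-up name), consider a function whose third argument fixes the length of the array it returns:

```{code-cell} ipython3
import jax
import jax.numpy as jnp

# The argument `size` determines the output shape, so it must be a
# compile-time constant: jitting without static_argnums would fail when
# JAX tries to trace it.
def make_grid(start, stop, size):
    return jnp.linspace(start, stop, size)

make_grid_jit = jax.jit(make_grid, static_argnums=(2,))
print(make_grid_jit(0.0, 1.0, 5))
```

Each new value of a static argument triggers a fresh compilation, which is harmless here because the array sizes stay fixed throughout the solution procedure.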

The Bellman operator $T$ can be implemented by

@@ -505,14 +530,19 @@ print(jnp.allclose(v_star_vmap, v_star_jax))
print(jnp.allclose(σ_star_vmap, σ_star_jax))
```

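As a quick reminder of the mechanism, here is a toy example of `jax.vmap` (not part of the lecture code; `f`, `xs` and `ys` are made up): nested `vmap` calls turn a scalar function into one that fills a whole grid of input pairs, much as explicit broadcasting does.

```{code-cell} ipython3
import jax
import jax.numpy as jnp

# vmap maps a function over one axis of its inputs; nesting two vmaps
# evaluates f on every (x, y) pair, returning a 2D array.
def f(x, y):
    return x**2 + y

xs = jnp.linspace(0.0, 1.0, 4)
ys = jnp.linspace(0.0, 1.0, 3)
f_grid = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))
print(f_grid(xs, ys).shape)   # (4, 3)
```
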
- Here's how long the `vmap` code takes relative to the first JAX implementation
- (which used direct vectorization).
+ Here's the speed gain associated with switching from the NumPy version to JAX with `vmap`:
+
+ ```{code-cell} ipython3
+ print(f"Relative speed = {numpy_elapsed / jax_vmap_elapsed}")
+ ```
+
+ And here's the comparison with the first JAX implementation (which used direct vectorization).

```{code-cell} ipython3
- print(f"Relative speed = {jax_vmap_elapsed / jax_elapsed}")
+ print(f"Relative speed = {jax_elapsed / jax_vmap_elapsed}")
```

- The execution times are relatively similar.
+ The execution times for the two JAX versions are relatively similar.

However, as emphasized above, having a second method up our sleeves (i.e., the
`vmap` approach) will be helpful when confronting dynamic programs with more