Skip to content

Commit 91efc04

Browse files
authored
misc (#153)
1 parent 0e3aa57 commit 91efc04

File tree

1 file changed

+44
-14
lines changed

1 file changed

+44
-14
lines changed

lectures/opt_savings_1.md

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,34 +56,52 @@ $$
5656
W_{t+1} + C_t \leq R W_t + Y_t
5757
$$
5858

59-
We assume that labor income $(Y_t)$ is a discretized AR(1) process.
59+
where
60+
61+
* $C_t$ is consumption and $C_t \geq 0$,
62+
* $W_t$ is wealth and $W_t \geq 0$,
63+
* $R > 0$ is a gross rate of return, and
64+
* $(Y_t)$ is labor income.
65+
66+
We assume below that labor income is a discretized AR(1) process.
6067

61-
The right-hand side of the Bellman equation is
68+
The Bellman equation is
6269

6370
$$
64-
B((w, y), w', v) = u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y').
71+
v(w) = \max_{0 \leq w' \leq Rw + y}
72+
\left\{
73+
u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y')
74+
\right\}
6575
$$
6676

6777
where
6878

6979
$$
70-
u(c) = \frac{c^{1-\gamma}}{1-\gamma}
80+
u(c) = \frac{c^{1-\gamma}}{1-\gamma}
7181
$$
7282

73-
## Starting with NumPy
83+
In the code we use the function
84+
85+
$$
86+
B((w, y), w', v) = u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y').
87+
$$
88+
89+
the encapsulate the right hand side of the Bellman equation.
90+
7491

75-
Let's start with a standard NumPy version, running on the CPU.
7692

77-
This is a traditional approach using relatively old technologies.
93+
## Starting with NumPy
94+
95+
Let's start with a standard NumPy version running on the CPU.
7896

79-
Starting with NumPy will allow us to record the speed gain associated with switching to JAX.
97+
Starting with this traditional approach will allow us to record the speed gain
98+
associated with switching to JAX.
8099

81100
(NumPy operations are similar to MATLAB operations, so this also serves as a
82101
rough comparison with MATLAB.)
83102

84103

85104

86-
87105
### Functions and operators
88106

89107
The following function contains default parameters and returns tuples that
@@ -218,6 +236,8 @@ ax.legend()
218236
plt.show()
219237
```
220238

239+
240+
221241
## Switching to JAX
222242

223243
To switch over to JAX, we change `np` to `jnp` throughout and add some
@@ -284,7 +304,6 @@ def B(v, constants, sizes, arrays):
284304
return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)
285305
286306
287-
B = jax.jit(B, static_argnums=(2,))
288307
```
289308

290309
Some readers might be concerned that we are creating high dimensional arrays,
@@ -295,6 +314,12 @@ Could they be avoided by more careful vectorization?
295314
In fact this is not necessary: this function will be JIT-compiled by JAX, and
296315
the JIT compiler will optimize compiled code to minimize memory use.
297316

317+
```{code-cell} ipython3
318+
B = jax.jit(B, static_argnums=(2,))
319+
```
320+
321+
In the call above, we indicate to the compiler that `sizes` is static, so the
322+
compiler can parallelize optimally while taking array sizes as fixed.
298323

299324
The Bellman operator $T$ can be implemented by
300325

@@ -505,14 +530,19 @@ print(jnp.allclose(v_star_vmap, v_star_jax))
505530
print(jnp.allclose(σ_star_vmap, σ_star_jax))
506531
```
507532

508-
Here's how long the `vmap` code takes relative to the first JAX implementation
509-
(which used direct vectorization).
533+
Here's the speed gain associated with switching from the NumPy version to JAX with `vmap`:
534+
535+
```{code-cell} ipython3
536+
print(f"Relative speed = {numpy_elapsed / jax_vmap_elapsed}")
537+
```
538+
539+
And here's the comparison with the first JAX implementation (which used direct vectorization).
510540

511541
```{code-cell} ipython3
512-
print(f"Relative speed = {jax_vmap_elapsed / jax_elapsed}")
542+
print(f"Relative speed = {jax_elapsed / jax_vmap_elapsed}")
513543
```
514544

515-
The execution times are relatively similar.
545+
The execution times for the two JAX versions are relatively similar.
516546

517547
However, as emphasized above, having a second method up our sleeves (i.e, the
518548
`vmap` approach) will be helpful when confronting dynamic programs with more

0 commit comments

Comments
 (0)