Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 44 additions & 14 deletions lectures/opt_savings_1.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,34 +56,52 @@ $$
W_{t+1} + C_t \leq R W_t + Y_t
$$

We assume that labor income $(Y_t)$ is a discretized AR(1) process.
where

* $C_t$ is consumption and $C_t \geq 0$,
* $W_t$ is wealth and $W_t \geq 0$,
* $R > 0$ is a gross rate of return, and
* $(Y_t)$ is labor income.

We assume below that labor income is a discretized AR(1) process.

The right-hand side of the Bellman equation is
The Bellman equation is

$$
B((w, y), w', v) = u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y').
v(w) = \max_{0 \leq w' \leq Rw + y}
\left\{
u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y')
\right\}
$$

where

$$
u(c) = \frac{c^{1-\gamma}}{1-\gamma}
u(c) = \frac{c^{1-\gamma}}{1-\gamma}
$$

## Starting with NumPy
In the code we use the function

$$
B((w, y), w', v) = u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y').
$$

the encapsulate the right hand side of the Bellman equation.


Let's start with a standard NumPy version, running on the CPU.

This is a traditional approach using relatively old technologies.
## Starting with NumPy

Let's start with a standard NumPy version running on the CPU.

Starting with NumPy will allow us to record the speed gain associated with switching to JAX.
Starting with this traditional approach will allow us to record the speed gain
associated with switching to JAX.

(NumPy operations are similar to MATLAB operations, so this also serves as a
rough comparison with MATLAB.)




### Functions and operators

The following function contains default parameters and returns tuples that
Expand Down Expand Up @@ -218,6 +236,8 @@ ax.legend()
plt.show()
```



## Switching to JAX

To switch over to JAX, we change `np` to `jnp` throughout and add some
Expand Down Expand Up @@ -284,7 +304,6 @@ def B(v, constants, sizes, arrays):
return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)


B = jax.jit(B, static_argnums=(2,))
```

Some readers might be concerned that we are creating high dimensional arrays,
Expand All @@ -295,6 +314,12 @@ Could they be avoided by more careful vectorization?
In fact this is not necessary: this function will be JIT-compiled by JAX, and
the JIT compiler will optimize compiled code to minimize memory use.

```{code-cell} ipython3
B = jax.jit(B, static_argnums=(2,))
```

In the call above, we indicate to the compiler that `sizes` is static, so the
compiler can parallelize optimally while taking array sizes as fixed.

The Bellman operator $T$ can be implemented by

Expand Down Expand Up @@ -505,14 +530,19 @@ print(jnp.allclose(v_star_vmap, v_star_jax))
print(jnp.allclose(σ_star_vmap, σ_star_jax))
```

Here's how long the `vmap` code takes relative to the first JAX implementation
(which used direct vectorization).
Here's the speed gain associated with switching from the NumPy version to JAX with `vmap`:

```{code-cell} ipython3
print(f"Relative speed = {numpy_elapsed / jax_vmap_elapsed}")
```

And here's the comparison with the first JAX implementation (which used direct vectorization).

```{code-cell} ipython3
print(f"Relative speed = {jax_vmap_elapsed / jax_elapsed}")
print(f"Relative speed = {jax_elapsed / jax_vmap_elapsed}")
```

The execution times are relatively similar.
The execution times for the two JAX versions are relatively similar.

However, as emphasized above, having a second method up our sleeves (i.e, the
`vmap` approach) will be helpful when confronting dynamic programs with more
Expand Down