
Commit 0e3aa57

Small patches to optimal savings (#152)

* misc
* misc

1 parent: 216ed17

2 files changed: +24 -11 lines


lectures/opt_savings_1.md

Lines changed: 23 additions & 8 deletions
@@ -76,15 +76,14 @@ Let's start with a standard NumPy version, running on the CPU.
 
 This is a traditional approach using relatively old technologies.
 
-One reason we start with NumPy is that switching from NumPy to JAX will be
-relatively trivial.
-
-The other reason is that we want to know the speed gain associated with
-switching to JAX.
+Starting with NumPy will allow us to record the speed gain associated with switching to JAX.
 
 (NumPy operations are similar to MATLAB operations, so this also serves as a
 rough comparison with MATLAB.)
 
+
+
+
 ### Functions and operators
 
 The following function contains default parameters and returns tuples that
@@ -106,7 +105,6 @@ def create_consumption_model(R=1.01, # Gross interest rate
     w_grid = np.linspace(w_min, w_max, w_size)
     mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
     y_grid, Q = np.exp(mc.state_values), mc.P
-    w_grid, y_grid, Q = tuple(map(jax.device_put, [w_grid, y_grid, Q]))
     sizes = w_size, y_size
     return (β, R, γ), sizes, (w_grid, y_grid, Q)
 ```
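
A note on the removed line: `jax.device_put` transfers NumPy arrays to JAX's default device (a GPU or TPU when one is available). Dropping the call keeps `create_consumption_model` in pure NumPy, so the CPU baseline timing is not mixed with device transfers. A minimal sketch of the pattern, with placeholder grid values rather than the lecture's parameters:

```python
import numpy as np
import jax

# Build a grid on the CPU with NumPy, as create_consumption_model does.
w_grid = np.linspace(1e-4, 5.0, 150)

# The JAX-based solver can transfer the data explicitly when needed;
# the pure-NumPy version no longer performs this step.
w_grid_jax = jax.device_put(w_grid)
```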
@@ -397,9 +395,19 @@ The relative speed gain is
 print(f"Relative speed gain = {numpy_elapsed / jax_elapsed}")
 ```
 
+
+This is an impressive speed up and in fact we can do better still by switching
+to alternative algorithms that are better suited to parallelization.
+
+These algorithms are discussed in a {doc}`separate lecture <opt_savings_2>`.
+
+
 ## Switching to vmap
 
-For this simple optimal savings problem direct vectorization is relatively easy.
+Before we discuss alternative algorithms, let's take another look at value
+function iteration.
+
+For this simple optimal savings problem, direct vectorization is relatively easy.
 
 In particular, it's straightforward to express the right hand side of the
 Bellman equation as an array that stores evaluations of the function at every
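
As a rough illustration of where this section is headed: with `vmap`, the Bellman right-hand side can be written for a single index triple and then lifted over the grids automatically. The sketch below uses hypothetical parameter values and a simplified budget constraint, not the lecture's actual code:

```python
import jax
import jax.numpy as jnp

# Toy model primitives (placeholders, not the lecture's values).
β, R, γ = 0.96, 1.01, 2.0
w_grid = jnp.linspace(0.1, 5.0, 10)
y_grid = jnp.array([0.5, 1.0, 1.5])
Q = jnp.full((3, 3), 1 / 3)          # income transition matrix
v = jnp.zeros((10, 3))               # current value function guess

def B_single(v, i, j, ip):
    """Bellman RHS at one (wealth, income, next-wealth) index triple."""
    w, y, wp = w_grid[i], y_grid[j], w_grid[ip]
    c = R * w + y - wp               # consumption implied by the choice
    ev = v[ip, :] @ Q[j, :]          # expected continuation value
    return jnp.where(c > 0, c**(1 - γ) / (1 - γ) + β * ev, -jnp.inf)

# Lift the scalar function over all three index dimensions.
B = jax.vmap(B_single, in_axes=(None, None, None, 0))   # over w'
B = jax.vmap(B, in_axes=(None, None, 0, None))          # over y
B = jax.vmap(B, in_axes=(None, 0, None, None))          # over w

vals = B(v, jnp.arange(10), jnp.arange(3), jnp.arange(10))
print(vals.shape)                    # (10, 3, 10)
```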
@@ -497,8 +505,15 @@ print(jnp.allclose(v_star_vmap, v_star_jax))
 print(jnp.allclose(σ_star_vmap, σ_star_jax))
 ```
 
-The relative speed is
+Here's how long the `vmap` code takes relative to the first JAX implementation
+(which used direct vectorization).
 
 ```{code-cell} ipython3
 print(f"Relative speed = {jax_vmap_elapsed / jax_elapsed}")
 ```
+
+The execution times are relatively similar.
+
+However, as emphasized above, having a second method up our sleeves (i.e., the
+`vmap` approach) will be helpful when confronting dynamic programs with more
+sophisticated Bellman equations.
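
For context, `numpy_elapsed`, `jax_elapsed`, and `jax_vmap_elapsed` are presumably wall-clock timings recorded earlier in the lecture. Because JAX dispatches work asynchronously, such timings normally follow the pattern below; this is a sketch with a stand-in computation, not the lecture's solvers:

```python
import time
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: jnp.cos(x @ x))    # stand-in for a VFI solver
x = jnp.ones((1000, 1000))

f(x).block_until_ready()                 # warm-up: trigger compilation
start = time.time()
f(x).block_until_ready()                 # wait for the device to finish
elapsed = time.time() - start
print(f"elapsed = {elapsed:.4f} seconds")
```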

lectures/opt_savings_2.md

Lines changed: 1 addition & 3 deletions
@@ -110,8 +110,6 @@ def create_consumption_model(R=1.01, # Gross interest rate
 Here's the right hand side of the Bellman equation:
 
 ```{code-cell} ipython3
-:tags: [hide-input]
-
 def B(v, constants, sizes, arrays):
     """
     A vectorized version of the right-hand side of the Bellman equation
427425
ax.set_xlabel("$m$")
428426
ax.set_ylabel("time")
429427
plt.show()
430-
```
428+
```

0 commit comments