Commit 4beb13a

Use JAX in successive_approx (#134)
1 parent d4e46d3 commit 4beb13a

File tree: 4 files changed (+40 -55 lines changed)

  lectures/_static/lecture_specific/successive_approx.py
  lectures/_static/lecture_specific/vfi.py
  lectures/opt_invest.md
  lectures/opt_savings.md

lectures/_static/lecture_specific/successive_approx.py

Lines changed: 19 additions & 19 deletions
@@ -1,21 +1,21 @@
-def successive_approx(T,                     # Operator (callable)
-                      x_0,                   # Initial condition
-                      tolerance=1e-6,        # Error tolerance
-                      max_iter=10_000,       # Max iteration bound
-                      print_step=25,         # Print at multiples
-                      verbose=False):
-    x = x_0
-    error = tolerance + 1
-    k = 1
-    while error > tolerance and k <= max_iter:
-        x_new = T(x)
+def successive_approx_jax(x_0,                 # Initial condition
+                          constants,
+                          sizes,
+                          arrays,
+                          tolerance=1e-6,      # Error tolerance
+                          max_iter=10_000):    # Max iteration bound
+
+    def body_fun(k_x_err):
+        k, x, error = k_x_err
+        x_new = T(x, constants, sizes, arrays)
         error = jnp.max(jnp.abs(x_new - x))
-        if verbose and k % print_step == 0:
-            print(f"Completed iteration {k} with error {error}.")
-        x = x_new
-        k += 1
-    if error > tolerance:
-        print(f"Warning: Iteration hit upper bound {max_iter}.")
-    elif verbose:
-        print(f"Terminated successfully in {k} iterations.")
+        return k + 1, x_new, error
+
+    def cond_fun(k_x_err):
+        k, x, error = k_x_err
+        return jnp.logical_and(error > tolerance, k < max_iter)
+
+    k, x, error = jax.lax.while_loop(cond_fun, body_fun, (1, x_0, tolerance + 1))
     return x
+
+successive_approx_jax = jax.jit(successive_approx_jax, static_argnums=(2,))
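
The new version runs the entire fixed-point iteration inside `jax.lax.while_loop`, so the loop itself is compiled rather than driven from Python. Note that it closes over a Bellman operator `T` that must already be defined when this file is loaded, and that `static_argnums=(2,)` marks `sizes` as static. Below is a minimal, self-contained sketch of the same pattern on a toy contraction map; the names `T_toy` and `fixed_point` are illustrative and not part of the commit.

```python
import jax
import jax.numpy as jnp

def T_toy(x):
    # A contraction on R^n with fixed point 2.0 in every coordinate
    return 0.5 * x + 1.0

def fixed_point(x_0, tolerance=1e-6, max_iter=10_000):

    def body_fun(k_x_err):
        k, x, error = k_x_err
        x_new = T_toy(x)
        error = jnp.max(jnp.abs(x_new - x))
        return k + 1, x_new, error

    def cond_fun(k_x_err):
        k, x, error = k_x_err
        return jnp.logical_and(error > tolerance, k < max_iter)

    # The loop carry is (iteration count, current iterate, current error)
    k, x, error = jax.lax.while_loop(cond_fun, body_fun, (1, x_0, tolerance + 1))
    return x

print(jax.jit(fixed_point)(jnp.zeros(3)))   # approximately [2. 2. 2.]
```

Seeding the carry with `tolerance + 1` guarantees the loop condition holds on the first pass, mirroring the `error = tolerance + 1` initialization in the deleted Python loop.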

lectures/_static/lecture_specific/vfi.py

Lines changed: 1 addition & 2 deletions
@@ -2,8 +2,7 @@

 def value_iteration(model, tol=1e-5):
     constants, sizes, arrays = model
-    _T = lambda v: T(v, constants, sizes, arrays)
     vz = jnp.zeros(sizes)

-    v_star = successive_approx(_T, vz, tolerance=tol)
+    v_star = successive_approx_jax(vz, constants, sizes, arrays, tolerance=tol)
     return get_greedy(v_star, constants, sizes, arrays)
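
The `_T` closure disappears here because the model data is now passed straight into the jitted loop. To see why moving the loop into JAX matters, here is a rough, self-contained comparison with a toy operator (illustrative names; timings vary by machine and backend): a Python `while` that calls a jitted step synchronizes with the device every iteration, while the fully jitted `lax.while_loop` version keeps the whole iteration on the device.

```python
import time
import jax
import jax.numpy as jnp

T_step = jax.jit(lambda x: 0.99 * x + 0.01)      # toy contraction, fixed point 1.0

def python_loop(x, tol=1e-8):                    # old style: loop runs on the host
    error = tol + 1.0
    while error > tol:
        x_new = T_step(x)
        error = float(jnp.max(jnp.abs(x_new - x)))   # device sync every pass
        x = x_new
    return x

@jax.jit
def device_loop(x, tol=1e-8):                    # new style: loop runs on the device
    def body(carry):
        x, _ = carry
        x_new = 0.99 * x + 0.01
        return x_new, jnp.max(jnp.abs(x_new - x))
    def cond(carry):
        _, err = carry
        return err > tol
    x, _ = jax.lax.while_loop(cond, body, (x, tol + 1.0))
    return x

x0 = jnp.zeros(10_000)
device_loop(x0).block_until_ready()              # compile once before timing
start = time.time()
device_loop(x0).block_until_ready()
print(f"device loop: {time.time() - start:.4f} seconds")
start = time.time()
python_loop(x0)
print(f"python loop: {time.time() - start:.4f} seconds")
```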

lectures/opt_invest.md

Lines changed: 8 additions & 19 deletions
@@ -4,14 +4,13 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.14.5
+    jupytext_version: 1.16.1
 kernelspec:
   display_name: Python 3 (ipykernel)
   language: python
   name: python3
 ---

-
 # Optimal Investment

 ```{include} _admonition/gpu.md
@@ -76,14 +75,6 @@ We will use 64 bit floats with JAX in order to increase the precision.
 jax.config.update("jax_enable_x64", True)
 ```

-
-We need the following successive approximation function.
-
-```{code-cell} ipython3
-:load: _static/lecture_specific/successive_approx.py
-```
-
-
 Let's define a function to create an investment model using the given parameters.

 ```{code-cell} ipython3
@@ -113,7 +104,6 @@ def create_investment_model(
     return constants, sizes, arrays
 ```

-
 Let's re-write the vectorized version of the right-hand side of the
 Bellman equation (before maximization), which is a 3D array representing

@@ -183,7 +173,6 @@ def compute_r_σ(σ, constants, sizes, arrays):
 compute_r_σ = jax.jit(compute_r_σ, static_argnums=(2,))
 ```

-
 Define the Bellman operator.

 ```{code-cell} ipython3
@@ -194,7 +183,6 @@ def T(v, constants, sizes, arrays):
 T = jax.jit(T, static_argnums=(2,))
 ```

-
 The following function computes a v-greedy policy.

 ```{code-cell} ipython3
@@ -205,7 +193,6 @@ def get_greedy(v, constants, sizes, arrays):
 get_greedy = jax.jit(get_greedy, static_argnums=(2,))
 ```

-
 Define the $\sigma$-policy operator.

 ```{code-cell} ipython3
@@ -236,7 +223,6 @@ def T_σ(v, σ, constants, sizes, arrays):
 T_σ = jax.jit(T_σ, static_argnums=(3,))
 ```

-
 Next, we want to compute the lifetime value of following policy $\sigma$.

 This lifetime value is a function $v_\sigma$ that satisfies
@@ -285,8 +271,7 @@ def L_σ(v, σ, constants, sizes, arrays):
 L_σ = jax.jit(L_σ, static_argnums=(3,))
 ```

-Now we can define a function to compute $v_{\sigma}$
-
+Now we can define a function to compute $v_{\sigma}$

 ```{code-cell} ipython3
 def get_value(σ, constants, sizes, arrays):
@@ -306,6 +291,11 @@ def get_value(σ, constants, sizes, arrays):
 get_value = jax.jit(get_value, static_argnums=(2,))
 ```

+We use successive approximation for VFI.
+
+```{code-cell} ipython3
+:load: _static/lecture_specific/successive_approx.py
+```

 Finally, we introduce the solvers that implement VFI, HPI and OPI.

311301

@@ -355,7 +345,6 @@ print(out)
 print(f"OPI completed in {elapsed} seconds.")
 ```

-
 Here's the plot of the Howard policy, as a function of $y$ at the highest and lowest values of $z$.

 ```{code-cell} ipython3
@@ -377,7 +366,6 @@ ax.legend(fontsize=12)
 plt.show()
 ```

-
 Let's plot the time taken by each of the solvers and compare them.

 ```{code-cell} ipython3
@@ -403,6 +391,7 @@ print(f"VFI completed in {vfi_time} seconds.")

 ```{code-cell} ipython3
 :tags: [hide-output]
+
 opi_times = []
 for m in m_vals:
     print(f"Running optimistic policy iteration with m={m}.")

lectures/opt_savings.md

Lines changed: 12 additions & 15 deletions
@@ -4,7 +4,7 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.14.5
+    jupytext_version: 1.16.1
 kernelspec:
   display_name: Python 3 (ipykernel)
   language: python
@@ -65,20 +65,13 @@ where

 $$ u(c) = \frac{c^{1-\gamma}}{1-\gamma} $$

-+++
-
-We use successive approximation for VFI.
-
-```{code-cell} ipython3
-:load: _static/lecture_specific/successive_approx.py
-```

 ## Model primitives

 First we define a model that stores parameters and grids

 ```{code-cell} ipython3
-def create_consumption_model(R=1.01,   # Gross interest rate
+def create_consumption_model(R=1.01,        # Gross interest rate
                              β=0.98,        # Discount factor
                              γ=2,           # CRRA parameter
                              w_min=0.01,    # Min wealth
@@ -140,8 +133,6 @@ which is defined as the vector

 $$ r_\sigma(w, y) := r(w, y, \sigma(w, y)) $$

-
-
 ```{code-cell} ipython3
 def compute_r_σ(σ, constants, sizes, arrays):
     """
@@ -187,9 +178,9 @@ def T_σ(v, σ, constants, sizes, arrays):
     Q = jnp.reshape(Q, (1, y_size, y_size))

     # Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[i, j, jp]
-    Ev = jnp.sum(V * Q, axis=2)
+    EV = jnp.sum(V * Q, axis=2)

-    return r_σ + β * Ev
+    return r_σ + β * EV
 ```

 and the Bellman operator $T$
@@ -260,7 +251,7 @@ def L_σ(v, σ, constants, sizes, arrays):
     return v - β * jnp.sum(V * Q, axis=2)
 ```

-Now we can define a function to compute $v_{\sigma}$
+Now we can define a function to compute $v_{\sigma}$

 ```{code-cell} ipython3
 def get_value(σ, constants, sizes, arrays):
@@ -291,6 +282,12 @@ T_σ = jax.jit(T_σ, static_argnums=(3,))
 L_σ = jax.jit(L_σ, static_argnums=(3,))
 ```

+We use successive approximation for VFI.
+
+```{code-cell} ipython3
+:load: _static/lecture_specific/successive_approx.py
+```
+
 ## Solvers

 Now we define the solvers, which implement VFI, HPI and OPI.
@@ -353,7 +350,7 @@ print("Starting VFI.")
 start_time = time.time()
 out = value_iteration(model)
 elapsed = time.time() - start_time
-print(f"VFI(jax not in succ) completed in {elapsed} seconds.")
+print(f"VFI completed in {elapsed} seconds.")
 ```

 ```{code-cell} ipython3
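
The `Ev` to `EV` rename above touches the expectation step, which is worth unpacking: with `V[i, j, jp] = v[σ[i, j], jp]` and `Q` reshaped to `(1, y_size, y_size)`, `jnp.sum(V * Q, axis=2)` computes the conditional expectation of `v[σ[i, j], jp]` given `j`. A small self-contained check against a brute-force loop (toy sizes and random data, not the lecture's model):

```python
import jax.numpy as jnp
import numpy as np

w_size, y_size = 4, 3
rng = np.random.default_rng(0)
v = jnp.array(rng.random((w_size, y_size)))
σ = jnp.array(rng.integers(0, w_size, (w_size, y_size)))
Q = jnp.array(rng.random((y_size, y_size)))
Q = Q / Q.sum(axis=1, keepdims=True)             # make each row a distribution

V = v[σ, :]                                      # V[i, j, jp] = v[σ[i, j], jp]
EV = jnp.sum(V * jnp.reshape(Q, (1, y_size, y_size)), axis=2)

EV_loop = np.empty((w_size, y_size))             # brute-force comparison
for i in range(w_size):
    for j in range(y_size):
        EV_loop[i, j] = sum(v[σ[i, j], jp] * Q[j, jp] for jp in range(y_size))

print(np.allclose(EV, EV_loop))                  # True
```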
