Skip to content

Commit b56e91d

Browse files
authored
Fixes in opt savings (#85)
* Fix order * fixes * Unused variables
1 parent ed4c005 commit b56e91d

File tree

2 files changed

+46
-39
lines changed

2 files changed

+46
-39
lines changed

lectures/opt_invest.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def T_σ(v, σ, constants, sizes, arrays):
250250
# Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[i, j, jp]
251251
Ev = jnp.sum(V * Q, axis=2)
252252
253-
return r_σ + β * jnp.sum(V * Q, axis=2)
253+
return r_σ + β * Ev
254254
255255
T_σ = jax.jit(T_σ, static_argnums=(3,))
256256
```
@@ -361,7 +361,6 @@ def value_iteration(model, tol=1e-5):
361361
362362
def policy_iteration(model, maxiter=250):
363363
constants, sizes, arrays = model
364-
vz = jnp.zeros(sizes)
365364
σ = jnp.zeros(sizes, dtype=int)
366365
i, error = 0, 1.0
367366
while error > 0 and i < maxiter:

lectures/opt_savings.md

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ $$ W_{t+1} + C_t \leq R W_t + Y_t $$
5757

5858
We assume that labor income $(Y_t)$ is a discretized AR(1) process.
5959

60-
The right-hand side of the Bellman equation is
60+
The right-hand side of the Bellman equation is
6161

6262
$$ B((w, y), w', v) = u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y'). $$
6363

@@ -75,7 +75,7 @@ def successive_approx(T, # Operator (callable)
7575
tolerance=1e-6, # Error tolerance
7676
max_iter=10_000, # Max iteration bound
7777
print_step=25, # Print at multiples
78-
verbose=False):
78+
verbose=False):
7979
x = x_0
8080
error = tolerance + 1
8181
k = 1
@@ -98,7 +98,7 @@ def successive_approx(T, # Operator (callable)
9898
Here’s a `namedtuple` definition for storing parameters and grids.
9999

100100
```{code-cell} ipython3
101-
Model = namedtuple('Model',
101+
Model = namedtuple('Model',
102102
('β', 'R', 'γ', 'w_grid', 'y_grid', 'Q'))
103103
```
104104

@@ -114,7 +114,7 @@ def create_consumption_model(R=1.01, # Gross interest rate
114114
A function that takes in parameters and returns an instance of Model that
115115
contains data for the optimal savings problem.
116116
"""
117-
w_grid = jnp.linspace(w_min, w_max, w_size)
117+
w_grid = jnp.linspace(w_min, w_max, w_size)
118118
mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
119119
y_grid, Q = jnp.exp(mc.state_values), mc.P
120120
return Model(β=β, R=R, γ=γ, w_grid=w_grid, y_grid=y_grid, Q=Q)
@@ -129,7 +129,7 @@ def create_consumption_model_jax(R=1.01, # Gross interest rate
129129
w_size=150, # Grid size
130130
ρ=0.9, ν=0.1, y_size=100): # Income parameters
131131
"""
132-
A function that takes in parameters and returns a JAX-compatible version of
132+
A function that takes in parameters and returns a JAX-compatible version of
133133
Model that contains data for the optimal savings problem.
134134
"""
135135
w_grid = jnp.linspace(w_min, w_max, w_size)
@@ -146,15 +146,15 @@ Here's the right hand side of the Bellman equation:
146146
```{code-cell} ipython3
147147
def B(v, constants, sizes, arrays):
148148
"""
149-
A vectorized version of the right-hand side of the Bellman equation
149+
A vectorized version of the right-hand side of the Bellman equation
150150
(before maximization), which is a 3D array representing
151151
152152
B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′)
153153
154154
for all (w, y, w′).
155155
"""
156156
157-
# Unpack
157+
# Unpack
158158
β, R, γ = constants
159159
w_size, y_size = sizes
160160
w_grid, y_grid, Q = arrays
@@ -215,15 +215,15 @@ def T_σ(v, σ, constants, sizes, arrays):
215215
yp_idx = jnp.arange(y_size)
216216
yp_idx = jnp.reshape(yp_idx, (1, 1, y_size))
217217
σ = jnp.reshape(σ, (w_size, y_size, 1))
218-
V = v[σ, yp_idx]
218+
V = v[σ, yp_idx]
219219
220-
# Convert Q[j, jp] to Q[i, j, jp]
220+
# Convert Q[j, jp] to Q[i, j, jp]
221221
Q = jnp.reshape(Q, (1, y_size, y_size))
222222
223223
# Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[i, j, jp]
224224
Ev = jnp.sum(V * Q, axis=2)
225225
226-
return r_σ + β * jnp.sum(V * Q, axis=2)
226+
return r_σ + β * Ev
227227
```
228228

229229
and the Bellman operator $T$
@@ -248,9 +248,9 @@ The basic problem is to solve the linear system
248248

249249
$$ v(w, y) = u(Rw + y - \sigma(w, y)) + β \sum_{y'} v(\sigma(w, y), y') Q(y, y') $$
250250

251-
for $v$.
251+
for $v$.
252252

253-
It turns out to be helpful to rewrite this as
253+
It turns out to be helpful to rewrite this as
254254

255255
$$ v(w,y) = r(w, y, \sigma(w, y)) + β \sum_{w', y'} v(w', y') P_\sigma(w, y, w', y') $$
256256

@@ -260,40 +260,23 @@ We want to write this as $v = r_\sigma + β P_\sigma v$ and then solve for $v$
260260

261261
Note, however,
262262

263-
* $v$ is a 2 index array, rather than a single vector.
264-
* $P_\sigma$ has four indices rather than 2
263+
* $v$ is a 2 index array, rather than a single vector.
264+
* $P_\sigma$ has four indices rather than 2
265265

266-
The code below
266+
The code below
267267

268268
1. reshapes $v$ and $r_\sigma$ to 1D arrays and $P_\sigma$ to a matrix
269269
2. solves the linear system
270270
3. converts back to multi-index arrays.
271271

272-
```{code-cell} ipython3
273-
def get_value(σ, constants, sizes, arrays):
274-
"Get the value v_σ of policy σ by inverting the linear map R_σ."
275-
276-
# Unpack
277-
β, R, γ = constants
278-
w_size, y_size = sizes
279-
w_grid, y_grid, Q = arrays
280-
281-
r_σ = compute_r_σ(σ, constants, sizes, arrays)
282-
283-
# Reduce R_σ to a function in v
284-
partial_R_σ = lambda v: R_σ(v, σ, constants, sizes, arrays)
285-
286-
return jax.scipy.sparse.linalg.bicgstab(partial_R_σ, r_σ)[0]
287-
```
288-
289272
```{code-cell} ipython3
290273
def R_σ(v, σ, constants, sizes, arrays):
291274
"""
292-
The value v_σ of a policy σ is defined as
275+
The value v_σ of a policy σ is defined as
293276
294277
v_σ = (I - β P_σ)^{-1} r_σ
295278
296-
Here we set up the linear map v -> R_σ v, where R_σ := I - β P_σ.
279+
Here we set up the linear map v -> R_σ v, where R_σ := I - β P_σ.
297280
298281
In the consumption problem, this map can be expressed as
299282
@@ -322,6 +305,23 @@ def R_σ(v, σ, constants, sizes, arrays):
322305
return v - β * jnp.sum(V * Q, axis=2)
323306
```
324307

308+
```{code-cell} ipython3
309+
def get_value(σ, constants, sizes, arrays):
310+
"Get the value v_σ of policy σ by inverting the linear map R_σ."
311+
312+
# Unpack
313+
β, R, γ = constants
314+
w_size, y_size = sizes
315+
w_grid, y_grid, Q = arrays
316+
317+
r_σ = compute_r_σ(σ, constants, sizes, arrays)
318+
319+
# Reduce R_σ to a function in v
320+
partial_R_σ = lambda v: R_σ(v, σ, constants, sizes, arrays)
321+
322+
return jax.scipy.sparse.linalg.bicgstab(partial_R_σ, r_σ)[0]
323+
```
324+
325325
## JIT compiled versions
326326

327327
```{code-cell} ipython3
@@ -354,7 +354,6 @@ def value_iteration(model, tol=1e-5):
354354
def policy_iteration(model):
355355
"Howard policy iteration routine."
356356
constants, sizes, arrays = model
357-
vz = jnp.zeros(sizes)
358357
σ = jnp.zeros(sizes, dtype=int)
359358
i, error = 0, 1.0
360359
while error > 0:
@@ -387,14 +386,19 @@ def optimistic_policy_iteration(model, tol=1e-5, m=10):
387386
Create a JAX model for consumption, perform policy iteration, and plot the resulting optimal policy function.
388387

389388
```{code-cell} ipython3
390-
fontsize=12
389+
fontsize = 12
391390
model = create_consumption_model_jax()
392-
# Unpack
391+
392+
# Unpack
393393
constants, sizes, arrays = model
394394
β, R, γ = constants
395395
w_size, y_size = sizes
396396
w_grid, y_grid, Q = arrays
397+
```
398+
399+
```{code-cell} ipython3
397400
σ_star = policy_iteration(model)
401+
398402
fig, ax = plt.subplots(figsize=(9, 5.2))
399403
ax.plot(w_grid, w_grid, "k--", label="45")
400404
ax.plot(w_grid, w_grid[σ_star[:, 1]], label="$\\sigma^*(\cdot, y_1)$")
@@ -443,7 +447,9 @@ def run_algorithm(algorithm, model, **kwargs):
443447
elapsed_time = end_time - start_time
444448
print(f"{algorithm.__name__} completed in {elapsed_time:.2f} seconds.")
445449
return result, elapsed_time
450+
```
446451

452+
```{code-cell} ipython3
447453
model = create_consumption_model_jax()
448454
σ_pi, pi_time = run_algorithm(policy_iteration, model)
449455
σ_vfi, vfi_time = run_algorithm(value_iteration, model, tol=1e-5)
@@ -453,7 +459,9 @@ opi_times = []
453459
for m in m_vals:
454460
σ_opi, opi_time = run_algorithm(optimistic_policy_iteration, model, m=m, tol=1e-5)
455461
opi_times.append(opi_time)
462+
```
456463

464+
```{code-cell} ipython3
457465
fig, ax = plt.subplots(figsize=(9, 5.2))
458466
ax.plot(m_vals, jnp.full(len(m_vals), pi_time), lw=2, label="Howard policy iteration")
459467
ax.plot(m_vals, jnp.full(len(m_vals), vfi_time), lw=2, label="value function iteration")

0 commit comments

Comments
 (0)