Skip to content

Commit a72d9b3

Browse files
committed
use vectorized computation
vectorized computation instead of repeatedly resetting arrays with for loops
1 parent 9143f6b commit a72d9b3

File tree

1 file changed

+27
-23
lines changed

1 file changed

+27
-23
lines changed

lectures/mccall_q.md

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ $$
124124
125125
Let's use Python code from {doc}`this quantecon lecture <mccall_model>`.
126126
127-
We use a Python method called `VFI` to compute the optimal value function using value function iterations.
127+
We use a Python method called `vfi` to compute the optimal value function using value function iterations.
128128
129129
We construct an assumed distribution of wages and plot it with the following Python code
130130
@@ -170,32 +170,36 @@ def state_action_values(model, i, v):
170170
return jnp.array([accept, reject])
171171
172172
@jax.jit
173-
def VFI(model, eps=1e-5, max_iter=500):
174-
"""Find the optimal value function."""
175-
n = len(model.w)
176-
v_init = model.w / (1 - model.β)
173+
def update(model, v):
174+
n = model.w.shape[0]
177175
178-
def body_fun(state):
179-
v, i, error = state
180-
v_next = jnp.empty_like(v)
176+
def v_at_state(i):
177+
sa = state_action_values(model, i, v)
178+
return jnp.max(sa)
181179
182-
# Update all elements of v_next
183-
for j in range(n):
184-
v_next = v_next.at[j].set(jnp.max(state_action_values(model, j, v)))
180+
indices = jnp.arange(n)
181+
v_new = jax.vmap(v_at_state)(indices)
182+
return v_new
185183
186-
error = jnp.max(jnp.abs(v_next - v))
187-
return v_next, i + 1, error
184+
@jax.jit
185+
def vfi(model, tol=1e-5, max_iter=500):
186+
187+
v0 = model.w / (1.0 - model.β)
188188
189-
def cond_fun(state):
190-
v, i, error = state
191-
return (error > eps) & (i < max_iter)
189+
def body_fun(state):
190+
v, i, err = state
191+
v_new = update(model, v)
192+
err_new = jnp.max(jnp.abs(v_new - v))
193+
return v_new, i + 1, err_new
192194
193-
# Initial state: (v, iteration, error)
194-
init_state = (v_init, 0, eps + 1)
195-
final_v, final_i, final_error = jax.lax.while_loop(cond_fun, body_fun, init_state)
195+
def cond_fun(state):
196+
_, i, err = state
197+
return (err > tol) & (i < max_iter)
196198
197-
flag = jnp.where(final_error <= eps, 1, 0)
198-
return final_v, flag
199+
init_state = (v0, 0, tol + 1.0)
200+
v_final, iters, err = jax.lax.while_loop(cond_fun, body_fun, init_state)
201+
converged = jnp.where(err <= tol, 1, 0)
202+
return v_final, converged
199203
200204
def plot_value_function_seq(mcm, ax, num_plots=8):
201205
"""
@@ -220,7 +224,7 @@ def plot_value_function_seq(mcm, ax, num_plots=8):
220224
221225
```{code-cell} ipython3
222226
mcm = create_mccall_model()
223-
valfunc_VFI, flag = VFI(mcm)
227+
valfunc_VFI, converged = vfi(mcm)
224228
225229
fig, ax = plt.subplots(figsize=(10,6))
226230
ax.set_xlabel('wage')
@@ -687,7 +691,7 @@ plt.show()
687691
```{code-cell} ipython3
688692
# VFI
689693
mcm = create_mccall_model(w=w_new, q=q_new)
690-
valfunc_VFI, flag = VFI(mcm)
694+
valfunc_VFI, converged = vfi(mcm)
691695
valfunc_VFI
692696
```
693697

0 commit comments

Comments
 (0)