124
124
125
125
Let's use Python code from {doc}`this quantecon lecture <mccall_model>`.
126
126
127
- We use a Python method called `VFI ` to compute the optimal value function using value function iterations.
127
+ We use a Python method called `vfi ` to compute the optimal value function using value function iterations.
128
128
129
129
We construct an assumed distribution of wages and plot it with the following Python code
130
130
@@ -170,32 +170,36 @@ def state_action_values(model, i, v):
170
170
return jnp.array([accept, reject])
171
171
172
172
@jax.jit
173
- def VFI(model, eps=1e-5, max_iter=500):
174
- """Find the optimal value function."""
175
- n = len(model.w)
176
- v_init = model.w / (1 - model.β)
173
+ def update(model, v):
174
+ n = model.w.shape[0]
177
175
178
- def body_fun(state ):
179
- v , i, error = state
180
- v_next = jnp.empty_like(v )
176
+ def v_at_state(i ):
177
+ sa = state_action_values(model , i, v)
178
+ return jnp.max(sa )
181
179
182
- # Update all elements of v_next
183
- for j in range(n):
184
- v_next = v_next.at[j].set(jnp.max(state_action_values(model, j, v)))
180
+ indices = jnp.arange(n)
181
+ v_new = jax.vmap(v_at_state)(indices)
182
+ return v_new
185
183
186
- error = jnp.max(jnp.abs(v_next - v))
187
- return v_next, i + 1, error
184
+ @jax.jit
185
+ def vfi(model, tol=1e-5, max_iter=500):
186
+
187
+ v0 = model.w / (1.0 - model.β)
188
188
189
- def cond_fun(state):
190
- v, i, error = state
191
- return (error > eps) & (i < max_iter)
189
+ def body_fun(state):
190
+ v, i, err = state
191
+ v_new = update(model, v)
192
+ err_new = jnp.max(jnp.abs(v_new - v))
193
+ return v_new, i + 1, err_new
192
194
193
- # Initial state: (v, iteration, error)
194
- init_state = (v_init, 0, eps + 1)
195
- final_v, final_i, final_error = jax.lax.while_loop(cond_fun, body_fun, init_state )
195
+ def cond_fun(state):
196
+ _, i, err = state
197
+ return (err > tol) & (i < max_iter )
196
198
197
- flag = jnp.where(final_error <= eps, 1, 0)
198
- return final_v, flag
199
+ init_state = (v0, 0, tol + 1.0)
200
+ v_final, iters, err = jax.lax.while_loop(cond_fun, body_fun, init_state)
201
+ converged = jnp.where(err <= tol, 1, 0)
202
+ return v_final, converged
199
203
200
204
def plot_value_function_seq(mcm, ax, num_plots=8):
201
205
"""
@@ -220,7 +224,7 @@ def plot_value_function_seq(mcm, ax, num_plots=8):
220
224
221
225
```{code-cell} ipython3
222
226
mcm = create_mccall_model()
223
- valfunc_VFI, flag = VFI (mcm)
227
+ valfunc_VFI, converged = vfi (mcm)
224
228
225
229
fig, ax = plt.subplots(figsize=(10,6))
226
230
ax.set_xlabel('wage')
@@ -687,7 +691,7 @@ plt.show()
687
691
```{code-cell} ipython3
688
692
# VFI
689
693
mcm = create_mccall_model(w=w_new, q=q_new)
690
- valfunc_VFI, flag = VFI (mcm)
694
+ valfunc_VFI, converged = vfi (mcm)
691
695
valfunc_VFI
692
696
```
693
697
0 commit comments