@@ -88,7 +88,7 @@ def create_consumption_model(
         β=0.98,                    # Discount factor
         γ=2,                       # CRRA parameter
         a_min=0.01,                # Min assets
-        a_max=5.0,                 # Max assets
+        a_max=10.0,                # Max assets
         a_size=150,                # Grid size
         ρ=0.9, ν=0.1, y_size=100   # Income parameters
     ):
@@ -109,52 +109,50 @@ We repeat some functions from {doc}`ifp_discrete`.
 Here is the right hand side of the Bellman equation:

 ```{code-cell} ipython3
-@jax.jit
-def B(v, model):
+def B(v, model, i, j, ip):
     """
-    A vectorized version of the right-hand side of the Bellman equation
-    (before maximization), which is a 3D array representing
+    The right-hand side of the Bellman equation before maximization, which takes
+    the form

         B(a, y, a′) = u(Ra + y - a′) + β Σ_y′ v(a′, y′) Q(y, y′)

-    for all (a, y, a′).
+    The indices are (i, j, ip) -> (a, y, a′).
     """
-
-    # Unpack
     β, R, γ, a_grid, y_grid, Q = model
-    a_size, y_size = len(a_grid), len(y_grid)
-
-    # Compute current rewards r(a, y, ap) as array r[i, j, ip]
-    a = jnp.reshape(a_grid, (a_size, 1, 1))    # a[i]   ->  a[i, j, ip]
-    y = jnp.reshape(y_grid, (1, y_size, 1))    # z[j]   ->  z[i, j, ip]
-    ap = jnp.reshape(a_grid, (1, 1, a_size))   # ap[ip] -> ap[i, j, ip]
+    a, y, ap = a_grid[i], y_grid[j], a_grid[ip]
     c = R * a + y - ap
+    EV = jnp.sum(v[ip, :] * Q[j, :])
+    return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)
+```

-    # Calculate continuation rewards at all combinations of (a, y, ap)
-    v = jnp.reshape(v, (1, 1, a_size, y_size))   # v[ip, jp] -> v[i, j, ip, jp]
-    Q = jnp.reshape(Q, (1, y_size, 1, y_size))   # Q[j, jp]  -> Q[i, j, ip, jp]
-    EV = jnp.sum(v * Q, axis=3)                  # sum over last index jp
+Now we successively apply `vmap` to vectorize over all indices:

-    # Compute the right-hand side of the Bellman equation
-    return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)
+```{code-cell} ipython3
+B_1 = jax.vmap(B, in_axes=(None, None, None, None, 0))
+B_2 = jax.vmap(B_1, in_axes=(None, None, None, 0, None))
+B_vmap = jax.vmap(B_2, in_axes=(None, None, 0, None, None))
 ```

 Here's the Bellman operator:

 ```{code-cell} ipython3
-@jax.jit
 def T(v, model):
     "The Bellman operator."
-    return jnp.max(B(v, model), axis=2)
+    a_indices = jnp.arange(len(model.a_grid))
+    y_indices = jnp.arange(len(model.y_grid))
+    B_values = B_vmap(v, model, a_indices, y_indices, a_indices)
+    return jnp.max(B_values, axis=-1)
 ```

 Here's the function that computes a $v$-greedy policy:

 ```{code-cell} ipython3
-@jax.jit
 def get_greedy(v, model):
     "Computes a v-greedy policy, returned as a set of indices."
-    return jnp.argmax(B(v, model), axis=2)
+    a_indices = jnp.arange(len(model.a_grid))
+    y_indices = jnp.arange(len(model.y_grid))
+    B_values = B_vmap(v, model, a_indices, y_indices, a_indices)
+    return jnp.argmax(B_values, axis=-1)
 ```

 Now we define the policy operator $T_\sigma$, which is the Bellman operator with
@@ -194,7 +192,6 @@ Apply vmap to vectorize:
 T_σ_1 = jax.vmap(T_σ, in_axes=(None, None, None, None, 0))
 T_σ_vmap = jax.vmap(T_σ_1, in_axes=(None, None, None, 0, None))

-@jax.jit
 def T_σ_vec(v, σ, model):
     """Vectorized version of T_σ."""
     a_size, y_size = len(model.a_grid), len(model.y_grid)
@@ -206,7 +203,6 @@ def T_σ_vec(v, σ, model):
 Now we need a function to apply the policy operator m times:

 ```{code-cell} ipython3
-@jax.jit
 def iterate_policy_operator(σ, v, m, model):
     """
     Apply the policy operator T_σ exactly m times to v.
@@ -324,9 +320,9 @@ print(f"VFI completed in {vfi_time:.2f} seconds.")
 Now let's time OPI with different values of m:

 ```{code-cell} ipython3
-print("Starting OPI with m=10.")
+print("Starting OPI with m=50.")
 start = time()
-v_star_opi, σ_star_opi = optimistic_policy_iteration(model, m=10)
+v_star_opi, σ_star_opi = optimistic_policy_iteration(model, m=50)
 v_star_opi.block_until_ready()
 opi_time_with_compile = time() - start
 print(f"OPI completed in {opi_time_with_compile:.2f} seconds.")
@@ -336,7 +332,7 @@ Run it again:
 
 ```{code-cell} ipython3
 start = time()
-v_star_opi, σ_star_opi = optimistic_policy_iteration(model, m=10)
+v_star_opi, σ_star_opi = optimistic_policy_iteration(model, m=50)
 v_star_opi.block_until_ready()
 opi_time = time() - start
 print(f"OPI completed in {opi_time:.2f} seconds.")
@@ -345,9 +341,38 @@ print(f"OPI completed in {opi_time:.2f} seconds.")
 Check that we get the same result:

 ```{code-cell} ipython3
-print(f"Policies match: {jnp.allclose(σ_star_vfi, σ_star_opi)}")
+print(f"Values match: {jnp.allclose(v_star_vfi, v_star_opi)}")
 ```

+The value functions match, confirming both algorithms converge to the same solution.
+
+Let's visually compare the asset dynamics under both policies:
+
+```{code-cell} ipython3
+fig, axes = plt.subplots(1, 2, figsize=(12, 5))
+
+# VFI policy
+for j, label in zip([0, -1], ['low income', 'high income']):
+    a_next_vfi = model.a_grid[σ_star_vfi[:, j]]
+    axes[0].plot(model.a_grid, a_next_vfi, label=label)
+axes[0].plot(model.a_grid, model.a_grid, 'k--', linewidth=0.5, alpha=0.5)
+axes[0].set(xlabel='current assets', ylabel='next period assets', title='VFI')
+axes[0].legend()
+
+# OPI policy
+for j, label in zip([0, -1], ['low income', 'high income']):
+    a_next_opi = model.a_grid[σ_star_opi[:, j]]
+    axes[1].plot(model.a_grid, a_next_opi, label=label)
+axes[1].plot(model.a_grid, model.a_grid, 'k--', linewidth=0.5, alpha=0.5)
+axes[1].set(xlabel='current assets', ylabel='next period assets', title='OPI')
+axes[1].legend()
+
+plt.tight_layout()
+plt.show()
+```
+
+The policies are visually indistinguishable, confirming both methods produce the same solution.
+
 Here's the speedup:

 ```{code-cell} ipython3
@@ -384,9 +409,7 @@ plt.show()
 
 Here's a summary of the results

-* When $m=1$, OPI is slight slower than VFI, even though they should be mathematically equivalent, due to small inefficiencies associated with extra function calls.
-
-* OPI outperforms VFI for a very large range of $m$ values.
+* OPI outperforms VFI for a large range of $m$ values.

 * For very large $m$, OPI performance begins to degrade as we spend too much
   time iterating the policy operator.
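
The tradeoff described in these bullets can be checked directly by timing OPI over a grid of $m$ values. The following is only a rough sketch, not part of the commit: it assumes the lecture's `create_consumption_model` and `optimistic_policy_iteration` (with the call signature used above) are in scope, that all model parameters have defaults, and the grid of `m` values is purely illustrative.

```python
from time import time
import matplotlib.pyplot as plt

model = create_consumption_model()
m_values = [1, 10, 50, 200, 1000]    # illustrative grid of m values
run_times = []

for m in m_values:
    start = time()
    v_m, σ_m = optimistic_policy_iteration(model, m=m)
    v_m.block_until_ready()          # wait for asynchronous JAX execution
    run_times.append(time() - start) # note: may include some compile time

fig, ax = plt.subplots()
ax.plot(m_values, run_times, 'o-')
ax.set(xlabel='$m$', ylabel='run time (seconds)', xscale='log')
plt.show()
```

If the summary holds, the resulting curve should fall at first and then rise again for very large $m$.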