@@ -57,7 +57,7 @@ $$ W_{t+1} + C_t \leq R W_t + Y_t $$
5757
5858We assume that labor income $(Y_t)$ is a discretized AR(1) process.
5959
60- The right-hand side of the Bellman equation is
60+ The right-hand side of the Bellman equation is
6161
6262$$ B((w, y), w', v) = u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y'). $$
6363
@@ -75,7 +75,7 @@ def successive_approx(T, # Operator (callable)
7575 tolerance=1e-6, # Error tolerance
7676 max_iter=10_000, # Max iteration bound
7777 print_step=25, # Print at multiples
78- verbose=False):
78+ verbose=False):
7979 x = x_0
8080 error = tolerance + 1
8181 k = 1
@@ -98,7 +98,7 @@ def successive_approx(T, # Operator (callable)
9898Here’s a `namedtuple` definition for storing parameters and grids.
9999
100100``` {code-cell} ipython3
101- Model = namedtuple('Model',
101+ Model = namedtuple('Model',
102102 ('β', 'R', 'γ', 'w_grid', 'y_grid', 'Q'))
103103```
104104
@@ -114,7 +114,7 @@ def create_consumption_model(R=1.01, # Gross interest rate
114114 A function that takes in parameters and returns an instance of Model that
115115 contains data for the optimal savings problem.
116116 """
117- w_grid = jnp.linspace(w_min, w_max, w_size)
117+ w_grid = jnp.linspace(w_min, w_max, w_size)
118118 mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
119119 y_grid, Q = jnp.exp(mc.state_values), mc.P
120120 return Model(β=β, R=R, γ=γ, w_grid=w_grid, y_grid=y_grid, Q=Q)
@@ -129,7 +129,7 @@ def create_consumption_model_jax(R=1.01, # Gross interest rate
129129 w_size=150, # Grid size
130130 ρ=0.9, ν=0.1, y_size=100): # Income parameters
131131 """
132- A function that takes in parameters and returns a JAX-compatible version of
132+ A function that takes in parameters and returns a JAX-compatible version of
133133 Model that contains data for the optimal savings problem.
134134 """
135135 w_grid = jnp.linspace(w_min, w_max, w_size)
@@ -146,15 +146,15 @@ Here's the right hand side of the Bellman equation:
146146``` {code-cell} ipython3
147147def B(v, constants, sizes, arrays):
148148 """
149- A vectorized version of the right-hand side of the Bellman equation
149+ A vectorized version of the right-hand side of the Bellman equation
150150 (before maximization), which is a 3D array representing
151151
152152 B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′)
153153
154154 for all (w, y, w′).
155155 """
156156
157- # Unpack
157+ # Unpack
158158 β, R, γ = constants
159159 w_size, y_size = sizes
160160 w_grid, y_grid, Q = arrays
@@ -215,15 +215,15 @@ def T_σ(v, σ, constants, sizes, arrays):
215215 yp_idx = jnp.arange(y_size)
216216 yp_idx = jnp.reshape(yp_idx, (1, 1, y_size))
217217 σ = jnp.reshape(σ, (w_size, y_size, 1))
218- V = v[σ, yp_idx]
218+ V = v[σ, yp_idx]
219219
220- # Convert Q[j, jp] to Q[i, j, jp]
220+ # Convert Q[j, jp] to Q[i, j, jp]
221221 Q = jnp.reshape(Q, (1, y_size, y_size))
222222
223223 # Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[i, j, jp]
224224 Ev = jnp.sum(V * Q, axis=2)
225225
226- return r_σ + β * jnp.sum(V * Q, axis=2)
226+ return r_σ + β * Ev
227227```
228228
229229and the Bellman operator $T$
@@ -248,9 +248,9 @@ The basic problem is to solve the linear system
248248
249249$$ v(w, y) = u(Rw + y - \sigma(w, y)) + β \sum_{y'} v(\sigma(w, y), y') Q(y, y') $$
250250
251- for $v$.
251+ for $v$.
252252
253- It turns out to be helpful to rewrite this as
253+ It turns out to be helpful to rewrite this as
254254
255255$$ v(w,y) = r(w, y, \sigma(w, y)) + β \sum_{w', y'} v(w', y') P_\sigma(w, y, w', y') $$
256256
@@ -260,40 +260,23 @@ We want to write this as $v = r_\sigma + β P_\sigma v$ and then solve for $v$
260260
261261Note, however,
262262
263- * $v$ is a 2 index array, rather than a single vector.
264- * $P_ \sigma$ has four indices rather than 2
263+ * $v$ is a 2 index array, rather than a single vector.
264+ * $P_ \sigma$ has four indices rather than 2
265265
266- The code below
266+ The code below
267267
2682681 . reshapes $v$ and $r_ \sigma$ to 1D arrays and $P_ \sigma$ to a matrix
2692692 . solves the linear system
2702703 . converts back to multi-index arrays.
271271
272- ``` {code-cell} ipython3
273- def get_value(σ, constants, sizes, arrays):
274- "Get the value v_σ of policy σ by inverting the linear map R_σ."
275-
276- # Unpack
277- β, R, γ = constants
278- w_size, y_size = sizes
279- w_grid, y_grid, Q = arrays
280-
281- r_σ = compute_r_σ(σ, constants, sizes, arrays)
282-
283- # Reduce R_σ to a function in v
284- partial_R_σ = lambda v: R_σ(v, σ, constants, sizes, arrays)
285-
286- return jax.scipy.sparse.linalg.bicgstab(partial_R_σ, r_σ)[0]
287- ```
288-
289272``` {code-cell} ipython3
290273def R_σ(v, σ, constants, sizes, arrays):
291274 """
292- The value v_σ of a policy σ is defined as
275+ The value v_σ of a policy σ is defined as
293276
294277 v_σ = (I - β P_σ)^{-1} r_σ
295278
296- Here we set up the linear map v -> R_σ v, where R_σ := I - β P_σ.
279+ Here we set up the linear map v -> R_σ v, where R_σ := I - β P_σ.
297280
298281 In the consumption problem, this map can be expressed as
299282
@@ -322,6 +305,23 @@ def R_σ(v, σ, constants, sizes, arrays):
322305 return v - β * jnp.sum(V * Q, axis=2)
323306```
324307
308+ ``` {code-cell} ipython3
309+ def get_value(σ, constants, sizes, arrays):
310+ "Get the value v_σ of policy σ by inverting the linear map R_σ."
311+
312+ # Unpack
313+ β, R, γ = constants
314+ w_size, y_size = sizes
315+ w_grid, y_grid, Q = arrays
316+
317+ r_σ = compute_r_σ(σ, constants, sizes, arrays)
318+
319+ # Reduce R_σ to a function in v
320+ partial_R_σ = lambda v: R_σ(v, σ, constants, sizes, arrays)
321+
322+ return jax.scipy.sparse.linalg.bicgstab(partial_R_σ, r_σ)[0]
323+ ```
324+
325325## JIT compiled versions
326326
327327``` {code-cell} ipython3
@@ -354,7 +354,6 @@ def value_iteration(model, tol=1e-5):
354354def policy_iteration(model):
355355 "Howard policy iteration routine."
356356 constants, sizes, arrays = model
357- vz = jnp.zeros(sizes)
358357 σ = jnp.zeros(sizes, dtype=int)
359358 i, error = 0, 1.0
360359 while error > 0:
@@ -387,14 +386,19 @@ def optimistic_policy_iteration(model, tol=1e-5, m=10):
387386Create a JAX model for consumption, perform policy iteration, and plot the resulting optimal policy function.
388387
389388``` {code-cell} ipython3
390- fontsize= 12
389+ fontsize = 12
391390model = create_consumption_model_jax()
392- # Unpack
391+
392+ # Unpack
393393constants, sizes, arrays = model
394394β, R, γ = constants
395395w_size, y_size = sizes
396396w_grid, y_grid, Q = arrays
397+ ```
398+
399+ ``` {code-cell} ipython3
397400σ_star = policy_iteration(model)
401+
398402fig, ax = plt.subplots(figsize=(9, 5.2))
399403ax.plot(w_grid, w_grid, "k--", label="45")
400404ax.plot(w_grid, w_grid[σ_star[:, 1]], label="$\\sigma^*(\cdot, y_1)$")
@@ -443,7 +447,9 @@ def run_algorithm(algorithm, model, **kwargs):
443447 elapsed_time = end_time - start_time
444448 print(f"{algorithm.__name__} completed in {elapsed_time:.2f} seconds.")
445449 return result, elapsed_time
450+ ```
446451
452+ ``` {code-cell} ipython3
447453model = create_consumption_model_jax()
448454σ_pi, pi_time = run_algorithm(policy_iteration, model)
449455σ_vfi, vfi_time = run_algorithm(value_iteration, model, tol=1e-5)
@@ -453,7 +459,9 @@ opi_times = []
453459for m in m_vals:
454460 σ_opi, opi_time = run_algorithm(optimistic_policy_iteration, model, m=m, tol=1e-5)
455461 opi_times.append(opi_time)
462+ ```
456463
464+ ``` {code-cell} ipython3
457465fig, ax = plt.subplots(figsize=(9, 5.2))
458466ax.plot(m_vals, jnp.full(len(m_vals), pi_time), lw=2, label="Howard policy iteration")
459467ax.plot(m_vals, jnp.full(len(m_vals), vfi_time), lw=2, label="value function iteration")
0 commit comments