
Commit 0d564d4

jstac, claude, and mmcky authored
Fix JAX compatibility issues in Job Search III lecture (#687)
* Fix JAX compatibility issues in Job Search III lecture

  Updated mccall_model_with_sep_markov.md to fix several JAX-related issues:

  - Refactored vfi() to return only v_final instead of a tuple, making it more consistent with the VFI pattern
  - Removed the separate successive_approx() function and integrated the iteration logic directly into vfi()
  - Fixed JAX decorators: changed @jit to @jax.jit and @partial(jit, ...) to @partial(jax.jit, ...)
  - Rewrote get_reservation_wage() to use jnp.argmax() instead of jnp.where(), avoiding JAX concretization errors in JIT compilation (a short sketch of why appears after the corresponding hunk below)
  - Updated all vfi() call sites to compute the policy explicitly with get_greedy(v_star, model)
  - Removed the unnecessary @jit decorators from T() and get_greedy()

  Also improved wording in mccall_model_with_separation.md for clarity.

  Tested: converted to Python and ran successfully without errors.

* Standardize continuation value notation in Job Search II lecture

  Updated the McCall model with separation to use h = u(c) + β * sum_w v_u(w) q(w) as the continuation value, matching the notation of the basic McCall model lecture. This makes the progression between lectures more intuitive for readers:

  - Basic model: h = c + β * sum_w v*(w) q(w)
  - Separation model: h = u(c) + β * sum_w v_u(w) q(w)

  Key changes:

  - Replaced the scalar d with h throughout the mathematical derivations
  - Updated the closed-form expression for v_e(w) to use h
  - Modified the iteration algorithm to solve for h instead of d
  - Simplified the Bellman equations using the h notation
  - Updated all code functions (compute_v_e, update_h, solve_model)
  - Changed plots and comments to reference h

  This improves consistency across the job search lecture series and makes the mathematical structure clearer by showing explicitly that the continuation value includes both current utility and discounted future value.

* Remove JAX from pip install in Job Search II lecture

  The JAX library is already included in the base environment and doesn't need to be explicitly installed; the explicit install was causing build failures.

* Remove JAX from pip install in Job Search III lecture

  Same as above: JAX is already included in the base environment, and the explicit install was causing build failures.

* Update PyTorch installation to use cu121

* Update CI workflow to install only PyTorch packages

  Removed 'torchaudio' from the installation command for PyTorch.

* fix: JAX version pinning

* install new CUDANN binaries

* remove pytorch

* misc

* Update mccall_model_with_separation: change the utility parameter from sigma to gamma and use glue for figures

  Updated the McCall model with separation lecture. Key changes:

  - Changed the utility function parameter from σ (sigma) to γ (gamma)
  - Moved the default value of γ from the utility function to the Model class (γ: float = 2.0)
  - Updated all functions (compute_v_e, update_h, solve_model) to pass the γ parameter
  - Simplified model unpacking to use tuple unpacking directly (e.g., α, β, γ, c, w, q = model)
  - Replaced static PNG figures with myst-nb glue functionality (see the sketch below this message)
  - Added the glue import and glue() calls in the exercise solutions
  - Converted {figure} directives to {glue:figure} directives for dynamic figure generation

  Benefits:

  - More consistent parameter naming (gamma is standard for CRRA utility)
  - Better code organization, with parameter defaults in the Model class
  - Cleaner unpacking syntax
  - Dynamic figure generation eliminates the need for static PNG files
  - Figures automatically stay in sync with the code

* Remove redundant PNG files for mccall_model_with_separation

  Removed static PNG files that are now generated dynamically with myst-nb glue:

  - mccall_resw_alpha.png
  - mccall_resw_beta.png
  - mccall_resw_c.png

  These figures are now produced by the exercise solution code and displayed via glue:figure directives, eliminating the need for static files and ensuring the figures always match the code.

* misc

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: mmcky <mamckay@gmail.com>
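
For reference, the glue workflow mentioned in the sigma-to-gamma commit follows a pattern like the sketch below; the key name and plotted data are illustrative, not taken from the lecture.

```python
# In a solution code cell: build the figure and register it under a key.
from myst_nb import glue
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0.1, 0.2, 0.3], [45.0, 47.5, 50.0])   # placeholder data
ax.set_xlabel("α")
ax.set_ylabel("reservation wage")
glue("resw_alpha_fig", fig, display=False)     # store the figure; don't display it here
```

In the markdown source, the old `{figure}` directive pointing at a static PNG is then replaced by a `{glue:figure} resw_alpha_fig` directive, so the rendered figure is regenerated from the code on every build.
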
1 parent b61b30e commit 0d564d4

6 files changed: +214 −184 lines

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
@@ -31,9 +31,9 @@ jobs:
       - name: Install JAX, Numpyro, PyTorch
         shell: bash -l {0}
         run: |
-          pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
-          pip install pyro-ppl
-          pip install --upgrade "jax[cuda12-local]==0.6.2"
+          # pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
+          # pip install pyro-ppl
+          pip install "jax[cuda12-local]==0.6.2"
           pip install numpyro pyro-ppl
           python scripts/test-jax-install.py
       - name: Check nvidia Drivers
3 binary files changed (contents not shown)

lectures/mccall_model_with_sep_markov.md

Lines changed: 61 additions & 74 deletions
@@ -4,11 +4,11 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.17.1
+    jupytext_version: 1.17.2
 kernelspec:
-  name: python3
   display_name: Python 3 (ipykernel)
   language: python
+  name: python3
 ---
 
 (mccall_with_sep_markov)=
@@ -49,7 +49,7 @@ libraries
 ```{code-cell} ipython3
 :tags: [hide-output]
 
-!pip install quantecon jax
+!pip install quantecon
 ```
 
 We use the following imports:
@@ -58,7 +58,7 @@ We use the following imports:
 from quantecon.markov import tauchen
 import jax.numpy as jnp
 import jax
-from jax import jit, lax
+from jax import lax
 from typing import NamedTuple
 import matplotlib.pyplot as plt
 from functools import partial
@@ -138,48 +138,11 @@ The optimal policy turns out to be a reservation wage strategy: accept all wages
 
 ## Code
 
-
-First, we implement the successive approximation algorithm.
-
-This algorithm takes an operator $T$ and an initial condition and iterates until
-convergence.
-
-We will use it for value function iteration.
-
-```{code-cell} ipython3
-@partial(jit, static_argnums=(0,))
-def successive_approx(
-        T,                        # Operator (callable) - marked as static
-        x_0,                      # Initial condition
-        tolerance: float = 1e-6,  # Error tolerance
-        max_iter: int = 100_000,  # Max iteration bound
-    ):
-    """Computes the approximate fixed point of T via successive
-    approximation using lax.while_loop."""
-
-    def cond_fn(carry):
-        x, error, k = carry
-        return (error > tolerance) & (k <= max_iter)
-
-    def body_fn(carry):
-        x, error, k = carry
-        x_new = T(x)
-        error = jnp.max(jnp.abs(x_new - x))
-        return (x_new, error, k + 1)
-
-    initial_carry = (x_0, tolerance + 1, 1)
-    x_final, _, _ = lax.while_loop(cond_fn, body_fn, initial_carry)
-
-    return x_final
-```
-
-
-Next let's set up a `Model` class to store information needed to solve the model.
+Let's set up a `Model` class to store information needed to solve the model.
 
 We include `P_cumsum`, the row-wise cumulative sum of the transition matrix, to
 optimize the simulation -- the details are explained below.
 
-
 ```{code-cell} ipython3
 class Model(NamedTuple):
     n: int
@@ -215,7 +178,6 @@ def create_js_with_sep_model(
 Here's the Bellman operator for the unemployed worker's value function:
 
 ```{code-cell} ipython3
-@jit
 def T(v: jnp.ndarray, model: Model) -> jnp.ndarray:
     """The Bellman operator for the value of being unemployed."""
     n, w_vals, P, P_cumsum, β, c, α = model
@@ -229,7 +191,6 @@ The next function computes the optimal policy under the assumption that $v$ is
 the value function:
 
 ```{code-cell} ipython3
-@jit
 def get_greedy(v: jnp.ndarray, model: Model) -> jnp.ndarray:
     """Get a v-greedy policy."""
     n, w_vals, P, P_cumsum, β, c, α = model
@@ -247,14 +208,34 @@ The second routine requires a policy function, which we will typically obtain by
 applying the `vfi` function.
 
 ```{code-cell} ipython3
-def vfi(model: Model):
-    """Solve by VFI."""
+@jax.jit
+def vfi(
+        model: Model,
+        tolerance: float = 1e-6,  # Error tolerance
+        max_iter: int = 100_000,  # Max iteration bound
+    ):
+
     v_init = jnp.zeros(model.w_vals.shape)
-    v_star = successive_approx(lambda v: T(v, model), v_init)
-    σ_star = get_greedy(v_star, model)
-    return v_star, σ_star
+
+    def cond(loop_state):
+        v, error, i = loop_state
+        return (error > tolerance) & (i <= max_iter)
+
+    def update(loop_state):
+        v, error, i = loop_state
+        v_new = T(v, model)
+        error = jnp.max(jnp.abs(v_new - v))
+        new_loop_state = v_new, error, i + 1
+        return new_loop_state
+
+    initial_state = (v_init, tolerance + 1, 1)
+    final_loop_state = lax.while_loop(cond, update, initial_state)
+    v_final, error, i = final_loop_state
 
+    return v_final
 
+
+@jax.jit
 def get_reservation_wage(σ: jnp.ndarray, model: Model) -> float:
     """
     Calculate the reservation wage from a given policy.
@@ -268,25 +249,24 @@ def get_reservation_wage(σ: jnp.ndarray, model: Model) -> float:
     """
     n, w_vals, P, P_cumsum, β, c, α = model
 
-    # Find all wage indices where policy indicates acceptance
-    accept_indices = jnp.where(σ == 1)[0]
-
-    if len(accept_indices) == 0:
-        return jnp.inf  # Agent never accepts any wage
+    # Find the first index where policy indicates acceptance
+    # σ is a boolean array, argmax returns the first True value
+    first_accept_idx = jnp.argmax(σ)
 
-    # Return the lowest wage that is accepted
-    return w_vals[accept_indices[0]]
+    # If no acceptance (all False), return infinity
+    # Otherwise return the wage at the first acceptance index
+    return jnp.where(jnp.any(σ), w_vals[first_accept_idx], jnp.inf)
 ```
 
-
 ## Computing the Solution
 
 Let's solve the model:
 
 ```{code-cell} ipython3
 model = create_js_with_sep_model()
 n, w_vals, P, P_cumsum, β, c, α = model
-v_star, σ_star = vfi(model)
+v_star = vfi(model)
+σ_star = get_greedy(v_star, model)
 ```
 
 Next we compute some related quantities, including the reservation wage.
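
A note on the `get_reservation_wage` rewrite in the hunk above: under `jax.jit`, boolean-mask indexing such as `jnp.where(σ == 1)[0]` produces an array whose size depends on runtime values, and the Python `if len(...) == 0` test forces a traced value into a concrete boolean, so the old version cannot be traced. Below is a minimal sketch of the shape-static pattern the new code uses, with a hypothetical wage grid and policy rather than the lecture's model:

```python
import jax
import jax.numpy as jnp

w_vals = jnp.linspace(10.0, 60.0, 6)            # hypothetical wage grid
σ = jnp.array([0, 0, 1, 1, 1, 1], dtype=bool)   # hypothetical accept/reject policy

@jax.jit
def reservation_wage(σ, w_vals):
    # argmax on a boolean array returns the index of the first True entry
    # (and 0 when every entry is False), so its shape is fixed at trace time.
    first_accept = jnp.argmax(σ)
    # Guard the all-False case with jnp.any: both branches are scalars,
    # so the output shape is also fixed.
    return jnp.where(jnp.any(σ), w_vals[first_accept], jnp.inf)

print(reservation_wage(σ, w_vals))   # expected: 30.0
```
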
@@ -312,19 +292,18 @@ ax.set_xlabel(r"$w$")
 plt.show()
 ```
 
-
 ## Sensitivity Analysis
 
 Let's examine how reservation wages change with the separation rate.
 
-
 ```{code-cell} ipython3
 α_vals: jnp.ndarray = jnp.linspace(0.0, 1.0, 10)
 
 w_star_vec = jnp.empty_like(α_vals)
 for (i_α, α) in enumerate(α_vals):
     model = create_js_with_sep_model(α=α)
-    v_star, σ_star = vfi(model)
+    v_star = vfi(model)
+    σ_star = get_greedy(v_star, model)
     w_star = get_reservation_wage(σ_star, model)
     w_star_vec = w_star_vec.at[i_α].set(w_star)
 
@@ -356,9 +335,8 @@ This is implemented via `jnp.searchsorted` on the precomputed cumulative sum
 
 The function `update_agent` advances the agent's state by one period.
 
-
 ```{code-cell} ipython3
-@jit
+@jax.jit
 def update_agent(key, is_employed, wage_idx, model, σ):
     """
     Updates an agent by one period. Updates their employment status and their
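
The hunk header above mentions that wage draws are implemented via `jnp.searchsorted` on the precomputed cumulative sum (`P_cumsum`). A toy sketch of that inverse-CDF sampling step, with a made-up distribution rather than the lecture's transition matrix:

```python
import jax
import jax.numpy as jnp

# A toy distribution over 4 wage states; in the lecture, each row of P plays
# this role and P_cumsum stores the row-wise cumulative sums.
q = jnp.array([0.1, 0.2, 0.3, 0.4])
q_cumsum = jnp.cumsum(q)            # [0.1, 0.3, 0.6, 1.0]

def draw_next_state(key, q_cumsum):
    # Draw u ~ Uniform(0, 1) and find the first index whose cumulative
    # probability reaches u -- a standard inverse-CDF draw.
    u = jax.random.uniform(key)
    return jnp.searchsorted(q_cumsum, u)

keys = jax.random.split(jax.random.PRNGKey(0), 5)
print(jax.vmap(draw_next_state, in_axes=(0, None))(keys, q_cumsum))
```
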
@@ -439,7 +417,8 @@ Let's create a comprehensive plot of the employment simulation:
 model = create_js_with_sep_model()
 
 # Calculate reservation wage for plotting
-v_star, σ_star = vfi(model)
+v_star = vfi(model)
+σ_star = get_greedy(v_star, model)
 w_star = get_reservation_wage(σ_star, model)
 
 wage_path, employment_status = simulate_employment_path(model, σ_star)
@@ -486,7 +465,6 @@ plt.tight_layout()
 plt.show()
 ```
 
-
 The simulation helps to visualize outcomes associated with this model.
 
 The agent follows a reservation wage strategy.
@@ -531,7 +509,7 @@ This holds because:
 
 These properties ensure the chain is ergodic with a unique stationary distribution $\pi$ over states $(s, w)$.
 
-For an ergodic Markov chain, the ergodic theorem guarantees that time averages = ensemble averages.
+For an ergodic Markov chain, the ergodic theorem guarantees that time averages = cross-sectional averages.
 
 In particular, the fraction of time a single agent spends unemployed (across all
 wage states) converges to the cross-sectional unemployment rate:
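
Schematically, writing $u_t \in \{0, 1\}$ for the agent's unemployment indicator at time $t$ and $\pi$ for the stationary distribution over $(s, w)$, the claim referenced here is (a sketch of the statement, not necessarily the lecture's exact notation):

$$
\lim_{T \to \infty} \frac{1}{T} \sum_{t=1}^{T} u_t
    \;=\; \sum_{w} \pi(\text{unemployed},\, w)
\quad \text{with probability one.}
$$
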
@@ -568,7 +546,7 @@ update_agents_vmap = jax.vmap(
 Next we define the core simulation function, which uses `lax.fori_loop` to efficiently iterate many agents forward in time:
 
 ```{code-cell} ipython3
-@partial(jit, static_argnums=(3, 4))
+@partial(jax.jit, static_argnums=(3, 4))
 def _simulate_cross_section_compiled(
         key: jnp.ndarray,
         model: Model,
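
The body of `_simulate_cross_section_compiled` falls outside this hunk. As a rough sketch of the pattern the surrounding text describes -- `jax.vmap` over agents combined with `lax.fori_loop` over time -- here is a toy version with placeholder dynamics and names (`update_one`, `simulate`), not the lecture's `update_agent` or its model:

```python
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial

def update_one(key, state):
    # Placeholder dynamics: flip a 0/1 state with probability 0.1.
    flip = jax.random.bernoulli(key, 0.1)
    return jnp.where(flip, 1 - state, state)

# Vectorize the per-agent update over a batch of keys and states.
update_all = jax.vmap(update_one, in_axes=(0, 0))

@partial(jax.jit, static_argnums=(1, 2))  # agent count and horizon fixed at trace time
def simulate(key, n_agents, n_periods):
    states = jnp.zeros(n_agents, dtype=jnp.int32)

    def step(t, carry):
        key, states = carry
        key, subkey = jax.random.split(key)
        agent_keys = jax.random.split(subkey, n_agents)
        return key, update_all(agent_keys, states)

    _, final_states = lax.fori_loop(0, n_periods, step, (key, states))
    return final_states

final = simulate(jax.random.PRNGKey(0), 1_000, 200)
print(final.mean())   # cross-sectional average of the toy state at t = 200
```

Passing `n_agents` and `n_periods` as static arguments is what `static_argnums` in the decorator is for: array shapes inside the loop must be concrete when the function is traced.
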
@@ -627,7 +605,8 @@ def simulate_cross_section(
     key = jax.random.PRNGKey(seed)
 
     # Solve for optimal policy
-    v_star, σ_star = vfi(model)
+    v_star = vfi(model)
+    σ_star = get_greedy(v_star, model)
 
     # Run JIT-compiled simulation
     final_employment = _simulate_cross_section_compiled(
@@ -655,7 +634,8 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200,
     """
     # Get final employment state directly
     key = jax.random.PRNGKey(42)
-    v_star, σ_star = vfi(model)
+    v_star = vfi(model)
+    σ_star = get_greedy(v_star, model)
     final_employment = _simulate_cross_section_compiled(
         key, model, σ_star, n_agents, t_snapshot
     )
@@ -681,7 +661,12 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200,
     plt.show()
 ```
 
-Now let's compare the time-average unemployment rate (from a single agent's long simulation) with the cross-sectional unemployment rate (from many agents at a single point in time):
+Now let's compare the time-average unemployment rate (from a single agent's long simulation) with the cross-sectional unemployment rate (from many agents at a single point in time).
+
+We claimed above that these numbers will be approximately equal in large
+samples, due to ergodicity.
+
+Let's see if that's true.
 
 ```{code-cell} ipython3
 model = create_js_with_sep_model()
@@ -697,28 +682,31 @@ print(f"Cross-sectional unemployment rate (at t=200): "
 print(f"Difference: {abs(time_avg_unemp - cross_sectional_unemp):.4f}")
 ```
 
+Indeed, they are very close.
+
 Now let's visualize the cross-sectional distribution:
 
 ```{code-cell} ipython3
 plot_cross_sectional_unemployment(model)
 ```
 
-## Cross-Sectional Analysis with Lower Unemployment Compensation (c=0.5)
+## Lower Unemployment Compensation (c=0.5)
 
-Let's examine how the cross-sectional unemployment rate changes with lower unemployment compensation:
+What happens to the cross-sectional unemployment rate with lower unemployment compensation?
 
 ```{code-cell} ipython3
 model_low_c = create_js_with_sep_model(c=0.5)
 plot_cross_sectional_unemployment(model_low_c)
 ```
 
+
 ## Exercises
 
 ```{exercise-start}
 :label: mmwsm_ex1
 ```
 
-Create a plot that shows how the steady state cross-sectional unemployment rate
+Create a plot that investigates more carefully how the steady state cross-sectional unemployment rate
 changes with unemployment compensation.
 
 ```{exercise-end}
@@ -751,4 +739,3 @@ plt.show()
 
 ```{solution-end}
 ```
-

Comments (0)