Skip to content

Commit 5957e45

Browse files
jstacclaude
andauthored
Improve McCall model lecture: Fix beta unpacking and enhance readability (#694)
Made several improvements to the mccall_model.md lecture to enhance code clarity and pedagogical value: **Key changes:** 1. Fixed beta unpacking in Comparative statics section - now unpacks model parameters before use 2. Improved derivation of nonlinear equation in h under "Take 2" - added step-by-step algebraic transformations 3. Enhanced JAX solution in Exercise 1 - extracted wage drawing logic into documented draw_wage() function 4. Broke vectorization into clear steps - separated vmap calls with explanatory comments 5. Simplified Exercise 2 solution - removed redundant keyword arguments and meshgrid, unpacked model directly 6. Made while_loop pattern clearer - explicitly created final_state before unpacking All changes tested and verified to run correctly. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude <noreply@anthropic.com>
1 parent 374bf14 commit 5957e45

File tree

2 files changed

+1143
-19
lines changed

2 files changed

+1143
-19
lines changed

lectures/mccall_model.md

Lines changed: 75 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -498,27 +498,29 @@ colors = prop_cycle.by_key()['color']
498498
499499
# Plot the wage offer distribution
500500
ax.plot(w, q, '-', alpha=0.6, lw=2,
501-
label='wage offer distribution',
501+
label='wage offer distribution',
502502
color=colors[0])
503503
504504
# Compute reservation wage with default beta
505505
model_default = McCallModel()
506-
v_init = model_default.w / (1 - model_default.β)
506+
c, β, w, q = model_default
507+
v_init = w / (1 - β)
507508
v_default, res_wage_default = compute_reservation_wage(
508509
model_default, v_init
509510
)
510511
511512
# Compute reservation wage with lower beta
512513
β_new = 0.96
513514
model_low_beta = McCallModel(β=β_new)
514-
v_init_low = model_low_beta.w / (1 - model_low_beta.β)
515+
c, β_low, w, q = model_low_beta
516+
v_init_low = w / (1 - β_low)
515517
v_low, res_wage_low = compute_reservation_wage(
516518
model_low_beta, v_init_low
517519
)
518520
519521
# Plot vertical lines for reservation wages
520522
ax.axvline(x=res_wage_default, color=colors[1], lw=2,
521-
label=f'reservation wage (β={model_default.β})')
523+
label=f'reservation wage (β={β})')
522524
ax.axvline(x=res_wage_low, color=colors[2], lw=2,
523525
label=f'reservation wage (β={β_new})')
524526
@@ -621,12 +623,43 @@ Let $h$ denote the continuation value:
621623

622624
The Bellman equation can now be written as
623625

624-
$$
626+
```{math}
627+
:label: j1b
628+
625629
v^*(w')
626630
= \max \left\{ \frac{w'}{1 - \beta}, \, h \right\}
631+
```
632+
633+
Now let's derive a nonlinear equation for $h$ alone.
634+
635+
Starting from {eq}`j1b`, we multiply both sides by $q(w')$ to get
636+
637+
$$
638+
v^*(w') q(w') = \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
639+
$$
640+
641+
Next, we sum both sides over $w' \in \mathbb{W}$:
642+
643+
$$
644+
\sum_{w' \in \mathbb W} v^*(w') q(w')
645+
= \sum_{w' \in \mathbb W} \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
646+
$$
647+
648+
Now multiply both sides by $\beta$:
649+
650+
$$
651+
\beta \sum_{w' \in \mathbb W} v^*(w') q(w')
652+
= \beta \sum_{w' \in \mathbb W} \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
627653
$$
628654

629-
Substituting this last equation into {eq}`j1` gives
655+
Add $c$ to both sides:
656+
657+
$$
658+
c + \beta \sum_{w' \in \mathbb W} v^*(w') q(w')
659+
= c + \beta \sum_{w' \in \mathbb W} \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
660+
$$
661+
662+
Finally, using the definition of $h$ from {eq}`j1`, the left-hand side is just $h$, giving us
630663

631664
```{math}
632665
:label: j2
@@ -638,7 +671,7 @@ Substituting this last equation into {eq}`j1` gives
638671
\right\} q (w')
639672
```
640673

641-
This is a nonlinear equation that we can solve for $h$.
674+
This is a nonlinear equation in the single scalar $h$ that we can solve for $h$.
642675

643676
As before, we will use successive approximations:
644677

@@ -781,8 +814,28 @@ plt.show()
781814
And here's a solution using JAX.
782815

783816
```{code-cell} ipython3
817+
# First, we set up a function to draw random wage offers from the distribution.
818+
# We use the inverse transform method: draw a uniform random variable u,
819+
# then find the smallest wage w such that the CDF at w is >= u.
784820
cdf = jnp.cumsum(q_default)
785821
822+
def draw_wage(uniform_rv):
823+
"""
824+
Draw a wage from the distribution q_default using the inverse transform method.
825+
826+
Parameters:
827+
-----------
828+
uniform_rv : float
829+
A uniform random variable on [0, 1]
830+
831+
Returns:
832+
--------
833+
wage : float
834+
A wage drawn from w_default with probabilities given by q_default
835+
"""
836+
return w_default[jnp.searchsorted(cdf, uniform_rv)]
837+
838+
786839
def compute_stopping_time(w_bar, key):
787840
"""
788841
Compute stopping time by drawing wages until one exceeds `w_bar`.
@@ -791,7 +844,7 @@ def compute_stopping_time(w_bar, key):
791844
t, key, accept = loop_state
792845
key, subkey = jax.random.split(key)
793846
u = jax.random.uniform(subkey)
794-
w = w_default[jnp.searchsorted(cdf, u)]
847+
w = draw_wage(u)
795848
accept = w >= w_bar
796849
t = t + 1
797850
return t, key, accept
@@ -831,7 +884,8 @@ def compute_stop_time_for_c(c):
831884
return compute_mean_stopping_time(w_bar)
832885
833886
# Vectorize across all c values
834-
stop_times = jax.vmap(compute_stop_time_for_c)(c_vals)
887+
compute_stop_time_vectorized = jax.vmap(compute_stop_time_for_c)
888+
stop_times = compute_stop_time_vectorized(c_vals)
835889
836890
fig, ax = plt.subplots()
837891
@@ -928,12 +982,12 @@ def create_mccall_continuous(
928982
key = jax.random.PRNGKey(seed)
929983
s = jax.random.normal(key, (mc_size,))
930984
w_draws = jnp.exp(μ + σ * s)
931-
return McCallModelContinuous(c=c, β, σ, μ=μ, w_draws=w_draws)
985+
return McCallModelContinuous(c, β, σ, μ, w_draws)
932986
933987
934988
@jax.jit
935989
def compute_reservation_wage_continuous(model, max_iter=500, tol=1e-5):
936-
c, β, σ, μ, w_draws = model.c, model.β, model.σ, model.μ, model.w_draws
990+
c, β, σ, μ, w_draws = model
937991
938992
h = jnp.mean(w_draws) / (1 - β) # initial guess
939993
@@ -949,8 +1003,9 @@ def compute_reservation_wage_continuous(model, max_iter=500, tol=1e-5):
9491003
return jnp.logical_and(i < max_iter, error > tol)
9501004
9511005
initial_state = (h, 0, tol + 1)
952-
h_final, _, _ = jax.lax.while_loop(cond, update, initial_state)
953-
1006+
final_state = jax.lax.while_loop(cond, update, initial_state)
1007+
h_final, _, _ = final_state
1008+
9541009
# Now compute the reservation wage
9551010
return (1 - β) * h_final
9561011
```
@@ -969,12 +1024,13 @@ def compute_R_element(c, β):
9691024
model = create_mccall_continuous(c=c, β=β)
9701025
return compute_reservation_wage_continuous(model)
9711026
972-
# Create meshgrid and vectorize computation
973-
c_grid, β_grid = jnp.meshgrid(c_vals, β_vals, indexing='ij')
974-
compute_R_vectorized = jax.vmap(
975-
jax.vmap(compute_R_element,
976-
in_axes=(None, 0)),
977-
in_axes=(0, None))
1027+
# First, vectorize over β (holding c fixed)
1028+
compute_R_over_β = jax.vmap(compute_R_element, in_axes=(None, 0))
1029+
1030+
# Next, vectorize over c (applying the above function to each c)
1031+
compute_R_vectorized = jax.vmap(compute_R_over_β, in_axes=(0, None))
1032+
1033+
# Apply to compute the full grid
9781034
R = compute_R_vectorized(c_vals, β_vals)
9791035
```
9801036

0 commit comments

Comments
 (0)