@@ -498,27 +498,29 @@ colors = prop_cycle.by_key()['color']
498498
499499# Plot the wage offer distribution
500500ax.plot(w, q, '-', alpha=0.6, lw=2,
501- label='wage offer distribution',
501+ label='wage offer distribution',
502502 color=colors[0])
503503
504504# Compute reservation wage with default beta
505505model_default = McCallModel()
506- v_init = model_default.w / (1 - model_default.β)
506+ c, β, w, q = model_default
507+ v_init = w / (1 - β)
507508v_default, res_wage_default = compute_reservation_wage(
508509 model_default, v_init
509510)
510511
511512# Compute reservation wage with lower beta
512513β_new = 0.96
513514model_low_beta = McCallModel(β=β_new)
514- v_init_low = model_low_beta.w / (1 - model_low_beta.β)
515+ c, β_low, w, q = model_low_beta
516+ v_init_low = w / (1 - β_low)
515517v_low, res_wage_low = compute_reservation_wage(
516518 model_low_beta, v_init_low
517519)
518520
519521# Plot vertical lines for reservation wages
520522ax.axvline(x=res_wage_default, color=colors[1], lw=2,
521- label=f'reservation wage (β={model_default. β})')
523+ label=f'reservation wage (β={β})')
522524ax.axvline(x=res_wage_low, color=colors[2], lw=2,
523525 label=f'reservation wage (β={β_new})')
524526
@@ -621,12 +623,43 @@ Let $h$ denote the continuation value:
621623
622624The Bellman equation can now be written as
623625
624- $$
626+ ``` {math}
627+ :label: j1b
628+
625629 v^*(w')
626630 = \max \left\{ \frac{w'}{1 - \beta}, \, h \right\}
631+ ```
632+
633+ Now let's derive a nonlinear equation for $h$ alone.
634+
635+ Starting from {eq}` j1b ` , we multiply both sides by $q(w')$ to get
636+
637+ $$
638+ v^*(w') q(w') = \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
639+ $$
640+
641+ Next, we sum both sides over $w' \in \mathbb{W}$:
642+
643+ $$
644+ \sum_{w' \in \mathbb W} v^*(w') q(w')
645+ = \sum_{w' \in \mathbb W} \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
646+ $$
647+
648+ Now multiply both sides by $\beta$:
649+
650+ $$
651+ \beta \sum_{w' \in \mathbb W} v^*(w') q(w')
652+ = \beta \sum_{w' \in \mathbb W} \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
627653$$
628654
629- Substituting this last equation into {eq}` j1 ` gives
655+ Add $c$ to both sides:
656+
657+ $$
658+ c + \beta \sum_{w' \in \mathbb W} v^*(w') q(w')
659+ = c + \beta \sum_{w' \in \mathbb W} \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
660+ $$
661+
662+ Finally, using the definition of $h$ from {eq}` j1 ` , the left-hand side is just $h$, giving us
630663
631664``` {math}
632665:label: j2
@@ -638,7 +671,7 @@ Substituting this last equation into {eq}`j1` gives
638671 \right\} q (w')
639672```
640673
641- This is a nonlinear equation that we can solve for $h$.
674+ This is a nonlinear equation in the single scalar $h$ that we can solve for $h$.
642675
643676As before, we will use successive approximations:
644677
@@ -781,8 +814,28 @@ plt.show()
781814And here's a solution using JAX.
782815
783816``` {code-cell} ipython3
817+ # First, we set up a function to draw random wage offers from the distribution.
818+ # We use the inverse transform method: draw a uniform random variable u,
819+ # then find the smallest wage w such that the CDF at w is >= u.
784820cdf = jnp.cumsum(q_default)
785821
822+ def draw_wage(uniform_rv):
823+ """
824+ Draw a wage from the distribution q_default using the inverse transform method.
825+
826+ Parameters:
827+ -----------
828+ uniform_rv : float
829+ A uniform random variable on [0, 1]
830+
831+ Returns:
832+ --------
833+ wage : float
834+ A wage drawn from w_default with probabilities given by q_default
835+ """
836+ return w_default[jnp.searchsorted(cdf, uniform_rv)]
837+
838+
786839def compute_stopping_time(w_bar, key):
787840 """
788841 Compute stopping time by drawing wages until one exceeds `w_bar`.
@@ -791,7 +844,7 @@ def compute_stopping_time(w_bar, key):
791844 t, key, accept = loop_state
792845 key, subkey = jax.random.split(key)
793846 u = jax.random.uniform(subkey)
794- w = w_default[jnp.searchsorted(cdf, u)]
847+ w = draw_wage(u)
795848 accept = w >= w_bar
796849 t = t + 1
797850 return t, key, accept
@@ -831,7 +884,8 @@ def compute_stop_time_for_c(c):
831884 return compute_mean_stopping_time(w_bar)
832885
833886# Vectorize across all c values
834- stop_times = jax.vmap(compute_stop_time_for_c)(c_vals)
887+ compute_stop_time_vectorized = jax.vmap(compute_stop_time_for_c)
888+ stop_times = compute_stop_time_vectorized(c_vals)
835889
836890fig, ax = plt.subplots()
837891
@@ -928,12 +982,12 @@ def create_mccall_continuous(
928982 key = jax.random.PRNGKey(seed)
929983 s = jax.random.normal(key, (mc_size,))
930984 w_draws = jnp.exp(μ + σ * s)
931- return McCallModelContinuous(c=c , β=β , σ=σ , μ=μ, w_draws= w_draws)
985+ return McCallModelContinuous(c, β, σ, μ, w_draws)
932986
933987
934988@jax.jit
935989def compute_reservation_wage_continuous(model, max_iter=500, tol=1e-5):
936- c, β, σ, μ, w_draws = model.c, model.β, model.σ, model.μ, model.w_draws
990+ c, β, σ, μ, w_draws = model
937991
938992 h = jnp.mean(w_draws) / (1 - β) # initial guess
939993
@@ -949,8 +1003,9 @@ def compute_reservation_wage_continuous(model, max_iter=500, tol=1e-5):
9491003 return jnp.logical_and(i < max_iter, error > tol)
9501004
9511005 initial_state = (h, 0, tol + 1)
952- h_final, _, _ = jax.lax.while_loop(cond, update, initial_state)
953-
1006+ final_state = jax.lax.while_loop(cond, update, initial_state)
1007+ h_final, _, _ = final_state
1008+
9541009 # Now compute the reservation wage
9551010 return (1 - β) * h_final
9561011```
@@ -969,12 +1024,13 @@ def compute_R_element(c, β):
9691024 model = create_mccall_continuous(c=c, β=β)
9701025 return compute_reservation_wage_continuous(model)
9711026
972- # Create meshgrid and vectorize computation
973- c_grid, β_grid = jnp.meshgrid(c_vals, β_vals, indexing='ij')
974- compute_R_vectorized = jax.vmap(
975- jax.vmap(compute_R_element,
976- in_axes=(None, 0)),
977- in_axes=(0, None))
1027+ # First, vectorize over β (holding c fixed)
1028+ compute_R_over_β = jax.vmap(compute_R_element, in_axes=(None, 0))
1029+
1030+ # Next, vectorize over c (applying the above function to each c)
1031+ compute_R_vectorized = jax.vmap(compute_R_over_β, in_axes=(0, None))
1032+
1033+ # Apply to compute the full grid
9781034R = compute_R_vectorized(c_vals, β_vals)
9791035```
9801036
0 commit comments