Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 75 additions & 19 deletions lectures/mccall_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -498,27 +498,29 @@ colors = prop_cycle.by_key()['color']

# Plot the wage offer distribution
ax.plot(w, q, '-', alpha=0.6, lw=2,
label='wage offer distribution',
label='wage offer distribution',
color=colors[0])

# Compute reservation wage with default beta
model_default = McCallModel()
v_init = model_default.w / (1 - model_default.β)
c, β, w, q = model_default
v_init = w / (1 - β)
v_default, res_wage_default = compute_reservation_wage(
model_default, v_init
)

# Compute reservation wage with lower beta
β_new = 0.96
model_low_beta = McCallModel(β=β_new)
v_init_low = model_low_beta.w / (1 - model_low_beta.β)
c, β_low, w, q = model_low_beta
v_init_low = w / (1 - β_low)
v_low, res_wage_low = compute_reservation_wage(
model_low_beta, v_init_low
)

# Plot vertical lines for reservation wages
ax.axvline(x=res_wage_default, color=colors[1], lw=2,
label=f'reservation wage (β={model_default.β})')
label=f'reservation wage (β={β})')
ax.axvline(x=res_wage_low, color=colors[2], lw=2,
label=f'reservation wage (β={β_new})')

Expand Down Expand Up @@ -621,12 +623,43 @@ Let $h$ denote the continuation value:

The Bellman equation can now be written as

$$
```{math}
:label: j1b

v^*(w')
= \max \left\{ \frac{w'}{1 - \beta}, \, h \right\}
```

Now let's derive a nonlinear equation for $h$ alone.

Starting from {eq}`j1b`, we multiply both sides by $q(w')$ to get

$$
v^*(w') q(w') = \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
$$

Next, we sum both sides over $w' \in \mathbb{W}$:

$$
\sum_{w' \in \mathbb W} v^*(w') q(w')
= \sum_{w' \in \mathbb W} \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
$$

Now multiply both sides by $\beta$:

$$
\beta \sum_{w' \in \mathbb W} v^*(w') q(w')
= \beta \sum_{w' \in \mathbb W} \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
$$

Substituting this last equation into {eq}`j1` gives
Add $c$ to both sides:

$$
c + \beta \sum_{w' \in \mathbb W} v^*(w') q(w')
= c + \beta \sum_{w' \in \mathbb W} \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
$$

Finally, using the definition of $h$ from {eq}`j1`, the left-hand side is just $h$, giving us

```{math}
:label: j2
Expand All @@ -638,7 +671,7 @@ Substituting this last equation into {eq}`j1` gives
\right\} q (w')
```

This is a nonlinear equation that we can solve for $h$.
This is a nonlinear equation in the single scalar $h$ that we can solve for $h$.

As before, we will use successive approximations:

Expand Down Expand Up @@ -781,8 +814,28 @@ plt.show()
And here's a solution using JAX.

```{code-cell} ipython3
# First, we set up a function to draw random wage offers from the distribution.
# We use the inverse transform method: draw a uniform random variable u,
# then find the smallest wage w such that the CDF at w is >= u.
cdf = jnp.cumsum(q_default)

def draw_wage(uniform_rv):
"""
Draw a wage from the distribution q_default using the inverse transform method.

Parameters:
-----------
uniform_rv : float
A uniform random variable on [0, 1]

Returns:
--------
wage : float
A wage drawn from w_default with probabilities given by q_default
"""
return w_default[jnp.searchsorted(cdf, uniform_rv)]


def compute_stopping_time(w_bar, key):
"""
Compute stopping time by drawing wages until one exceeds `w_bar`.
Expand All @@ -791,7 +844,7 @@ def compute_stopping_time(w_bar, key):
t, key, accept = loop_state
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)
w = w_default[jnp.searchsorted(cdf, u)]
w = draw_wage(u)
accept = w >= w_bar
t = t + 1
return t, key, accept
Expand Down Expand Up @@ -831,7 +884,8 @@ def compute_stop_time_for_c(c):
return compute_mean_stopping_time(w_bar)

# Vectorize across all c values
stop_times = jax.vmap(compute_stop_time_for_c)(c_vals)
compute_stop_time_vectorized = jax.vmap(compute_stop_time_for_c)
stop_times = compute_stop_time_vectorized(c_vals)

fig, ax = plt.subplots()

Expand Down Expand Up @@ -928,12 +982,12 @@ def create_mccall_continuous(
key = jax.random.PRNGKey(seed)
s = jax.random.normal(key, (mc_size,))
w_draws = jnp.exp(μ + σ * s)
return McCallModelContinuous(c=c, β, σ, μ=μ, w_draws=w_draws)
return McCallModelContinuous(c, β, σ, μ, w_draws)


@jax.jit
def compute_reservation_wage_continuous(model, max_iter=500, tol=1e-5):
c, β, σ, μ, w_draws = model.c, model.β, model.σ, model.μ, model.w_draws
c, β, σ, μ, w_draws = model

h = jnp.mean(w_draws) / (1 - β) # initial guess

Expand All @@ -949,8 +1003,9 @@ def compute_reservation_wage_continuous(model, max_iter=500, tol=1e-5):
return jnp.logical_and(i < max_iter, error > tol)

initial_state = (h, 0, tol + 1)
h_final, _, _ = jax.lax.while_loop(cond, update, initial_state)

final_state = jax.lax.while_loop(cond, update, initial_state)
h_final, _, _ = final_state

# Now compute the reservation wage
return (1 - β) * h_final
```
Expand All @@ -969,12 +1024,13 @@ def compute_R_element(c, β):
model = create_mccall_continuous(c=c, β=β)
return compute_reservation_wage_continuous(model)

# Create meshgrid and vectorize computation
c_grid, β_grid = jnp.meshgrid(c_vals, β_vals, indexing='ij')
compute_R_vectorized = jax.vmap(
jax.vmap(compute_R_element,
in_axes=(None, 0)),
in_axes=(0, None))
# First, vectorize over β (holding c fixed)
compute_R_over_β = jax.vmap(compute_R_element, in_axes=(None, 0))

# Next, vectorize over c (applying the above function to each c)
compute_R_vectorized = jax.vmap(compute_R_over_β, in_axes=(0, None))

# Apply to compute the full grid
R = compute_R_vectorized(c_vals, β_vals)
```

Expand Down
Loading
Loading