Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 47 additions & 50 deletions lectures/mccall_model_with_sep_markov.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ kernelspec:
This lecture builds on the job search model with separation presented in the
{doc}`previous lecture <mccall_model_with_separation>`.

The key difference is that wage offers now follow a **Markov chain** rather than
being independent and identically distributed (IID).
The key difference is that wage offers now follow a {doc}`Markov chain <finite_markov>` rather than
being IID.

This modification adds persistence to the wage offer process, meaning that
today's wage offer provides information about tomorrow's offer.
Expand Down Expand Up @@ -264,8 +264,7 @@ def T(v: jnp.ndarray, model: Model) -> jnp.ndarray:
return jnp.maximum(accept, reject)
```

Here's a routine for value function iteration, as well as a second routine that
computes the reservation wage directly from the value function.
Here's a routine for value function iteration.

```{code-cell} ipython3
@jax.jit
Expand Down Expand Up @@ -293,67 +292,65 @@ def vfi(
v_final, error, i = final_loop_state

return v_final
```

Here is a routine that computes the reservation wage from the value function.

```{code-cell} ipython3
@jax.jit
def get_reservation_wage(v: jnp.ndarray, model: Model) -> float:
"""
Calculate the reservation wage directly from the value function.
Calculate the reservation wage from the unemployed agents
value function v := v_u.

The reservation wage is the lowest wage w where accepting (v_e(w))
is at least as good as rejecting (u(c) + β(Pv)(w)).

Parameters:
- v: Value function v_u
- model: Model instance containing parameters
is at least as good as rejecting (u(c) + β(Pv_u)(w)).

Returns:
- Reservation wage (lowest wage for which acceptance is optimal)
"""
n, w_vals, P, P_cumsum, β, c, α, γ = model

# Compute accept and reject values
d = 1 / (1 - β * (1 - α))
accept = d * (u(w_vals, γ) + α * β * P @ v)
reject = u(c, γ) + β * P @ v
v_e = d * (u(w_vals, γ) + α * β * P @ v)
continuation_value = u(c, γ) + β * P @ v

# Find where acceptance becomes optimal
should_accept = accept >= reject
first_accept_idx = jnp.argmax(should_accept)
accept_indices = v_e >= continuation_value
first_accept_idx = jnp.argmax(accept_indices) # index of first True

# If no acceptance (all False), return infinity
# Otherwise return the wage at the first acceptance index
return jnp.where(jnp.any(should_accept), w_vals[first_accept_idx], jnp.inf)
return jnp.where(jnp.any(accept_indices), w_vals[first_accept_idx], jnp.inf)
```


## Computing the Solution

Let's solve the model:

```{code-cell} ipython3
model = create_js_with_sep_model()
n, w_vals, P, P_cumsum, β, c, α, γ = model
v_star = vfi(model)
w_star = get_reservation_wage(v_star, model)
v_u = vfi(model)
w_bar = get_reservation_wage(v_u, model)
```

Next we compute some related quantities for plotting.

```{code-cell} ipython3
d = 1 / (1 - β * (1 - α))
accept = d * (u(w_vals, γ) + α * β * P @ v_star)
h_star = u(c, γ) + β * P @ v_star
v_e = d * (u(w_vals, γ) + α * β * P @ v_u)
h = u(c, γ) + β * P @ v_u
```

Let's plot our results.

```{code-cell} ipython3
fig, ax = plt.subplots(figsize=(9, 5.2))
ax.plot(w_vals, h_star, linewidth=4, ls="--", alpha=0.4,
label="continuation value")
ax.plot(w_vals, accept, linewidth=4, ls="--", alpha=0.4,
label="stopping value")
ax.plot(w_vals, v_star, "k-", alpha=0.7, label=r"$v_u^*(w)$")
ax.plot(w_vals, h, 'g-', linewidth=2,
label="continuation value function $h$")
ax.plot(w_vals, v_e, 'b-', linewidth=2,
label="employment value function $v_e$")
ax.legend(frameon=False)
ax.set_xlabel(r"$w$")
plt.show()
Expand All @@ -370,16 +367,16 @@ Let's examine how reservation wages change with the separation rate.
```{code-cell} ipython3
α_vals: jnp.ndarray = jnp.linspace(0.0, 1.0, 10)

w_star_vec = []
w_bar_vec = []
for α in α_vals:
model = create_js_with_sep_model(α=α)
v_star = vfi(model)
w_star = get_reservation_wage(v_star, model)
w_star_vec.append(w_star)
v_u = vfi(model)
w_bar = get_reservation_wage(v_u, model)
w_bar_vec.append(w_bar)

fig, ax = plt.subplots(figsize=(9, 5.2))
ax.plot(
α_vals, w_star_vec, linewidth=2, alpha=0.6, label="reservation wage"
α_vals, w_bar_vec, linewidth=2, alpha=0.6, label="reservation wage"
)
ax.legend(frameon=False)
ax.set_xlabel(r"$\alpha$")
Expand Down Expand Up @@ -414,7 +411,7 @@ unemployed, 1 if employed) and $w_t$ is
* their current wage, if employed.

```{code-cell} ipython3
def update_agent(key, status, wage_idx, model, w_star):
def update_agent(key, status, wage_idx, model, w_bar):
"""
Updates an agent's employment status and current wage.

Expand All @@ -423,7 +420,7 @@ def update_agent(key, status, wage_idx, model, w_star):
- status: Current employment status (0 or 1)
- wage_idx: Current wage, recorded as an array index
- model: Model instance
- w_star: Reservation wage
- w_bar: Reservation wage

"""
n, w_vals, P, P_cumsum, β, c, α, γ = model
Expand All @@ -436,7 +433,7 @@ def update_agent(key, status, wage_idx, model, w_star):
)
separation_occurs = jax.random.uniform(key2) < α
# Accept if current wage meets or exceeds reservation wage
accepts = w_vals[wage_idx] >= w_star
accepts = w_vals[wage_idx] >= w_bar

# If employed: status = 1 if no separation, 0 if separation
# If unemployed: status = 1 if accepts, 0 if rejects
Expand All @@ -462,7 +459,7 @@ Here's a function to simulate the employment path of a single agent.
```{code-cell} ipython3
def simulate_employment_path(
model: Model, # Model details
w_star: float, # Reservation wage
w_bar: float, # Reservation wage
T: int = 2_000, # Simulation length
seed: int = 42 # Set seed for simulation
):
Expand All @@ -487,7 +484,7 @@ def simulate_employment_path(

key, subkey = jax.random.split(key)
status, wage_idx = update_agent(
subkey, status, wage_idx, model, w_star
subkey, status, wage_idx, model, w_bar
)

return jnp.array(wage_path), jnp.array(status_path)
Expand All @@ -499,10 +496,10 @@ Let's create a comprehensive plot of the employment simulation:
model = create_js_with_sep_model()

# Calculate reservation wage for plotting
v_star = vfi(model)
w_star = get_reservation_wage(v_star, model)
v_u = vfi(model)
w_bar = get_reservation_wage(v_u, model)

wage_path, employment_status = simulate_employment_path(model, w_star)
wage_path, employment_status = simulate_employment_path(model, w_bar)

fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(8, 6))

Expand All @@ -518,8 +515,8 @@ ax1.set_ylim(-0.1, 1.1)

# Plot wage path with employment status coloring
ax2.plot(wage_path, 'b-', alpha=0.7, linewidth=1)
ax2.axhline(y=w_star, color='black', linestyle='--', alpha=0.8,
label=f'Reservation wage: {w_star:.2f}')
ax2.axhline(y=w_bar, color='black', linestyle='--', alpha=0.8,
label=f'Reservation wage: {w_bar:.2f}')
ax2.set_xlabel('time')
ax2.set_ylabel('wage')
ax2.set_title('Wage path (actual and offers)')
Expand Down Expand Up @@ -620,7 +617,7 @@ We first create a vectorized version of `update_agent` to efficiently update all

```{code-cell} ipython3
# Create vectorized version of update_agent
# The last parameter is now w_star (scalar) instead of σ (array)
# The last parameter is now w_bar (scalar) instead of σ (array)
update_agents_vmap = jax.vmap(
update_agent, in_axes=(0, 0, 0, None, None)
)
Expand All @@ -633,7 +630,7 @@ Next we define the core simulation function, which uses `lax.fori_loop` to effic
def _simulate_cross_section_compiled(
key: jnp.ndarray,
model: Model,
w_star: float,
w_bar: float,
n_agents: int,
T: int
):
Expand All @@ -653,7 +650,7 @@ def _simulate_cross_section_compiled(
agent_keys = jax.random.split(subkey, n_agents)

status, wage_indices = update_agents_vmap(
agent_keys, status, wage_indices, model, w_star
agent_keys, status, wage_indices, model, w_bar
)

return key, status, wage_indices
Expand Down Expand Up @@ -688,12 +685,12 @@ def simulate_cross_section(
key = jax.random.PRNGKey(seed)

# Solve for optimal reservation wage
v_star = vfi(model)
w_star = get_reservation_wage(v_star, model)
v_u = vfi(model)
w_bar = get_reservation_wage(v_u, model)

# Run JIT-compiled simulation
final_status = _simulate_cross_section_compiled(
key, model, w_star, n_agents, T
key, model, w_bar, n_agents, T
)

# Calculate unemployment rate at final period
Expand All @@ -717,10 +714,10 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200,
"""
# Get final employment state directly
key = jax.random.PRNGKey(42)
v_star = vfi(model)
w_star = get_reservation_wage(v_star, model)
v_u = vfi(model)
w_bar = get_reservation_wage(v_u, model)
final_status = _simulate_cross_section_compiled(
key, model, w_star, n_agents, t_snapshot
key, model, w_bar, n_agents, t_snapshot
)

# Calculate unemployment rate
Expand Down
Loading