Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 66 additions & 54 deletions lectures/mccall_fitted_vfi.md
Original file line number Diff line number Diff line change
Expand Up @@ -611,55 +611,50 @@ When employed, the agent faces job separation with probability $\alpha$ each per

Now let's simulate many agents simultaneously to examine the cross-sectional unemployment rate.

We first create a vectorized version of `update_agent` to efficiently update all agents in parallel:
To do this efficiently, we need a different approach than `simulate_employment_path` defined above.

```{code-cell} ipython3
# Create vectorized version of update_agent
update_agents_vmap = jax.vmap(
update_agent, in_axes=(0, 0, 0, None, None)
)
```
The key differences are:

Next we define the core simulation function, which uses `lax.fori_loop` to efficiently iterate many agents forward in time:

```{code-cell} ipython3
@partial(jax.jit, static_argnums=(3, 4))
def _simulate_cross_section_compiled(
key: jnp.ndarray,
model: Model,
w_bar: float,
n_agents: int,
T: int
):
"""JIT-compiled core simulation loop using lax.fori_loop.
Returns only the final employment state to save memory."""
c, α, β, ρ, ν, γ, w_grid, z_draws = model
- `simulate_employment_path` records the entire history (all T periods) for a single agent, which is useful for visualization but memory-intensive
- The new function `sim_agent` below only tracks and returns the final state, which is all we need for cross-sectional statistics
- `sim_agent` uses `lax.fori_loop` instead of a Python loop, making it JIT-compilable and suitable for vectorization across many agents

# Initialize arrays
init_key, subkey = jax.random.split(key)
wages = jnp.exp(jax.random.normal(subkey, (n_agents,)) * ν)
status = jnp.zeros(n_agents, dtype=jnp.int32)
We first define a function that simulates a single agent forward T time steps:

def update(t, loop_state):
status, wages = loop_state
```{code-cell} ipython3
@jax.jit
def sim_agent(key, initial_status, initial_wage, model, w_bar, T):
"""
Simulate a single agent forward T time steps using lax.fori_loop.

# Shift loop state forwards
step_key = jax.random.fold_in(init_key, t)
agent_keys = jax.random.split(step_key, n_agents)
Uses fold_in to generate a new key at each time step.

status, wages = update_agents_vmap(
agent_keys, status, wages, model, w_bar
)
Parameters:
- key: JAX random key for this agent
- initial_status: Initial employment status (0 or 1)
- initial_wage: Initial wage
- model: Model instance
- w_bar: Reservation wage
- T: Number of time periods to simulate

return status, wages
Returns:
- final_status: Employment status after T periods
- final_wage: Wage after T periods
"""
def update(t, loop_state):
status, wage = loop_state
step_key = jax.random.fold_in(key, t)
status, wage = update_agent(step_key, status, wage, model, w_bar)
return status, wage

# Run simulation using fori_loop
initial_loop_state = (status, wages)
initial_loop_state = (initial_status, initial_wage)
final_loop_state = lax.fori_loop(0, T, update, initial_loop_state)
final_status, final_wage = final_loop_state
return final_status, final_wage


# Return only final employment state
final_is_employed, _ = final_loop_state
return final_is_employed
# Create vectorized version of sim_agent to process multiple agents in parallel
sim_agents_vmap = jax.vmap(sim_agent, in_axes=(0, 0, 0, None, None, None))


def simulate_cross_section(
Expand All @@ -669,30 +664,36 @@ def simulate_cross_section(
seed: int = 42
) -> float:
"""
Simulate employment paths for many agents and return final unemployment rate.
Simulate cross-section of agents and return unemployment rate.

Parameters:
- model: Model instance with parameters
- n_agents: Number of agents to simulate
- T: Number of periods to simulate
- seed: Random seed for reproducibility
This approach:
1. Generates n_agents random keys
2. Calls sim_agent for each agent (vectorized via vmap)
3. Collects the final states to produce the cross-section

Returns:
- unemployment_rate: Fraction of agents unemployed at time T
Returns the cross-sectional unemployment rate.
"""
c, α, β, ρ, ν, γ, w_grid, z_draws = model

key = jax.random.PRNGKey(seed)

# Solve for optimal reservation wage
w_bar = get_reservation_wage(model)

# Run JIT-compiled simulation
final_status = _simulate_cross_section_compiled(
key, model, w_bar, n_agents, T
# Initialize arrays
init_key, subkey = jax.random.split(key)
initial_wages = jnp.exp(jax.random.normal(subkey, (n_agents,)) * ν)
initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)

# Generate n_agents random keys
agent_keys = jax.random.split(init_key, n_agents)

# Simulate each agent forward T steps (vectorized)
final_status, final_wages = sim_agents_vmap(
agent_keys, initial_status_vec, initial_wages, model, w_bar, T
)

# Calculate unemployment rate at final period
unemployment_rate = 1 - jnp.mean(final_status)

return unemployment_rate
```

Expand Down Expand Up @@ -743,12 +744,23 @@ def plot_cross_sectional_unemployment(
Generate histogram of cross-sectional unemployment at a specific time.

"""
c, α, β, ρ, ν, γ, w_grid, z_draws = model

# Get final employment state directly
key = jax.random.PRNGKey(42)
w_bar = get_reservation_wage(model)
final_status = _simulate_cross_section_compiled(
key, model, w_bar, n_agents, t_snapshot

# Initialize arrays
init_key, subkey = jax.random.split(key)
initial_wages = jnp.exp(jax.random.normal(subkey, (n_agents,)) * ν)
initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)

# Generate n_agents random keys
agent_keys = jax.random.split(init_key, n_agents)

# Simulate each agent forward T steps (vectorized)
final_status, _ = sim_agents_vmap(
agent_keys, initial_status_vec, initial_wages, model, w_bar, t_snapshot
)

# Calculate unemployment rate
Expand Down
92 changes: 49 additions & 43 deletions lectures/mccall_model_with_sep_markov.md
Original file line number Diff line number Diff line change
Expand Up @@ -748,55 +748,50 @@ Often the second approach is better for our purposes, since it's easier to paral

Now let's simulate many agents simultaneously to examine the cross-sectional unemployment rate.

We first create a vectorized version of `update_agent` to efficiently update all agents in parallel:
To do this efficiently, we need a different approach than `simulate_employment_path` defined above.

```{code-cell} ipython3
# Create vectorized version of update_agent.
# Vectorize over key, status, wage_idx
update_agents_vmap = jax.vmap(
update_agent, in_axes=(0, 0, 0, None, None)
)
```
The key differences are:

Next we define the core simulation function, which uses `lax.fori_loop` to efficiently iterate many agents forward in time:
- `simulate_employment_path` records the entire history (all T periods) for a single agent, which is useful for visualization but memory-intensive
- The new function `sim_agent` below only tracks and returns the final state, which is all we need for cross-sectional statistics
- `sim_agent` uses `lax.fori_loop` instead of a Python loop, making it JIT-compilable and suitable for vectorization across many agents

We first define a function that simulates a single agent forward T time steps:

```{code-cell} ipython3
@jax.jit
def _simulate_cross_section_compiled(
key: jnp.ndarray,
model: Model,
w_bar: float,
initial_wage_indices: jnp.ndarray,
initial_status_vec: jnp.ndarray,
T: int
):
def sim_agent(key, initial_status, initial_wage_idx, model, w_bar, T):
"""
JIT-compiled core simulation loop for shifting the cross section
using lax.fori_loop. Returns the final employment employment status
cross-section.
Simulate a single agent forward T time steps using lax.fori_loop.

"""
n, w_vals, P, P_cumsum, β, c, α, γ = model
n_agents = len(initial_wage_indices)
Uses fold_in to generate a new key at each time step.

Parameters:
- key: JAX random key for this agent
- initial_status: Initial employment status (0 or 1)
- initial_wage_idx: Initial wage index
- model: Model instance
- w_bar: Reservation wage
- T: Number of time periods to simulate

Returns:
- final_status: Employment status after T periods
- final_wage_idx: Wage index after T periods
"""
def update(t, loop_state):
" Shift loop state forwards "
status, wage_indices = loop_state
status, wage_idx = loop_state
step_key = jax.random.fold_in(key, t)
agent_keys = jax.random.split(step_key, n_agents)
status, wage_indices = update_agents_vmap(
agent_keys, status, wage_indices, model, w_bar
)
return status, wage_indices
status, wage_idx = update_agent(step_key, status, wage_idx, model, w_bar)
return status, wage_idx

# Run simulation using fori_loop
initial_loop_state = (initial_status_vec, initial_wage_indices)
initial_loop_state = (initial_status, initial_wage_idx)
final_loop_state = lax.fori_loop(0, T, update, initial_loop_state)
final_status, final_wage_idx = final_loop_state
return final_status, final_wage_idx

# Return only final employment state
final_is_employed, _ = final_loop_state
return final_is_employed

# Create vectorized version of sim_agent to process multiple agents in parallel
sim_agents_vmap = jax.vmap(sim_agent, in_axes=(0, 0, 0, None, None, None))


def simulate_cross_section(
Expand All @@ -806,11 +801,14 @@ def simulate_cross_section(
seed: int = 42 # For reproducibility
) -> float:
"""
Wrapper function for _simulate_cross_section_compiled.
Simulate cross-section of agents and return unemployment rate.

Push forward a cross-section for T periods and return the final
cross-sectional unemployment rate.
This approach:
1. Generates n_agents random keys
2. Calls sim_agent for each agent (vectorized via vmap)
3. Collects the final states to produce the cross-section

Returns the cross-sectional unemployment rate.
"""
key = jax.random.PRNGKey(seed)

Expand All @@ -822,8 +820,12 @@ def simulate_cross_section(
initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)

final_status = _simulate_cross_section_compiled(
key, model, w_bar, initial_wage_indices, initial_status_vec, T
# Generate n_agents random keys
agent_keys = jax.random.split(key, n_agents)

# Simulate each agent forward T steps (vectorized)
final_status, final_wage_idx = sim_agents_vmap(
agent_keys, initial_status_vec, initial_wage_indices, model, w_bar, T
)

unemployment_rate = 1 - jnp.mean(final_status)
Expand All @@ -834,7 +836,7 @@ This function generates a histogram showing the distribution of employment statu

```{code-cell} ipython3
def plot_cross_sectional_unemployment(
model: Model,
model: Model,
t_snapshot: int = 200, # Time of cross-sectional snapshot
n_agents: int = 20_000 # Number of agents to simulate
):
Expand All @@ -851,8 +853,12 @@ def plot_cross_sectional_unemployment(
initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)

final_status = _simulate_cross_section_compiled(
key, model, w_bar, initial_wage_indices, initial_status_vec, t_snapshot
# Generate n_agents random keys
agent_keys = jax.random.split(key, n_agents)

# Simulate each agent forward T steps (vectorized)
final_status, _ = sim_agents_vmap(
agent_keys, initial_status_vec, initial_wage_indices, model, w_bar, t_snapshot
)

# Calculate unemployment rate
Expand Down
Loading