From bc17309860d623b2cfee4501d3cbbe5918df8094 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Mon, 24 Nov 2025 06:11:13 +0900 Subject: [PATCH] Refactor cross-section simulation: reverse loop structure for better performance MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit refactors the cross-sectional agent simulation in both McCall model lectures to use a more efficient loop structure. Changes: - Replaced old approach (loop over time, vectorize over agents at each step) with new approach (vectorize over agents, loop over time per agent) - Added sim_agent() function that uses lax.fori_loop to simulate a single agent forward T time steps - Added sim_agents_vmap to vectorize sim_agent across multiple agents - Updated simulate_cross_section() to use the new implementation - Updated plot_cross_sectional_unemployment() to use sim_agents_vmap - Added explanatory text clarifying differences between simulate_employment_path() and sim_agent() Performance: The new approach has comparable or slightly better performance while being more modular and conceptually cleaner. Files modified: - mccall_model_with_sep_markov.md (discrete wage case) - mccall_fitted_vfi.md (continuous wage case) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- lectures/mccall_fitted_vfi.md | 120 +++++++++++++---------- lectures/mccall_model_with_sep_markov.md | 92 +++++++++-------- 2 files changed, 115 insertions(+), 97 deletions(-) diff --git a/lectures/mccall_fitted_vfi.md b/lectures/mccall_fitted_vfi.md index 8cd5fb2b7..c712d8ebb 100644 --- a/lectures/mccall_fitted_vfi.md +++ b/lectures/mccall_fitted_vfi.md @@ -611,55 +611,50 @@ When employed, the agent faces job separation with probability $\alpha$ each per Now let's simulate many agents simultaneously to examine the cross-sectional unemployment rate. -We first create a vectorized version of `update_agent` to efficiently update all agents in parallel: +To do this efficiently, we need a different approach than `simulate_employment_path` defined above. -```{code-cell} ipython3 -# Create vectorized version of update_agent -update_agents_vmap = jax.vmap( - update_agent, in_axes=(0, 0, 0, None, None) -) -``` +The key differences are: -Next we define the core simulation function, which uses `lax.fori_loop` to efficiently iterate many agents forward in time: - -```{code-cell} ipython3 -@partial(jax.jit, static_argnums=(3, 4)) -def _simulate_cross_section_compiled( - key: jnp.ndarray, - model: Model, - w_bar: float, - n_agents: int, - T: int - ): - """JIT-compiled core simulation loop using lax.fori_loop. - Returns only the final employment state to save memory.""" - c, α, β, ρ, ν, γ, w_grid, z_draws = model +- `simulate_employment_path` records the entire history (all T periods) for a single agent, which is useful for visualization but memory-intensive +- The new function `sim_agent` below only tracks and returns the final state, which is all we need for cross-sectional statistics +- `sim_agent` uses `lax.fori_loop` instead of a Python loop, making it JIT-compilable and suitable for vectorization across many agents - # Initialize arrays - init_key, subkey = jax.random.split(key) - wages = jnp.exp(jax.random.normal(subkey, (n_agents,)) * ν) - status = jnp.zeros(n_agents, dtype=jnp.int32) +We first define a function that simulates a single agent forward T time steps: - def update(t, loop_state): - status, wages = loop_state +```{code-cell} ipython3 +@jax.jit +def sim_agent(key, initial_status, initial_wage, model, w_bar, T): + """ + Simulate a single agent forward T time steps using lax.fori_loop. - # Shift loop state forwards - step_key = jax.random.fold_in(init_key, t) - agent_keys = jax.random.split(step_key, n_agents) + Uses fold_in to generate a new key at each time step. - status, wages = update_agents_vmap( - agent_keys, status, wages, model, w_bar - ) + Parameters: + - key: JAX random key for this agent + - initial_status: Initial employment status (0 or 1) + - initial_wage: Initial wage + - model: Model instance + - w_bar: Reservation wage + - T: Number of time periods to simulate - return status, wages + Returns: + - final_status: Employment status after T periods + - final_wage: Wage after T periods + """ + def update(t, loop_state): + status, wage = loop_state + step_key = jax.random.fold_in(key, t) + status, wage = update_agent(step_key, status, wage, model, w_bar) + return status, wage - # Run simulation using fori_loop - initial_loop_state = (status, wages) + initial_loop_state = (initial_status, initial_wage) final_loop_state = lax.fori_loop(0, T, update, initial_loop_state) + final_status, final_wage = final_loop_state + return final_status, final_wage + - # Return only final employment state - final_is_employed, _ = final_loop_state - return final_is_employed +# Create vectorized version of sim_agent to process multiple agents in parallel +sim_agents_vmap = jax.vmap(sim_agent, in_axes=(0, 0, 0, None, None, None)) def simulate_cross_section( @@ -669,30 +664,36 @@ def simulate_cross_section( seed: int = 42 ) -> float: """ - Simulate employment paths for many agents and return final unemployment rate. + Simulate cross-section of agents and return unemployment rate. - Parameters: - - model: Model instance with parameters - - n_agents: Number of agents to simulate - - T: Number of periods to simulate - - seed: Random seed for reproducibility + This approach: + 1. Generates n_agents random keys + 2. Calls sim_agent for each agent (vectorized via vmap) + 3. Collects the final states to produce the cross-section - Returns: - - unemployment_rate: Fraction of agents unemployed at time T + Returns the cross-sectional unemployment rate. """ + c, α, β, ρ, ν, γ, w_grid, z_draws = model + key = jax.random.PRNGKey(seed) # Solve for optimal reservation wage w_bar = get_reservation_wage(model) - # Run JIT-compiled simulation - final_status = _simulate_cross_section_compiled( - key, model, w_bar, n_agents, T + # Initialize arrays + init_key, subkey = jax.random.split(key) + initial_wages = jnp.exp(jax.random.normal(subkey, (n_agents,)) * ν) + initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32) + + # Generate n_agents random keys + agent_keys = jax.random.split(init_key, n_agents) + + # Simulate each agent forward T steps (vectorized) + final_status, final_wages = sim_agents_vmap( + agent_keys, initial_status_vec, initial_wages, model, w_bar, T ) - # Calculate unemployment rate at final period unemployment_rate = 1 - jnp.mean(final_status) - return unemployment_rate ``` @@ -743,12 +744,23 @@ def plot_cross_sectional_unemployment( Generate histogram of cross-sectional unemployment at a specific time. """ + c, α, β, ρ, ν, γ, w_grid, z_draws = model # Get final employment state directly key = jax.random.PRNGKey(42) w_bar = get_reservation_wage(model) - final_status = _simulate_cross_section_compiled( - key, model, w_bar, n_agents, t_snapshot + + # Initialize arrays + init_key, subkey = jax.random.split(key) + initial_wages = jnp.exp(jax.random.normal(subkey, (n_agents,)) * ν) + initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32) + + # Generate n_agents random keys + agent_keys = jax.random.split(init_key, n_agents) + + # Simulate each agent forward T steps (vectorized) + final_status, _ = sim_agents_vmap( + agent_keys, initial_status_vec, initial_wages, model, w_bar, t_snapshot ) # Calculate unemployment rate diff --git a/lectures/mccall_model_with_sep_markov.md b/lectures/mccall_model_with_sep_markov.md index 36aa1be3f..7902bf70f 100644 --- a/lectures/mccall_model_with_sep_markov.md +++ b/lectures/mccall_model_with_sep_markov.md @@ -748,55 +748,50 @@ Often the second approach is better for our purposes, since it's easier to paral Now let's simulate many agents simultaneously to examine the cross-sectional unemployment rate. -We first create a vectorized version of `update_agent` to efficiently update all agents in parallel: +To do this efficiently, we need a different approach than `simulate_employment_path` defined above. -```{code-cell} ipython3 -# Create vectorized version of update_agent. -# Vectorize over key, status, wage_idx -update_agents_vmap = jax.vmap( - update_agent, in_axes=(0, 0, 0, None, None) -) -``` +The key differences are: -Next we define the core simulation function, which uses `lax.fori_loop` to efficiently iterate many agents forward in time: +- `simulate_employment_path` records the entire history (all T periods) for a single agent, which is useful for visualization but memory-intensive +- The new function `sim_agent` below only tracks and returns the final state, which is all we need for cross-sectional statistics +- `sim_agent` uses `lax.fori_loop` instead of a Python loop, making it JIT-compilable and suitable for vectorization across many agents + +We first define a function that simulates a single agent forward T time steps: ```{code-cell} ipython3 @jax.jit -def _simulate_cross_section_compiled( - key: jnp.ndarray, - model: Model, - w_bar: float, - initial_wage_indices: jnp.ndarray, - initial_status_vec: jnp.ndarray, - T: int - ): +def sim_agent(key, initial_status, initial_wage_idx, model, w_bar, T): """ - JIT-compiled core simulation loop for shifting the cross section - using lax.fori_loop. Returns the final employment employment status - cross-section. + Simulate a single agent forward T time steps using lax.fori_loop. - """ - n, w_vals, P, P_cumsum, β, c, α, γ = model - n_agents = len(initial_wage_indices) + Uses fold_in to generate a new key at each time step. + Parameters: + - key: JAX random key for this agent + - initial_status: Initial employment status (0 or 1) + - initial_wage_idx: Initial wage index + - model: Model instance + - w_bar: Reservation wage + - T: Number of time periods to simulate + Returns: + - final_status: Employment status after T periods + - final_wage_idx: Wage index after T periods + """ def update(t, loop_state): - " Shift loop state forwards " - status, wage_indices = loop_state + status, wage_idx = loop_state step_key = jax.random.fold_in(key, t) - agent_keys = jax.random.split(step_key, n_agents) - status, wage_indices = update_agents_vmap( - agent_keys, status, wage_indices, model, w_bar - ) - return status, wage_indices + status, wage_idx = update_agent(step_key, status, wage_idx, model, w_bar) + return status, wage_idx - # Run simulation using fori_loop - initial_loop_state = (initial_status_vec, initial_wage_indices) + initial_loop_state = (initial_status, initial_wage_idx) final_loop_state = lax.fori_loop(0, T, update, initial_loop_state) + final_status, final_wage_idx = final_loop_state + return final_status, final_wage_idx - # Return only final employment state - final_is_employed, _ = final_loop_state - return final_is_employed + +# Create vectorized version of sim_agent to process multiple agents in parallel +sim_agents_vmap = jax.vmap(sim_agent, in_axes=(0, 0, 0, None, None, None)) def simulate_cross_section( @@ -806,11 +801,14 @@ def simulate_cross_section( seed: int = 42 # For reproducibility ) -> float: """ - Wrapper function for _simulate_cross_section_compiled. + Simulate cross-section of agents and return unemployment rate. - Push forward a cross-section for T periods and return the final - cross-sectional unemployment rate. + This approach: + 1. Generates n_agents random keys + 2. Calls sim_agent for each agent (vectorized via vmap) + 3. Collects the final states to produce the cross-section + Returns the cross-sectional unemployment rate. """ key = jax.random.PRNGKey(seed) @@ -822,8 +820,12 @@ def simulate_cross_section( initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32) initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32) - final_status = _simulate_cross_section_compiled( - key, model, w_bar, initial_wage_indices, initial_status_vec, T + # Generate n_agents random keys + agent_keys = jax.random.split(key, n_agents) + + # Simulate each agent forward T steps (vectorized) + final_status, final_wage_idx = sim_agents_vmap( + agent_keys, initial_status_vec, initial_wage_indices, model, w_bar, T ) unemployment_rate = 1 - jnp.mean(final_status) @@ -834,7 +836,7 @@ This function generates a histogram showing the distribution of employment statu ```{code-cell} ipython3 def plot_cross_sectional_unemployment( - model: Model, + model: Model, t_snapshot: int = 200, # Time of cross-sectional snapshot n_agents: int = 20_000 # Number of agents to simulate ): @@ -851,8 +853,12 @@ def plot_cross_sectional_unemployment( initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32) initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32) - final_status = _simulate_cross_section_compiled( - key, model, w_bar, initial_wage_indices, initial_status_vec, t_snapshot + # Generate n_agents random keys + agent_keys = jax.random.split(key, n_agents) + + # Simulate each agent forward T steps (vectorized) + final_status, _ = sim_agents_vmap( + agent_keys, initial_status_vec, initial_wage_indices, model, w_bar, t_snapshot ) # Calculate unemployment rate