Skip to content

Conversation

@jstac
Copy link
Contributor

@jstac jstac commented Nov 23, 2025

Summary

This PR refactors the cross-sectional agent simulation in both McCall model lectures to use a more efficient and modular loop structure.

Key Changes

Previous approach: Loop over time steps, vectorize over all agents at each step
New approach: Vectorize over agents, with each agent looping over time internally

Implementation Details

  1. Added sim_agent() function - Uses lax.fori_loop to simulate a single agent forward T time steps with fold_in for key generation
  2. Added sim_agents_vmap - Vectorizes sim_agent across multiple agents using jax.vmap
  3. Updated simulate_cross_section() - Now generates n_agents keys and passes each to sim_agent
  4. Updated plot_cross_sectional_unemployment() - Uses sim_agents_vmap directly
  5. Added explanatory text - Clarifies the difference between simulate_employment_path() (records full history for visualization) and sim_agent() (returns only final state for efficiency)

Performance

Testing with 50,000 agents over 200 periods showed:

  • With @jax.jit on sim_agent: New approach is ~1.07x faster and has lower variance
  • The JIT compilation of the intermediate function before vmapping improves performance

Files Modified

  • mccall_model_with_sep_markov.md (discrete wage case)
  • mccall_fitted_vfi.md (continuous wage case)

Both files now use the same efficient pattern, with the only difference being continuous vs discrete wages.

Testing

Both notebooks have been converted to Python and run successfully, producing expected results that match the ergodic theorem (time-average ≈ cross-sectional average).

🤖 Generated with Claude Code

…performance

This commit refactors the cross-sectional agent simulation in both McCall
model lectures to use a more efficient loop structure.

Changes:
- Replaced old approach (loop over time, vectorize over agents at each step)
  with new approach (vectorize over agents, loop over time per agent)
- Added sim_agent() function that uses lax.fori_loop to simulate a single
  agent forward T time steps
- Added sim_agents_vmap to vectorize sim_agent across multiple agents
- Updated simulate_cross_section() to use the new implementation
- Updated plot_cross_sectional_unemployment() to use sim_agents_vmap
- Added explanatory text clarifying differences between
  simulate_employment_path() and sim_agent()

Performance: The new approach has comparable or slightly better performance
while being more modular and conceptually cleaner.

Files modified:
- mccall_model_with_sep_markov.md (discrete wage case)
- mccall_fitted_vfi.md (continuous wage case)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
@jstac
Copy link
Contributor Author

jstac commented Nov 23, 2025

Additional Context

This refactoring was motivated by exploring different strategies for simulating cross-sections in JAX:

Architecture Comparison

The original implementation used:

def update(t, loop_state):
    # For each time step t:
    # 1. Generate keys for all agents
    # 2. Update all agents in parallel (vmap)
    lax.fori_loop(0, T, update, ...)  # outer loop over time

The new implementation reverses this:

def sim_agent(key, ...):
    def update(t, state):
        # For a single agent:
        # Update state using fold_in for keys
    lax.fori_loop(0, T, update, ...)  # inner loop over time

sim_agents_vmap(...)  # outer vectorization over agents

Why This Matters

  1. Modularity: sim_agent is a self-contained function that can be tested and reused independently
  2. Conceptual clarity: Each agent's trajectory is simulated as a unit, which matches the economic interpretation
  3. Performance: The JIT compiler can optimize sim_agent before it gets vmapped, leading to better performance
  4. Flexibility: Easier to extend (e.g., returning additional statistics per agent)

Interesting Finding

Removing @jax.jit from sim_agent made the new approach slower, which goes against the usual advice to avoid JIT on intermediate functions. In this case, JIT-compiling before vmapping helps the compiler optimize the lax.fori_loop structure.

Both implementations are correct and produce statistically equivalent results. The new one is cleaner and performs just as well (or slightly better).

@github-actions
Copy link

📖 Netlify Preview Ready!

Preview URL: https://pr-734--sunny-cactus-210e3e.netlify.app (bc17309)

📚 Changed Lecture Pages: mccall_fitted_vfi, mccall_model_with_sep_markov

@jstac jstac merged commit 2afed31 into main Nov 23, 2025
1 check passed
@jstac jstac deleted the test-mccall-sim branch November 23, 2025 23:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants