Skip to content

Conversation

@jstac
Copy link
Contributor

@jstac jstac commented Nov 17, 2025

Summary

  • Refactored random key handling in both McCall lectures to use jax.random.fold_in instead of threading keys through loop state
  • Fixed bug where n_agents was undefined in _simulate_cross_section_compiled
  • Standardized separation rate (α = 0.05) across both lectures for consistency

Changes

Key Handling Improvements

Both mccall_model_with_sep_markov.md and mccall_fitted_vfi.md now use the more idiomatic JAX pattern:

  • Replace double key splitting (split then split again) with fold_in
  • Remove key from loop state (cleaner, more efficient)
  • Deterministic randomness based on time step t

Before:

def update(t, loop_state):
    key, status, wages = loop_state
    key, subkey = jax.random.split(key)
    agent_keys = jax.random.split(subkey, n_agents)
    # ...
    return key, status, wages

After:

def update(t, loop_state):
    status, wages = loop_state
    step_key = jax.random.fold_in(init_key, t)
    agent_keys = jax.random.split(step_key, n_agents)
    # ...
    return status, wages

Bug Fixes

  • Fixed NameError in mccall_model_with_sep_markov.md where n_agents was undefined
  • Added n_agents = len(initial_wage_indices) to extract from input arrays

Parameter Consistency

  • Changed separation rate in mccall_fitted_vfi.md from α = 0.1 to α = 0.05
  • All economic parameters now match between the two lectures:
    • c = 1.0 (unemployment compensation)
    • α = 0.05 (separation rate)
    • β = 0.96 (discount factor)
    • ρ = 0.9 (wage persistence)
    • ν = 0.2 (wage volatility)
    • γ = 1.5 (utility parameter)

Testing

Both lectures tested with jupytext --to py and run successfully without errors.

🤖 Generated with Claude Code

- Refactor random key handling to use fold_in instead of key threading
  - More idiomatic JAX pattern for indexed loops
  - Removes key from loop state for cleaner code
  - Deterministic randomness based on time step

- Fix missing n_agents variable in _simulate_cross_section_compiled
  - Extract from initial_wage_indices using len()

- Standardize separation rate across lectures
  - Set α = 0.05 in mccall_fitted_vfi to match mccall_model_with_sep_markov
  - All economic parameters now consistent between lectures

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
@jstac
Copy link
Contributor Author

jstac commented Nov 17, 2025

Work Summary

This PR modernizes the JAX code in two McCall job search lectures by improving random key handling patterns and ensuring parameter consistency.

Key Improvements

1. More Efficient Random Key Handling
Replaced the inefficient double-split pattern with jax.random.fold_in:

  • Removes unnecessary intermediate variables
  • Eliminates key from loop state (reduces tuple packing/unpacking overhead)
  • More idiomatic JAX pattern for time-indexed loops
  • Provides deterministic randomness based on loop iteration

2. Bug Fixes

  • Fixed NameError where n_agents was undefined in _simulate_cross_section_compiled
  • Solution: Extract n_agents = len(initial_wage_indices) from input arrays

3. Parameter Standardization

  • Aligned separation rate α across both lectures (0.1 → 0.05)
  • Makes lectures directly comparable with consistent economic assumptions
  • Cross-sectional unemployment rate dropped from ~29% to ~20% in the fitted VFI lecture

Testing

Both lectures verified by:

  1. Converting to Python with jupytext --to py
  2. Running full execution without errors
  3. Confirming expected output and visualizations

All economic parameters now consistent between lectures, improving the learning experience for students working through the sequence.

@jstac
Copy link
Contributor Author

jstac commented Nov 17, 2025

@HumphreyYang can you please run your eyes over these lectures quickly, and this PR?

If you're happy please flag it for @mmcky so he can merge it.

@HumphreyYang
Copy link
Member

Roger that @jstac, once it builds I will look into them!

@github-actions
Copy link

📖 Netlify Preview Ready!

Preview URL: https://pr-715--sunny-cactus-210e3e.netlify.app (72ae11f)

📚 Changed Lecture Pages: mccall_fitted_vfi, mccall_model_with_sep_markov

@HumphreyYang
Copy link
Member

Hi @mmcky,

I think this is ready to merge once it is built!

@github-actions
Copy link

📖 Netlify Preview Ready!

Preview URL: https://pr-715--sunny-cactus-210e3e.netlify.app (f1ad9ff)

📚 Changed Lecture Pages: mccall_fitted_vfi, mccall_model_with_sep_markov

@mmcky mmcky merged commit c6cc43c into main Nov 17, 2025
1 check passed
@mmcky mmcky deleted the jsiii_jit branch November 17, 2025 10:13
@mmcky
Copy link
Contributor

mmcky commented Nov 17, 2025

thanks @HumphreyYang and @jstac.

Merged and will make live.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants