From 72ae11fe2c535507d3da587bcb1fd3dcb838b482 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Mon, 17 Nov 2025 17:04:07 +0900 Subject: [PATCH 1/2] Improve key handling and fix parameter consistency in McCall lectures MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Refactor random key handling to use fold_in instead of key threading - More idiomatic JAX pattern for indexed loops - Removes key from loop state for cleaner code - Deterministic randomness based on time step - Fix missing n_agents variable in _simulate_cross_section_compiled - Extract from initial_wage_indices using len() - Standardize separation rate across lectures - Set α = 0.05 in mccall_fitted_vfi to match mccall_model_with_sep_markov - All economic parameters now consistent between lectures 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- lectures/mccall_fitted_vfi.md | 16 +++--- lectures/mccall_model_with_sep_markov.md | 73 ++++++++++++------------ 2 files changed, 45 insertions(+), 44 deletions(-) diff --git a/lectures/mccall_fitted_vfi.md b/lectures/mccall_fitted_vfi.md index f25e9e7cf..5f05c56cb 100644 --- a/lectures/mccall_fitted_vfi.md +++ b/lectures/mccall_fitted_vfi.md @@ -268,7 +268,7 @@ class Model(NamedTuple): def create_mccall_model( c: float = 1.0, - α: float = 0.1, + α: float = 0.05, β: float = 0.96, ρ: float = 0.9, ν: float = 0.2, @@ -633,29 +633,29 @@ def _simulate_cross_section_compiled( c, α, β, ρ, ν, γ, w_grid, z_draws = model # Initialize arrays - key, subkey = jax.random.split(key) + init_key, subkey = jax.random.split(key) wages = jnp.exp(jax.random.normal(subkey, (n_agents,)) * ν) status = jnp.zeros(n_agents, dtype=jnp.int32) def update(t, loop_state): - key, status, wages = loop_state + status, wages = loop_state # Shift loop state forwards - key, subkey = jax.random.split(key) - agent_keys = jax.random.split(subkey, n_agents) + step_key = jax.random.fold_in(init_key, t) + agent_keys = jax.random.split(step_key, n_agents) status, wages = update_agents_vmap( agent_keys, status, wages, model, w_bar ) - return key, status, wages + return status, wages # Run simulation using fori_loop - initial_loop_state = (key, status, wages) + initial_loop_state = (status, wages) final_loop_state = lax.fori_loop(0, T, update, initial_loop_state) # Return only final employment state - _, final_is_employed, _ = final_loop_state + final_is_employed, _ = final_loop_state return final_is_employed diff --git a/lectures/mccall_model_with_sep_markov.md b/lectures/mccall_model_with_sep_markov.md index c6f36f4f7..49d1bcb35 100644 --- a/lectures/mccall_model_with_sep_markov.md +++ b/lectures/mccall_model_with_sep_markov.md @@ -751,8 +751,8 @@ Now let's simulate many agents simultaneously to examine the cross-sectional une We first create a vectorized version of `update_agent` to efficiently update all agents in parallel: ```{code-cell} ipython3 -# Create vectorized version of update_agent -# The last parameter is now w_bar (scalar) instead of σ (array) +# Create vectorized version of update_agent. +# Vectorize over key, status, wage_idx update_agents_vmap = jax.vmap( update_agent, in_axes=(0, 0, 0, None, None) ) @@ -761,61 +761,56 @@ update_agents_vmap = jax.vmap( Next we define the core simulation function, which uses `lax.fori_loop` to efficiently iterate many agents forward in time: ```{code-cell} ipython3 -@partial(jax.jit, static_argnums=(3, 4)) +@jax.jit def _simulate_cross_section_compiled( key: jnp.ndarray, model: Model, w_bar: float, - n_agents: int, + initial_wage_indices: jnp.ndarray, + initial_status_vec: jnp.ndarray, T: int ): - """JIT-compiled core simulation loop using lax.fori_loop. - Returns only the final employment state to save memory.""" + """ + JIT-compiled core simulation loop for shifting the cross section + using lax.fori_loop. Returns the final employment employment status + cross-section. + + """ n, w_vals, P, P_cumsum, β, c, α, γ = model + n_agents = len(initial_wage_indices) - # Initialize arrays - wage_indices = jnp.zeros(n_agents, dtype=jnp.int32) - status = jnp.zeros(n_agents, dtype=jnp.int32) def update(t, loop_state): - key, status, wage_indices = loop_state - - # Shift loop state forwards - key, subkey = jax.random.split(key) - agent_keys = jax.random.split(subkey, n_agents) - + " Shift loop state forwards " + status, wage_indices = loop_state + step_key = jax.random.fold_in(key, t) + agent_keys = jax.random.split(step_key, n_agents) status, wage_indices = update_agents_vmap( agent_keys, status, wage_indices, model, w_bar ) - - return key, status, wage_indices + return status, wage_indices # Run simulation using fori_loop - initial_loop_state = (key, status, wage_indices) + initial_loop_state = (initial_status_vec, initial_wage_indices) final_loop_state = lax.fori_loop(0, T, update, initial_loop_state) # Return only final employment state - _, final_is_employed, _ = final_loop_state + final_is_employed, _ = final_loop_state return final_is_employed def simulate_cross_section( - model: Model, - n_agents: int = 100_000, - T: int = 200, - seed: int = 42 + model: Model, # Model instance with parameters + n_agents: int = 100_000, # Number of agents to simulate + T: int = 200, # Length of burn-in + seed: int = 42 # For reproducibility ) -> float: """ - Simulate employment paths for many agents and return final unemployment rate. + Wrapper function for _simulate_cross_section_compiled. - Parameters: - - model: Model instance with parameters - - n_agents: Number of agents to simulate - - T: Number of periods to simulate - - seed: Random seed for reproducibility + Push forward a cross-section for T periods and return the final + cross-sectional unemployment rate. - Returns: - - unemployment_rate: Fraction of agents unemployed at time T """ key = jax.random.PRNGKey(seed) @@ -823,14 +818,15 @@ def simulate_cross_section( v_u = vfi(model) w_bar = get_reservation_wage(v_u, model) - # Run JIT-compiled simulation + # Initialize arrays + initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32) + initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32) + final_status = _simulate_cross_section_compiled( - key, model, w_bar, n_agents, T + key, model, w_bar, initial_wage_indices, initial_status_vec, T ) - # Calculate unemployment rate at final period unemployment_rate = 1 - jnp.mean(final_status) - return unemployment_rate ``` @@ -850,8 +846,13 @@ def plot_cross_sectional_unemployment( key = jax.random.PRNGKey(42) v_u = vfi(model) w_bar = get_reservation_wage(v_u, model) + + # Initialize arrays + initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32) + initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32) + final_status = _simulate_cross_section_compiled( - key, model, w_bar, n_agents, t_snapshot + key, model, w_bar, initial_wage_indices, initial_status_vec, t_snapshot ) # Calculate unemployment rate From f1ad9ff1421435d2153abddb7b734ff500fbe5fc Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Mon, 17 Nov 2025 20:21:35 +1100 Subject: [PATCH 2/2] fix minor typos and section title capitalization --- lectures/mccall_fitted_vfi.md | 9 ++++-- lectures/mccall_model_with_sep_markov.md | 38 ++++++++++++------------ 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/lectures/mccall_fitted_vfi.md b/lectures/mccall_fitted_vfi.md index 5f05c56cb..8cd5fb2b7 100644 --- a/lectures/mccall_fitted_vfi.md +++ b/lectures/mccall_fitted_vfi.md @@ -361,14 +361,17 @@ def vfi( ``` Here's a function that uses a solution $v_u$ to compute the remaining functions of -interest: $v_u$, and the continuation value function $h$. +interest: $v_e$, and the continuation value function $h$. We use the same expressions as we did in the {doc}`discrete case `, after replacing sums with integrals. ```{code-cell} ipython3 def compute_solution_functions(model, v_u): - # Interpolate v_u + # Unpack model parameters + c, α, β, ρ, ν, γ, w_grid, z_draws = model + + # Interpolate v_u on the wage grid vf = lambda x: jnp.interp(x, w_grid, v_u) def compute_expectation(w): @@ -604,7 +607,7 @@ When unemployed, the agent accepts offers that exceed the reservation wage. When employed, the agent faces job separation with probability $\alpha$ each period. -### Cross-Sectional Analysis +### Cross-sectional analysis Now let's simulate many agents simultaneously to examine the cross-sectional unemployment rate. diff --git a/lectures/mccall_model_with_sep_markov.md b/lectures/mccall_model_with_sep_markov.md index 49d1bcb35..36aa1be3f 100644 --- a/lectures/mccall_model_with_sep_markov.md +++ b/lectures/mccall_model_with_sep_markov.md @@ -49,7 +49,7 @@ libraries ```{code-cell} ipython3 :tags: [hide-output] -!pip install quantecon +!pip install quantecon jax ``` We use the following imports: @@ -64,7 +64,7 @@ import matplotlib.pyplot as plt from functools import partial ``` -## Model Setup +## Model setup The setting is as follows: @@ -74,7 +74,7 @@ The setting is as follows: - Unemployed workers receive compensation $c$ per period - Future payoffs are discounted by factor $\beta \in (0,1)$ -### Decision Problem +### Decision problem When unemployed and receiving wage offer $w$, the agent chooses between: @@ -86,7 +86,7 @@ The wage updates are as follows: * If an unemployed agent rejects offer $w$, then their next offer is drawn from $P(w, \cdot)$ * If an employed agent loses a job in which they were paid wage $w$, then their next offer is drawn from $P(w, \cdot)$ -### The Wage Offer Process +### The wage offer process To construct the wage offer process we start with an AR1 process. @@ -112,7 +112,7 @@ Actually, in practice, we approximate this wage process as follows: -### Value Functions +### Value functions We let @@ -168,12 +168,12 @@ $$ +++ -### Optimal Policy +### Optimal policy Once we have the solutions $v_e$ and $v_u$ to these Bellman equations, we can compute the optimal policy: accept at current wage offer $w$ if $$ - v_e(w) ≥ u(c) + β(Pv_u)(w) + v_e(w) \geq u(c) + \beta (P v_u)(w) $$ The optimal policy turns out to be a reservation wage strategy: accept all wages above some threshold. @@ -185,7 +185,7 @@ The optimal policy turns out to be a reservation wage strategy: accept all wages Let's now implement the model. -### Set Up +### Set up The default utility function is a CRRA utility function @@ -234,7 +234,7 @@ def create_js_with_sep_model( ``` -### Solution: First Pass +### Solution: first pass Let's put together a (not very efficient) routine for calculating the reservation wage. @@ -244,7 +244,7 @@ reservation wage. It works by starting with guesses for $v_e$ and $v_u$ and iterating to convergence. -Here's are Bellman operators that update $v_u$ and $v_e$ respectively. +Here are Bellman operators that update $v_u$ and $v_e$ respectively. ```{code-cell} ipython3 @@ -313,7 +313,7 @@ def solve_model_first_pass( ``` -### Road Test +### Road test Let's solve the model: @@ -348,9 +348,9 @@ The reservation wage is at the intersection of $v_e$, and the continuation value function, which is the value of rejecting. -## Improving Efficiency +## Improving efficiency -The solution method desribed above works fine but we can do much better. +The solution method described above works fine but we can do much better. First, we use the employed worker's Bellman equation to express $v_e$ in terms of $Pv_u$ @@ -495,7 +495,7 @@ The result is the same as before but we only iterate on one array --- and also our JAX code is more efficient. -## Sensitivity Analysis +## Sensitivity analysis Let's examine how reservation wages change with the separation rate. @@ -523,7 +523,7 @@ Can you provide an intuitive economic story behind the outcome that you see in t +++ -## Employment Simulation +## Employment simulation Now let's simulate the employment dynamics of a single agent under the optimal policy. @@ -691,7 +691,7 @@ often leads a high new draw. +++ -## The Ergodic Property +## Ergodic property Below we examine cross-sectional unemployment. @@ -699,7 +699,7 @@ In particular, we will look at the unemployment rate in a cross-sectional simulation and compare it to the time-average unemployment rate, which is the fraction of time an agent spends unemployed over a long time series. -We will see that these two values are approximately equal -- if fact they are +We will see that these two values are approximately equal -- in fact they are exactly equal in the limit. The reason is that the process $(S_t, W_t)$, where @@ -744,7 +744,7 @@ Often the second approach is better for our purposes, since it's easier to paral +++ -## Cross-Sectional Analysis +## Cross-sectional analysis Now let's simulate many agents simultaneously to examine the cross-sectional unemployment rate. @@ -907,7 +907,7 @@ Now let's visualize the cross-sectional distribution: plot_cross_sectional_unemployment(model) ``` -## Lower Unemployment Compensation (c=0.5) +## Lower unemployment compensation (c=0.5) What happens to the cross-sectional unemployment rate with lower unemployment compensation?