Skip to content

Commit c6cc43c

Browse files
jstacclaudeHumphreyYang
authored
Improve key handling and fix parameter consistency in McCall lectures (#715)
* Improve key handling and fix parameter consistency in McCall lectures - Refactor random key handling to use fold_in instead of key threading - More idiomatic JAX pattern for indexed loops - Removes key from loop state for cleaner code - Deterministic randomness based on time step - Fix missing n_agents variable in _simulate_cross_section_compiled - Extract from initial_wage_indices using len() - Standardize separation rate across lectures - Set α = 0.05 in mccall_fitted_vfi to match mccall_model_with_sep_markov - All economic parameters now consistent between lectures 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * fix minor typos and section title capitalization --------- Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Humphrey Yang <u6474961@anu.edu.au>
1 parent 977a93e commit c6cc43c

File tree

2 files changed

+70
-66
lines changed

2 files changed

+70
-66
lines changed

lectures/mccall_fitted_vfi.md

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ class Model(NamedTuple):
268268
269269
def create_mccall_model(
270270
c: float = 1.0,
271-
α: float = 0.1,
271+
α: float = 0.05,
272272
β: float = 0.96,
273273
ρ: float = 0.9,
274274
ν: float = 0.2,
@@ -361,14 +361,17 @@ def vfi(
361361
```
362362

363363
Here's a function that uses a solution $v_u$ to compute the remaining functions of
364-
interest: $v_u$, and the continuation value function $h$.
364+
interest: $v_e$, and the continuation value function $h$.
365365

366366
We use the same expressions as we did in the {doc}`discrete case <mccall_model_with_sep_markov>`, after replacing sums with integrals.
367367

368368
```{code-cell} ipython3
369369
def compute_solution_functions(model, v_u):
370370
371-
# Interpolate v_u
371+
# Unpack model parameters
372+
c, α, β, ρ, ν, γ, w_grid, z_draws = model
373+
374+
# Interpolate v_u on the wage grid
372375
vf = lambda x: jnp.interp(x, w_grid, v_u)
373376
374377
def compute_expectation(w):
@@ -604,7 +607,7 @@ When unemployed, the agent accepts offers that exceed the reservation wage.
604607

605608
When employed, the agent faces job separation with probability $\alpha$ each period.
606609

607-
### Cross-Sectional Analysis
610+
### Cross-sectional analysis
608611

609612
Now let's simulate many agents simultaneously to examine the cross-sectional unemployment rate.
610613

@@ -633,29 +636,29 @@ def _simulate_cross_section_compiled(
633636
c, α, β, ρ, ν, γ, w_grid, z_draws = model
634637
635638
# Initialize arrays
636-
key, subkey = jax.random.split(key)
639+
init_key, subkey = jax.random.split(key)
637640
wages = jnp.exp(jax.random.normal(subkey, (n_agents,)) * ν)
638641
status = jnp.zeros(n_agents, dtype=jnp.int32)
639642
640643
def update(t, loop_state):
641-
key, status, wages = loop_state
644+
status, wages = loop_state
642645
643646
# Shift loop state forwards
644-
key, subkey = jax.random.split(key)
645-
agent_keys = jax.random.split(subkey, n_agents)
647+
step_key = jax.random.fold_in(init_key, t)
648+
agent_keys = jax.random.split(step_key, n_agents)
646649
647650
status, wages = update_agents_vmap(
648651
agent_keys, status, wages, model, w_bar
649652
)
650653
651-
return key, status, wages
654+
return status, wages
652655
653656
# Run simulation using fori_loop
654-
initial_loop_state = (key, status, wages)
657+
initial_loop_state = (status, wages)
655658
final_loop_state = lax.fori_loop(0, T, update, initial_loop_state)
656659
657660
# Return only final employment state
658-
_, final_is_employed, _ = final_loop_state
661+
final_is_employed, _ = final_loop_state
659662
return final_is_employed
660663
661664

lectures/mccall_model_with_sep_markov.md

Lines changed: 56 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ libraries
4949
```{code-cell} ipython3
5050
:tags: [hide-output]
5151
52-
!pip install quantecon
52+
!pip install quantecon jax
5353
```
5454

5555
We use the following imports:
@@ -64,7 +64,7 @@ import matplotlib.pyplot as plt
6464
from functools import partial
6565
```
6666

67-
## Model Setup
67+
## Model setup
6868

6969
The setting is as follows:
7070

@@ -74,7 +74,7 @@ The setting is as follows:
7474
- Unemployed workers receive compensation $c$ per period
7575
- Future payoffs are discounted by factor $\beta \in (0,1)$
7676

77-
### Decision Problem
77+
### Decision problem
7878

7979
When unemployed and receiving wage offer $w$, the agent chooses between:
8080

@@ -86,7 +86,7 @@ The wage updates are as follows:
8686
* If an unemployed agent rejects offer $w$, then their next offer is drawn from $P(w, \cdot)$
8787
* If an employed agent loses a job in which they were paid wage $w$, then their next offer is drawn from $P(w, \cdot)$
8888

89-
### The Wage Offer Process
89+
### The wage offer process
9090

9191
To construct the wage offer process we start with an AR1 process.
9292

@@ -112,7 +112,7 @@ Actually, in practice, we approximate this wage process as follows:
112112

113113

114114

115-
### Value Functions
115+
### Value functions
116116

117117
We let
118118

@@ -168,12 +168,12 @@ $$
168168

169169
+++
170170

171-
### Optimal Policy
171+
### Optimal policy
172172

173173
Once we have the solutions $v_e$ and $v_u$ to these Bellman equations, we can compute the optimal policy: accept at current wage offer $w$ if
174174

175175
$$
176-
v_e(w) u(c) + β(Pv_u)(w)
176+
v_e(w) \geq u(c) + \beta (P v_u)(w)
177177
$$
178178

179179
The optimal policy turns out to be a reservation wage strategy: accept all wages above some threshold.
@@ -185,7 +185,7 @@ The optimal policy turns out to be a reservation wage strategy: accept all wages
185185

186186
Let's now implement the model.
187187

188-
### Set Up
188+
### Set up
189189

190190
The default utility function is a CRRA utility function
191191

@@ -234,7 +234,7 @@ def create_js_with_sep_model(
234234
```
235235

236236

237-
### Solution: First Pass
237+
### Solution: first pass
238238

239239
Let's put together a (not very efficient) routine for calculating the
240240
reservation wage.
@@ -244,7 +244,7 @@ reservation wage.
244244
It works by starting with guesses for $v_e$ and $v_u$ and iterating to
245245
convergence.
246246

247-
Here's are Bellman operators that update $v_u$ and $v_e$ respectively.
247+
Here are Bellman operators that update $v_u$ and $v_e$ respectively.
248248

249249

250250
```{code-cell} ipython3
@@ -313,7 +313,7 @@ def solve_model_first_pass(
313313
```
314314

315315

316-
### Road Test
316+
### Road test
317317

318318
Let's solve the model:
319319

@@ -348,9 +348,9 @@ The reservation wage is at the intersection of $v_e$, and the continuation value
348348
function, which is the value of rejecting.
349349

350350

351-
## Improving Efficiency
351+
## Improving efficiency
352352

353-
The solution method desribed above works fine but we can do much better.
353+
The solution method described above works fine but we can do much better.
354354

355355
First, we use the employed worker's Bellman equation to express
356356
$v_e$ in terms of $Pv_u$
@@ -495,7 +495,7 @@ The result is the same as before but we only iterate on one array --- and also
495495
our JAX code is more efficient.
496496

497497

498-
## Sensitivity Analysis
498+
## Sensitivity analysis
499499

500500
Let's examine how reservation wages change with the separation rate.
501501

@@ -523,7 +523,7 @@ Can you provide an intuitive economic story behind the outcome that you see in t
523523

524524
+++
525525

526-
## Employment Simulation
526+
## Employment simulation
527527

528528
Now let's simulate the employment dynamics of a single agent under the optimal policy.
529529

@@ -691,15 +691,15 @@ often leads a high new draw.
691691

692692
+++
693693

694-
## The Ergodic Property
694+
## Ergodic property
695695

696696
Below we examine cross-sectional unemployment.
697697

698698
In particular, we will look at the unemployment rate in a cross-sectional
699699
simulation and compare it to the time-average unemployment rate, which is the
700700
fraction of time an agent spends unemployed over a long time series.
701701

702-
We will see that these two values are approximately equal -- if fact they are
702+
We will see that these two values are approximately equal -- in fact they are
703703
exactly equal in the limit.
704704

705705
The reason is that the process $(S_t, W_t)$, where
@@ -744,15 +744,15 @@ Often the second approach is better for our purposes, since it's easier to paral
744744

745745
+++
746746

747-
## Cross-Sectional Analysis
747+
## Cross-sectional analysis
748748

749749
Now let's simulate many agents simultaneously to examine the cross-sectional unemployment rate.
750750

751751
We first create a vectorized version of `update_agent` to efficiently update all agents in parallel:
752752

753753
```{code-cell} ipython3
754-
# Create vectorized version of update_agent
755-
# The last parameter is now w_bar (scalar) instead of σ (array)
754+
# Create vectorized version of update_agent.
755+
# Vectorize over key, status, wage_idx
756756
update_agents_vmap = jax.vmap(
757757
update_agent, in_axes=(0, 0, 0, None, None)
758758
)
@@ -761,76 +761,72 @@ update_agents_vmap = jax.vmap(
761761
Next we define the core simulation function, which uses `lax.fori_loop` to efficiently iterate many agents forward in time:
762762

763763
```{code-cell} ipython3
764-
@partial(jax.jit, static_argnums=(3, 4))
764+
@jax.jit
765765
def _simulate_cross_section_compiled(
766766
key: jnp.ndarray,
767767
model: Model,
768768
w_bar: float,
769-
n_agents: int,
769+
initial_wage_indices: jnp.ndarray,
770+
initial_status_vec: jnp.ndarray,
770771
T: int
771772
):
772-
"""JIT-compiled core simulation loop using lax.fori_loop.
773-
Returns only the final employment state to save memory."""
773+
"""
774+
JIT-compiled core simulation loop for shifting the cross section
775+
using lax.fori_loop. Returns the final employment employment status
776+
cross-section.
777+
778+
"""
774779
n, w_vals, P, P_cumsum, β, c, α, γ = model
780+
n_agents = len(initial_wage_indices)
775781
776-
# Initialize arrays
777-
wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
778-
status = jnp.zeros(n_agents, dtype=jnp.int32)
779782
780783
def update(t, loop_state):
781-
key, status, wage_indices = loop_state
782-
783-
# Shift loop state forwards
784-
key, subkey = jax.random.split(key)
785-
agent_keys = jax.random.split(subkey, n_agents)
786-
784+
" Shift loop state forwards "
785+
status, wage_indices = loop_state
786+
step_key = jax.random.fold_in(key, t)
787+
agent_keys = jax.random.split(step_key, n_agents)
787788
status, wage_indices = update_agents_vmap(
788789
agent_keys, status, wage_indices, model, w_bar
789790
)
790-
791-
return key, status, wage_indices
791+
return status, wage_indices
792792
793793
# Run simulation using fori_loop
794-
initial_loop_state = (key, status, wage_indices)
794+
initial_loop_state = (initial_status_vec, initial_wage_indices)
795795
final_loop_state = lax.fori_loop(0, T, update, initial_loop_state)
796796
797797
# Return only final employment state
798-
_, final_is_employed, _ = final_loop_state
798+
final_is_employed, _ = final_loop_state
799799
return final_is_employed
800800
801801
802802
def simulate_cross_section(
803-
model: Model,
804-
n_agents: int = 100_000,
805-
T: int = 200,
806-
seed: int = 42
803+
model: Model, # Model instance with parameters
804+
n_agents: int = 100_000, # Number of agents to simulate
805+
T: int = 200, # Length of burn-in
806+
seed: int = 42 # For reproducibility
807807
) -> float:
808808
"""
809-
Simulate employment paths for many agents and return final unemployment rate.
809+
Wrapper function for _simulate_cross_section_compiled.
810810
811-
Parameters:
812-
- model: Model instance with parameters
813-
- n_agents: Number of agents to simulate
814-
- T: Number of periods to simulate
815-
- seed: Random seed for reproducibility
811+
Push forward a cross-section for T periods and return the final
812+
cross-sectional unemployment rate.
816813
817-
Returns:
818-
- unemployment_rate: Fraction of agents unemployed at time T
819814
"""
820815
key = jax.random.PRNGKey(seed)
821816
822817
# Solve for optimal reservation wage
823818
v_u = vfi(model)
824819
w_bar = get_reservation_wage(v_u, model)
825820
826-
# Run JIT-compiled simulation
821+
# Initialize arrays
822+
initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
823+
initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)
824+
827825
final_status = _simulate_cross_section_compiled(
828-
key, model, w_bar, n_agents, T
826+
key, model, w_bar, initial_wage_indices, initial_status_vec, T
829827
)
830828
831-
# Calculate unemployment rate at final period
832829
unemployment_rate = 1 - jnp.mean(final_status)
833-
834830
return unemployment_rate
835831
```
836832

@@ -850,8 +846,13 @@ def plot_cross_sectional_unemployment(
850846
key = jax.random.PRNGKey(42)
851847
v_u = vfi(model)
852848
w_bar = get_reservation_wage(v_u, model)
849+
850+
# Initialize arrays
851+
initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
852+
initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)
853+
853854
final_status = _simulate_cross_section_compiled(
854-
key, model, w_bar, n_agents, t_snapshot
855+
key, model, w_bar, initial_wage_indices, initial_status_vec, t_snapshot
855856
)
856857
857858
# Calculate unemployment rate
@@ -906,7 +907,7 @@ Now let's visualize the cross-sectional distribution:
906907
plot_cross_sectional_unemployment(model)
907908
```
908909

909-
## Lower Unemployment Compensation (c=0.5)
910+
## Lower unemployment compensation (c=0.5)
910911

911912
What happens to the cross-sectional unemployment rate with lower unemployment compensation?
912913

0 commit comments

Comments
 (0)