From f0516084c7f6bbf87be96399f1be9733ea39902b Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Fri, 14 Nov 2025 07:55:25 +0900 Subject: [PATCH] misc --- lectures/mccall_model_with_sep_markov.md | 97 ++++++++++++------------ 1 file changed, 47 insertions(+), 50 deletions(-) diff --git a/lectures/mccall_model_with_sep_markov.md b/lectures/mccall_model_with_sep_markov.md index 013a618b0..ca7ac8f35 100644 --- a/lectures/mccall_model_with_sep_markov.md +++ b/lectures/mccall_model_with_sep_markov.md @@ -34,8 +34,8 @@ kernelspec: This lecture builds on the job search model with separation presented in the {doc}`previous lecture `. -The key difference is that wage offers now follow a **Markov chain** rather than -being independent and identically distributed (IID). +The key difference is that wage offers now follow a {doc}`Markov chain ` rather than +being IID. This modification adds persistence to the wage offer process, meaning that today's wage offer provides information about tomorrow's offer. @@ -264,8 +264,7 @@ def T(v: jnp.ndarray, model: Model) -> jnp.ndarray: return jnp.maximum(accept, reject) ``` -Here's a routine for value function iteration, as well as a second routine that -computes the reservation wage directly from the value function. +Here's a routine for value function iteration. ```{code-cell} ipython3 @jax.jit @@ -293,39 +292,38 @@ def vfi( v_final, error, i = final_loop_state return v_final +``` +Here is a routine that computes the reservation wage from the value function. +```{code-cell} ipython3 @jax.jit def get_reservation_wage(v: jnp.ndarray, model: Model) -> float: """ - Calculate the reservation wage directly from the value function. + Calculate the reservation wage from the unemployed agents + value function v := v_u. The reservation wage is the lowest wage w where accepting (v_e(w)) - is at least as good as rejecting (u(c) + β(Pv)(w)). - - Parameters: - - v: Value function v_u - - model: Model instance containing parameters + is at least as good as rejecting (u(c) + β(Pv_u)(w)). - Returns: - - Reservation wage (lowest wage for which acceptance is optimal) """ n, w_vals, P, P_cumsum, β, c, α, γ = model # Compute accept and reject values d = 1 / (1 - β * (1 - α)) - accept = d * (u(w_vals, γ) + α * β * P @ v) - reject = u(c, γ) + β * P @ v + v_e = d * (u(w_vals, γ) + α * β * P @ v) + continuation_value = u(c, γ) + β * P @ v # Find where acceptance becomes optimal - should_accept = accept >= reject - first_accept_idx = jnp.argmax(should_accept) + accept_indices = v_e >= continuation_value + first_accept_idx = jnp.argmax(accept_indices) # index of first True # If no acceptance (all False), return infinity # Otherwise return the wage at the first acceptance index - return jnp.where(jnp.any(should_accept), w_vals[first_accept_idx], jnp.inf) + return jnp.where(jnp.any(accept_indices), w_vals[first_accept_idx], jnp.inf) ``` + ## Computing the Solution Let's solve the model: @@ -333,27 +331,26 @@ Let's solve the model: ```{code-cell} ipython3 model = create_js_with_sep_model() n, w_vals, P, P_cumsum, β, c, α, γ = model -v_star = vfi(model) -w_star = get_reservation_wage(v_star, model) +v_u = vfi(model) +w_bar = get_reservation_wage(v_u, model) ``` Next we compute some related quantities for plotting. ```{code-cell} ipython3 d = 1 / (1 - β * (1 - α)) -accept = d * (u(w_vals, γ) + α * β * P @ v_star) -h_star = u(c, γ) + β * P @ v_star +v_e = d * (u(w_vals, γ) + α * β * P @ v_u) +h = u(c, γ) + β * P @ v_u ``` Let's plot our results. ```{code-cell} ipython3 fig, ax = plt.subplots(figsize=(9, 5.2)) -ax.plot(w_vals, h_star, linewidth=4, ls="--", alpha=0.4, - label="continuation value") -ax.plot(w_vals, accept, linewidth=4, ls="--", alpha=0.4, - label="stopping value") -ax.plot(w_vals, v_star, "k-", alpha=0.7, label=r"$v_u^*(w)$") +ax.plot(w_vals, h, 'g-', linewidth=2, + label="continuation value function $h$") +ax.plot(w_vals, v_e, 'b-', linewidth=2, + label="employment value function $v_e$") ax.legend(frameon=False) ax.set_xlabel(r"$w$") plt.show() @@ -370,16 +367,16 @@ Let's examine how reservation wages change with the separation rate. ```{code-cell} ipython3 α_vals: jnp.ndarray = jnp.linspace(0.0, 1.0, 10) -w_star_vec = [] +w_bar_vec = [] for α in α_vals: model = create_js_with_sep_model(α=α) - v_star = vfi(model) - w_star = get_reservation_wage(v_star, model) - w_star_vec.append(w_star) + v_u = vfi(model) + w_bar = get_reservation_wage(v_u, model) + w_bar_vec.append(w_bar) fig, ax = plt.subplots(figsize=(9, 5.2)) ax.plot( - α_vals, w_star_vec, linewidth=2, alpha=0.6, label="reservation wage" + α_vals, w_bar_vec, linewidth=2, alpha=0.6, label="reservation wage" ) ax.legend(frameon=False) ax.set_xlabel(r"$\alpha$") @@ -414,7 +411,7 @@ unemployed, 1 if employed) and $w_t$ is * their current wage, if employed. ```{code-cell} ipython3 -def update_agent(key, status, wage_idx, model, w_star): +def update_agent(key, status, wage_idx, model, w_bar): """ Updates an agent's employment status and current wage. @@ -423,7 +420,7 @@ def update_agent(key, status, wage_idx, model, w_star): - status: Current employment status (0 or 1) - wage_idx: Current wage, recorded as an array index - model: Model instance - - w_star: Reservation wage + - w_bar: Reservation wage """ n, w_vals, P, P_cumsum, β, c, α, γ = model @@ -436,7 +433,7 @@ def update_agent(key, status, wage_idx, model, w_star): ) separation_occurs = jax.random.uniform(key2) < α # Accept if current wage meets or exceeds reservation wage - accepts = w_vals[wage_idx] >= w_star + accepts = w_vals[wage_idx] >= w_bar # If employed: status = 1 if no separation, 0 if separation # If unemployed: status = 1 if accepts, 0 if rejects @@ -462,7 +459,7 @@ Here's a function to simulate the employment path of a single agent. ```{code-cell} ipython3 def simulate_employment_path( model: Model, # Model details - w_star: float, # Reservation wage + w_bar: float, # Reservation wage T: int = 2_000, # Simulation length seed: int = 42 # Set seed for simulation ): @@ -487,7 +484,7 @@ def simulate_employment_path( key, subkey = jax.random.split(key) status, wage_idx = update_agent( - subkey, status, wage_idx, model, w_star + subkey, status, wage_idx, model, w_bar ) return jnp.array(wage_path), jnp.array(status_path) @@ -499,10 +496,10 @@ Let's create a comprehensive plot of the employment simulation: model = create_js_with_sep_model() # Calculate reservation wage for plotting -v_star = vfi(model) -w_star = get_reservation_wage(v_star, model) +v_u = vfi(model) +w_bar = get_reservation_wage(v_u, model) -wage_path, employment_status = simulate_employment_path(model, w_star) +wage_path, employment_status = simulate_employment_path(model, w_bar) fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(8, 6)) @@ -518,8 +515,8 @@ ax1.set_ylim(-0.1, 1.1) # Plot wage path with employment status coloring ax2.plot(wage_path, 'b-', alpha=0.7, linewidth=1) -ax2.axhline(y=w_star, color='black', linestyle='--', alpha=0.8, - label=f'Reservation wage: {w_star:.2f}') +ax2.axhline(y=w_bar, color='black', linestyle='--', alpha=0.8, + label=f'Reservation wage: {w_bar:.2f}') ax2.set_xlabel('time') ax2.set_ylabel('wage') ax2.set_title('Wage path (actual and offers)') @@ -620,7 +617,7 @@ We first create a vectorized version of `update_agent` to efficiently update all ```{code-cell} ipython3 # Create vectorized version of update_agent -# The last parameter is now w_star (scalar) instead of σ (array) +# The last parameter is now w_bar (scalar) instead of σ (array) update_agents_vmap = jax.vmap( update_agent, in_axes=(0, 0, 0, None, None) ) @@ -633,7 +630,7 @@ Next we define the core simulation function, which uses `lax.fori_loop` to effic def _simulate_cross_section_compiled( key: jnp.ndarray, model: Model, - w_star: float, + w_bar: float, n_agents: int, T: int ): @@ -653,7 +650,7 @@ def _simulate_cross_section_compiled( agent_keys = jax.random.split(subkey, n_agents) status, wage_indices = update_agents_vmap( - agent_keys, status, wage_indices, model, w_star + agent_keys, status, wage_indices, model, w_bar ) return key, status, wage_indices @@ -688,12 +685,12 @@ def simulate_cross_section( key = jax.random.PRNGKey(seed) # Solve for optimal reservation wage - v_star = vfi(model) - w_star = get_reservation_wage(v_star, model) + v_u = vfi(model) + w_bar = get_reservation_wage(v_u, model) # Run JIT-compiled simulation final_status = _simulate_cross_section_compiled( - key, model, w_star, n_agents, T + key, model, w_bar, n_agents, T ) # Calculate unemployment rate at final period @@ -717,10 +714,10 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200, """ # Get final employment state directly key = jax.random.PRNGKey(42) - v_star = vfi(model) - w_star = get_reservation_wage(v_star, model) + v_u = vfi(model) + w_bar = get_reservation_wage(v_u, model) final_status = _simulate_cross_section_compiled( - key, model, w_star, n_agents, t_snapshot + key, model, w_bar, n_agents, t_snapshot ) # Calculate unemployment rate