6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
@@ -31,9 +31,9 @@ jobs:
    - name: Install JAX, Numpyro, PyTorch
      shell: bash -l {0}
      run: |
        pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
        pip install pyro-ppl
        pip install --upgrade "jax[cuda12-local]==0.6.2"
        # pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
        # pip install pyro-ppl
        pip install "jax[cuda12-local]==0.6.2"
        pip install numpyro pyro-ppl
        python scripts/test-jax-install.py
    - name: Check nvidia Drivers
3 binary files changed (not shown)
135 changes: 61 additions & 74 deletions lectures/mccall_model_with_sep_markov.md
@@ -4,11 +4,11 @@ jupytext:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.17.1
    jupytext_version: 1.17.2
kernelspec:
  name: python3
  display_name: Python 3 (ipykernel)
  language: python
  name: python3
---

(mccall_with_sep_markov)=
@@ -49,7 +49,7 @@ libraries
```{code-cell} ipython3
:tags: [hide-output]

!pip install quantecon jax
!pip install quantecon
```

We use the following imports:
@@ -58,7 +58,7 @@ We use the following imports:
from quantecon.markov import tauchen
import jax.numpy as jnp
import jax
from jax import jit, lax
from jax import lax
from typing import NamedTuple
import matplotlib.pyplot as plt
from functools import partial
@@ -138,48 +138,11 @@ The optimal policy turns out to be a reservation wage strategy: accept all wages

## Code


First, we implement the successive approximation algorithm.

This algorithm takes an operator $T$ and an initial condition and iterates until
convergence.

We will use it for value function iteration.

```{code-cell} ipython3
@partial(jit, static_argnums=(0,))
def successive_approx(
        T,                        # Operator (callable) - marked as static
        x_0,                      # Initial condition
        tolerance: float = 1e-6,  # Error tolerance
        max_iter: int = 100_000,  # Max iteration bound
    ):
    """Computes the approximate fixed point of T via successive
    approximation using lax.while_loop."""

    def cond_fn(carry):
        x, error, k = carry
        return (error > tolerance) & (k <= max_iter)

    def body_fn(carry):
        x, error, k = carry
        x_new = T(x)
        error = jnp.max(jnp.abs(x_new - x))
        return (x_new, error, k + 1)

    initial_carry = (x_0, tolerance + 1, 1)
    x_final, _, _ = lax.while_loop(cond_fn, body_fn, initial_carry)

    return x_final
```


Next let's set up a `Model` class to store information needed to solve the model.
Let's set up a `Model` class to store information needed to solve the model.

We include `P_cumsum`, the row-wise cumulative sum of the transition matrix, to
optimize the simulation -- the details are explained below.
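
To illustrate (a standalone sketch using a small made-up matrix, not the model's actual $P$), the cumulative sum is just `jnp.cumsum` along each row:

```{code-cell} ipython3
# Hypothetical demonstration: row-wise cumulative sum of a
# two-state transition matrix invented for this example.
P_demo = jnp.array([[0.9, 0.1],
                    [0.2, 0.8]])
P_demo_cumsum = jnp.cumsum(P_demo, axis=1)  # each row ends at 1.0
print(P_demo_cumsum)
```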


```{code-cell} ipython3
class Model(NamedTuple):
    n: int
@@ -215,7 +178,6 @@ def create_js_with_sep_model(
Here's the Bellman operator for the unemployed worker's value function:

```{code-cell} ipython3
@jit
def T(v: jnp.ndarray, model: Model) -> jnp.ndarray:
    """The Bellman operator for the value of being unemployed."""
    n, w_vals, P, P_cumsum, β, c, α = model
@@ -229,7 +191,6 @@ The next function computes the optimal policy under the assumption that $v$ is
the value function:

```{code-cell} ipython3
@jit
def get_greedy(v: jnp.ndarray, model: Model) -> jnp.ndarray:
    """Get a v-greedy policy."""
    n, w_vals, P, P_cumsum, β, c, α = model
@@ -247,14 +208,34 @@ The second routine requires a policy function, which we will typically obtain by
applying the `vfi` function.

```{code-cell} ipython3
def vfi(model: Model):
    """Solve by VFI."""
@jax.jit
def vfi(
        model: Model,
        tolerance: float = 1e-6,  # Error tolerance
        max_iter: int = 100_000,  # Max iteration bound
    ):

    v_init = jnp.zeros(model.w_vals.shape)
    v_star = successive_approx(lambda v: T(v, model), v_init)
    σ_star = get_greedy(v_star, model)
    return v_star, σ_star

    def cond(loop_state):
        v, error, i = loop_state
        return (error > tolerance) & (i <= max_iter)

    def update(loop_state):
        v, error, i = loop_state
        v_new = T(v, model)
        error = jnp.max(jnp.abs(v_new - v))
        new_loop_state = v_new, error, i + 1
        return new_loop_state

    initial_state = (v_init, tolerance + 1, 1)
    final_loop_state = lax.while_loop(cond, update, initial_state)
    v_final, error, i = final_loop_state

    return v_final


@jax.jit
def get_reservation_wage(σ: jnp.ndarray, model: Model) -> float:
    """
    Calculate the reservation wage from a given policy.
@@ -268,25 +249,24 @@ def get_reservation_wage(σ: jnp.ndarray, model: Model) -> float:
    """
    n, w_vals, P, P_cumsum, β, c, α = model

    # Find all wage indices where policy indicates acceptance
    accept_indices = jnp.where(σ == 1)[0]

    if len(accept_indices) == 0:
        return jnp.inf  # Agent never accepts any wage
    # Find the first index where policy indicates acceptance
    # σ is a boolean array, argmax returns the first True value
    first_accept_idx = jnp.argmax(σ)

    # Return the lowest wage that is accepted
    return w_vals[accept_indices[0]]
    # If no acceptance (all False), return infinity
    # Otherwise return the wage at the first acceptance index
    return jnp.where(jnp.any(σ), w_vals[first_accept_idx], jnp.inf)
```
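
As a quick sanity check of the `argmax` logic (a hypothetical example with made-up policies):

```{code-cell} ipython3
# A policy that rejects the two lowest wages and accepts the rest.
σ_demo = jnp.array([False, False, True, True])
print(jnp.argmax(σ_demo))   # 2: the index of the first True entry

# For an all-False policy, argmax returns 0, which is why the
# jnp.any guard above is needed to return jnp.inf instead.
σ_never = jnp.zeros(4, dtype=bool)
print(jnp.argmax(σ_never))  # 0, even though no wage is accepted
```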


## Computing the Solution

Let's solve the model:

```{code-cell} ipython3
model = create_js_with_sep_model()
n, w_vals, P, P_cumsum, β, c, α = model
v_star, σ_star = vfi(model)
v_star = vfi(model)
σ_star = get_greedy(v_star, model)
```

Next we compute some related quantities, including the reservation wage.
@@ -312,19 +292,18 @@ ax.set_xlabel(r"$w$")
plt.show()
```


## Sensitivity Analysis

Let's examine how reservation wages change with the separation rate.


```{code-cell} ipython3
α_vals: jnp.ndarray = jnp.linspace(0.0, 1.0, 10)

w_star_vec = jnp.empty_like(α_vals)
for (i_α, α) in enumerate(α_vals):
    model = create_js_with_sep_model(α=α)
    v_star, σ_star = vfi(model)
    v_star = vfi(model)
    σ_star = get_greedy(v_star, model)
    w_star = get_reservation_wage(σ_star, model)
    w_star_vec = w_star_vec.at[i_α].set(w_star)

@@ -356,9 +335,8 @@ This is implemented via `jnp.searchsorted` on the precomputed cumulative sum
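
To sketch the idea with a made-up row (hypothetical values, not the model's actual `P_cumsum`): a uniform draw is mapped to a state index by locating it in the row's cumulative distribution, which is equivalent to sampling from that row of $P$.

```{code-cell} ipython3
# Sample a next state from the distribution (0.2, 0.5, 0.3).
row_cumsum = jnp.array([0.2, 0.7, 1.0])       # cumulative probabilities
u = jax.random.uniform(jax.random.PRNGKey(0))
next_state = jnp.searchsorted(row_cumsum, u)  # first index with cumsum >= u
print(u, next_state)
```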

The function `update_agent` advances the agent's state by one period.


```{code-cell} ipython3
@jit
@jax.jit
def update_agent(key, is_employed, wage_idx, model, σ):
    """
    Updates an agent by one period. Updates their employment status and their
@@ -439,7 +417,8 @@ Let's create a comprehensive plot of the employment simulation:
model = create_js_with_sep_model()

# Calculate reservation wage for plotting
v_star, σ_star = vfi(model)
v_star = vfi(model)
σ_star = get_greedy(v_star, model)
w_star = get_reservation_wage(σ_star, model)

wage_path, employment_status = simulate_employment_path(model, σ_star)
@@ -486,7 +465,6 @@ plt.tight_layout()
plt.show()
```


The simulation helps to visualize outcomes associated with this model.

The agent follows a reservation wage strategy.
@@ -531,7 +509,7 @@ This holds because:

These properties ensure the chain is ergodic with a unique stationary distribution $\pi$ over states $(s, w)$.

For an ergodic Markov chain, the ergodic theorem guarantees that time averages = ensemble averages.
For an ergodic Markov chain, the ergodic theorem guarantees that time averages = cross-sectional averages.

In particular, the fraction of time a single agent spends unemployed (across all
wage states) converges to the cross-sectional unemployment rate:
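
Loosely, writing $s_t$ for the agent's employment status at time $t$, the claim is that

$$
\lim_{T \to \infty} \frac{1}{T} \sum_{t=1}^{T} \mathbf{1}\{s_t = \text{unemployed}\}
= \sum_{w} \pi(\text{unemployed}, w)
$$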
@@ -568,7 +546,7 @@ update_agents_vmap = jax.vmap(
Next we define the core simulation function, which uses `lax.fori_loop` to efficiently iterate many agents forward in time:

```{code-cell} ipython3
@partial(jit, static_argnums=(3, 4))
@partial(jax.jit, static_argnums=(3, 4))
def _simulate_cross_section_compiled(
        key: jnp.ndarray,
        model: Model,
Expand Down Expand Up @@ -627,7 +605,8 @@ def simulate_cross_section(
    key = jax.random.PRNGKey(seed)

    # Solve for optimal policy
    v_star, σ_star = vfi(model)
    v_star = vfi(model)
    σ_star = get_greedy(v_star, model)

    # Run JIT-compiled simulation
    final_employment = _simulate_cross_section_compiled(
@@ -655,7 +634,8 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200,
"""
# Get final employment state directly
key = jax.random.PRNGKey(42)
v_star, σ_star = vfi(model)
v_star = vfi(model)
σ_star = get_greedy(v_star, model)
final_employment = _simulate_cross_section_compiled(
key, model, σ_star, n_agents, t_snapshot
)
@@ -681,7 +661,12 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200,
    plt.show()
```

Now let's compare the time-average unemployment rate (from a single agent's long simulation) with the cross-sectional unemployment rate (from many agents at a single point in time):
Now let's compare the time-average unemployment rate (from a single agent's long simulation) with the cross-sectional unemployment rate (from many agents at a single point in time).

We claimed above that these numbers will be approximately equal in large
samples, due to ergodicity.

Let's see if that's true.

```{code-cell} ipython3
model = create_js_with_sep_model()
Expand All @@ -697,28 +682,31 @@ print(f"Cross-sectional unemployment rate (at t=200): "
print(f"Difference: {abs(time_avg_unemp - cross_sectional_unemp):.4f}")
```

Indeed, they are very close.

Now let's visualize the cross-sectional distribution:

```{code-cell} ipython3
plot_cross_sectional_unemployment(model)
```

## Cross-Sectional Analysis with Lower Unemployment Compensation (c=0.5)
## Lower Unemployment Compensation (c=0.5)

Let's examine how the cross-sectional unemployment rate changes with lower unemployment compensation:
What happens to the cross-sectional unemployment rate with lower unemployment compensation?

```{code-cell} ipython3
model_low_c = create_js_with_sep_model(c=0.5)
plot_cross_sectional_unemployment(model_low_c)
```


## Exercises

```{exercise-start}
:label: mmwsm_ex1
```

Create a plot that shows how the steady state cross-sectional unemployment rate
Create a plot that investigates more carefully how the steady state cross-sectional unemployment rate
changes with unemployment compensation.

```{exercise-end}
@@ -751,4 +739,3 @@ plt.show()

```{solution-end}
```
