Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions lectures/mccall_fitted_vfi.md
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ class Model(NamedTuple):

def create_mccall_model(
c: float = 1.0,
α: float = 0.1,
α: float = 0.05,
β: float = 0.96,
ρ: float = 0.9,
ν: float = 0.2,
Expand Down Expand Up @@ -361,14 +361,17 @@ def vfi(
```

Here's a function that uses a solution $v_u$ to compute the remaining functions of
interest: $v_u$, and the continuation value function $h$.
interest: $v_e$, and the continuation value function $h$.

We use the same expressions as we did in the {doc}`discrete case <mccall_model_with_sep_markov>`, after replacing sums with integrals.

```{code-cell} ipython3
def compute_solution_functions(model, v_u):

# Interpolate v_u
# Unpack model parameters
c, α, β, ρ, ν, γ, w_grid, z_draws = model

# Interpolate v_u on the wage grid
vf = lambda x: jnp.interp(x, w_grid, v_u)

def compute_expectation(w):
Expand Down Expand Up @@ -604,7 +607,7 @@ When unemployed, the agent accepts offers that exceed the reservation wage.

When employed, the agent faces job separation with probability $\alpha$ each period.

### Cross-Sectional Analysis
### Cross-sectional analysis

Now let's simulate many agents simultaneously to examine the cross-sectional unemployment rate.

Expand Down Expand Up @@ -633,29 +636,29 @@ def _simulate_cross_section_compiled(
c, α, β, ρ, ν, γ, w_grid, z_draws = model

# Initialize arrays
key, subkey = jax.random.split(key)
init_key, subkey = jax.random.split(key)
wages = jnp.exp(jax.random.normal(subkey, (n_agents,)) * ν)
status = jnp.zeros(n_agents, dtype=jnp.int32)

def update(t, loop_state):
key, status, wages = loop_state
status, wages = loop_state

# Shift loop state forwards
key, subkey = jax.random.split(key)
agent_keys = jax.random.split(subkey, n_agents)
step_key = jax.random.fold_in(init_key, t)
agent_keys = jax.random.split(step_key, n_agents)

status, wages = update_agents_vmap(
agent_keys, status, wages, model, w_bar
)

return key, status, wages
return status, wages

# Run simulation using fori_loop
initial_loop_state = (key, status, wages)
initial_loop_state = (status, wages)
final_loop_state = lax.fori_loop(0, T, update, initial_loop_state)

# Return only final employment state
_, final_is_employed, _ = final_loop_state
final_is_employed, _ = final_loop_state
return final_is_employed


Expand Down
111 changes: 56 additions & 55 deletions lectures/mccall_model_with_sep_markov.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ libraries
```{code-cell} ipython3
:tags: [hide-output]

!pip install quantecon
!pip install quantecon jax
```

We use the following imports:
Expand All @@ -64,7 +64,7 @@ import matplotlib.pyplot as plt
from functools import partial
```

## Model Setup
## Model setup

The setting is as follows:

Expand All @@ -74,7 +74,7 @@ The setting is as follows:
- Unemployed workers receive compensation $c$ per period
- Future payoffs are discounted by factor $\beta \in (0,1)$

### Decision Problem
### Decision problem

When unemployed and receiving wage offer $w$, the agent chooses between:

Expand All @@ -86,7 +86,7 @@ The wage updates are as follows:
* If an unemployed agent rejects offer $w$, then their next offer is drawn from $P(w, \cdot)$
* If an employed agent loses a job in which they were paid wage $w$, then their next offer is drawn from $P(w, \cdot)$

### The Wage Offer Process
### The wage offer process

To construct the wage offer process we start with an AR1 process.

Expand All @@ -112,7 +112,7 @@ Actually, in practice, we approximate this wage process as follows:



### Value Functions
### Value functions

We let

Expand Down Expand Up @@ -168,12 +168,12 @@ $$

+++

### Optimal Policy
### Optimal policy

Once we have the solutions $v_e$ and $v_u$ to these Bellman equations, we can compute the optimal policy: accept at current wage offer $w$ if

$$
v_e(w) u(c) + β(Pv_u)(w)
v_e(w) \geq u(c) + \beta (P v_u)(w)
$$

The optimal policy turns out to be a reservation wage strategy: accept all wages above some threshold.
Expand All @@ -185,7 +185,7 @@ The optimal policy turns out to be a reservation wage strategy: accept all wages

Let's now implement the model.

### Set Up
### Set up

The default utility function is a CRRA utility function

Expand Down Expand Up @@ -234,7 +234,7 @@ def create_js_with_sep_model(
```


### Solution: First Pass
### Solution: first pass

Let's put together a (not very efficient) routine for calculating the
reservation wage.
Expand All @@ -244,7 +244,7 @@ reservation wage.
It works by starting with guesses for $v_e$ and $v_u$ and iterating to
convergence.

Here's are Bellman operators that update $v_u$ and $v_e$ respectively.
Here are Bellman operators that update $v_u$ and $v_e$ respectively.


```{code-cell} ipython3
Expand Down Expand Up @@ -313,7 +313,7 @@ def solve_model_first_pass(
```


### Road Test
### Road test

Let's solve the model:

Expand Down Expand Up @@ -348,9 +348,9 @@ The reservation wage is at the intersection of $v_e$, and the continuation value
function, which is the value of rejecting.


## Improving Efficiency
## Improving efficiency

The solution method desribed above works fine but we can do much better.
The solution method described above works fine but we can do much better.

First, we use the employed worker's Bellman equation to express
$v_e$ in terms of $Pv_u$
Expand Down Expand Up @@ -495,7 +495,7 @@ The result is the same as before but we only iterate on one array --- and also
our JAX code is more efficient.


## Sensitivity Analysis
## Sensitivity analysis

Let's examine how reservation wages change with the separation rate.

Expand Down Expand Up @@ -523,7 +523,7 @@ Can you provide an intuitive economic story behind the outcome that you see in t

+++

## Employment Simulation
## Employment simulation

Now let's simulate the employment dynamics of a single agent under the optimal policy.

Expand Down Expand Up @@ -691,15 +691,15 @@ often leads a high new draw.

+++

## The Ergodic Property
## Ergodic property

Below we examine cross-sectional unemployment.

In particular, we will look at the unemployment rate in a cross-sectional
simulation and compare it to the time-average unemployment rate, which is the
fraction of time an agent spends unemployed over a long time series.

We will see that these two values are approximately equal -- if fact they are
We will see that these two values are approximately equal -- in fact they are
exactly equal in the limit.

The reason is that the process $(S_t, W_t)$, where
Expand Down Expand Up @@ -744,15 +744,15 @@ Often the second approach is better for our purposes, since it's easier to paral

+++

## Cross-Sectional Analysis
## Cross-sectional analysis

Now let's simulate many agents simultaneously to examine the cross-sectional unemployment rate.

We first create a vectorized version of `update_agent` to efficiently update all agents in parallel:

```{code-cell} ipython3
# Create vectorized version of update_agent
# The last parameter is now w_bar (scalar) instead of σ (array)
# Create vectorized version of update_agent.
# Vectorize over key, status, wage_idx
update_agents_vmap = jax.vmap(
update_agent, in_axes=(0, 0, 0, None, None)
)
Expand All @@ -761,76 +761,72 @@ update_agents_vmap = jax.vmap(
Next we define the core simulation function, which uses `lax.fori_loop` to efficiently iterate many agents forward in time:

```{code-cell} ipython3
@partial(jax.jit, static_argnums=(3, 4))
@jax.jit
def _simulate_cross_section_compiled(
key: jnp.ndarray,
model: Model,
w_bar: float,
n_agents: int,
initial_wage_indices: jnp.ndarray,
initial_status_vec: jnp.ndarray,
T: int
):
"""JIT-compiled core simulation loop using lax.fori_loop.
Returns only the final employment state to save memory."""
"""
JIT-compiled core simulation loop for shifting the cross section
using lax.fori_loop. Returns the final employment employment status
cross-section.

"""
n, w_vals, P, P_cumsum, β, c, α, γ = model
n_agents = len(initial_wage_indices)

# Initialize arrays
wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
status = jnp.zeros(n_agents, dtype=jnp.int32)

def update(t, loop_state):
key, status, wage_indices = loop_state

# Shift loop state forwards
key, subkey = jax.random.split(key)
agent_keys = jax.random.split(subkey, n_agents)

" Shift loop state forwards "
status, wage_indices = loop_state
step_key = jax.random.fold_in(key, t)
agent_keys = jax.random.split(step_key, n_agents)
status, wage_indices = update_agents_vmap(
agent_keys, status, wage_indices, model, w_bar
)

return key, status, wage_indices
return status, wage_indices

# Run simulation using fori_loop
initial_loop_state = (key, status, wage_indices)
initial_loop_state = (initial_status_vec, initial_wage_indices)
final_loop_state = lax.fori_loop(0, T, update, initial_loop_state)

# Return only final employment state
_, final_is_employed, _ = final_loop_state
final_is_employed, _ = final_loop_state
return final_is_employed


def simulate_cross_section(
model: Model,
n_agents: int = 100_000,
T: int = 200,
seed: int = 42
model: Model, # Model instance with parameters
n_agents: int = 100_000, # Number of agents to simulate
T: int = 200, # Length of burn-in
seed: int = 42 # For reproducibility
) -> float:
"""
Simulate employment paths for many agents and return final unemployment rate.
Wrapper function for _simulate_cross_section_compiled.

Parameters:
- model: Model instance with parameters
- n_agents: Number of agents to simulate
- T: Number of periods to simulate
- seed: Random seed for reproducibility
Push forward a cross-section for T periods and return the final
cross-sectional unemployment rate.

Returns:
- unemployment_rate: Fraction of agents unemployed at time T
"""
key = jax.random.PRNGKey(seed)

# Solve for optimal reservation wage
v_u = vfi(model)
w_bar = get_reservation_wage(v_u, model)

# Run JIT-compiled simulation
# Initialize arrays
initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)

final_status = _simulate_cross_section_compiled(
key, model, w_bar, n_agents, T
key, model, w_bar, initial_wage_indices, initial_status_vec, T
)

# Calculate unemployment rate at final period
unemployment_rate = 1 - jnp.mean(final_status)

return unemployment_rate
```

Expand All @@ -850,8 +846,13 @@ def plot_cross_sectional_unemployment(
key = jax.random.PRNGKey(42)
v_u = vfi(model)
w_bar = get_reservation_wage(v_u, model)

# Initialize arrays
initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)

final_status = _simulate_cross_section_compiled(
key, model, w_bar, n_agents, t_snapshot
key, model, w_bar, initial_wage_indices, initial_status_vec, t_snapshot
)

# Calculate unemployment rate
Expand Down Expand Up @@ -906,7 +907,7 @@ Now let's visualize the cross-sectional distribution:
plot_cross_sectional_unemployment(model)
```

## Lower Unemployment Compensation (c=0.5)
## Lower unemployment compensation (c=0.5)

What happens to the cross-sectional unemployment rate with lower unemployment compensation?

Expand Down
Loading