diff --git a/lectures/mccall_model.md b/lectures/mccall_model.md
index 3e8f10067..43469f062 100644
--- a/lectures/mccall_model.md
+++ b/lectures/mccall_model.md
@@ -498,12 +498,13 @@ colors = prop_cycle.by_key()['color']
# Plot the wage offer distribution
ax.plot(w, q, '-', alpha=0.6, lw=2,
- label='wage offer distribution',
+ label='wage offer distribution',
color=colors[0])
# Compute reservation wage with default beta
model_default = McCallModel()
-v_init = model_default.w / (1 - model_default.β)
+c, β, w, q = model_default
+v_init = w / (1 - β)
v_default, res_wage_default = compute_reservation_wage(
model_default, v_init
)
@@ -511,14 +512,15 @@ v_default, res_wage_default = compute_reservation_wage(
# Compute reservation wage with lower beta
β_new = 0.96
model_low_beta = McCallModel(β=β_new)
-v_init_low = model_low_beta.w / (1 - model_low_beta.β)
+c, β_low, w, q = model_low_beta
+v_init_low = w / (1 - β_low)
v_low, res_wage_low = compute_reservation_wage(
model_low_beta, v_init_low
)
# Plot vertical lines for reservation wages
ax.axvline(x=res_wage_default, color=colors[1], lw=2,
- label=f'reservation wage (β={model_default.β})')
+ label=f'reservation wage (β={β})')
ax.axvline(x=res_wage_low, color=colors[2], lw=2,
label=f'reservation wage (β={β_new})')
@@ -621,12 +623,43 @@ Let $h$ denote the continuation value:
The Bellman equation can now be written as
-$$
+```{math}
+:label: j1b
+
v^*(w')
= \max \left\{ \frac{w'}{1 - \beta}, \, h \right\}
+```
+
+Now let's derive a nonlinear equation for $h$ alone.
+
+Starting from {eq}`j1b`, we multiply both sides by $q(w')$ to get
+
+$$
+ v^*(w') q(w') = \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
+$$
+
+Next, we sum both sides over $w' \in \mathbb{W}$:
+
+$$
+ \sum_{w' \in \mathbb W} v^*(w') q(w')
+ = \sum_{w' \in \mathbb W} \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
+$$
+
+Now multiply both sides by $\beta$:
+
+$$
+ \beta \sum_{w' \in \mathbb W} v^*(w') q(w')
+ = \beta \sum_{w' \in \mathbb W} \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
$$
-Substituting this last equation into {eq}`j1` gives
+Add $c$ to both sides:
+
+$$
+ c + \beta \sum_{w' \in \mathbb W} v^*(w') q(w')
+ = c + \beta \sum_{w' \in \mathbb W} \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
+$$
+
+Finally, using the definition of $h$ from {eq}`j1`, the left-hand side is just $h$, giving us
```{math}
:label: j2
@@ -638,7 +671,7 @@ Substituting this last equation into {eq}`j1` gives
\right\} q (w')
```
-This is a nonlinear equation that we can solve for $h$.
+This is a nonlinear equation in the single scalar $h$ that we can solve for $h$.
As before, we will use successive approximations:
@@ -781,8 +814,28 @@ plt.show()
And here's a solution using JAX.
```{code-cell} ipython3
+# First, we set up a function to draw random wage offers from the distribution.
+# We use the inverse transform method: draw a uniform random variable u,
+# then find the smallest wage w such that the CDF at w is >= u.
cdf = jnp.cumsum(q_default)
+def draw_wage(uniform_rv):
+ """
+ Draw a wage from the distribution q_default using the inverse transform method.
+
+ Parameters:
+ -----------
+ uniform_rv : float
+ A uniform random variable on [0, 1]
+
+ Returns:
+ --------
+ wage : float
+ A wage drawn from w_default with probabilities given by q_default
+ """
+ return w_default[jnp.searchsorted(cdf, uniform_rv)]
+
+
def compute_stopping_time(w_bar, key):
"""
Compute stopping time by drawing wages until one exceeds `w_bar`.
@@ -791,7 +844,7 @@ def compute_stopping_time(w_bar, key):
t, key, accept = loop_state
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)
- w = w_default[jnp.searchsorted(cdf, u)]
+ w = draw_wage(u)
accept = w >= w_bar
t = t + 1
return t, key, accept
@@ -831,7 +884,8 @@ def compute_stop_time_for_c(c):
return compute_mean_stopping_time(w_bar)
# Vectorize across all c values
-stop_times = jax.vmap(compute_stop_time_for_c)(c_vals)
+compute_stop_time_vectorized = jax.vmap(compute_stop_time_for_c)
+stop_times = compute_stop_time_vectorized(c_vals)
fig, ax = plt.subplots()
@@ -928,12 +982,12 @@ def create_mccall_continuous(
key = jax.random.PRNGKey(seed)
s = jax.random.normal(key, (mc_size,))
w_draws = jnp.exp(μ + σ * s)
- return McCallModelContinuous(c=c, β=β, σ=σ, μ=μ, w_draws=w_draws)
+ return McCallModelContinuous(c, β, σ, μ, w_draws)
@jax.jit
def compute_reservation_wage_continuous(model, max_iter=500, tol=1e-5):
- c, β, σ, μ, w_draws = model.c, model.β, model.σ, model.μ, model.w_draws
+ c, β, σ, μ, w_draws = model
h = jnp.mean(w_draws) / (1 - β) # initial guess
@@ -949,8 +1003,9 @@ def compute_reservation_wage_continuous(model, max_iter=500, tol=1e-5):
return jnp.logical_and(i < max_iter, error > tol)
initial_state = (h, 0, tol + 1)
- h_final, _, _ = jax.lax.while_loop(cond, update, initial_state)
-
+ final_state = jax.lax.while_loop(cond, update, initial_state)
+ h_final, _, _ = final_state
+
# Now compute the reservation wage
return (1 - β) * h_final
```
@@ -969,12 +1024,13 @@ def compute_R_element(c, β):
model = create_mccall_continuous(c=c, β=β)
return compute_reservation_wage_continuous(model)
-# Create meshgrid and vectorize computation
-c_grid, β_grid = jnp.meshgrid(c_vals, β_vals, indexing='ij')
-compute_R_vectorized = jax.vmap(
- jax.vmap(compute_R_element,
- in_axes=(None, 0)),
- in_axes=(0, None))
+# First, vectorize over β (holding c fixed)
+compute_R_over_β = jax.vmap(compute_R_element, in_axes=(None, 0))
+
+# Next, vectorize over c (applying the above function to each c)
+compute_R_vectorized = jax.vmap(compute_R_over_β, in_axes=(0, None))
+
+# Apply to compute the full grid
R = compute_R_vectorized(c_vals, β_vals)
```
diff --git a/lectures/mccall_model.py b/lectures/mccall_model.py
new file mode 100644
index 000000000..d52f50406
--- /dev/null
+++ b/lectures/mccall_model.py
@@ -0,0 +1,1068 @@
+# ---
+# jupyter:
+# jupytext:
+# default_lexer: ipython3
+# text_representation:
+# extension: .py
+# format_name: percent
+# format_version: '1.3'
+# jupytext_version: 1.17.2
+# kernelspec:
+# display_name: Python 3 (ipykernel)
+# language: python
+# name: python3
+# ---
+
+# %% [markdown]
+# (mccall)=
+# ```{raw} jupyter
+#
+# ```
+#
+# # Job Search I: The McCall Search Model
+#
+# ```{contents} Contents
+# :depth: 2
+# ```
+#
+# ```{epigraph}
+# "Questioning a McCall worker is like having a conversation with an out-of-work friend:
+# 'Maybe you are setting your sights too high', or 'Why did you quit your old job before you
+# had a new one lined up?' This is real social science: an attempt to model, to understand,
+# human behavior by visualizing the situation people find themselves in, the options they face
+# and the pros and cons as they themselves see them." -- Robert E. Lucas, Jr.
+# ```
+#
+# In addition to what's in Anaconda, this lecture will need the following libraries:
+
+# %% tags=["hide-output"]
+# !pip install quantecon jax
+
+# %% [markdown]
+# ## Overview
+#
+# The McCall search model {cite}`McCall1970` helped transform economists' way of thinking about labor markets.
+#
+# To clarify notions such as "involuntary" unemployment, McCall modeled the decision problem of an unemployed worker in terms of factors including
+#
+# * current and likely future wages
+# * impatience
+# * unemployment compensation
+#
+# To solve the decision problem McCall used dynamic programming.
+#
+# Here we set up McCall's model and use dynamic programming to analyze it.
+#
+# As we'll see, McCall's model is not only interesting in its own right but also an excellent vehicle for learning dynamic programming.
+#
+# Let's start with some imports:
+
+# %%
+import matplotlib.pyplot as plt
+import numpy as np
+import numba
+import jax
+import jax.numpy as jnp
+from typing import NamedTuple
+from functools import partial
+import quantecon as qe
+from quantecon.distributions import BetaBinomial
+
+# %% [markdown]
+# ## The McCall Model
+#
+# ```{index} single: Models; McCall
+# ```
+#
+# An unemployed agent receives in each period a job offer at wage $w_t$.
+#
+# In this lecture, we adopt the following simple environment:
+#
+# * The offer sequence $\{w_t\}_{t \geq 0}$ is IID, with $q(w)$ being the probability of observing wage $w$ in finite set $\mathbb{W}$.
+# * The agent observes $w_t$ at the start of $t$.
+# * The agent knows that $\{w_t\}$ is IID with common distribution $q$ and can use this when computing expectations.
+#
+# (In later lectures, we will relax these assumptions.)
+#
+# At time $t$, our agent has two choices:
+#
+# 1. Accept the offer and work permanently at constant wage $w_t$.
+# 1. Reject the offer, receive unemployment compensation $c$, and reconsider next period.
+#
+# The agent is infinitely lived and aims to maximize the expected discounted
+# sum of earnings
+#
+# ```{math}
+# :label: obj_model
+#
+# {\mathbb E} \sum_{t=0}^\infty \beta^t y_t
+# ```
+#
+# The constant $\beta$ lies in $(0, 1)$ and is called a **discount factor**.
+#
+# The smaller is $\beta$, the more the agent discounts future earnings relative to current earnings.
+#
+# The variable $y_t$ is income, equal to
+#
+# * his/her wage $w_t$ when employed
+# * unemployment compensation $c$ when unemployed
+#
+#
+# ### A Trade-Off
+#
+# The worker faces a trade-off:
+#
+# * Waiting too long for a good offer is costly, since the future is discounted.
+# * Accepting too early is costly, since better offers might arrive in the future.
+#
+# To decide the optimal wait time in the face of this trade-off, we use [dynamic programming](https://dp.quantecon.org/).
+#
+# Dynamic programming can be thought of as a two-step procedure that
+#
+# 1. first assigns values to "states" and
+# 1. then deduces optimal actions given those values
+#
+# We'll go through these steps in turn.
+#
+# ### The Value Function
+#
+# In order to optimally trade-off current and future rewards, we need to think about two things:
+#
+# 1. the current payoffs we get from different choices
+# 1. the different states that those choices will lead to in next period
+#
+# To weigh these two aspects of the decision problem, we need to assign *values*
+# to states.
+#
+# To this end, let $v^*(w)$ be the total lifetime value accruing to an
+# unemployed worker who enters the current period unemployed when the wage is
+# $w \in \mathbb{W}$.
+#
+# (In particular, the agent has wage offer $w$ in hand and can accept or reject it.)
+#
+# More precisely, $v^*(w)$ denotes the total sum of expected discounted earnings
+# when an agent always behaves in an optimal way. points in time.
+#
+# Of course $v^*(w)$ is not trivial to calculate because we don't yet know
+# what decisions are optimal and what aren't!
+#
+# If we don't know what opimal choices are, it feels imposible to calculate
+# $v^*(w)$.
+#
+# But let's put this aside for now and think of $v^*$ as a function that assigns
+# to each possible wage $w$ the maximal lifetime value $v^*(w)$ that can be
+# obtained with that offer in hand.
+#
+# A crucial observation is that this function $v^*$ must satisfy
+#
+# ```{math}
+# :label: odu_pv
+#
+# v^*(w)
+# = \max \left\{
+# \frac{w}{1 - \beta}, \, c + \beta
+# \sum_{w' \in \mathbb{W}} v^*(w') q (w')
+# \right\}
+# ```
+#
+# for every possible $w$ in $\mathbb{W}$.
+#
+# This is a version of the **Bellman equation**, which is
+# ubiquitous in economic dynamics and other fields involving planning over time.
+#
+# The intuition behind it is as follows:
+#
+# * the first term inside the max operation is the lifetime payoff from accepting current offer, since
+# such a worker works forever at $w$ and values this income stream as
+#
+# $$
+# \frac{w}{1 - \beta} = w + \beta w + \beta^2 w + \cdots
+# $$
+#
+# * the second term inside the max operation is the continuation value, which is
+# the lifetime payoff from rejecting the current offer and then behaving
+# optimally in all subsequent periods
+#
+# If we optimize and pick the best of these two options, we obtain maximal
+# lifetime value from today, given current offer $w$.
+#
+# But this is precisely $v^*(w)$, which is the left-hand side of {eq}`odu_pv`.
+#
+# Putting this all together, we see that {eq}`odu_pv` is valid for all $w$.
+#
+#
+# ### The Optimal Policy
+#
+# We still don't know how to compute $v^*$ (although {eq}`odu_pv` gives us hints
+# we'll return to below).
+#
+# But suppose for now that we do know $v^*$.
+#
+# Once we have this function in hand we can easily make optimal choices (i.e., make the
+# right choice between accept and reject given any $w$).
+#
+# All we have to do is select the maximal choice on the right-hand side of {eq}`odu_pv`.
+#
+# In other words, we make the best choice between stopping and continuing, given
+# the information provided to us by $v^*$.
+#
+# The optimal action is best thought of as a **policy**, which is, in general, a map from
+# states to actions.
+#
+# Given any $w$, we can read off the corresponding best choice (accept or
+# reject) by picking the max on the right-hand side of {eq}`odu_pv`.
+#
+# Thus, we have a map from $\mathbb W$ to $\{0, 1\}$, with 1 meaning accept and 0 meaning reject.
+#
+# We can write the policy as follows
+#
+# $$
+# \sigma(w) := \mathbf{1}
+# \left\{
+# \frac{w}{1 - \beta} \geq c + \beta \sum_{w' \in \mathbb W}
+# v^*(w') q (w')
+# \right\}
+# $$
+#
+# Here $\mathbf{1}\{ P \} = 1$ if statement $P$ is true and equals 0 otherwise.
+#
+# We can also write this as
+#
+# $$
+# \sigma(w) := \mathbf{1} \{ w \geq \bar w \}
+# $$
+#
+# where
+#
+# ```{math}
+# :label: reswage
+#
+# \bar w := (1 - \beta) \left\{ c + \beta \sum_{w'} v^*(w') q (w') \right\}
+# ```
+#
+# Here $\bar w$ (called the **reservation wage**) is a constant depending on
+# $\beta, c$ and the wage distribution.
+#
+# The agent should accept if and only if the current wage offer exceeds the reservation wage.
+#
+# In view of {eq}`reswage`, we can compute this reservation wage if we can compute the value function.
+#
+#
+# ## Computing the Optimal Policy: Take 1
+#
+# To put the above ideas into action, we need to compute the value function at each $w \in \mathbb W$.
+#
+# To simplify notation, let's set
+#
+# $$
+# \mathbb W := \{w_1, \ldots, w_n \}
+# \quad \text{and} \quad
+# v^*(i) := v^*(w_i)
+# $$
+#
+# The value function is then represented by the vector $v^* = (v^*(i))_{i=1}^n$.
+#
+# In view of {eq}`odu_pv`, this vector satisfies the nonlinear system of equations
+#
+# ```{math}
+# :label: odu_pv2
+#
+# v^*(i)
+# = \max \left\{
+# \frac{w(i)}{1 - \beta}, \, c + \beta \sum_{j=1}^n
+# v^*(j) q (j)
+# \right\}
+# \quad
+# \text{for } i = 1, \ldots, n
+# ```
+#
+#
+#
+# ### The Algorithm
+#
+# To compute this vector, we use successive approximations:
+#
+# Step 1: pick an arbitrary initial guess $v \in \mathbb R^n$.
+#
+# Step 2: compute a new vector $v' \in \mathbb R^n$ via
+#
+# ```{math}
+# :label: odu_pv2p
+#
+# v'(i)
+# = \max \left\{
+# \frac{w(i)}{1 - \beta}, \, c + \beta \sum_{j=1}^n
+# v(j) q (j)
+# \right\}
+# \quad
+# \text{for } i = 1, \ldots, n
+# ```
+#
+# Step 3: calculate a measure of a discrepancy between $v$ and $v'$, such as $\max_i |v(i)- v'(i)|$.
+#
+# Step 4: if the deviation is larger than some fixed tolerance, set $v = v'$ and go to step 2, else continue.
+#
+# Step 5: return $v$.
+#
+# For a small tolerance, the returned function $v$ is a close approximation to the value function $v^*$.
+#
+# The theory below elaborates on this point.
+#
+# ### Fixed Point Theory
+#
+# What's the mathematics behind these ideas?
+#
+# First, one defines a mapping $T$ from $\mathbb R^n$ to itself via
+#
+# ```{math}
+# :label: odu_pv3
+#
+# (Tv)(i)
+# = \max \left\{
+# \frac{w(i)}{1 - \beta}, \, c + \beta \sum_{j=1}^n
+# v(j) q (j)
+# \right\}
+# \quad
+# \text{for } i = 1, \ldots, n
+# ```
+#
+# (A new vector $Tv$ is obtained from given vector $v$ by evaluating
+# the r.h.s. at each $i$.)
+#
+# The element $v_k$ in the sequence $\{v_k\}$ of successive approximations corresponds to $T^k v$.
+#
+# * This is $T$ applied $k$ times, starting at the initial guess $v$
+#
+# One can show that the conditions of the [Banach fixed point theorem](https://en.wikipedia.org/wiki/Banach_fixed-point_theorem) are
+# satisfied by $T$ on $\mathbb R^n$.
+#
+# One implication is that $T$ has a unique fixed point in $\mathbb R^n$.
+#
+# * That is, a unique vector $\bar v$ such that $T \bar v = \bar v$.
+#
+# Moreover, it's immediate from the definition of $T$ that this fixed point is $v^*$.
+#
+# A second implication of the Banach contraction mapping theorem is that
+# $\{ T^k v \}$ converges to the fixed point $v^*$ regardless of $v$.
+#
+#
+# ### Implementation
+#
+# Our default for $q$, the wage offer distribution, will be [Beta-binomial](https://en.wikipedia.org/wiki/Beta-binomial_distribution).
+
+# %%
+n, a, b = 50, 200, 100 # default parameters
+q_default = jnp.array(BetaBinomial(n, a, b).pdf())
+
+# %% [markdown]
+# Our default set of values for wages will be
+
+# %%
+w_min, w_max = 10, 60
+w_default = jnp.linspace(w_min, w_max, n+1)
+
+# %% [markdown]
+# Here's a plot of the probabilities of different wage outcomes:
+
+# %%
+fig, ax = plt.subplots()
+ax.plot(w_default, q_default, '-o', label='$q(w(i))$')
+ax.set_xlabel('wages')
+ax.set_ylabel('probabilities')
+
+plt.show()
+
+
+# %% [markdown]
+# We will use [JAX](https://python-programming.quantecon.org/jax_intro.html) to write our code.
+#
+# We'll use `NamedTuple` for our model class to maintain immutability, which works well with JAX's functional programming paradigm.
+#
+# Here's a class that stores the model parameters with default values.
+
+# %%
+class McCallModel(NamedTuple):
+ c: float = 25 # unemployment compensation
+ β: float = 0.99 # discount factor
+ w: jnp.ndarray = w_default # array of wage values, w[i] = wage at state i
+ q: jnp.ndarray = q_default # array of probabilities
+
+
+# %% [markdown]
+# We implement the Bellman operator $T$ from {eq}`odu_pv3`, which we can write in
+# terms of array operations as
+#
+# ```{math}
+# :label: odu_pv4
+#
+# Tv
+# = \max \left\{
+# \frac{w}{1 - \beta}, \, c + \beta \sum_{j=1}^n v(j) q (j)
+# \right\}
+# \quad
+# ```
+#
+# (The first term inside the max is an array and the second is just a number -- here
+# we mean that the max comparison against this number is done element-by-element for all elements in the array.)
+#
+# We can code $T$ up as follows.
+
+# %%
+def T(model: McCallModel, v: jnp.ndarray):
+ c, β, w, q = model
+ accept = w / (1 - β)
+ reject = c + β * v @ q
+ return jnp.maximum(accept, reject)
+
+
+# %% [markdown]
+# Based on these defaults, let's try plotting the first few approximate value functions
+# in the sequence $\{ T^k v \}$.
+#
+# We will start from guess $v$ given by $v(i) = w(i) / (1 - β)$, which is the value of accepting at every given wage.
+
+# %%
+model = McCallModel()
+c, β, w, q = model
+v = w / (1 - β) # Initial condition
+fig, ax = plt.subplots()
+
+num_plots = 6
+for i in range(num_plots):
+ ax.plot(w, v, '-', alpha=0.6, lw=2, label=f"iterate {i}")
+ v = T(model, v)
+
+ax.legend(loc='lower right')
+ax.set_xlabel('wage')
+ax.set_ylabel('value')
+plt.show()
+
+
+# %% [markdown]
+# You can see that convergence is occurring: successive iterates are getting closer together.
+#
+# Here's a more serious iteration effort to compute the limit, which continues
+# until measured deviation between successive iterates is below `tol`.
+#
+# Once we obtain a good approximation to the limit, we will use it to calculate
+# the reservation wage.
+
+# %%
+def compute_reservation_wage(
+ model: McCallModel, # instance containing default parameters
+ v_init: jnp.ndarray, # initial condition for iteration
+ tol: float=1e-6, # error tolerance
+ max_iter: int=500, # maximum number of iterations for loop
+ ):
+ "Computes the reservation wage in the McCall job search model."
+ c, β, w, q = model
+ i = 0
+ error = tol + 1
+ v = v_init
+
+ while i < max_iter and error > tol:
+ v_next = T(model, v)
+ error = jnp.max(jnp.abs(v_next - v))
+ v = v_next
+ i += 1
+
+ res_wage = (1 - β) * (c + β * v @ q)
+ return v, res_wage
+
+
+# %% [markdown]
+# The cell computes the reservation wage at the default parameters
+
+# %%
+model = McCallModel()
+c, β, w, q = model
+v_init = w / (1 - β) # initial guess
+v, res_wage = compute_reservation_wage(model, v_init)
+print(res_wage)
+
+# %% [markdown]
+# ### Comparative Statics
+#
+# Now that we know how to compute the reservation wage, let's see how it varies with
+# parameters.
+#
+# Here we compare the reservation wage at two values of $\beta$.
+#
+# The reservation wages will be plotted alongside the wage offer distribution, so
+# that we can get a sense of what fraction of offers will be accepted.
+
+# %%
+fig, ax = plt.subplots()
+
+# Get the default color cycle
+prop_cycle = plt.rcParams['axes.prop_cycle']
+colors = prop_cycle.by_key()['color']
+
+# Plot the wage offer distribution
+ax.plot(w, q, '-', alpha=0.6, lw=2,
+ label='wage offer distribution',
+ color=colors[0])
+
+# Compute reservation wage with default beta
+model_default = McCallModel()
+c, β, w, q = model_default
+v_init = w / (1 - β)
+v_default, res_wage_default = compute_reservation_wage(
+ model_default, v_init
+)
+
+# Compute reservation wage with lower beta
+β_new = 0.96
+model_low_beta = McCallModel(β=β_new)
+c, β_low, w, q = model_low_beta
+v_init_low = w / (1 - β_low)
+v_low, res_wage_low = compute_reservation_wage(
+ model_low_beta, v_init_low
+)
+
+# Plot vertical lines for reservation wages
+ax.axvline(x=res_wage_default, color=colors[1], lw=2,
+ label=f'reservation wage (β={β})')
+ax.axvline(x=res_wage_low, color=colors[2], lw=2,
+ label=f'reservation wage (β={β_new})')
+
+ax.set_xlabel('wage', fontsize=12)
+ax.set_ylabel('probability', fontsize=12)
+ax.tick_params(axis='both', which='major', labelsize=11)
+ax.legend(loc='upper left', frameon=False, fontsize=11)
+plt.show()
+
+
+# %% [markdown]
+# We see that the reservation wage is higher when $\beta$ is higher.
+#
+# This is not surprising, since higher $\beta$ is associated with more patience.
+#
+# Now let's look more systematically at what happens when we change $\beta$ and $c$.
+#
+# As a first step, given that we'll use it many times, let's create a more
+# efficient, jit-complied version of the function that computes the reservation
+# wage:
+
+# %%
+@jax.jit
+def compute_res_wage_jitted(
+ model: McCallModel, # instance containing default parameters
+ v_init: jnp.ndarray, # initial condition for iteration
+ tol: float=1e-6, # error tolerance
+ max_iter: int=500, # maximum number of iterations for loop
+ ):
+ c, β, w, q = model
+ i = 0
+ error = tol + 1
+ initial_state = v_init, i, error
+
+ def cond(loop_state):
+ v, i, error = loop_state
+ return jnp.logical_and(i < max_iter, error > tol)
+
+ def update(loop_state):
+ v, i, error = loop_state
+ v_next = T(model, v)
+ error = jnp.max(jnp.abs(v_next - v))
+ i += 1
+ new_loop_state = v_next, i, error
+ return new_loop_state
+
+ final_state = jax.lax.while_loop(cond, update, initial_state)
+ v, i, error = final_state
+
+ res_wage = (1 - β) * (c + β * v @ q)
+ return v, res_wage
+
+
+# %% [markdown]
+# Now we compute the reservation wage at each $c, \beta$ pair.
+
+# %%
+grid_size = 25
+c_vals = jnp.linspace(10.0, 30.0, grid_size)
+β_vals = jnp.linspace(0.9, 0.99, grid_size)
+
+res_wage_matrix = np.empty((grid_size, grid_size))
+model = McCallModel()
+v_init = model.w / (1 - model.β)
+
+for i, c in enumerate(c_vals):
+ for j, β in enumerate(β_vals):
+ model = McCallModel(c=c, β=β)
+ v, res_wage = compute_res_wage_jitted(model, v_init)
+ v_init = v
+ res_wage_matrix[i, j] = res_wage
+
+fig, ax = plt.subplots()
+cs1 = ax.contourf(c_vals, β_vals, res_wage_matrix.T, alpha=0.75)
+ctr1 = ax.contour(c_vals, β_vals, res_wage_matrix.T)
+plt.clabel(ctr1, inline=1, fontsize=13)
+plt.colorbar(cs1, ax=ax)
+ax.set_title("reservation wage")
+ax.set_xlabel("$c$", fontsize=16)
+ax.set_ylabel("$β$", fontsize=16)
+ax.ticklabel_format(useOffset=False)
+plt.show()
+
+
+# %% [markdown]
+# As expected, the reservation wage increases with both patience and unemployment compensation.
+#
+# (mm_op2)=
+# ## Computing an Optimal Policy: Take 2
+#
+# The approach to dynamic programming just described is standard and broadly applicable.
+#
+# But for our McCall search model there's also an easier way that circumvents the
+# need to compute the value function.
+#
+# Let $h$ denote the continuation value:
+#
+# ```{math}
+# :label: j1
+#
+# h = c + \beta \sum_{w'} v^*(w') q (w')
+# ```
+#
+# The Bellman equation can now be written as
+#
+# ```{math}
+# :label: j1b
+#
+# v^*(w')
+# = \max \left\{ \frac{w'}{1 - \beta}, \, h \right\}
+# ```
+#
+# Now let's derive a nonlinear equation for $h$ alone.
+#
+# Starting from {eq}`j1b`, we multiply both sides by $q(w')$ to get
+#
+# $$
+# v^*(w') q(w') = \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
+# $$
+#
+# Next, we sum both sides over $w' \in \mathbb{W}$:
+#
+# $$
+# \sum_{w' \in \mathbb W} v^*(w') q(w')
+# = \sum_{w' \in \mathbb W} \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
+# $$
+#
+# Now multiply both sides by $\beta$:
+#
+# $$
+# \beta \sum_{w' \in \mathbb W} v^*(w') q(w')
+# = \beta \sum_{w' \in \mathbb W} \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
+# $$
+#
+# Add $c$ to both sides:
+#
+# $$
+# c + \beta \sum_{w' \in \mathbb W} v^*(w') q(w')
+# = c + \beta \sum_{w' \in \mathbb W} \max \left\{ \frac{w'}{1 - \beta}, h \right\} q(w')
+# $$
+#
+# Finally, using the definition of $h$ from {eq}`j1`, the left-hand side is just $h$, giving us
+#
+# ```{math}
+# :label: j2
+#
+# h = c + \beta
+# \sum_{w' \in \mathbb W}
+# \max \left\{
+# \frac{w'}{1 - \beta}, h
+# \right\} q (w')
+# ```
+#
+# This is a nonlinear equation in the single scalar $h$ that we can solve for $h$.
+#
+# As before, we will use successive approximations:
+#
+# Step 1: pick an initial guess $h$.
+#
+# Step 2: compute the update $h'$ via
+#
+# ```{math}
+# :label: j3
+#
+# h'
+# = c + \beta
+# \sum_{w' \in \mathbb W}
+# \max \left\{
+# \frac{w'}{1 - \beta}, h
+# \right\} q (w')
+# \quad
+# ```
+#
+# Step 3: calculate the deviation $|h - h'|$.
+#
+# Step 4: if the deviation is larger than some fixed tolerance, set $h = h'$ and go to step 2, else return $h$.
+#
+# One can again use the Banach contraction mapping theorem to show that this process always converges.
+#
+# The big difference here, however, is that we're iterating on a scalar $h$, rather than an $n$-vector, $v(i), i = 1, \ldots, n$.
+#
+# Here's an implementation:
+
+# %%
+def compute_reservation_wage_two(
+ model: McCallModel, # instance containing default parameters
+ tol: float=1e-5, # error tolerance
+ max_iter: int=500, # maximum number of iterations for loop
+ ):
+ c, β, w, q = model
+ h = (w @ q) / (1 - β) # initial condition
+ i = 0
+ error = tol + 1
+ initial_loop_state = i, h, error
+
+ def cond(loop_state):
+ i, h, error = loop_state
+ return jnp.logical_and(i < max_iter, error > tol)
+
+ def update(loop_state):
+ i, h, error = loop_state
+ s = jnp.maximum(w / (1 - β), h)
+ h_next = c + β * (s @ q)
+ error = jnp.abs(h_next - h)
+ i_next = i + 1
+ new_loop_state = i_next, h_next, error
+ return new_loop_state
+
+ final_state = jax.lax.while_loop(cond, update, initial_loop_state)
+ i, h, error = final_state
+
+ # Compute and return the reservation wage
+ return (1 - β) * h
+
+
+# %% [markdown]
+# You can use this code to solve the exercise below.
+#
+# ## Exercises
+#
+# ```{exercise}
+# :label: mm_ex1
+#
+# Compute the average duration of unemployment when $\beta=0.99$ and
+# $c$ takes the following values
+#
+# > `c_vals = np.linspace(10, 40, 25)`
+#
+# That is, start the agent off as unemployed, compute their reservation wage
+# given the parameters, and then simulate to see how long it takes to accept.
+#
+# Repeat a large number of times and take the average.
+#
+# Plot mean unemployment duration as a function of $c$ in `c_vals`.
+# ```
+#
+# ```{solution-start} mm_ex1
+# :class: dropdown
+# ```
+#
+# Here's a solution using Numba.
+
+# %%
+# Convert JAX arrays to NumPy arrays for use with Numba
+q_default_np = np.array(q_default)
+w_default_np = np.array(w_default)
+cdf = np.cumsum(q_default_np)
+
+@numba.jit
+def compute_stopping_time(w_bar, seed=1234):
+ """
+ Compute stopping time by drawing wages until one exceeds w_bar.
+ """
+ np.random.seed(seed)
+ t = 1
+ while True:
+ # Generate a wage draw
+ w = w_default_np[qe.random.draw(cdf)]
+
+ # Stop when the draw is above the reservation wage
+ if w >= w_bar:
+ stopping_time = t
+ break
+ else:
+ t += 1
+ return stopping_time
+
+@numba.jit(parallel=True)
+def compute_mean_stopping_time(w_bar, num_reps=100000):
+ """
+ Generate a mean stopping time over `num_reps` repetitions by
+ drawing from `compute_stopping_time`.
+ """
+ obs = np.empty(num_reps)
+ for i in numba.prange(num_reps):
+ obs[i] = compute_stopping_time(w_bar, seed=i)
+ return obs.mean()
+
+c_vals = np.linspace(10, 40, 25)
+stop_times = np.empty_like(c_vals)
+for i, c in enumerate(c_vals):
+ mcm = McCallModel(c=c)
+ w_bar = compute_reservation_wage_two(mcm)
+ stop_times[i] = compute_mean_stopping_time(float(w_bar))
+
+fig, ax = plt.subplots()
+
+ax.plot(c_vals, stop_times, label="mean unemployment duration")
+ax.set(xlabel="unemployment compensation", ylabel="months")
+ax.legend()
+
+plt.show()
+
+# %% [markdown]
+# And here's a solution using JAX.
+
+# %%
+# First, we set up a function to draw random wage offers from the distribution.
+# We use the inverse transform method: draw a uniform random variable u,
+# then find the smallest wage w such that the CDF at w is >= u.
+cdf = jnp.cumsum(q_default)
+
+def draw_wage(uniform_rv):
+ """
+ Draw a wage from the distribution q_default using the inverse transform method.
+
+ Parameters:
+ -----------
+ uniform_rv : float
+ A uniform random variable on [0, 1]
+
+ Returns:
+ --------
+ wage : float
+ A wage drawn from w_default with probabilities given by q_default
+ """
+ return w_default[jnp.searchsorted(cdf, uniform_rv)]
+
+
+def compute_stopping_time(w_bar, key):
+ """
+ Compute stopping time by drawing wages until one exceeds `w_bar`.
+ """
+ def update(loop_state):
+ t, key, accept = loop_state
+ key, subkey = jax.random.split(key)
+ u = jax.random.uniform(subkey)
+ w = draw_wage(u)
+ accept = w >= w_bar
+ t = t + 1
+ return t, key, accept
+
+ def cond(loop_state):
+ _, _, accept = loop_state
+ return jnp.logical_not(accept)
+
+ initial_loop_state = (0, key, False)
+ t_final, _, _ = jax.lax.while_loop(cond, update, initial_loop_state)
+ return t_final
+
+
+def compute_mean_stopping_time(w_bar, num_reps=100000, seed=1234):
+ """
+ Generate a mean stopping time over `num_reps` repetitions by
+ drawing from `compute_stopping_time`.
+ """
+ # Generate a key for each MC replication
+ key = jax.random.PRNGKey(seed)
+ keys = jax.random.split(key, num_reps)
+
+ # Vectorize compute_stopping_time and evaluate across keys
+ compute_fn = jax.vmap(compute_stopping_time, in_axes=(None, 0))
+ obs = compute_fn(w_bar, keys)
+
+ # Return mean stopping time
+ return jnp.mean(obs)
+
+c_vals = jnp.linspace(10, 40, 25)
+
+@jax.jit
+def compute_stop_time_for_c(c):
+ """Compute mean stopping time for a given compensation value c."""
+ model = McCallModel(c=c)
+ w_bar = compute_reservation_wage_two(model)
+ return compute_mean_stopping_time(w_bar)
+
+# Vectorize across all c values
+compute_stop_time_vectorized = jax.vmap(compute_stop_time_for_c)
+stop_times = compute_stop_time_vectorized(c_vals)
+
+fig, ax = plt.subplots()
+
+ax.plot(c_vals, stop_times, label="mean unemployment duration")
+ax.set(xlabel="unemployment compensation", ylabel="months")
+ax.legend()
+
+plt.show()
+
+
+# %% [markdown]
+# At least for our hardware, Numba is faster on the CPU while JAX is faster on the GPU.
+#
+# ```{solution-end}
+# ```
+#
+# ```{exercise-start}
+# :label: mm_ex2
+# ```
+#
+# The purpose of this exercise is to show how to replace the discrete wage
+# offer distribution used above with a continuous distribution.
+#
+# This is a significant topic because many convenient distributions are
+# continuous (i.e., have a density).
+#
+# Fortunately, the theory changes little in our simple model.
+#
+# Recall that $h$ in {eq}`j1` denotes the value of not accepting a job in this period but
+# then behaving optimally in all subsequent periods:
+#
+# To shift to a continuous offer distribution, we can replace {eq}`j1` by
+#
+# ```{math}
+# :label: j1c
+#
+# h
+# = c + \beta
+# \int v^*(s') q (s') ds'.
+# \quad
+# ```
+#
+# Equation {eq}`j2` becomes
+#
+# ```{math}
+# :label: j2c
+#
+# h
+# = c + \beta
+# \int
+# \max \left\{
+# \frac{w(s')}{1 - \beta}, h
+# \right\} q (s') d s'
+# \quad
+# ```
+#
+# The aim is to solve this nonlinear equation by iteration, and from it obtain
+# the reservation wage.
+#
+# Try to carry this out, setting
+#
+# * the state sequence $\{ s_t \}$ to be IID and standard normal and
+# * the wage function to be $w(s) = \exp(\mu + \sigma s)$.
+#
+# You will need to implement a new version of the `McCallModel` class that
+# assumes a lognormal wage distribution.
+#
+# Calculate the integral by Monte Carlo, by averaging over a large number of wage draws.
+#
+# For default parameters, use `c=25, β=0.99, σ=0.5, μ=2.5`.
+#
+# Once your code is working, investigate how the reservation wage changes with $c$ and $\beta$.
+#
+# ```{exercise-end}
+# ```
+#
+# ```{solution-start} mm_ex2
+# :class: dropdown
+# ```
+#
+# Here is one solution:
+
+# %%
+class McCallModelContinuous(NamedTuple):
+ c: float # unemployment compensation
+ β: float # discount factor
+ σ: float # scale parameter in lognormal distribution
+ μ: float # location parameter in lognormal distribution
+ w_draws: jnp.ndarray # draws of wages for Monte Carlo
+
+
+def create_mccall_continuous(
+ c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000, seed=1234
+ ):
+ key = jax.random.PRNGKey(seed)
+ s = jax.random.normal(key, (mc_size,))
+ w_draws = jnp.exp(μ + σ * s)
+ return McCallModelContinuous(c, β, σ, μ, w_draws)
+
+
+@jax.jit
+def compute_reservation_wage_continuous(model, max_iter=500, tol=1e-5):
+ c, β, σ, μ, w_draws = model
+
+ h = jnp.mean(w_draws) / (1 - β) # initial guess
+
+ def update(state):
+ h, i, error = state
+ integral = jnp.mean(jnp.maximum(w_draws / (1 - β), h))
+ h_next = c + β * integral
+ error = jnp.abs(h_next - h)
+ return h_next, i + 1, error
+
+ def cond(state):
+ h, i, error = state
+ return jnp.logical_and(i < max_iter, error > tol)
+
+ initial_state = (h, 0, tol + 1)
+ final_state = jax.lax.while_loop(cond, update, initial_state)
+ h_final, _, _ = final_state
+
+ # Now compute the reservation wage
+ return (1 - β) * h_final
+
+
+# %% [markdown]
+# Now we investigate how the reservation wage changes with $c$ and
+# $\beta$.
+#
+# We will do this using a contour plot.
+
+# %%
+grid_size = 25
+c_vals = jnp.linspace(10.0, 30.0, grid_size)
+β_vals = jnp.linspace(0.9, 0.99, grid_size)
+
+def compute_R_element(c, β):
+ model = create_mccall_continuous(c=c, β=β)
+ return compute_reservation_wage_continuous(model)
+
+# First, vectorize over β (holding c fixed)
+compute_R_over_β = jax.vmap(compute_R_element, in_axes=(None, 0))
+
+# Next, vectorize over c (applying the above function to each c)
+compute_R_vectorized = jax.vmap(compute_R_over_β, in_axes=(0, None))
+
+# Apply to compute the full grid
+R = compute_R_vectorized(c_vals, β_vals)
+
+# %%
+fig, ax = plt.subplots()
+
+cs1 = ax.contourf(c_vals, β_vals, R.T, alpha=0.75)
+ctr1 = ax.contour(c_vals, β_vals, R.T)
+
+plt.clabel(ctr1, inline=1, fontsize=13)
+plt.colorbar(cs1, ax=ax)
+
+
+ax.set_title("reservation wage")
+ax.set_xlabel("$c$", fontsize=16)
+ax.set_ylabel("$β$", fontsize=16)
+
+ax.ticklabel_format(useOffset=False)
+
+plt.show()
+
+# %% [markdown]
+# ```{solution-end}
+# ```