diff --git a/lectures/mccall_model.md b/lectures/mccall_model.md index 202b9d591..6205e276b 100644 --- a/lectures/mccall_model.md +++ b/lectures/mccall_model.md @@ -34,11 +34,11 @@ and the pros and cons as they themselves see them." -- Robert E. Lucas, Jr. In addition to what's in Anaconda, this lecture will need the following libraries: -```{code-cell} ipython +```{code-cell} ipython3 --- tags: [hide-output] --- -!pip install quantecon +!pip install quantecon jax ``` ## Overview @@ -62,8 +62,10 @@ Let's start with some imports: ```{code-cell} ipython import matplotlib.pyplot as plt import numpy as np -from numba import jit, float64 -from numba.experimental import jitclass +import jax +import jax.numpy as jnp +import jax.random as jr +from typing import NamedTuple import quantecon as qe from quantecon.distributions import BetaBinomial ``` @@ -91,9 +93,11 @@ At time $t$, our agent has two choices: The agent is infinitely lived and aims to maximize the expected discounted sum of earnings -$$ -\mathbb{E} \sum_{t=0}^{\infty} \beta^t y_t -$$ +```{math} +:label: obj_model + +{\mathbb E} \sum_{t=0}^\infty \beta^t u(y_t) +``` The constant $\beta$ lies in $(0, 1)$ and is called a **discount factor**. @@ -138,7 +142,7 @@ $w \in \mathbb{W}$. In particular, the agent has wage offer $w$ in hand. More precisely, $v^*(w)$ denotes the value of the objective function -{eq}`objective` when an agent in this situation makes *optimal* decisions now +{eq}`obj_model` when an agent in this situation makes *optimal* decisions now and at all future points in time. Of course $v^*(w)$ is not trivial to calculate because we don't yet know @@ -343,14 +347,14 @@ Our default for $q$, the distribution of the state process, will be ```{code-cell} python3 n, a, b = 50, 200, 100 # default parameters -q_default = BetaBinomial(n, a, b).pdf() # default choice of q +q_default = jnp.array(BetaBinomial(n, a, b).pdf()) ``` Our default set of values for wages will be ```{code-cell} python3 w_min, w_max = 10, 60 -w_default = np.linspace(w_min, w_max, n+1) +w_default = jnp.linspace(w_min, w_max, n+1) ``` Here's a plot of the probabilities of different wage outcomes: @@ -364,60 +368,32 @@ ax.set_ylabel('probabilities') plt.show() ``` -We are going to use Numba to accelerate our code. - -* See, in particular, the discussion of `@jitclass` in [our lecture on Numba](https://python-programming.quantecon.org/numba.html). - -The following helps Numba by providing some type specifications. - -```{code-cell} python3 -mccall_data = [ - ('c', float64), # unemployment compensation - ('β', float64), # discount factor - ('w', float64[::1]), # array of wage values, w[i] = wage at state i - ('q', float64[::1]) # array of probabilities -] -``` - -```{note} -Note the use of `[::1]` in the array type declarations above. - -This notation specifies that the arrays should be C-contiguous. - -This is important for performance, especially when using the `@` operator for matrix multiplication (e.g., `v @ q`). +We are going to use JAX to accelerate our code. -Without this specification, Numba might need to handle non-contiguous arrays, which can significantly slow down these operations. +* We'll use NamedTuple for our model class to maintain immutability, which works well with JAX's functional programming paradigm. -Try to replace `[::1]` with `[:]` and see what happens. -``` - -Here's a class that stores the data and computes the values of state-action pairs, -i.e. the value in the maximum bracket on the right hand side of the Bellman equation {eq}`odu_pv2p`, -given the current state and an arbitrary feasible action. - -Default parameter values are embedded in the class. +Here's a class that stores the model parameters with default values, and a separate function that computes the values of state-action pairs (i.e., the value in the maximum bracket on the right hand side of the Bellman equation {eq}`odu_pv2p`). ```{code-cell} python3 -@jitclass(mccall_data) -class McCallModel: - - def __init__(self, c=25, β=0.99, w=w_default, q=q_default): - - self.c, self.β = c, β - self.w, self.q = w_default, q_default - - def state_action_values(self, i, v): - """ - The values of state-action pairs. - """ - # Simplify names - c, β, w, q = self.c, self.β, self.w, self.q - # Evaluate value for each state-action pair - # Consider action = accept or reject the current offer - accept = w[i] / (1 - β) - reject = c + β * (v @ q) - - return np.array([accept, reject]) +class McCallModel(NamedTuple): + c: float = 25 # unemployment compensation + β: float = 0.99 # discount factor + w: jnp.ndarray = w_default # array of wage values, w[i] = wage at state i + q: jnp.ndarray = q_default # array of probabilities + +@jax.jit +def state_action_values(model, i, v): + """ + The values of state-action pairs. + """ + # Simplify names + c, β, w, q = model.c, model.β, model.w, model.q + # Evaluate value for each state-action pair + # Consider action = accept or reject the current offer + accept = w[i] / (1 - β) + reject = c + β * (v @ q) + + return jnp.array([accept, reject]) ``` Based on these defaults, let's try plotting the first few approximate value functions @@ -439,13 +415,13 @@ def plot_value_function_seq(mcm, ax, num_plots=6): n = len(mcm.w) v = mcm.w / (1 - mcm.β) - v_next = np.empty_like(v) for i in range(num_plots): ax.plot(mcm.w, v, '-', alpha=0.4, label=f"iterate {i}") # Update guess + v_next = jnp.zeros_like(v) for j in range(n): - v_next[j] = np.max(mcm.state_action_values(j, v)) - v[:] = v_next # copy contents into v + v_next = v_next.at[j].set(jnp.max(state_action_values(mcm, j, v))) + v = v_next # update v ax.legend(loc='lower right') ``` @@ -469,37 +445,35 @@ Here's a more serious iteration effort to compute the limit, which continues unt Once we obtain a good approximation to the limit, we will use it to calculate the reservation wage. -We'll be using JIT compilation via Numba to turbocharge our loops. +We'll be using JIT compilation via JAX to accelerate our loops. ```{code-cell} python3 -@jit -def compute_reservation_wage(mcm, - max_iter=500, - tol=1e-6): - +@jax.jit +def compute_reservation_wage(mcm, max_iter=500, tol=1e-6): # Simplify names c, β, w, q = mcm.c, mcm.β, mcm.w, mcm.q - - # == First compute the value function == # - + + # First compute the value function n = len(w) - v = w / (1 - β) # initial guess - v_next = np.empty_like(v) - j = 0 - error = tol + 1 - while j < max_iter and error > tol: - + v = w / (1 - β) # initial guess + + def body_fun(state): + v, i, error = state + v_next = jnp.zeros_like(v) for j in range(n): - v_next[j] = np.max(mcm.state_action_values(j, v)) - - error = np.max(np.abs(v_next - v)) - j += 1 - - v[:] = v_next # copy contents into v - - # == Now compute the reservation wage == # - - return (1 - β) * (c + β * (v @ q)) + v_next = v_next.at[j].set(jnp.max(state_action_values(mcm, j, v))) + error = jnp.max(jnp.abs(v_next - v)) + return v_next, i + 1, error + + def cond_fun(state): + v, i, error = state + return jnp.logical_and(i < max_iter, error > tol) + + initial_state = (v, 0, tol + 1) + v_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state) + + # Now compute the reservation wage + return (1 - β) * (c + β * (v_final @ q)) ``` The next line computes the reservation wage at default parameters @@ -518,15 +492,17 @@ $c$. ```{code-cell} python3 grid_size = 25 -R = np.empty((grid_size, grid_size)) +c_vals = jnp.linspace(10.0, 30.0, grid_size) +β_vals = jnp.linspace(0.9, 0.99, grid_size) -c_vals = np.linspace(10.0, 30.0, grid_size) -β_vals = np.linspace(0.9, 0.99, grid_size) +def compute_R_element(c, β): + mcm = McCallModel(c=c, β=β) + return compute_reservation_wage(mcm) -for i, c in enumerate(c_vals): - for j, β in enumerate(β_vals): - mcm = McCallModel(c=c, β=β) - R[i, j] = compute_reservation_wage(mcm) +# Create meshgrid and vectorize computation +c_grid, β_grid = jnp.meshgrid(c_vals, β_vals, indexing='ij') +compute_R_vectorized = jax.vmap(jax.vmap(compute_R_element, in_axes=(None, 0)), in_axes=(0, None)) +R = compute_R_vectorized(c_vals, β_vals) ``` ```{code-cell} python3 @@ -623,32 +599,30 @@ The big difference here, however, is that we're iterating on a scalar $h$, rathe Here's an implementation: ```{code-cell} python3 -@jit -def compute_reservation_wage_two(mcm, - max_iter=500, - tol=1e-5): - +@jax.jit +def compute_reservation_wage_two(mcm, max_iter=500, tol=1e-5): # Simplify names c, β, w, q = mcm.c, mcm.β, mcm.w, mcm.q - - # == First compute h == # - + + # First compute h h = (w @ q) / (1 - β) - i = 0 - error = tol + 1 - while i < max_iter and error > tol: - - s = np.maximum(w / (1 - β), h) + + def body_fun(state): + h, i, error = state + s = jnp.maximum(w / (1 - β), h) h_next = c + β * (s @ q) - - error = np.abs(h_next - h) - i += 1 - - h = h_next - - # == Now compute the reservation wage == # - - return (1 - β) * h + error = jnp.abs(h_next - h) + return h_next, i + 1, error + + def cond_fun(state): + h, i, error = state + return jnp.logical_and(i < max_iter, error > tol) + + initial_state = (h, 0, tol + 1) + h_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state) + + # Now compute the reservation wage + return (1 - β) * h_final ``` You can use this code to solve the exercise below. @@ -678,37 +652,42 @@ Plot mean unemployment duration as a function of $c$ in `c_vals`. Here's one solution ```{code-cell} python3 -cdf = np.cumsum(q_default) - -@jit -def compute_stopping_time(w_bar, seed=1234): - - np.random.seed(seed) - t = 1 - while True: - # Generate a wage draw - w = w_default[qe.random.draw(cdf)] - # Stop when the draw is above the reservation wage - if w >= w_bar: - stopping_time = t - break - else: - t += 1 - return stopping_time - -@jit -def compute_mean_stopping_time(w_bar, num_reps=100000): - obs = np.empty(num_reps) - for i in range(num_reps): - obs[i] = compute_stopping_time(w_bar, seed=i) - return obs.mean() - -c_vals = np.linspace(10, 40, 25) -stop_times = np.empty_like(c_vals) -for i, c in enumerate(c_vals): +cdf = jnp.cumsum(q_default) + +@jax.jit +def compute_stopping_time(w_bar, key): + def body_fun(state): + t, key, done = state + key, subkey = jr.split(key) + u = jr.uniform(subkey) + w = w_default[jnp.searchsorted(cdf, u)] + done = w >= w_bar + t = jnp.where(done, t, t + 1) + return t, key, done + + def cond_fun(state): + t, _, done = state + return jnp.logical_not(done) + + initial_state = (1, key, False) + t_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state) + return t_final + +@jax.jit +def compute_mean_stopping_time(w_bar, num_reps=100000, seed=1234): + key = jr.PRNGKey(seed) + keys = jr.split(key, num_reps) + obs = jax.vmap(compute_stopping_time, in_axes=(None, 0))(w_bar, keys) + return jnp.mean(obs) + +c_vals = jnp.linspace(10, 40, 25) + +def compute_stop_time_for_c(c): mcm = McCallModel(c=c) w_bar = compute_reservation_wage_two(mcm) - stop_times[i] = compute_mean_stopping_time(w_bar) + return compute_mean_stopping_time(w_bar) + +stop_times = jax.vmap(compute_stop_time_for_c)(c_vals) fig, ax = plt.subplots() @@ -789,48 +768,41 @@ Once your code is working, investigate how the reservation wage changes with $c$ Here is one solution: ```{code-cell} python3 -mccall_data_continuous = [ - ('c', float64), # unemployment compensation - ('β', float64), # discount factor - ('σ', float64), # scale parameter in lognormal distribution - ('μ', float64), # location parameter in lognormal distribution - ('w_draws', float64[:]) # draws of wages for Monte Carlo -] - -@jitclass(mccall_data_continuous) -class McCallModelContinuous: - - def __init__(self, c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000): - - self.c, self.β, self.σ, self.μ = c, β, σ, μ - - # Draw and store shocks - np.random.seed(1234) - s = np.random.randn(mc_size) - self.w_draws = np.exp(μ+ σ * s) - - -@jit +class McCallModelContinuous(NamedTuple): + c: float # unemployment compensation + β: float # discount factor + σ: float # scale parameter in lognormal distribution + μ: float # location parameter in lognormal distribution + w_draws: jnp.ndarray # draws of wages for Monte Carlo + +def create_mccall_continuous(c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000, seed=1234): + key = jr.PRNGKey(seed) + s = jr.normal(key, (mc_size,)) + w_draws = jnp.exp(μ + σ * s) + return McCallModelContinuous(c=c, β=β, σ=σ, μ=μ, w_draws=w_draws) + +@jax.jit def compute_reservation_wage_continuous(mcmc, max_iter=500, tol=1e-5): - c, β, σ, μ, w_draws = mcmc.c, mcmc.β, mcmc.σ, mcmc.μ, mcmc.w_draws - - h = np.mean(w_draws) / (1 - β) # initial guess - i = 0 - error = tol + 1 - while i < max_iter and error > tol: - - integral = np.mean(np.maximum(w_draws / (1 - β), h)) + + h = jnp.mean(w_draws) / (1 - β) # initial guess + + def body_fun(state): + h, i, error = state + integral = jnp.mean(jnp.maximum(w_draws / (1 - β), h)) h_next = c + β * integral - - error = np.abs(h_next - h) - i += 1 - - h = h_next - - # == Now compute the reservation wage == # - - return (1 - β) * h + error = jnp.abs(h_next - h) + return h_next, i + 1, error + + def cond_fun(state): + h, i, error = state + return jnp.logical_and(i < max_iter, error > tol) + + initial_state = (h, 0, tol + 1) + h_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state) + + # Now compute the reservation wage + return (1 - β) * h_final ``` Now we investigate how the reservation wage changes with $c$ and @@ -840,15 +812,20 @@ We will do this using a contour plot. ```{code-cell} python3 grid_size = 25 -R = np.empty((grid_size, grid_size)) - -c_vals = np.linspace(10.0, 30.0, grid_size) -β_vals = np.linspace(0.9, 0.99, grid_size) - -for i, c in enumerate(c_vals): - for j, β in enumerate(β_vals): - mcmc = McCallModelContinuous(c=c, β=β) - R[i, j] = compute_reservation_wage_continuous(mcmc) +c_vals = jnp.linspace(10.0, 30.0, grid_size) +β_vals = jnp.linspace(0.9, 0.99, grid_size) + +def compute_R_element(c, β): + mcmc = create_mccall_continuous(c=c, β=β) + return compute_reservation_wage_continuous(mcmc) + +# Create meshgrid and vectorize computation +c_grid, β_grid = jnp.meshgrid(c_vals, β_vals, indexing='ij') +compute_R_vectorized = jax.vmap( + jax.vmap(compute_R_element, + in_axes=(None, 0)), + in_axes=(0, None)) +R = compute_R_vectorized(c_vals, β_vals) ``` ```{code-cell} python3