diff --git a/lectures/mccall_fitted_vfi.md b/lectures/mccall_fitted_vfi.md index 9395cab87..62d035d31 100644 --- a/lectures/mccall_fitted_vfi.md +++ b/lectures/mccall_fitted_vfi.md @@ -44,7 +44,7 @@ $$ and $\{Z_t\}$ is IID and standard normal. -While we already considered continuous wage distributions briefly in Exercise {ref}`mm_ex2` of the {doc}`first job search lecture `, the change was relatively trivial in that case. +While we already considered continuous wage distributions briefly in {doc}`mccall_model`, the change was relatively trivial in that case. The reason is that we were able to reduce the problem to solving for a single scalar value (the continuation value). diff --git a/lectures/mccall_model.md b/lectures/mccall_model.md index febec370e..e29581369 100644 --- a/lectures/mccall_model.md +++ b/lectures/mccall_model.md @@ -735,186 +735,16 @@ def compute_reservation_wage_two( You can use this code to solve the exercise below. -## Exercises - -```{exercise} -:label: mm_ex1 - -Compute the average duration of unemployment when $\beta=0.99$ and -$c$ takes the following values - -> `c_vals = np.linspace(10, 40, 25)` - -That is, start the agent off as unemployed, compute their reservation wage -given the parameters, and then simulate to see how long it takes to accept. - -Repeat a large number of times and take the average. - -Plot mean unemployment duration as a function of $c$ in `c_vals`. -``` - -```{solution-start} mm_ex1 -:class: dropdown -``` - -Here's a solution using Numba. - -```{code-cell} ipython3 -# Convert JAX arrays to NumPy arrays for use with Numba -q_default_np = np.array(q_default) -w_default_np = np.array(w_default) -cdf = np.cumsum(q_default_np) - -@numba.jit -def compute_stopping_time(w_bar, seed=1234): - """ - Compute stopping time by drawing wages until one exceeds w_bar. - """ - np.random.seed(seed) - t = 1 - while True: - # Generate a wage draw - w = w_default_np[qe.random.draw(cdf)] - - # Stop when the draw is above the reservation wage - if w >= w_bar: - stopping_time = t - break - else: - t += 1 - return stopping_time - -@numba.jit(parallel=True) -def compute_mean_stopping_time(w_bar, num_reps=100000): - """ - Generate a mean stopping time over `num_reps` repetitions by - drawing from `compute_stopping_time`. - """ - obs = np.empty(num_reps) - for i in numba.prange(num_reps): - obs[i] = compute_stopping_time(w_bar, seed=i) - return obs.mean() - -c_vals = np.linspace(10, 40, 25) -stop_times = np.empty_like(c_vals) -for i, c in enumerate(c_vals): - mcm = McCallModel(c=c) - w_bar = compute_reservation_wage_two(mcm) - stop_times[i] = compute_mean_stopping_time(float(w_bar)) - -fig, ax = plt.subplots() - -ax.plot(c_vals, stop_times, label="mean unemployment duration") -ax.set(xlabel="unemployment compensation", ylabel="months") -ax.legend() - -plt.show() -``` - -And here's a solution using JAX. - -```{code-cell} ipython3 -# First, we set up a function to draw random wage offers from the distribution. -# We use the inverse transform method: draw a uniform random variable u, -# then find the smallest wage w such that the CDF at w is >= u. -cdf = jnp.cumsum(q_default) - -def draw_wage(uniform_rv): - """ - Draw a wage from the distribution q_default using the inverse transform method. - - Parameters: - ----------- - uniform_rv : float - A uniform random variable on [0, 1] - - Returns: - -------- - wage : float - A wage drawn from w_default with probabilities given by q_default - """ - return w_default[jnp.searchsorted(cdf, uniform_rv)] - - -def compute_stopping_time(w_bar, key): - """ - Compute stopping time by drawing wages until one exceeds `w_bar`. - """ - def update(loop_state): - t, key, accept = loop_state - key, subkey = jax.random.split(key) - u = jax.random.uniform(subkey) - w = draw_wage(u) - accept = w >= w_bar - t = t + 1 - return t, key, accept - - def cond(loop_state): - _, _, accept = loop_state - return jnp.logical_not(accept) - - initial_loop_state = (0, key, False) - t_final, _, _ = jax.lax.while_loop(cond, update, initial_loop_state) - return t_final - - -def compute_mean_stopping_time(w_bar, num_reps=100000, seed=1234): - """ - Generate a mean stopping time over `num_reps` repetitions by - drawing from `compute_stopping_time`. - """ - # Generate a key for each MC replication - key = jax.random.PRNGKey(seed) - keys = jax.random.split(key, num_reps) - - # Vectorize compute_stopping_time and evaluate across keys - compute_fn = jax.vmap(compute_stopping_time, in_axes=(None, 0)) - obs = compute_fn(w_bar, keys) - - # Return mean stopping time - return jnp.mean(obs) - -c_vals = jnp.linspace(10, 40, 25) - -@jax.jit -def compute_stop_time_for_c(c): - """Compute mean stopping time for a given compensation value c.""" - model = McCallModel(c=c) - w_bar = compute_reservation_wage_two(model) - return compute_mean_stopping_time(w_bar) - -# Vectorize across all c values -compute_stop_time_vectorized = jax.vmap(compute_stop_time_for_c) -stop_times = compute_stop_time_vectorized(c_vals) - -fig, ax = plt.subplots() +## Continuous Offer Distribution -ax.plot(c_vals, stop_times, label="mean unemployment duration") -ax.set(xlabel="unemployment compensation", ylabel="months") -ax.legend() - -plt.show() -``` - -At least for our hardware, Numba is faster on the CPU while JAX is faster on the GPU. - -```{solution-end} -``` - -```{exercise-start} -:label: mm_ex2 -``` - -The purpose of this exercise is to show how to replace the discrete wage -offer distribution used above with a continuous distribution. - -This is a significant topic because many convenient distributions are -continuous (i.e., have a density). +The discrete wage offer distribution used above is convenient for theory and +computation, but many realistic distributions are continuous (i.e., have a density). -Fortunately, the theory changes little in our simple model. +Fortunately, the theory changes little in our simple model when we shift to a +continuous offer distribution. Recall that $h$ in {eq}`j1` denotes the value of not accepting a job in this period but -then behaving optimally in all subsequent periods: +then behaving optimally in all subsequent periods. To shift to a continuous offer distribution, we can replace {eq}`j1` by @@ -944,28 +774,18 @@ h The aim is to solve this nonlinear equation by iteration, and from it obtain the reservation wage. -Try to carry this out, setting +### Implementation with Lognormal Wages -* the state sequence $\{ s_t \}$ to be IID and standard normal and -* the wage function to be $w(s) = \exp(\mu + \sigma s)$. +Let's implement this for the case where -You will need to implement a new version of the `McCallModel` class that -assumes a lognormal wage distribution. +* the state sequence $\{ s_t \}$ is IID and standard normal and +* the wage function is $w(s) = \exp(\mu + \sigma s)$. -Calculate the integral by Monte Carlo, by averaging over a large number of wage draws. +This gives us a lognormal wage distribution. -For default parameters, use `c=25, β=0.99, σ=0.5, μ=2.5`. +We use Monte Carlo integration to evaluate the integral, averaging over a large number of wage draws. -Once your code is working, investigate how the reservation wage changes with $c$ and $\beta$. - -```{exercise-end} -``` - -```{solution-start} mm_ex2 -:class: dropdown -``` - -Here is one solution: +For default parameters, we use `c=25, β=0.99, σ=0.5, μ=2.5`. ```{code-cell} ipython3 class McCallModelContinuous(NamedTuple): @@ -988,20 +808,20 @@ def create_mccall_continuous( @jax.jit def compute_reservation_wage_continuous(model, max_iter=500, tol=1e-5): c, β, σ, μ, w_draws = model - + h = jnp.mean(w_draws) / (1 - β) # initial guess - + def update(state): h, i, error = state integral = jnp.mean(jnp.maximum(w_draws / (1 - β), h)) h_next = c + β * integral error = jnp.abs(h_next - h) return h_next, i + 1, error - + def cond(state): h, i, error = state return jnp.logical_and(i < max_iter, error > tol) - + initial_state = (h, 0, tol + 1) final_state = jax.lax.while_loop(cond, update, initial_state) h_final, _, _ = final_state @@ -1010,10 +830,8 @@ def compute_reservation_wage_continuous(model, max_iter=500, tol=1e-5): return (1 - β) * h_final ``` -Now we investigate how the reservation wage changes with $c$ and -$\beta$. - -We will do this using a contour plot. +Now let's investigate how the reservation wage changes with $c$ and +$\beta$ using a contour plot. ```{code-cell} ipython3 grid_size = 25 @@ -1053,5 +871,312 @@ ax.ticklabel_format(useOffset=False) plt.show() ``` +As with the discrete case, the reservation wage increases with both patience and unemployment compensation. + +## Volatility + +An interesting feature of the McCall model is that increased volatility in wage offers +tends to increase the reservation wage. + +The intuition is that volatility is attractive to the worker because they can enjoy +the upside (high wage offers) while rejecting the downside (low wage offers). + +Hence, with more volatility, workers are more willing to continue searching rather than +accept a given offer, which means the reservation wage rises. + +To illustrate this phenomenon, we use a mean-preserving spread of the wage distribution. + +In particular, we vary the scale parameter $\sigma$ in the lognormal wage distribution +$w(s) = \exp(\mu + \sigma s)$ while adjusting $\mu$ to keep the mean constant. + +Recall that for a lognormal distribution with parameters $\mu$ and $\sigma$, the mean is +$\exp(\mu + \sigma^2/2)$. + +To keep the mean constant at some value $m$, we need: + +$$ +\mu = \ln(m) - \frac{\sigma^2}{2} +$$ + +Let's implement this and compute the reservation wage for different values of $\sigma$: + +```{code-cell} ipython3 +# Fix the mean wage +mean_wage = 20.0 + +# Create a range of volatility values +σ_vals = jnp.linspace(0.1, 1.0, 25) + +# Given σ, compute μ to maintain constant mean +def compute_μ_for_mean(σ, mean_wage): + return jnp.log(mean_wage) - (σ**2) / 2 + +# Compute reservation wage for each volatility level +res_wages_volatility = [] + +for σ in σ_vals: + μ = compute_μ_for_mean(σ, mean_wage) + model = create_mccall_continuous(σ=float(σ), μ=float(μ)) + res_wage = compute_reservation_wage_continuous(model) + res_wages_volatility.append(res_wage) + +res_wages_volatility = jnp.array(res_wages_volatility) +``` + +Now let's plot the reservation wage as a function of volatility: + +```{code-cell} ipython3 +fig, ax = plt.subplots() +ax.plot(σ_vals, res_wages_volatility, linewidth=2) +ax.set_xlabel(r'volatility ($\sigma$)', fontsize=12) +ax.set_ylabel('reservation wage', fontsize=12) +plt.show() +``` + +As expected, the reservation wage is increasing in $\sigma$. + +### Lifetime Value and Volatility + +We've seen that the reservation wage increases with volatility. + +It's also the case that maximal lifetime value increases with volatility. + +Higher volatility provides more upside potential, while at the same time +workers can protect themselves against downside risk by rejecting low offers. + +This option value translates into higher expected lifetime utility. + +To demonstrate this, we will: + +1. Compute the reservation wage for each volatility level +3. Calculate the expected discounted value of the lifetime income stream + associated with that reservation wage, using Monte Carlo. + +The simulation works as follows: + +1. Compute the present discounted value of one lifetime earnings path, from a given wage path. +2. Average over a large number of such calculations to approximate expected discounted value. + +We truncate each path at $T=100$, which provides sufficient resolution for our purposes. + +```{code-cell} ipython3 +@jax.jit +def simulate_lifetime_value(key, model, w_bar, n_periods=100): + """ + Simulate one realization of the wage path and compute lifetime value. + + Parameters: + ----------- + key : jax.random.PRNGKey + Random key for JAX + model : McCallModelContinuous + The model containing parameters + w_bar : float + The reservation wage + n_periods : int + Number of periods to simulate + + Returns: + -------- + lifetime_value : float + Discounted sum of income over n_periods + """ + c, β, σ, μ, w_draws = model + + # Draw all wage offers upfront + key, subkey = jax.random.split(key) + s_vals = jax.random.normal(subkey, (n_periods,)) + wage_offers = jnp.exp(μ + σ * s_vals) + + # Determine which offers are acceptable + accept = wage_offers >= w_bar + + # Track employment status: employed from first acceptance onward + employed = jnp.cumsum(accept) > 0 + + # Get the accepted wage (first wage where accept is True) + first_accept_idx = jnp.argmax(accept) + accepted_wage = wage_offers[first_accept_idx] + + # Earnings at each period: accepted_wage if employed, c if unemployed + earnings = jnp.where(employed, accepted_wage, c) + + # Compute discounted sum + periods = jnp.arange(n_periods) + discount_factors = β ** periods + lifetime_value = jnp.sum(discount_factors * earnings) + + return lifetime_value + + +@jax.jit +def compute_mean_lifetime_value(model, w_bar, num_reps=10000, seed=1234): + """ + Compute mean lifetime value across many simulations. + + """ + key = jax.random.PRNGKey(seed) + keys = jax.random.split(key, num_reps) + + # Vectorize the simulation across all replications + simulate_fn = jax.vmap(simulate_lifetime_value, in_axes=(0, None, None)) + lifetime_values = simulate_fn(keys, model, w_bar) + return jnp.mean(lifetime_values) +``` + +Now let's compute the expected lifetime value for each volatility level: + +```{code-cell} ipython3 +# Use the same volatility range and mean wage +σ_vals = jnp.linspace(0.1, 1.0, 25) +mean_wage = 20.0 + +lifetime_vals = [] +for σ in σ_vals: + μ = compute_μ_for_mean(σ, mean_wage) + model = create_mccall_continuous(σ=σ, μ=μ) + lv = compute_mean_lifetime_value(model, w_bar) + lifetime_vals.append(lv) + +``` + +Let's visualize the expected lifetime value as a function of volatility: + +```{code-cell} ipython3 +fig, ax = plt.subplots() +ax.plot(σ_vals, lifetime_vals, linewidth=2, color='green') +ax.set_xlabel(r'volatility ($\sigma$)', fontsize=12) +ax.set_ylabel('expected lifetime value', fontsize=12) +plt.show() +``` + +The plot confirms that despite workers setting higher reservation wages when facing +more volatile wage offers (as shown above), they achieve higher expected lifetime +values due to the option value of search. + + +## Exercises + + +```{exercise} +:label: mm_ex1 + +Compute the average duration of unemployment when $\beta=0.99$ and +$c$ takes the following values + +> `c_vals = np.linspace(10, 40, 4)` + +That is, start the agent off as unemployed, compute their reservation wage +given the parameters, and then simulate to see how long it takes to accept. + +Repeat a large number of times and take the average. + +Plot mean unemployment duration as a function of $c$ in `c_vals`. + +Try to explain what you see. +``` + +```{solution-start} mm_ex1 +:class: dropdown +``` + +Here's a solution using the continuous wage offer distribution with JAX. + +```{code-cell} ipython3 +def compute_stopping_time_continuous(w_bar, key, model): + """ + Compute stopping time by drawing wages from the continuous distribution + until one exceeds `w_bar`. + + Parameters: + ----------- + w_bar : float + The reservation wage + key : jax.random.PRNGKey + Random key for JAX + model : McCallModelContinuous + The model containing wage draws + + Returns: + -------- + t_final : int + The stopping time (number of periods until acceptance) + """ + c, β, σ, μ, w_draws = model + + def update(loop_state): + t, key, accept = loop_state + key, subkey = jax.random.split(key) + # Draw a standard normal and transform to wage + s = jax.random.normal(subkey) + w = jnp.exp(μ + σ * s) + accept = w >= w_bar + t = t + 1 + return t, key, accept + + def cond(loop_state): + _, _, accept = loop_state + return jnp.logical_not(accept) + + initial_loop_state = (0, key, False) + t_final, _, _ = jax.lax.while_loop(cond, update, initial_loop_state) + return t_final + + +def compute_mean_stopping_time_continuous(w_bar, model, num_reps=100000, seed=1234): + """ + Generate a mean stopping time over `num_reps` repetitions. + + Parameters: + ----------- + w_bar : float + The reservation wage + model : McCallModelContinuous + The model containing parameters + num_reps : int + Number of simulation replications + seed : int + Random seed + + Returns: + -------- + mean_time : float + Average stopping time across all replications + """ + # Generate a key for each MC replication + key = jax.random.PRNGKey(seed) + keys = jax.random.split(key, num_reps) + + # Vectorize compute_stopping_time_continuous and evaluate across keys + compute_fn = jax.vmap(compute_stopping_time_continuous, in_axes=(None, 0, None)) + obs = compute_fn(w_bar, keys, model) + + # Return mean stopping time + return jnp.mean(obs) + + +# Compute mean stopping time for different values of c +c_vals = jnp.linspace(10, 40, 4) + +@jax.jit +def compute_stop_time_for_c_continuous(c): + """Compute mean stopping time for a given compensation value c.""" + model = create_mccall_continuous(c=c) + w_bar = compute_reservation_wage_continuous(model) + return compute_mean_stopping_time_continuous(w_bar, model) + +# Vectorize across all c values +compute_stop_time_vectorized = jax.vmap(compute_stop_time_for_c_continuous) +stop_times = compute_stop_time_vectorized(c_vals) + +fig, ax = plt.subplots() + +ax.plot(c_vals, stop_times, label="mean unemployment duration") +ax.set(xlabel="unemployment compensation", ylabel="months") +ax.legend() + +plt.show() +``` + ```{solution-end} ```