diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3bc26f081..80e4d9a83 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,9 +31,9 @@ jobs: - name: Install JAX, Numpyro, PyTorch shell: bash -l {0} run: | - pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 - pip install pyro-ppl - pip install --upgrade "jax[cuda12-local]==0.6.2" + # pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 + # pip install pyro-ppl + pip install "jax[cuda12-local]==0.6.2" pip install numpyro pyro-ppl python scripts/test-jax-install.py - name: Check nvidia Drivers diff --git a/lectures/_static/lecture_specific/mccall_model_with_separation/mccall_resw_alpha.png b/lectures/_static/lecture_specific/mccall_model_with_separation/mccall_resw_alpha.png deleted file mode 100644 index 706fa128d..000000000 Binary files a/lectures/_static/lecture_specific/mccall_model_with_separation/mccall_resw_alpha.png and /dev/null differ diff --git a/lectures/_static/lecture_specific/mccall_model_with_separation/mccall_resw_beta.png b/lectures/_static/lecture_specific/mccall_model_with_separation/mccall_resw_beta.png deleted file mode 100644 index 80f320b25..000000000 Binary files a/lectures/_static/lecture_specific/mccall_model_with_separation/mccall_resw_beta.png and /dev/null differ diff --git a/lectures/_static/lecture_specific/mccall_model_with_separation/mccall_resw_c.png b/lectures/_static/lecture_specific/mccall_model_with_separation/mccall_resw_c.png deleted file mode 100644 index a969cf02a..000000000 Binary files a/lectures/_static/lecture_specific/mccall_model_with_separation/mccall_resw_c.png and /dev/null differ diff --git a/lectures/mccall_model_with_sep_markov.md b/lectures/mccall_model_with_sep_markov.md index 76342a870..8ba89fa60 100644 --- a/lectures/mccall_model_with_sep_markov.md +++ b/lectures/mccall_model_with_sep_markov.md @@ -4,11 +4,11 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.17.1 + jupytext_version: 1.17.2 kernelspec: - name: python3 display_name: Python 3 (ipykernel) language: python + name: python3 --- (mccall_with_sep_markov)= @@ -49,7 +49,7 @@ libraries ```{code-cell} ipython3 :tags: [hide-output] -!pip install quantecon jax +!pip install quantecon ``` We use the following imports: @@ -58,7 +58,7 @@ We use the following imports: from quantecon.markov import tauchen import jax.numpy as jnp import jax -from jax import jit, lax +from jax import lax from typing import NamedTuple import matplotlib.pyplot as plt from functools import partial @@ -138,48 +138,11 @@ The optimal policy turns out to be a reservation wage strategy: accept all wages ## Code - -First, we implement the successive approximation algorithm. - -This algorithm takes an operator $T$ and an initial condition and iterates until -convergence. - -We will use it for value function iteration. 
- -```{code-cell} ipython3 -@partial(jit, static_argnums=(0,)) -def successive_approx( - T, # Operator (callable) - marked as static - x_0, # Initial condition - tolerance: float = 1e-6, # Error tolerance - max_iter: int = 100_000, # Max iteration bound - ): - """Computes the approximate fixed point of T via successive - approximation using lax.while_loop.""" - - def cond_fn(carry): - x, error, k = carry - return (error > tolerance) & (k <= max_iter) - - def body_fn(carry): - x, error, k = carry - x_new = T(x) - error = jnp.max(jnp.abs(x_new - x)) - return (x_new, error, k + 1) - - initial_carry = (x_0, tolerance + 1, 1) - x_final, _, _ = lax.while_loop(cond_fn, body_fn, initial_carry) - - return x_final -``` - - -Next let's set up a `Model` class to store information needed to solve the model. +Let's set up a `Model` class to store information needed to solve the model. We include `P_cumsum`, the row-wise cumulative sum of the transition matrix, to optimize the simulation -- the details are explained below. - ```{code-cell} ipython3 class Model(NamedTuple): n: int @@ -215,7 +178,6 @@ def create_js_with_sep_model( Here's the Bellman operator for the unemployed worker's value function: ```{code-cell} ipython3 -@jit def T(v: jnp.ndarray, model: Model) -> jnp.ndarray: """The Bellman operator for the value of being unemployed.""" n, w_vals, P, P_cumsum, β, c, α = model @@ -229,7 +191,6 @@ The next function computes the optimal policy under the assumption that $v$ is the value function: ```{code-cell} ipython3 -@jit def get_greedy(v: jnp.ndarray, model: Model) -> jnp.ndarray: """Get a v-greedy policy.""" n, w_vals, P, P_cumsum, β, c, α = model @@ -247,14 +208,34 @@ The second routine requires a policy function, which we will typically obtain by applying the `vfi` function. ```{code-cell} ipython3 -def vfi(model: Model): - """Solve by VFI.""" +@jax.jit +def vfi( + model: Model, + tolerance: float = 1e-6, # Error tolerance + max_iter: int = 100_000, # Max iteration bound + ): + v_init = jnp.zeros(model.w_vals.shape) - v_star = successive_approx(lambda v: T(v, model), v_init) - σ_star = get_greedy(v_star, model) - return v_star, σ_star + + def cond(loop_state): + v, error, i = loop_state + return (error > tolerance) & (i <= max_iter) + + def update(loop_state): + v, error, i = loop_state + v_new = T(v, model) + error = jnp.max(jnp.abs(v_new - v)) + new_loop_state = v_new, error, i + 1 + return new_loop_state + + initial_state = (v_init, tolerance + 1, 1) + final_loop_state = lax.while_loop(cond, update, initial_state) + v_final, error, i = final_loop_state + return v_final + +@jax.jit def get_reservation_wage(σ: jnp.ndarray, model: Model) -> float: """ Calculate the reservation wage from a given policy. 
@@ -268,17 +249,15 @@ def get_reservation_wage(σ: jnp.ndarray, model: Model) -> float: """ n, w_vals, P, P_cumsum, β, c, α = model - # Find all wage indices where policy indicates acceptance - accept_indices = jnp.where(σ == 1)[0] - - if len(accept_indices) == 0: - return jnp.inf # Agent never accepts any wage + # Find the first index where policy indicates acceptance + # σ is a boolean array, argmax returns the first True value + first_accept_idx = jnp.argmax(σ) - # Return the lowest wage that is accepted - return w_vals[accept_indices[0]] + # If no acceptance (all False), return infinity + # Otherwise return the wage at the first acceptance index + return jnp.where(jnp.any(σ), w_vals[first_accept_idx], jnp.inf) ``` - ## Computing the Solution Let's solve the model: @@ -286,7 +265,8 @@ Let's solve the model: ```{code-cell} ipython3 model = create_js_with_sep_model() n, w_vals, P, P_cumsum, β, c, α = model -v_star, σ_star = vfi(model) +v_star = vfi(model) +σ_star = get_greedy(v_star, model) ``` Next we compute some related quantities, including the reservation wage. @@ -312,19 +292,18 @@ ax.set_xlabel(r"$w$") plt.show() ``` - ## Sensitivity Analysis Let's examine how reservation wages change with the separation rate. - ```{code-cell} ipython3 α_vals: jnp.ndarray = jnp.linspace(0.0, 1.0, 10) w_star_vec = jnp.empty_like(α_vals) for (i_α, α) in enumerate(α_vals): model = create_js_with_sep_model(α=α) - v_star, σ_star = vfi(model) + v_star = vfi(model) + σ_star = get_greedy(v_star, model) w_star = get_reservation_wage(σ_star, model) w_star_vec = w_star_vec.at[i_α].set(w_star) @@ -356,9 +335,8 @@ This is implemented via `jnp.searchsorted` on the precomputed cumulative sum The function `update_agent` advances the agent's state by one period. - ```{code-cell} ipython3 -@jit +@jax.jit def update_agent(key, is_employed, wage_idx, model, σ): """ Updates an agent by one period. Updates their employment status and their @@ -439,7 +417,8 @@ Let's create a comprehensive plot of the employment simulation: model = create_js_with_sep_model() # Calculate reservation wage for plotting -v_star, σ_star = vfi(model) +v_star = vfi(model) +σ_star = get_greedy(v_star, model) w_star = get_reservation_wage(σ_star, model) wage_path, employment_status = simulate_employment_path(model, σ_star) @@ -486,7 +465,6 @@ plt.tight_layout() plt.show() ``` - The simulation helps to visualize outcomes associated with this model. The agent follows a reservation wage strategy. @@ -531,7 +509,7 @@ This holds because: These properties ensure the chain is ergodic with a unique stationary distribution $\pi$ over states $(s, w)$. -For an ergodic Markov chain, the ergodic theorem guarantees that time averages = ensemble averages. +For an ergodic Markov chain, the ergodic theorem guarantees that time averages = cross-sectional averages. 
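
To see what this claim means in the simplest possible setting, here is a minimal sketch (an illustrative aside, not part of the lecture's model, reusing the imports above). It simulates a two-state employment chain in which `p_sep` and `p_find` are made-up hazard rates: with these numbers the stationary unemployment rate is $p_{\mathrm{sep}} / (p_{\mathrm{sep}} + p_{\mathrm{find}}) = 0.25$, and a long simulation should spend roughly that fraction of time unemployed.

```{code-cell} ipython3
# Illustrative two-state chain: state 1 = employed, 0 = unemployed.
# p_sep and p_find are made-up hazard rates, not the lecture's calibration.
p_sep, p_find = 0.1, 0.3

def step(state, key):
    u = jax.random.uniform(key)
    stay_employed = (u > p_sep).astype(jnp.int32)   # employed workers keep jobs w.p. 1 - p_sep
    find_job = (u < p_find).astype(jnp.int32)       # unemployed workers find jobs w.p. p_find
    new_state = jnp.where(state == 1, stay_employed, find_job)
    return new_state, new_state

keys = jax.random.split(jax.random.PRNGKey(0), 200_000)
_, path = lax.scan(step, jnp.int32(1), keys)

time_avg_unemp = 1 - path.mean()                # time average for a single agent
stationary_unemp = p_sep / (p_sep + p_find)     # cross-sectional (stationary) rate
print(time_avg_unemp, stationary_unemp)         # both should be close to 0.25
```

The simulation below performs the same comparison in the full model, where the job-finding rate is determined by the reservation-wage policy rather than fixed.
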
In particular, the fraction of time a single agent spends unemployed (across all wage states) converges to the cross-sectional unemployment rate: @@ -568,7 +546,7 @@ update_agents_vmap = jax.vmap( Next we define the core simulation function, which uses `lax.fori_loop` to efficiently iterate many agents forward in time: ```{code-cell} ipython3 -@partial(jit, static_argnums=(3, 4)) +@partial(jax.jit, static_argnums=(3, 4)) def _simulate_cross_section_compiled( key: jnp.ndarray, model: Model, @@ -627,7 +605,8 @@ def simulate_cross_section( key = jax.random.PRNGKey(seed) # Solve for optimal policy - v_star, σ_star = vfi(model) + v_star = vfi(model) + σ_star = get_greedy(v_star, model) # Run JIT-compiled simulation final_employment = _simulate_cross_section_compiled( @@ -655,7 +634,8 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200, """ # Get final employment state directly key = jax.random.PRNGKey(42) - v_star, σ_star = vfi(model) + v_star = vfi(model) + σ_star = get_greedy(v_star, model) final_employment = _simulate_cross_section_compiled( key, model, σ_star, n_agents, t_snapshot ) @@ -681,7 +661,12 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200, plt.show() ``` -Now let's compare the time-average unemployment rate (from a single agent's long simulation) with the cross-sectional unemployment rate (from many agents at a single point in time): +Now let's compare the time-average unemployment rate (from a single agent's long simulation) with the cross-sectional unemployment rate (from many agents at a single point in time). + +We claimed above that these numbers will be approximately equal in large +samples, due to ergodicity. + +Let's see if that's true. ```{code-cell} ipython3 model = create_js_with_sep_model() @@ -697,28 +682,31 @@ print(f"Cross-sectional unemployment rate (at t=200): " print(f"Difference: {abs(time_avg_unemp - cross_sectional_unemp):.4f}") ``` +Indeed, they are very close. + Now let's visualize the cross-sectional distribution: ```{code-cell} ipython3 plot_cross_sectional_unemployment(model) ``` -## Cross-Sectional Analysis with Lower Unemployment Compensation (c=0.5) +## Lower Unemployment Compensation (c=0.5) -Let's examine how the cross-sectional unemployment rate changes with lower unemployment compensation: +What happens to the cross-sectional unemployment rate with lower unemployment compensation? ```{code-cell} ipython3 model_low_c = create_js_with_sep_model(c=0.5) plot_cross_sectional_unemployment(model_low_c) ``` + ## Exercises ```{exercise-start} :label: mmwsm_ex1 ``` -Create a plot that shows how the steady state cross-sectional unemployment rate +Create a plot that investigates more carefully how the steady state cross-sectional unemployment rate changes with unemployment compensation. ```{exercise-end} @@ -751,4 +739,3 @@ plt.show() ```{solution-end} ``` - diff --git a/lectures/mccall_model_with_separation.md b/lectures/mccall_model_with_separation.md index 2911dc139..612e1a290 100644 --- a/lectures/mccall_model_with_separation.md +++ b/lectures/mccall_model_with_separation.md @@ -41,9 +41,9 @@ In addition to what's in Anaconda, this lecture will need the following librarie Previously {doc}`we looked ` at the McCall job search model {cite}`McCall1970` as a way of understanding unemployment and worker decisions. -One unrealistic feature of the model is that every job is permanent. +One unrealistic feature of that version of the model was that every job is permanent. 
-
-In this lecture, we extend the McCall model by introducing job separation.
+In this lecture, we extend the model by introducing job separation.

Once separation enters the picture, the agent comes to view

@@ -62,6 +62,7 @@
import jax
import jax.numpy as jnp
from typing import NamedTuple
from quantecon.distributions import BetaBinomial
+from myst_nb import glue
```

## The Model
@@ -125,53 +126,78 @@
We drop time subscripts in what follows and primes denote next period values.

Let

-* $v_e(w)$ be total lifetime value accruing to a worker who enters the current period *employed* with existing wage $w$
-* $v_u(w)$ be total lifetime value accruing to a worker who who enters the current period *unemployed* and receives
- wage offer $w$.
+* $v_e(w)$ be the maximum lifetime value accruing to a worker who enters the current
+ period *employed* with existing wage $w$
+* $v_u(w)$ be the maximum lifetime value accruing to a worker who enters the
+ current period *unemployed* and receives wage offer $w$.

-Here *value* means the value of the objective function {eq}`objective` when the worker makes optimal decisions at all future points in time.
+Here **maximum lifetime value** means the value of {eq}`objective` when
+the worker makes optimal decisions at all future points in time.

-Our first aim is to obtain these functions.
+As we now show, obtaining these functions is key to solving the new model.

### The Bellman Equations

-Suppose for now that the worker can calculate the functions $v_e$ and $v_u$ and use them in his decision making.
+We recall that, in {doc}`the original job search model `, the
+value function (the value of being unemployed with a given wage offer) satisfied
+a Bellman equation.

-Then $v_e$ and $v_u$ should satisfy
+Here this function again satisfies a Bellman equation that looks very similar.

-```{math}
-:label: bell1_mccall
-
-v_e(w) = u(w) + \beta
-    \left[
-        (1-\alpha)v_e(w) + \alpha \sum_{w' \in \mathbb W} v_u(w') q(w')
-    \right]
-```
-
-and

```{math}
:label: bell2_mccall

-v_u(w) = \max \left\{ v_e(w), \, u(c) + \beta \sum_{w' \in \mathbb W} v_u(w') q(w') \right\}
+    v_u(w) = \max
+    \left\{
+        v_e(w), \,
+        u(c) + \beta \sum_{w' \in \mathbb W} v_u(w') q(w')
+    \right\}
```

-Equation {eq}`bell1_mccall` expresses the value of being employed at wage $w$ in terms of
+The difference is that the value of accepting is $v_e(w)$ rather than
+$w/(1-\beta)$.

-* current reward $u(w)$ plus
-* discounted expected reward tomorrow, given the $\alpha$ probability of being fired
+We have to make this change because jobs are not permanent.
+
+Accepting transitions the worker to employment and hence yields reward $v_e(w)$.
+
+Rejecting leads to unemployment compensation and unemployment tomorrow.

Equation {eq}`bell2_mccall` expresses the value of being unemployed with offer
$w$ in hand as a maximum over the value of two options: accept or reject the
current offer.

-Accepting transitions the worker to employment and hence yields reward $v_e(w)$.
+The function $v_e$ also satisfies a Bellman equation:

-Rejecting leads to unemployment compensation and unemployment tomorrow.
+```{math}
+:label: bell1_mccall

-Equations {eq}`bell1_mccall` and {eq}`bell2_mccall` are the Bellman equations for this model.
+    v_e(w) = u(w) + \beta
+    \left[
+        (1-\alpha)v_e(w) + \alpha \sum_{w' \in \mathbb W} v_u(w') q(w')
+    \right]
+```

-They provide enough information to solve for both $v_e$ and $v_u$.
+```{note}
+This equation differs from a traditional Bellman equation because there is no max.
+ +There is no max because an employed agent has no choices. + +Nonetheless, in keeping with most of the literature, we also refer to it as a +Bellman equation. + +``` + +Equation {eq}`bell1_mccall` expresses the value of being employed at wage $w$ in terms of + +* current reward $u(w)$ plus +* discounted expected reward tomorrow, given the $\alpha$ probability of being fired + +As we will see, equations {eq}`bell1_mccall` and {eq}`bell2_mccall` provide +enough information to solve for both $v_e$ and $v_u$. + +Once we have them in hand, we will be able to make optimal choices. (ast_mcm)= ### A Simplifying Transformation @@ -180,66 +206,79 @@ Rather than jumping straight into solving these equations, let's see if we can simplify them somewhat. (This process will be analogous to our {ref}`second pass ` at the plain vanilla -McCall model, where we simplified the Bellman equation.) +McCall model, where we reduced the Bellman equation to an equation in an unknown +scalar value, rather than an unknown vector.) First, let ```{math} -:label: defd_mm +:label: defh_mm -d := \sum_{w' \in \mathbb W} v_u(w') q(w') +h := u(c) + \beta \sum_{w' \in \mathbb W} v_u(w') q(w') ``` -be the expected value of unemployment tomorrow. +be the continuation value associated with unemployment (the value of rejecting the current offer). We can now write {eq}`bell2_mccall` as $$ -v_u(w) = \max \left\{ v_e(w), \, u(c) + \beta d \right\} +v_u(w) = \max \left\{ v_e(w), \, h \right\} $$ or, shifting time forward one period $$ \sum_{w' \in \mathbb W} v_u(w') q(w') - = \sum_{w' \in \mathbb W} \max \left\{ v_e(w'), \, u(c) + \beta d \right\} q(w') + = \sum_{w' \in \mathbb W} \max \left\{ v_e(w'), \, h \right\} q(w') $$ -Using {eq}`defd_mm` again now gives +Using {eq}`defh_mm` again now gives ```{math} :label: bell02_mccall -d = \sum_{w' \in \mathbb W} \max \left\{ v_e(w'), \, u(c) + \beta d \right\} q(w') +h = u(c) + \beta \sum_{w' \in \mathbb W} \max \left\{ v_e(w'), \, h \right\} q(w') ``` -Finally, {eq}`bell1_mccall` can now be rewritten as +Finally, from {eq}`defh_mm` we have + +$$ +\sum_{w' \in \mathbb W} v_u(w') q(w') = \frac{h - u(c)}{\beta} +$$ + +so {eq}`bell1_mccall` can now be rewritten as ```{math} :label: bell01_mccall v_e(w) = u(w) + \beta \left[ - (1-\alpha)v_e(w) + \alpha d + (1-\alpha)v_e(w) + \alpha \frac{h - u(c)}{\beta} \right] ``` ### Simplifying to a Single Equation -We can simplify further by solving {eq}`bell01_mccall` for $v_e$ as a function of $d$. +We can simplify further by solving {eq}`bell01_mccall` for $v_e$ as a function of $h$. Rearranging {eq}`bell01_mccall` gives $$ -v_e(w) - \beta(1-\alpha)v_e(w) = u(w) + \beta\alpha d +v_e(w) = u(w) + \beta(1-\alpha)v_e(w) + \alpha(h - u(c)) $$ or +$$ +v_e(w) - \beta(1-\alpha)v_e(w) = u(w) + \alpha(h - u(c)) +$$ + +Solving for $v_e(w)$: + ```{math} :label: v_e_closed -v_e(w) = \frac{u(w) + \beta\alpha d}{1 - \beta(1-\alpha)} +v_e(w) = \frac{u(w) + \alpha(h - u(c))}{1 - \beta(1-\alpha)} ``` Substituting this into {eq}`bell02_mccall` yields @@ -247,23 +286,23 @@ Substituting this into {eq}`bell02_mccall` yields ```{math} :label: bell_scalar -d = \sum_{w' \in \mathbb W} \max \left\{ \frac{u(w') + \beta\alpha d}{1 - \beta(1-\alpha)}, \, u(c) + \beta d \right\} q(w') +h = u(c) + \beta \sum_{w' \in \mathbb W} \max \left\{ \frac{u(w') + \alpha(h - u(c))}{1 - \beta(1-\alpha)}, \, h \right\} q(w') ``` -This is a single scalar equation in $d$. +This is a single scalar equation in $h$. ### The Reservation Wage -Suppose we can use {eq}`bell_scalar` to solve for $d$. 
+Suppose we can use {eq}`bell_scalar` to solve for $h$. -Once we have $d$, we can obtain $v_e$ from {eq}`v_e_closed`. +Once we have $h$, we can obtain $v_e$ from {eq}`v_e_closed`. We can then determine optimal behavior for the worker. From {eq}`bell2_mccall`, we see that an unemployed agent accepts current offer -$w$ if $v_e(w) \geq u(c) + \beta d$. +$w$ if $v_e(w) \geq h$. -This means precisely that the value of accepting is higher than the expected value of rejecting. +This means precisely that the value of accepting is higher than the value of rejecting. It is clear that $v_e$ is (at least weakly) increasing in $w$, since the agent is never made worse off by a higher wage offer. @@ -272,7 +311,7 @@ Hence, we can express the optimal choice as accepting wage offer $w$ if and only $$ w \geq \bar w \quad \text{where} \quad -\bar w \text{ solves } v_e(\bar w) = u(c) + \beta d +\bar w \text{ solves } v_e(\bar w) = h $$ ### Solving the Bellman Equations @@ -280,32 +319,22 @@ $$ We'll use the same iterative approach to solving the Bellman equations that we adopted in the {doc}`first job search lecture `. -Since we have reduced the problem to a single scalar equation {eq}`bell_scalar`, -we only need to iterate on $d$. +In this case we only need to iterate on the single scalar equation {eq}`bell_scalar`. The iteration rule is ```{math} :label: bell_iter -d_{n+1} = \sum_{w' \in \mathbb W} - \max \left\{ \frac{u(w') + \beta\alpha d_n}{1 - \beta(1-\alpha)}, \, u(c) + \beta d_n \right\} q(w') +h_{n+1} = u(c) + \beta \sum_{w' \in \mathbb W} + \max \left\{ \frac{u(w') + \alpha(h_n - u(c))}{1 - \beta(1-\alpha)}, \, h_n \right\} q(w') ``` -starting from some initial condition $d_0$. +starting from some initial condition $h_0$. -Once convergence is achieved, we can compute $v_e$ from {eq}`v_e_closed`: +Once convergence is achieved, we can compute $v_e$ from {eq}`v_e_closed`. -```{math} -:label: bell_v_e_final - -v_e(w) = \frac{u(w) + \beta\alpha d}{1 - \beta(1-\alpha)} -``` - -This approach is simpler than iterating on both $d$ and $v_e$ simultaneously, as -we now only need to track a single scalar value. - -(Convergence can be established via the Banach contraction mapping theorem.) +(It is possible to prove that {eq}`bell_iter` converges via the Banach contraction mapping theorem.) ## Implementation @@ -319,8 +348,8 @@ This helps to tidy up the code and provides an object that's easy to pass to fun The default utility function is a CRRA utility function ```{code-cell} ipython3 -def u(c, σ=2.0): - return (c**(1 - σ) - 1) / (1 - σ) +def u(c, γ): + return (c**(1 - γ) - 1) / (1 - γ) ``` Also, here's a default wage distribution, based around the BetaBinomial @@ -340,6 +369,7 @@ Here's our model class for the McCall model with separation. class Model(NamedTuple): α: float = 0.2 # job separation rate β: float = 0.98 # discount factor + γ: float = 2.0 # utility parameter (CRRA) c: float = 6.0 # unemployment compensation w: jnp.ndarray = w_default # wage outcome space q: jnp.ndarray = q_default # probabilities over wage offers @@ -349,52 +379,56 @@ Now we iterate until successive realizations are closer together than some small We then return the current iterate as an approximate solution. -First, we define a function to compute $v_e$ from $d$: +First, we define a function to compute $v_e$ from $h$: ```{code-cell} ipython3 -def compute_v_e(model, d): - " Compute v_e from d using the closed-form expression. 
" - α, β, w = model.α, model.β, model.w - return (u(w) + β * α * d) / (1 - β * (1 - α)) +def compute_v_e(model, h): + " Compute v_e from h using the closed-form expression. " + α, β, γ, c, w, q = model + return (u(w, γ) + α * (h - u(c, γ))) / (1 - β * (1 - α)) ``` -Now we implement the iteration on $d$ only: +Now we implement the iteration on $h$ only: ```{code-cell} ipython3 -def update_d(model, d): - " One update of the scalar d. " - α, β, c, w, q = model.α, model.β, model.c, model.w, model.q - v_e = compute_v_e(model, d) - d_new = jnp.maximum(v_e, u(c) + β * d) @ q - return d_new +def update_h(model, h): + " One update of the scalar h. " + α, β, γ, c, w, q = model + v_e = compute_v_e(model, h) + h_new = u(c, γ) + β * (jnp.maximum(v_e, h) @ q) + return h_new +``` + +Using this iteration rule, we can write our model solver. +```{code-cell} ipython3 @jax.jit def solve_model(model, tol=1e-5, max_iter=2000): " Iterates to convergence on the Bellman equations. " - def cond_fun(state): - d, i, error = state + def cond(loop_state): + h, i, error = loop_state return jnp.logical_and(error > tol, i < max_iter) - def body_fun(state): - d, i, error = state - d_new = update_d(model, d) - error_new = jnp.abs(d_new - d) - return d_new, i + 1, error_new + def update(loop_state): + h, i, error = loop_state + h_new = update_h(model, h) + error_new = jnp.abs(h_new - h) + return h_new, i + 1, error_new - # Initial state: (d, i, error) - d_init = 1.0 + # Initialize + h_init = u(model.c, model.γ) / (1 - model.β) i_init = 0 error_init = tol + 1 + init_state = (h_init, i_init, error_init) - init_state = (d_init, i_init, error_init) - final_state = jax.lax.while_loop(cond_fun, body_fun, init_state) - d_final, _, _ = final_state + final_state = jax.lax.while_loop(cond, update, init_state) + h_final, _, _ = final_state - # Compute v_e from the converged d - v_e_final = compute_v_e(model, d_final) + # Compute v_e from the converged h + v_e_final = compute_v_e(model, h_final) - return v_e_final, d_final + return v_e_final, h_final ``` ### The Reservation Wage: First Pass @@ -402,22 +436,20 @@ def solve_model(model, tol=1e-5, max_iter=2000): The optimal choice of the agent is summarized by the reservation wage. As discussed above, the reservation wage is the $\bar w$ that solves -$v_e(\bar w) = v_u^*$ where $v_u^* := u(c) + \beta d$ is the continuation -value. +$v_e(\bar w) = h$ where $h$ is the continuation value. -Let's compare $v_e$ and $v_u^*$ to see what they look like. +Let's compare $v_e$ and $h$ to see what they look like. We'll use the default parameterizations found in the code above. ```{code-cell} ipython3 model = Model() -v_e, d = solve_model(model) -v_u_star = u(model.c) + model.β * d +v_e, h = solve_model(model) fig, ax = plt.subplots() ax.plot(model.w, v_e, 'b-', lw=2, alpha=0.7, label='$v_e$') -ax.plot(model.w, [v_u_star] * len(model.w), - 'g-', lw=2, alpha=0.7, label='$v_u^*$') +ax.plot(model.w, [h] * len(model.w), + 'g-', lw=2, alpha=0.7, label='$h$') ax.set_xlim(min(model.w), max(model.w)) ax.legend() plt.show() @@ -425,7 +457,11 @@ plt.show() The value $v_e$ is increasing because higher $w$ generates a higher wage flow conditional on staying employed. -### The Reservation Wage: Computation + +The reservation wage is the $w$ where these lines meet. + + +### Computing the Reservation Wage Here's a function `compute_reservation_wage` that takes an instance of `Model` and returns the associated reservation wage. @@ -435,19 +471,20 @@ and returns the associated reservation wage. 
def compute_reservation_wage(model): """ Computes the reservation wage of an instance of the McCall model - by finding the smallest w such that v_e(w) >= v_u^*. If no such w exists, then + by finding the smallest w such that v_e(w) >= h. If no such w exists, then w_bar is set to np.inf. """ - v_e, d = solve_model(model) - v_u_star = u(model.c) + model.β * d - i = jnp.searchsorted(v_e, v_u_star, side='left') + v_e, h = solve_model(model) + i = jnp.searchsorted(v_e, h, side='left') w_bar = jnp.where(i >= len(model.w), jnp.inf, model.w[i]) return w_bar ``` Next we will investigate how the reservation wage varies with parameters. + + ## Impact of Parameters In each instance below, we'll show you a figure and then ask you to reproduce it in the exercises. @@ -459,7 +496,8 @@ First, let's look at how $\bar w$ varies with unemployment compensation. In the figure below, we use the default parameters in the `Model` class, apart from c (which takes the values given on the horizontal axis) -```{figure} /_static/lecture_specific/mccall_model_with_separation/mccall_resw_c.png +```{glue:figure} mccall_resw_c +:figwidth: 600px ``` @@ -474,7 +512,8 @@ Next, let's investigate how $\bar w$ varies with the discount factor. The next figure plots the reservation wage associated with different values of $\beta$ -```{figure} /_static/lecture_specific/mccall_model_with_separation/mccall_resw_beta.png +```{glue:figure} mccall_resw_beta +:figwidth: 600px ``` @@ -486,7 +525,8 @@ Finally, let's look at how $\bar w$ varies with the job separation rate $\alpha$ Higher $\alpha$ translates to a greater chance that a worker will face termination in each period once employed. -```{figure} /_static/lecture_specific/mccall_model_with_separation/mccall_resw_alpha.png +```{glue:figure} mccall_resw_alpha +:figwidth: 600px ``` @@ -534,6 +574,7 @@ fig, ax = plt.subplots() ax.set(xlabel='unemployment compensation', ylabel='reservation wage') ax.plot(c_vals, w_bar_vals, label=r'$\bar w$ as a function of $c$') ax.legend() +glue("mccall_resw_c", fig, display=False) plt.show() ``` @@ -551,6 +592,7 @@ fig, ax = plt.subplots() ax.set(xlabel='discount factor', ylabel='reservation wage') ax.plot(β_vals, w_bar_vals, label=r'$\bar w$ as a function of $\beta$') ax.legend() +glue("mccall_resw_beta", fig, display=False) plt.show() ``` @@ -568,6 +610,7 @@ fig, ax = plt.subplots() ax.set(xlabel='separation rate', ylabel='reservation wage') ax.plot(α_vals, w_bar_vals, label=r'$\bar w$ as a function of $\alpha$') ax.legend() +glue("mccall_resw_alpha", fig, display=False) plt.show() ```