diff --git a/lectures/lake_model.md b/lectures/lake_model.md index 93fb06156..7cc6420a2 100644 --- a/lectures/lake_model.md +++ b/lectures/lake_model.md @@ -108,13 +108,6 @@ We want to derive the dynamics of the following aggregates: * $U_t$, the total number of unemployed workers at $t$ * $N_t$, the number of workers in the labor force at $t$ -We also want to know the values of the following objects: - -* The employment rate $e_t := E_t/N_t$. -* The unemployment rate $u_t := U_t/N_t$. - -(Here and below, capital letters represent aggregates and lowercase letters represent rates) - ### Laws of motion for stock variables We begin by constructing laws of motion for the aggregate variables $E_t,U_t, N_t$. @@ -167,6 +160,13 @@ This law tells us how total employment and unemployment evolve over time. Now let's derive the law of motion for rates. +We want to track the values of the following objects: + +* The employment rate $e_t := E_t/N_t$. +* The unemployment rate $u_t := U_t/N_t$. + +(Here and below, capital letters represent aggregates and lowercase letters represent rates) + To get these we can divide both sides of $X_{t+1} = A X_t$ by $N_{t+1}$ to get $$ @@ -197,37 +197,92 @@ $$ we can also write this as $$ -x_{t+1} = \hat A x_t +x_{t+1} = R x_t \quad \text{where} \quad -\hat A := \frac{1}{1 + g} A +R := \frac{1}{1 + g} A $$ You can check that $e_t + u_t = 1$ implies that $e_{t+1}+u_{t+1} = 1$. -This follows from the fact that the columns of $\hat A$ sum to 1. +This follows from the fact that the columns of $R$ sum to 1. ## Implementation Let's code up these equations. -To do this we're going to use a class that we'll call `LakeModel` that stores the primitives $\alpha, \lambda, b, d$ +### Model + +To begin, we set up a class called `LakeModel` that stores the primitives $\alpha, \lambda, b, d$. ```{code-cell} ipython3 class LakeModel(NamedTuple): """ Parameters for the lake model """ - λ: float = 0.283 - α: float = 0.013 - b: float = 0.0124 - d: float = 0.00822 + λ: float + α: float + b: float + d: float + A: jnp.ndarray + R: jnp.ndarray + g: float + + +def create_lake_model( + λ: float = 0.283, # job finding rate + α: float = 0.013, # separation rate + b: float = 0.0124, # birth rate + d: float = 0.00822 # death rate + ) -> LakeModel: + """ + Create a LakeModel instance with default parameters. + + Computes and stores the transition matrices A and R, + and the labor force growth rate g. + + """ + # Compute growth rate + g = b - d + + # Compute transition matrix A + A = jnp.array([ + [(1-d) * (1-λ) + b, (1-d) * α + b], + [(1-d) * λ, (1-d) * (1-α)] + ]) + + # Compute normalized transition matrix R + R = A / (1 + g) + + return LakeModel(λ=λ, α=α, b=b, d=d, A=A, R=R, g=g) +``` + +As an experiment, let's create two instances, one with $α=0.013$ and another with $α=0.03$ + +```{code-cell} ipython3 +model = create_lake_model() +print(f"Default α: {model.α}") +print(f"A matrix:\n{model.A}") +print(f"R matrix:\n{model.R}") ``` +```{code-cell} ipython3 +model_new = create_lake_model(α=0.03) +print(f"New α: {model_new.α}") +print(f"New A matrix:\n{model_new.A}") +print(f"New R matrix:\n{model_new.R}") +``` + +### Code for dynamics + We will also use a specialized function to generate time series in an efficient JAX-compatible manner. -(Iteratively generating time series is somewhat nontrivial in JAX because arrays -are immutable.) +Iteratively generating time series is somewhat nontrivial in JAX because arrays +are immutable. + +Here we use `lax.scan`, which allows the function to be jit-compiled. + +Readers who prefer to skip the details can safely continue reading after the function definition. ```{code-cell} ipython3 @partial(jax.jit, static_argnames=['f', 'num_steps']) @@ -235,13 +290,13 @@ def generate_path(f, initial_state, num_steps, **kwargs): """ Generate a time series by repeatedly applying an update rule. - Given a map f, initial state x_0, and a set of model parameter θ, this + Given a map f, initial state x_0, and model parameters, this function computes and returns the sequence {x_t}_{t=0}^{T-1} when - x_{t+1} = f(x_t, t, θ) + x_{t+1} = f(x_t, **kwargs) Args: - f: Update function mapping (x_t, t, θ) -> x_{t+1} + f: Update function mapping (x_t, **kwargs) -> x_{t+1} initial_state: Initial state x_0 num_steps: Number of time steps T to simulate **kwargs: Optional extra arguments passed to f @@ -255,7 +310,7 @@ def generate_path(f, initial_state, num_steps, **kwargs): """ Wrapper function that adapts f for use with JAX scan. """ - next_state = f(state, t, **kwargs) + next_state = f(state, **kwargs) return next_state, state _, path = jax.lax.scan(update_wrapper, @@ -263,59 +318,26 @@ def generate_path(f, initial_state, num_steps, **kwargs): return path.T ``` -Now we can compute the matrices and simulate the dynamics. +Here are functions to update $X_t$ and $x_t$. ```{code-cell} ipython3 -@jax.jit -def compute_matrices(model: LakeModel): - """Compute the transition matrices A and A_hat for the model.""" - λ, α, b, d = model.λ, model.α, model.b, model.d - g = b - d - A = jnp.array([[(1-d) * (1-λ) + b, (1 - d) * α + b], - [ (1-d) * λ, (1 - d) * (1 - α)]]) - A_hat = A / (1 + g) - return A, A_hat, g - - -@jax.jit -def stock_update(current_stocks, time_step, model): - """ - Apply transition matrix to get next period's stocks. - """ - A, A_hat, g = compute_matrices(model) - next_stocks = A @ current_stocks - return next_stocks - -@jax.jit -def rate_update(current_rates, time_step, model): - """ - Apply normalized transition matrix for next period's rates. - """ - A, A_hat, g = compute_matrices(model) - next_rates = A_hat @ current_rates - return next_rates -``` - -We create two instances, one with $α=0.013$ and another with $α=0.03$ - -```{code-cell} ipython3 -model = LakeModel() -model_new = LakeModel(α=0.03) - -print(f"Default α: {model.α}") -A, A_hat, g = compute_matrices(model) -print(f"A matrix:\n{A}") +def stock_update(X: jnp.ndarray, model: LakeModel) -> jnp.ndarray: + """Apply transition matrix to get next period's stocks.""" + λ, α, b, d, A, R, g = model + return A @ X + +def rate_update(x: jnp.ndarray, model: LakeModel) -> jnp.ndarray: + """Apply normalized transition matrix for next period's rates.""" + λ, α, b, d, A, R, g = model + return R @ x ``` -```{code-cell} ipython3 -A_new, A_hat_new, g_new = compute_matrices(model_new) -print(f"New α: {model_new.α}") -print(f"New A matrix:\n{A_new}") -``` ### Aggregate dynamics -Let's run a simulation under the default parameters (see above) starting from $X_0 = (12, 138)$. +Let's run a simulation under the default parameters starting from $X_0 = (12, 138)$. + +We will plot the sequences $\{E_t\}$, $\{U_t\}$ and $\{N_t\}$. ```{code-cell} ipython3 N_0 = 150 # Population @@ -326,66 +348,72 @@ T = 50 # Simulation length U_0 = u_0 * N_0 E_0 = e_0 * N_0 -fig, axes = plt.subplots(3, 1, figsize=(10, 8)) +# Generate X path X_0 = jnp.array([U_0, E_0]) X_path = generate_path(stock_update, X_0, T, model=model) -axes[0].plot(X_path[0, :], lw=2) -axes[0].set_title('unemployment') - -axes[1].plot(X_path[1, :], lw=2) -axes[1].set_title('employment') - -axes[2].plot(X_path.sum(0), lw=2) -axes[2].set_title('labor force') - +# Plot +fig, axes = plt.subplots(3, 1, figsize=(10, 8)) +titles = ['unemployment', 'employment', 'labor force'] +data = [X_path[0, :], X_path[1, :], X_path.sum(0)] +for ax, title, series in zip(axes, titles, data): + ax.plot(series, lw=2) + ax.set_title(title) plt.tight_layout() plt.show() ``` The aggregates $E_t$ and $U_t$ don't converge because their sum $E_t + U_t$ grows at rate $g$. + +### Rate dynamics + On the other hand, the vector of employment and unemployment rates $x_t$ can be in a steady state $\bar x$ if there exists an $\bar x$ such that -* $\bar x = \hat A \bar x$ +* $\bar x = R \bar x$ * the components satisfy $\bar e + \bar u = 1$ -This equation tells us that a steady state level $\bar x$ is an eigenvector of $\hat A$ associated with a unit eigenvalue. +This equation tells us that a steady state level $\bar x$ is an eigenvector of $R$ associated with a unit eigenvalue. The following function can be used to compute the steady state. ```{code-cell} ipython3 @jax.jit -def rate_steady_state(model: LakeModel): +def rate_steady_state(model: LakeModel) -> jnp.ndarray: r""" - Finds the steady state of the system :math:`x_{t+1} = \hat A x_{t}` - by computing the eigenvector corresponding to the unit eigenvalue. + Finds the steady state of the system :math:`x_{t+1} = R x_{t}` + by computing the eigenvector corresponding to the largest eigenvalue. + + By the Perron-Frobenius theorem, since :math:`R` is a non-negative + matrix with columns summing to 1 (a stochastic matrix), the largest + eigenvalue equals 1 and the corresponding eigenvector gives the steady state. """ - A, A_hat, g = compute_matrices(model) - eigenvals, eigenvec = jnp.linalg.eig(A_hat) - - # Find the eigenvector corresponding to eigenvalue 1 - unit_idx = jnp.argmin(jnp.abs(eigenvals - 1.0)) + λ, α, b, d, A, R, g = model + eigenvals, eigenvec = jnp.linalg.eig(R) + + # Find the eigenvector corresponding to the largest eigenvalue + # (which is 1 for a stochastic matrix by Perron-Frobenius theorem) + max_idx = jnp.argmax(jnp.abs(eigenvals)) # Get the corresponding eigenvector - steady_state = jnp.real(eigenvec[:, unit_idx]) - + steady_state = jnp.real(eigenvec[:, max_idx]) + # Normalize to ensure positive values and sum to 1 steady_state = jnp.abs(steady_state) steady_state = steady_state / jnp.sum(steady_state) - + return steady_state ``` We also have $x_t \to \bar x$ as $t \to \infty$ provided that the remaining -eigenvalue of $\hat A$ has modulus less than 1. +eigenvalue of $R$ has modulus less than 1. This is the case for our default parameters: ```{code-cell} ipython3 -A, A_hat, g = compute_matrices(model) -e, f = jnp.linalg.eigvals(A_hat) +model = create_lake_model() +e, f = jnp.linalg.eigvals(model.R) print(f"Eigenvalue magnitudes: {abs(e):.2f}, {abs(f):.2f}") ``` @@ -409,6 +437,41 @@ plt.tight_layout() plt.show() ``` +```{exercise} +:label: model_ex1 + +Use JAX's `vmap` to compute steady-state unemployment rates for a range of job finding rates $\lambda$ (from 0.1 to 0.5), and plot the relationship. +``` + +```{solution-start} model_ex1 +:class: dropdown +``` + +Here is one solution + +```{code-cell} ipython3 +@jax.jit +def compute_unemployment_rate(λ_val): + """Computes steady-state unemployment for a given λ""" + model = create_lake_model(λ=λ_val) + steady_state = rate_steady_state(model) + return steady_state[0] + +# Use vmap to compute for multiple λ values +λ_values = jnp.linspace(0.1, 0.5, 50) +unemployment_rates = jax.vmap(compute_unemployment_rate)(λ_values) + +# Plot the results +fig, ax = plt.subplots(figsize=(10, 6)) +ax.plot(λ_values, unemployment_rates, lw=2) +ax.set_xlabel(r'$\lambda$') +ax.set_ylabel('steady-state unemployment rate') +plt.show() +``` + +```{solution-end} +``` + (dynamics_workers)= ## Dynamics of an individual worker @@ -487,7 +550,7 @@ $$ with probability one. -Inspection tells us that $P$ is exactly the transpose of $\hat A$ under the assumption $b=d=0$. +Inspection tells us that $P$ is exactly the transpose of $R$ under the assumption $b=d=0$. Thus, the percentages of time that an infinitely lived worker spends employed and unemployed equal the fractions of workers employed and unemployed in the steady state distribution. @@ -500,18 +563,17 @@ We can investigate this by simulating the Markov chain. Let's plot the path of the sample averages over 5,000 periods ```{code-cell} ipython3 -@jax.jit -def markov_update(state, t, P, keys): +def markov_update(state, P, key): """ Sample next state from transition probabilities. """ probs = P[state] - state_new = jax.random.choice(keys[t], + state_new = jax.random.choice(key, a=jnp.arange(len(probs)), p=probs) return state_new -model_markov = LakeModel(d=0, b=0) +model_markov = create_lake_model(d=0, b=0) T = 5000 # Simulation length α, λ = model_markov.α, model_markov.λ @@ -521,10 +583,21 @@ P = jnp.array([[1 - λ, λ], xbar = rate_steady_state(model_markov) -# Simulate the Markov chain +# Simulate the Markov chain - we need a different approach for random updates key = jax.random.PRNGKey(0) -keys = jax.random.split(key, T) -s_path = generate_path(markov_update, 1, T, P=P, keys=keys) + +def simulate_markov(P, initial_state, T, key): + """Simulate Markov chain for T periods""" + keys = jax.random.split(key, T) + + def scan_fn(state, key): + next_state = markov_update(state, P, key) + return next_state, state + + _, path = jax.lax.scan(scan_fn, initial_state, keys) + return path + +s_path = simulate_markov(P, 1, T, key) fig, axes = plt.subplots(2, 1, figsize=(10, 8)) s_bar_e = jnp.cumsum(s_path) / jnp.arange(1, T+1) @@ -535,14 +608,14 @@ titles = ['percent of time unemployed', 'percent of time employed'] for i, plot in enumerate(to_plot): axes[i].plot(plot, lw=2, alpha=0.5) - axes[i].hlines(xbar[i], 0, T, 'r', '--') + axes[i].hlines(xbar[i], 0, T, linestyles='--') axes[i].set_title(titles[i]) plt.tight_layout() plt.show() ``` -The stationary probabilities are given by the dashed red line. +The stationary probabilities are given by the dashed line. In this case it takes much of the sample for these two objects to converge. @@ -812,25 +885,25 @@ def compute_optimal_quantities(c, τ, @jax.jit -def compute_steady_state_quantities(c, τ, +def compute_steady_state_quantities(c, τ, params: EconomyParameters, w_vec, p_vec): """ Compute the steady state unemployment rate given c and τ using optimal quantities from the McCall model and computing corresponding steady state quantities """ - w_bar, λ, V, U = compute_optimal_quantities(c, τ, + w_bar, λ, V, U = compute_optimal_quantities(c, τ, params, w_vec, p_vec) - + # Compute steady state employment and unemployment rates - model = LakeModel(α=params.α_q, λ=λ, b=params.b, d=params.d) + model = create_lake_model(λ=λ, α=params.α_q, b=params.b, d=params.d) u, e = rate_steady_state(model) - + # Compute steady state welfare mask = (w_vec - τ > w_bar) w = jnp.sum(V * p_vec * mask) / jnp.sum(p_vec * mask) welfare = e * w + u * U - + return e, u, welfare @@ -905,63 +978,6 @@ The level that maximizes steady state welfare is approximately 62. ## Exercises -```{exercise} -:label: model_ex1 - -In the JAX implementation of the Lake Model, we use a `NamedTuple` for parameters and separate functions for computations. - -This approach has several advantages: -1. It's immutable, which aligns with JAX's functional programming paradigm -2. Functions can be JIT-compiled for better performance - -In this exercise, your task is to: -1. Update parameters by creating a new instance of the model with the parameters (`α=0.02, λ=0.3`). -2. Use JAX's `vmap` to compute steady states for different parameter values -3. Plot how the steady-state unemployment rate varies with the job finding rate $\lambda$ -``` - -```{solution-start} model_ex1 -:class: dropdown -``` - -Here is one solution - -```{code-cell} ipython3 -@jax.jit -def compute_unemployment_rate(λ_val): - """Computes steady-state unemployment for a given λ""" - model = LakeModel(λ=λ_val) - steady_state = rate_steady_state(model) - return steady_state[0] - -# Use vmap to compute for multiple λ values -λ_values = jnp.linspace(0.1, 0.5, 50) -unemployment_rates = jax.vmap(compute_unemployment_rate)(λ_values) - -# Plot the results -fig, ax = plt.subplots(figsize=(10, 6)) -ax.plot(λ_values, unemployment_rates, lw=2) -ax.set_xlabel(r'$\lambda$') -ax.set_ylabel('steady-state unemployment rate') -plt.show() - -model_base = LakeModel() -model_ex1 = LakeModel(α=0.02, λ=0.3) - -print(f"Base model α: {model_base.α}") -print(f"New model α: {model_ex1.α}, λ: {model_ex1.λ}") - -# Compute steady states for both -base_steady_state = rate_steady_state(model_base) -new_steady_state = rate_steady_state(model_ex1) - -print(f"Base unemployment rate: {base_steady_state[0]:.4f}") -print(f"New unemployment rate: {new_steady_state[0]:.4f}") -``` - -```{solution-end} -``` - ```{exercise-start} :label: model_ex2 ``` @@ -998,7 +1014,7 @@ We begin by constructing the model with default parameters and finding the initial steady state ```{code-cell} ipython3 -model_initial = LakeModel() +model_initial = create_lake_model() x0 = rate_steady_state(model_initial) print(f"Initial Steady State: {x0}") ``` @@ -1013,7 +1029,7 @@ T = 50 New legislation changes $\lambda$ to $0.2$ ```{code-cell} ipython3 -model_ex2 = LakeModel(λ=0.2) +model_ex2 = create_lake_model(λ=0.2) xbar = rate_steady_state(model_ex2) # new steady state # Simulate paths @@ -1049,7 +1065,7 @@ titles = ['unemployment rate', 'employment rate'] for i, title in enumerate(titles): axes[i].plot(x_path[i, :]) - axes[i].hlines(xbar[i], 0, T, 'r', '--') + axes[i].hlines(xbar[i], 0, T, linestyles='--') axes[i].set_title(title) plt.tight_layout() @@ -1091,7 +1107,7 @@ Let's start off at the baseline parameterization and record the steady state ```{code-cell} ipython3 -model_baseline = LakeModel() +model_baseline = create_lake_model() x0 = rate_steady_state(model_baseline) N0 = 100 T = 50 @@ -1107,7 +1123,7 @@ T_hat = 20 Let's increase $b$ to the new value and simulate for 20 periods ```{code-cell} ipython3 -model_high_b = LakeModel(b=b_hat) +model_high_b = create_lake_model(b=b_hat) # Simulate stocks and rates for first 20 periods X_path1 = generate_path(stock_update, x0 * N0, T_hat, model=model_high_b) @@ -1157,7 +1173,7 @@ titles = ['unemployment rate', 'employment rate'] for i, title in enumerate(titles): axes[i].plot(x_path[i, :]) - axes[i].hlines(x0[i], 0, T, 'r', '--') + axes[i].hlines(x0[i], 0, T, linestyles='--') axes[i].set_title(title) plt.tight_layout()