From 27bb29ba8cb1c6f9934dae4462e91fe9d1128e61 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Sun, 2 Nov 2025 12:39:17 +0900 Subject: [PATCH 1/3] Refine lake_model.md: Remove redundant decorators and improve organization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit improves the Lake Model lecture with several refinements: **Code improvements:** - Remove redundant `@jax.jit` decorators from `compute_matrices`, `stock_update`, `rate_update`, and `markov_update` (these functions are only called from within jitted functions, so the decorators are unnecessary and can inhibit compiler optimization) - Refactor aggregate dynamics plot to use a for loop instead of repetitive code - Remove hardcoded colors ('r') from plots to use matplotlib's default color cycle **Content organization:** - Move rate definitions ($e_t$, $u_t$) to "Laws of motion for rates" section where they logically belong - Relocate Exercise 1 to appear immediately before "Dynamics of an individual worker" section for better flow - Simplify Exercise 1 to focus on the pedagogically interesting `vmap` usage, removing less interesting parameter comparison parts These changes improve code clarity, performance, and pedagogical flow without changing functionality. 馃 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- lectures/lake_model.md | 130 ++++++++++++++++------------------------- 1 file changed, 51 insertions(+), 79 deletions(-) diff --git a/lectures/lake_model.md b/lectures/lake_model.md index 93fb06156..192318a22 100644 --- a/lectures/lake_model.md +++ b/lectures/lake_model.md @@ -108,13 +108,6 @@ We want to derive the dynamics of the following aggregates: * $U_t$, the total number of unemployed workers at $t$ * $N_t$, the number of workers in the labor force at $t$ -We also want to know the values of the following objects: - -* The employment rate $e_t := E_t/N_t$. -* The unemployment rate $u_t := U_t/N_t$. - -(Here and below, capital letters represent aggregates and lowercase letters represent rates) - ### Laws of motion for stock variables We begin by constructing laws of motion for the aggregate variables $E_t,U_t, N_t$. @@ -167,6 +160,13 @@ This law tells us how total employment and unemployment evolve over time. Now let's derive the law of motion for rates. +We want to track the values of the following objects: + +* The employment rate $e_t := E_t/N_t$. +* The unemployment rate $u_t := U_t/N_t$. + +(Here and below, capital letters represent aggregates and lowercase letters represent rates) + To get these we can divide both sides of $X_{t+1} = A X_t$ by $N_{t+1}$ to get $$ @@ -266,7 +266,6 @@ def generate_path(f, initial_state, num_steps, **kwargs): Now we can compute the matrices and simulate the dynamics. ```{code-cell} ipython3 -@jax.jit def compute_matrices(model: LakeModel): """Compute the transition matrices A and A_hat for the model.""" 位, 伪, b, d = model.位, model.伪, model.b, model.d @@ -277,7 +276,6 @@ def compute_matrices(model: LakeModel): return A, A_hat, g -@jax.jit def stock_update(current_stocks, time_step, model): """ Apply transition matrix to get next period's stocks. @@ -286,7 +284,6 @@ def stock_update(current_stocks, time_step, model): next_stocks = A @ current_stocks return next_stocks -@jax.jit def rate_update(current_rates, time_step, model): """ Apply normalized transition matrix for next period's rates. @@ -330,14 +327,12 @@ fig, axes = plt.subplots(3, 1, figsize=(10, 8)) X_0 = jnp.array([U_0, E_0]) X_path = generate_path(stock_update, X_0, T, model=model) -axes[0].plot(X_path[0, :], lw=2) -axes[0].set_title('unemployment') - -axes[1].plot(X_path[1, :], lw=2) -axes[1].set_title('employment') +titles = ['unemployment', 'employment', 'labor force'] +data = [X_path[0, :], X_path[1, :], X_path.sum(0)] -axes[2].plot(X_path.sum(0), lw=2) -axes[2].set_title('labor force') +for ax, title, series in zip(axes, titles, data): + ax.plot(series, lw=2) + ax.set_title(title) plt.tight_layout() plt.show() @@ -409,6 +404,41 @@ plt.tight_layout() plt.show() ``` +```{exercise} +:label: model_ex1 + +Use JAX's `vmap` to compute steady-state unemployment rates for a range of job finding rates $\lambda$ (from 0.1 to 0.5), and plot the relationship. +``` + +```{solution-start} model_ex1 +:class: dropdown +``` + +Here is one solution + +```{code-cell} ipython3 +@jax.jit +def compute_unemployment_rate(位_val): + """Computes steady-state unemployment for a given 位""" + model = LakeModel(位=位_val) + steady_state = rate_steady_state(model) + return steady_state[0] + +# Use vmap to compute for multiple 位 values +位_values = jnp.linspace(0.1, 0.5, 50) +unemployment_rates = jax.vmap(compute_unemployment_rate)(位_values) + +# Plot the results +fig, ax = plt.subplots(figsize=(10, 6)) +ax.plot(位_values, unemployment_rates, lw=2) +ax.set_xlabel(r'$\lambda$') +ax.set_ylabel('steady-state unemployment rate') +plt.show() +``` + +```{solution-end} +``` + (dynamics_workers)= ## Dynamics of an individual worker @@ -500,7 +530,6 @@ We can investigate this by simulating the Markov chain. Let's plot the path of the sample averages over 5,000 periods ```{code-cell} ipython3 -@jax.jit def markov_update(state, t, P, keys): """ Sample next state from transition probabilities. @@ -535,14 +564,14 @@ titles = ['percent of time unemployed', 'percent of time employed'] for i, plot in enumerate(to_plot): axes[i].plot(plot, lw=2, alpha=0.5) - axes[i].hlines(xbar[i], 0, T, 'r', '--') + axes[i].hlines(xbar[i], 0, T, linestyles='--') axes[i].set_title(titles[i]) plt.tight_layout() plt.show() ``` -The stationary probabilities are given by the dashed red line. +The stationary probabilities are given by the dashed line. In this case it takes much of the sample for these two objects to converge. @@ -905,63 +934,6 @@ The level that maximizes steady state welfare is approximately 62. ## Exercises -```{exercise} -:label: model_ex1 - -In the JAX implementation of the Lake Model, we use a `NamedTuple` for parameters and separate functions for computations. - -This approach has several advantages: -1. It's immutable, which aligns with JAX's functional programming paradigm -2. Functions can be JIT-compiled for better performance - -In this exercise, your task is to: -1. Update parameters by creating a new instance of the model with the parameters (`伪=0.02, 位=0.3`). -2. Use JAX's `vmap` to compute steady states for different parameter values -3. Plot how the steady-state unemployment rate varies with the job finding rate $\lambda$ -``` - -```{solution-start} model_ex1 -:class: dropdown -``` - -Here is one solution - -```{code-cell} ipython3 -@jax.jit -def compute_unemployment_rate(位_val): - """Computes steady-state unemployment for a given 位""" - model = LakeModel(位=位_val) - steady_state = rate_steady_state(model) - return steady_state[0] - -# Use vmap to compute for multiple 位 values -位_values = jnp.linspace(0.1, 0.5, 50) -unemployment_rates = jax.vmap(compute_unemployment_rate)(位_values) - -# Plot the results -fig, ax = plt.subplots(figsize=(10, 6)) -ax.plot(位_values, unemployment_rates, lw=2) -ax.set_xlabel(r'$\lambda$') -ax.set_ylabel('steady-state unemployment rate') -plt.show() - -model_base = LakeModel() -model_ex1 = LakeModel(伪=0.02, 位=0.3) - -print(f"Base model 伪: {model_base.伪}") -print(f"New model 伪: {model_ex1.伪}, 位: {model_ex1.位}") - -# Compute steady states for both -base_steady_state = rate_steady_state(model_base) -new_steady_state = rate_steady_state(model_ex1) - -print(f"Base unemployment rate: {base_steady_state[0]:.4f}") -print(f"New unemployment rate: {new_steady_state[0]:.4f}") -``` - -```{solution-end} -``` - ```{exercise-start} :label: model_ex2 ``` @@ -1049,7 +1021,7 @@ titles = ['unemployment rate', 'employment rate'] for i, title in enumerate(titles): axes[i].plot(x_path[i, :]) - axes[i].hlines(xbar[i], 0, T, 'r', '--') + axes[i].hlines(xbar[i], 0, T, linestyles='--') axes[i].set_title(title) plt.tight_layout() @@ -1157,7 +1129,7 @@ titles = ['unemployment rate', 'employment rate'] for i, title in enumerate(titles): axes[i].plot(x_path[i, :]) - axes[i].hlines(x0[i], 0, T, 'r', '--') + axes[i].hlines(x0[i], 0, T, linestyles='--') axes[i].set_title(title) plt.tight_layout() From 8cce1a7e82d20afdfc56d5dbb7d002d20bd05fdd Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Sun, 2 Nov 2025 17:40:51 +0900 Subject: [PATCH 2/3] Refactor lake_model.md: Improve model structure and notation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Major improvements: - Added create_lake_model() function to generate model instances with precomputed matrices A, R, and g - Replaced A_hat notation with R throughout (code and LaTeX) for cleaner notation - Updated LakeModel NamedTuple to store computed matrices A and R - Modified all functions to unpack model using tuple unpacking for efficiency - Added type hints to stock_update(), rate_update(), and create_lake_model() - Simplified generate_path() function by removing unused time parameter - Updated rate_steady_state() to use Perron-Frobenius theorem (argmax instead of searching for eigenvalue near 1) - Converted all LakeModel() instantiations to use create_lake_model() - Updated markov simulation to use dedicated simulate_markov() function Benefits: - Matrices computed once at model creation instead of repeatedly - Cleaner mathematical notation using R instead of \hat{A} - More efficient code with direct tuple unpacking - Better type safety with added annotations - More mathematically rigorous using Perron-Frobenius theorem 馃 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- lectures/lake_model.md | 190 +++++++++++++++++++++++++---------------- 1 file changed, 117 insertions(+), 73 deletions(-) diff --git a/lectures/lake_model.md b/lectures/lake_model.md index 192318a22..f56218753 100644 --- a/lectures/lake_model.md +++ b/lectures/lake_model.md @@ -197,14 +197,14 @@ $$ we can also write this as $$ -x_{t+1} = \hat A x_t +x_{t+1} = R x_t \quad \text{where} \quad -\hat A := \frac{1}{1 + g} A +R := \frac{1}{1 + g} A $$ You can check that $e_t + u_t = 1$ implies that $e_{t+1}+u_{t+1} = 1$. -This follows from the fact that the columns of $\hat A$ sum to 1. +This follows from the fact that the columns of $R$ sum to 1. ## Implementation @@ -221,6 +221,50 @@ class LakeModel(NamedTuple): 伪: float = 0.013 b: float = 0.0124 d: float = 0.00822 + A: jnp.ndarray = None + R: jnp.ndarray = None + g: float = None + + +def create_lake_model(位: float = 0.283, + 伪: float = 0.013, + b: float = 0.0124, + d: float = 0.00822) -> LakeModel: + """ + Create a LakeModel instance with default parameters. + + Computes and stores the transition matrices A and R, + and the labor force growth rate g. + + Parameters + ---------- + 位 : float, optional + Job finding rate (default: 0.283) + 伪 : float, optional + Job separation rate (default: 0.013) + b : float, optional + Entry rate into labor force (default: 0.0124) + d : float, optional + Exit rate from labor force (default: 0.00822) + + Returns + ------- + LakeModel + A LakeModel instance with computed matrices A, R, and growth rate g + """ + # Compute growth rate + g = b - d + + # Compute transition matrix A + A = jnp.array([ + [(1-d) * (1-位) + b, (1-d) * 伪 + b], + [(1-d) * 位, (1-d) * (1-伪)] + ]) + + # Compute normalized transition matrix R + R = A / (1 + g) + + return LakeModel(位=位, 伪=伪, b=b, d=d, A=A, R=R, g=g) ``` We will also use a specialized function to generate time series in an efficient @@ -235,13 +279,13 @@ def generate_path(f, initial_state, num_steps, **kwargs): """ Generate a time series by repeatedly applying an update rule. - Given a map f, initial state x_0, and a set of model parameter 胃, this + Given a map f, initial state x_0, and model parameters, this function computes and returns the sequence {x_t}_{t=0}^{T-1} when - x_{t+1} = f(x_t, t, 胃) + x_{t+1} = f(x_t, **kwargs) Args: - f: Update function mapping (x_t, t, 胃) -> x_{t+1} + f: Update function mapping (x_t, **kwargs) -> x_{t+1} initial_state: Initial state x_0 num_steps: Number of time steps T to simulate **kwargs: Optional extra arguments passed to f @@ -255,7 +299,7 @@ def generate_path(f, initial_state, num_steps, **kwargs): """ Wrapper function that adapts f for use with JAX scan. """ - next_state = f(state, t, **kwargs) + next_state = f(state, **kwargs) return next_state, state _, path = jax.lax.scan(update_wrapper, @@ -263,51 +307,35 @@ def generate_path(f, initial_state, num_steps, **kwargs): return path.T ``` -Now we can compute the matrices and simulate the dynamics. +Now we can simulate the dynamics. ```{code-cell} ipython3 -def compute_matrices(model: LakeModel): - """Compute the transition matrices A and A_hat for the model.""" - 位, 伪, b, d = model.位, model.伪, model.b, model.d - g = b - d - A = jnp.array([[(1-d) * (1-位) + b, (1 - d) * 伪 + b], - [ (1-d) * 位, (1 - d) * (1 - 伪)]]) - A_hat = A / (1 + g) - return A, A_hat, g - - -def stock_update(current_stocks, time_step, model): - """ - Apply transition matrix to get next period's stocks. - """ - A, A_hat, g = compute_matrices(model) - next_stocks = A @ current_stocks - return next_stocks - -def rate_update(current_rates, time_step, model): - """ - Apply normalized transition matrix for next period's rates. - """ - A, A_hat, g = compute_matrices(model) - next_rates = A_hat @ current_rates - return next_rates +def stock_update(X: jnp.ndarray, model: LakeModel) -> jnp.ndarray: + """Apply transition matrix to get next period's stocks.""" + 位, 伪, b, d, A, R, g = model + return A @ X + +def rate_update(x: jnp.ndarray, model: LakeModel) -> jnp.ndarray: + """Apply normalized transition matrix for next period's rates.""" + 位, 伪, b, d, A, R, g = model + return R @ x ``` We create two instances, one with $伪=0.013$ and another with $伪=0.03$ ```{code-cell} ipython3 -model = LakeModel() -model_new = LakeModel(伪=0.03) +model = create_lake_model() +model_new = create_lake_model(伪=0.03) print(f"Default 伪: {model.伪}") -A, A_hat, g = compute_matrices(model) -print(f"A matrix:\n{A}") +print(f"A matrix:\n{model.A}") +print(f"R matrix:\n{model.R}") ``` ```{code-cell} ipython3 -A_new, A_hat_new, g_new = compute_matrices(model_new) print(f"New 伪: {model_new.伪}") -print(f"New A matrix:\n{A_new}") +print(f"New A matrix:\n{model_new.A}") +print(f"New R matrix:\n{model_new.R}") ``` ### Aggregate dynamics @@ -343,44 +371,49 @@ The aggregates $E_t$ and $U_t$ don't converge because their sum $E_t + U_t$ grow On the other hand, the vector of employment and unemployment rates $x_t$ can be in a steady state $\bar x$ if there exists an $\bar x$ such that -* $\bar x = \hat A \bar x$ +* $\bar x = R \bar x$ * the components satisfy $\bar e + \bar u = 1$ -This equation tells us that a steady state level $\bar x$ is an eigenvector of $\hat A$ associated with a unit eigenvalue. +This equation tells us that a steady state level $\bar x$ is an eigenvector of $R$ associated with a unit eigenvalue. The following function can be used to compute the steady state. ```{code-cell} ipython3 @jax.jit -def rate_steady_state(model: LakeModel): +def rate_steady_state(model: LakeModel) -> jnp.ndarray: r""" - Finds the steady state of the system :math:`x_{t+1} = \hat A x_{t}` - by computing the eigenvector corresponding to the unit eigenvalue. + Finds the steady state of the system :math:`x_{t+1} = R x_{t}` + by computing the eigenvector corresponding to the largest eigenvalue. + + By the Perron-Frobenius theorem, since :math:`R` is a non-negative + matrix with columns summing to 1 (a stochastic matrix), the largest + eigenvalue equals 1 and the corresponding eigenvector gives the steady state. """ - A, A_hat, g = compute_matrices(model) - eigenvals, eigenvec = jnp.linalg.eig(A_hat) - - # Find the eigenvector corresponding to eigenvalue 1 - unit_idx = jnp.argmin(jnp.abs(eigenvals - 1.0)) + 位, 伪, b, d, A, R, g = model + eigenvals, eigenvec = jnp.linalg.eig(R) + + # Find the eigenvector corresponding to the largest eigenvalue + # (which is 1 for a stochastic matrix by Perron-Frobenius theorem) + max_idx = jnp.argmax(jnp.abs(eigenvals)) # Get the corresponding eigenvector - steady_state = jnp.real(eigenvec[:, unit_idx]) - + steady_state = jnp.real(eigenvec[:, max_idx]) + # Normalize to ensure positive values and sum to 1 steady_state = jnp.abs(steady_state) steady_state = steady_state / jnp.sum(steady_state) - + return steady_state ``` We also have $x_t \to \bar x$ as $t \to \infty$ provided that the remaining -eigenvalue of $\hat A$ has modulus less than 1. +eigenvalue of $R$ has modulus less than 1. This is the case for our default parameters: ```{code-cell} ipython3 -A, A_hat, g = compute_matrices(model) -e, f = jnp.linalg.eigvals(A_hat) +model = create_lake_model() +e, f = jnp.linalg.eigvals(model.R) print(f"Eigenvalue magnitudes: {abs(e):.2f}, {abs(f):.2f}") ``` @@ -420,7 +453,7 @@ Here is one solution @jax.jit def compute_unemployment_rate(位_val): """Computes steady-state unemployment for a given 位""" - model = LakeModel(位=位_val) + model = create_lake_model(位=位_val) steady_state = rate_steady_state(model) return steady_state[0] @@ -517,7 +550,7 @@ $$ with probability one. -Inspection tells us that $P$ is exactly the transpose of $\hat A$ under the assumption $b=d=0$. +Inspection tells us that $P$ is exactly the transpose of $R$ under the assumption $b=d=0$. Thus, the percentages of time that an infinitely lived worker spends employed and unemployed equal the fractions of workers employed and unemployed in the steady state distribution. @@ -530,17 +563,17 @@ We can investigate this by simulating the Markov chain. Let's plot the path of the sample averages over 5,000 periods ```{code-cell} ipython3 -def markov_update(state, t, P, keys): +def markov_update(state, P, key): """ Sample next state from transition probabilities. """ probs = P[state] - state_new = jax.random.choice(keys[t], + state_new = jax.random.choice(key, a=jnp.arange(len(probs)), p=probs) return state_new -model_markov = LakeModel(d=0, b=0) +model_markov = create_lake_model(d=0, b=0) T = 5000 # Simulation length 伪, 位 = model_markov.伪, model_markov.位 @@ -550,10 +583,21 @@ P = jnp.array([[1 - 位, 位], xbar = rate_steady_state(model_markov) -# Simulate the Markov chain +# Simulate the Markov chain - we need a different approach for random updates key = jax.random.PRNGKey(0) -keys = jax.random.split(key, T) -s_path = generate_path(markov_update, 1, T, P=P, keys=keys) + +def simulate_markov(P, initial_state, T, key): + """Simulate Markov chain for T periods""" + keys = jax.random.split(key, T) + + def scan_fn(state, key): + next_state = markov_update(state, P, key) + return next_state, state + + _, path = jax.lax.scan(scan_fn, initial_state, keys) + return path + +s_path = simulate_markov(P, 1, T, key) fig, axes = plt.subplots(2, 1, figsize=(10, 8)) s_bar_e = jnp.cumsum(s_path) / jnp.arange(1, T+1) @@ -841,25 +885,25 @@ def compute_optimal_quantities(c, 蟿, @jax.jit -def compute_steady_state_quantities(c, 蟿, +def compute_steady_state_quantities(c, 蟿, params: EconomyParameters, w_vec, p_vec): """ Compute the steady state unemployment rate given c and 蟿 using optimal quantities from the McCall model and computing corresponding steady state quantities """ - w_bar, 位, V, U = compute_optimal_quantities(c, 蟿, + w_bar, 位, V, U = compute_optimal_quantities(c, 蟿, params, w_vec, p_vec) - + # Compute steady state employment and unemployment rates - model = LakeModel(伪=params.伪_q, 位=位, b=params.b, d=params.d) + model = create_lake_model(位=位, 伪=params.伪_q, b=params.b, d=params.d) u, e = rate_steady_state(model) - + # Compute steady state welfare mask = (w_vec - 蟿 > w_bar) w = jnp.sum(V * p_vec * mask) / jnp.sum(p_vec * mask) welfare = e * w + u * U - + return e, u, welfare @@ -970,7 +1014,7 @@ We begin by constructing the model with default parameters and finding the initial steady state ```{code-cell} ipython3 -model_initial = LakeModel() +model_initial = create_lake_model() x0 = rate_steady_state(model_initial) print(f"Initial Steady State: {x0}") ``` @@ -985,7 +1029,7 @@ T = 50 New legislation changes $\lambda$ to $0.2$ ```{code-cell} ipython3 -model_ex2 = LakeModel(位=0.2) +model_ex2 = create_lake_model(位=0.2) xbar = rate_steady_state(model_ex2) # new steady state # Simulate paths @@ -1063,7 +1107,7 @@ Let's start off at the baseline parameterization and record the steady state ```{code-cell} ipython3 -model_baseline = LakeModel() +model_baseline = create_lake_model() x0 = rate_steady_state(model_baseline) N0 = 100 T = 50 @@ -1079,7 +1123,7 @@ T_hat = 20 Let's increase $b$ to the new value and simulate for 20 periods ```{code-cell} ipython3 -model_high_b = LakeModel(b=b_hat) +model_high_b = create_lake_model(b=b_hat) # Simulate stocks and rates for first 20 periods X_path1 = generate_path(stock_update, x0 * N0, T_hat, model=model_high_b) From 8a8406750f1aa4718e09fdfb3094223f63f03fcd Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Sun, 2 Nov 2025 18:13:53 +0900 Subject: [PATCH 3/3] misc --- lectures/lake_model.md | 104 ++++++++++++++++++++--------------------- 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/lectures/lake_model.md b/lectures/lake_model.md index f56218753..7cc6420a2 100644 --- a/lectures/lake_model.md +++ b/lectures/lake_model.md @@ -210,47 +210,36 @@ This follows from the fact that the columns of $R$ sum to 1. Let's code up these equations. -To do this we're going to use a class that we'll call `LakeModel` that stores the primitives $\alpha, \lambda, b, d$ +### Model + +To begin, we set up a class called `LakeModel` that stores the primitives $\alpha, \lambda, b, d$. ```{code-cell} ipython3 class LakeModel(NamedTuple): """ Parameters for the lake model """ - 位: float = 0.283 - 伪: float = 0.013 - b: float = 0.0124 - d: float = 0.00822 - A: jnp.ndarray = None - R: jnp.ndarray = None - g: float = None - - -def create_lake_model(位: float = 0.283, - 伪: float = 0.013, - b: float = 0.0124, - d: float = 0.00822) -> LakeModel: + 位: float + 伪: float + b: float + d: float + A: jnp.ndarray + R: jnp.ndarray + g: float + + +def create_lake_model( + 位: float = 0.283, # job finding rate + 伪: float = 0.013, # separation rate + b: float = 0.0124, # birth rate + d: float = 0.00822 # death rate + ) -> LakeModel: """ Create a LakeModel instance with default parameters. Computes and stores the transition matrices A and R, and the labor force growth rate g. - Parameters - ---------- - 位 : float, optional - Job finding rate (default: 0.283) - 伪 : float, optional - Job separation rate (default: 0.013) - b : float, optional - Entry rate into labor force (default: 0.0124) - d : float, optional - Exit rate from labor force (default: 0.00822) - - Returns - ------- - LakeModel - A LakeModel instance with computed matrices A, R, and growth rate g """ # Compute growth rate g = b - d @@ -267,11 +256,33 @@ def create_lake_model(位: float = 0.283, return LakeModel(位=位, 伪=伪, b=b, d=d, A=A, R=R, g=g) ``` +As an experiment, let's create two instances, one with $伪=0.013$ and another with $伪=0.03$ + +```{code-cell} ipython3 +model = create_lake_model() +print(f"Default 伪: {model.伪}") +print(f"A matrix:\n{model.A}") +print(f"R matrix:\n{model.R}") +``` + +```{code-cell} ipython3 +model_new = create_lake_model(伪=0.03) +print(f"New 伪: {model_new.伪}") +print(f"New A matrix:\n{model_new.A}") +print(f"New R matrix:\n{model_new.R}") +``` + +### Code for dynamics + We will also use a specialized function to generate time series in an efficient JAX-compatible manner. -(Iteratively generating time series is somewhat nontrivial in JAX because arrays -are immutable.) +Iteratively generating time series is somewhat nontrivial in JAX because arrays +are immutable. + +Here we use `lax.scan`, which allows the function to be jit-compiled. + +Readers who prefer to skip the details can safely continue reading after the function definition. ```{code-cell} ipython3 @partial(jax.jit, static_argnames=['f', 'num_steps']) @@ -307,7 +318,7 @@ def generate_path(f, initial_state, num_steps, **kwargs): return path.T ``` -Now we can simulate the dynamics. +Here are functions to update $X_t$ and $x_t$. ```{code-cell} ipython3 def stock_update(X: jnp.ndarray, model: LakeModel) -> jnp.ndarray: @@ -321,26 +332,12 @@ def rate_update(x: jnp.ndarray, model: LakeModel) -> jnp.ndarray: return R @ x ``` -We create two instances, one with $伪=0.013$ and another with $伪=0.03$ - -```{code-cell} ipython3 -model = create_lake_model() -model_new = create_lake_model(伪=0.03) - -print(f"Default 伪: {model.伪}") -print(f"A matrix:\n{model.A}") -print(f"R matrix:\n{model.R}") -``` - -```{code-cell} ipython3 -print(f"New 伪: {model_new.伪}") -print(f"New A matrix:\n{model_new.A}") -print(f"New R matrix:\n{model_new.R}") -``` ### Aggregate dynamics -Let's run a simulation under the default parameters (see above) starting from $X_0 = (12, 138)$. +Let's run a simulation under the default parameters starting from $X_0 = (12, 138)$. + +We will plot the sequences $\{E_t\}$, $\{U_t\}$ and $\{N_t\}$. ```{code-cell} ipython3 N_0 = 150 # Population @@ -351,23 +348,26 @@ T = 50 # Simulation length U_0 = u_0 * N_0 E_0 = e_0 * N_0 -fig, axes = plt.subplots(3, 1, figsize=(10, 8)) +# Generate X path X_0 = jnp.array([U_0, E_0]) X_path = generate_path(stock_update, X_0, T, model=model) +# Plot +fig, axes = plt.subplots(3, 1, figsize=(10, 8)) titles = ['unemployment', 'employment', 'labor force'] data = [X_path[0, :], X_path[1, :], X_path.sum(0)] - for ax, title, series in zip(axes, titles, data): ax.plot(series, lw=2) ax.set_title(title) - plt.tight_layout() plt.show() ``` The aggregates $E_t$ and $U_t$ don't converge because their sum $E_t + U_t$ grows at rate $g$. + +### Rate dynamics + On the other hand, the vector of employment and unemployment rates $x_t$ can be in a steady state $\bar x$ if there exists an $\bar x$ such that