From e860f3d2f53dfd22ed204584e6803b4a02142fcb Mon Sep 17 00:00:00 2001 From: Longye Tian Date: Wed, 17 Sep 2025 11:39:48 +1000 Subject: [PATCH 01/10] Update mccall_model.md --- lectures/mccall_model.md | 328 ++++++++++++++++++--------------------- 1 file changed, 152 insertions(+), 176 deletions(-) diff --git a/lectures/mccall_model.md b/lectures/mccall_model.md index 202b9d591..45017beb7 100644 --- a/lectures/mccall_model.md +++ b/lectures/mccall_model.md @@ -62,8 +62,9 @@ Let's start with some imports: ```{code-cell} ipython import matplotlib.pyplot as plt import numpy as np -from numba import jit, float64 -from numba.experimental import jitclass +import jax +import jax.numpy as jnp +from typing import NamedTuple import quantecon as qe from quantecon.distributions import BetaBinomial ``` @@ -343,14 +344,14 @@ Our default for $q$, the distribution of the state process, will be ```{code-cell} python3 n, a, b = 50, 200, 100 # default parameters -q_default = BetaBinomial(n, a, b).pdf() # default choice of q +q_default = jnp.array(BetaBinomial(n, a, b).pdf()) # default choice of q as JAX array ``` Our default set of values for wages will be ```{code-cell} python3 w_min, w_max = 10, 60 -w_default = np.linspace(w_min, w_max, n+1) +w_default = jnp.linspace(w_min, w_max, n+1) ``` Here's a plot of the probabilities of different wage outcomes: @@ -364,32 +365,10 @@ ax.set_ylabel('probabilities') plt.show() ``` -We are going to use Numba to accelerate our code. +We are going to use JAX to accelerate our code. -* See, in particular, the discussion of `@jitclass` in [our lecture on Numba](https://python-programming.quantecon.org/numba.html). - -The following helps Numba by providing some type specifications. - -```{code-cell} python3 -mccall_data = [ - ('c', float64), # unemployment compensation - ('β', float64), # discount factor - ('w', float64[::1]), # array of wage values, w[i] = wage at state i - ('q', float64[::1]) # array of probabilities -] -``` - -```{note} -Note the use of `[::1]` in the array type declarations above. - -This notation specifies that the arrays should be C-contiguous. - -This is important for performance, especially when using the `@` operator for matrix multiplication (e.g., `v @ q`). - -Without this specification, Numba might need to handle non-contiguous arrays, which can significantly slow down these operations. - -Try to replace `[::1]` with `[:]` and see what happens. -``` +* JAX provides automatic differentiation and JIT compilation capabilities. +* We'll use NamedTuple for our model class to maintain immutability, which works well with JAX's functional programming paradigm. Here's a class that stores the data and computes the values of state-action pairs, i.e. the value in the maximum bracket on the right hand side of the Bellman equation {eq}`odu_pv2p`, @@ -398,26 +377,25 @@ given the current state and an arbitrary feasible action. Default parameter values are embedded in the class. ```{code-cell} python3 -@jitclass(mccall_data) -class McCallModel: - - def __init__(self, c=25, β=0.99, w=w_default, q=q_default): - - self.c, self.β = c, β - self.w, self.q = w_default, q_default - - def state_action_values(self, i, v): - """ - The values of state-action pairs. 
- """ - # Simplify names - c, β, w, q = self.c, self.β, self.w, self.q - # Evaluate value for each state-action pair - # Consider action = accept or reject the current offer - accept = w[i] / (1 - β) - reject = c + β * (v @ q) - - return np.array([accept, reject]) +class McCallModel(NamedTuple): + c: float = 25 # unemployment compensation + β: float = 0.99 # discount factor + w: jnp.ndarray = w_default # array of wage values, w[i] = wage at state i + q: jnp.ndarray = q_default # array of probabilities + +@jax.jit +def state_action_values(model, i, v): + """ + The values of state-action pairs. + """ + # Simplify names + c, β, w, q = model.c, model.β, model.w, model.q + # Evaluate value for each state-action pair + # Consider action = accept or reject the current offer + accept = w[i] / (1 - β) + reject = c + β * (v @ q) + + return jnp.array([accept, reject]) ``` Based on these defaults, let's try plotting the first few approximate value functions @@ -439,13 +417,13 @@ def plot_value_function_seq(mcm, ax, num_plots=6): n = len(mcm.w) v = mcm.w / (1 - mcm.β) - v_next = np.empty_like(v) for i in range(num_plots): ax.plot(mcm.w, v, '-', alpha=0.4, label=f"iterate {i}") # Update guess + v_next = jnp.zeros_like(v) for j in range(n): - v_next[j] = np.max(mcm.state_action_values(j, v)) - v[:] = v_next # copy contents into v + v_next = v_next.at[j].set(jnp.max(state_action_values(mcm, j, v))) + v = v_next # update v ax.legend(loc='lower right') ``` @@ -469,37 +447,35 @@ Here's a more serious iteration effort to compute the limit, which continues unt Once we obtain a good approximation to the limit, we will use it to calculate the reservation wage. -We'll be using JIT compilation via Numba to turbocharge our loops. +We'll be using JIT compilation via JAX to accelerate our loops. ```{code-cell} python3 -@jit -def compute_reservation_wage(mcm, - max_iter=500, - tol=1e-6): - +@jax.jit +def compute_reservation_wage(mcm, max_iter=500, tol=1e-6): # Simplify names c, β, w, q = mcm.c, mcm.β, mcm.w, mcm.q - + # == First compute the value function == # - n = len(w) - v = w / (1 - β) # initial guess - v_next = np.empty_like(v) - j = 0 - error = tol + 1 - while j < max_iter and error > tol: - + v = w / (1 - β) # initial guess + + def body_fun(state): + v, i, error = state + v_next = jnp.zeros_like(v) for j in range(n): - v_next[j] = np.max(mcm.state_action_values(j, v)) - - error = np.max(np.abs(v_next - v)) - j += 1 - - v[:] = v_next # copy contents into v - + v_next = v_next.at[j].set(jnp.max(state_action_values(mcm, j, v))) + error = jnp.max(jnp.abs(v_next - v)) + return v_next, i + 1, error + + def cond_fun(state): + v, i, error = state + return jnp.logical_and(i < max_iter, error > tol) + + initial_state = (v, 0, tol + 1) + v_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state) + # == Now compute the reservation wage == # - - return (1 - β) * (c + β * (v @ q)) + return (1 - β) * (c + β * (v_final @ q)) ``` The next line computes the reservation wage at default parameters @@ -518,15 +494,17 @@ $c$. 
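The next cell evaluates the reservation wage over the whole $(c, \beta)$ grid by nesting two `jax.vmap` calls.

Here is a minimal sketch of that nested-`vmap` pattern on a toy scalar function (the function `f` and the small grids below are made up purely for illustration):

```{code-cell} python3
def f(c, β):
    # toy stand-in for a scalar function of (c, β), such as a reservation wage
    return c + 100 * β

c_toy = jnp.linspace(10.0, 30.0, 3)
β_toy = jnp.linspace(0.90, 0.99, 4)

# The outer vmap sweeps the first argument (c), the inner vmap the second (β)
f_grid = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))

R_toy = f_grid(c_toy, β_toy)   # shape (3, 4), with R_toy[i, j] = f(c_toy[i], β_toy[j])
R_toy.shape
```

The same pattern is applied to `compute_reservation_wage` below, so that `R[i, j]` holds the reservation wage at `c_vals[i]` and `β_vals[j]`.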
```{code-cell} python3 grid_size = 25 -R = np.empty((grid_size, grid_size)) +c_vals = jnp.linspace(10.0, 30.0, grid_size) +β_vals = jnp.linspace(0.9, 0.99, grid_size) -c_vals = np.linspace(10.0, 30.0, grid_size) -β_vals = np.linspace(0.9, 0.99, grid_size) +def compute_R_element(c, β): + mcm = McCallModel(c=c, β=β) + return compute_reservation_wage(mcm) -for i, c in enumerate(c_vals): - for j, β in enumerate(β_vals): - mcm = McCallModel(c=c, β=β) - R[i, j] = compute_reservation_wage(mcm) +# Create meshgrid and vectorize computation +c_grid, β_grid = jnp.meshgrid(c_vals, β_vals, indexing='ij') +compute_R_vectorized = jax.vmap(jax.vmap(compute_R_element, in_axes=(None, 0)), in_axes=(0, None)) +R = compute_R_vectorized(c_vals, β_vals) ``` ```{code-cell} python3 @@ -623,32 +601,30 @@ The big difference here, however, is that we're iterating on a scalar $h$, rathe Here's an implementation: ```{code-cell} python3 -@jit -def compute_reservation_wage_two(mcm, - max_iter=500, - tol=1e-5): - +@jax.jit +def compute_reservation_wage_two(mcm, max_iter=500, tol=1e-5): # Simplify names c, β, w, q = mcm.c, mcm.β, mcm.w, mcm.q - + # == First compute h == # - h = (w @ q) / (1 - β) - i = 0 - error = tol + 1 - while i < max_iter and error > tol: - - s = np.maximum(w / (1 - β), h) + + def body_fun(state): + h, i, error = state + s = jnp.maximum(w / (1 - β), h) h_next = c + β * (s @ q) - - error = np.abs(h_next - h) - i += 1 - - h = h_next - + error = jnp.abs(h_next - h) + return h_next, i + 1, error + + def cond_fun(state): + h, i, error = state + return jnp.logical_and(i < max_iter, error > tol) + + initial_state = (h, 0, tol + 1) + h_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state) + # == Now compute the reservation wage == # - - return (1 - β) * h + return (1 - β) * h_final ``` You can use this code to solve the exercise below. @@ -678,37 +654,42 @@ Plot mean unemployment duration as a function of $c$ in `c_vals`. 
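Before simulating, note a useful analytical benchmark: offers are drawn independently from $q$ each period, so once the reservation wage $\bar w$ is fixed, each offer is accepted with probability $p = \sum_{i:\, w_i \ge \bar w} q_i$, and the time until acceptance is geometric with mean $1/p$.

Here is a minimal sketch of this cross-check, using the objects defined above (the helper name `mean_duration_analytic` is illustrative, not part of the lecture code):

```{code-cell} python3
def mean_duration_analytic(mcm):
    # per-period acceptance probability: mass of offers at or above the reservation wage
    w_bar = compute_reservation_wage_two(mcm)
    p_accept = jnp.sum(mcm.q * (mcm.w >= w_bar))
    # the stopping time is geometric(p_accept), so its mean is 1 / p_accept
    return 1 / p_accept

mean_duration_analytic(McCallModel(c=25))
```

The simulated mean durations computed below should be close to this value at $c = 25$.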
Here's one solution ```{code-cell} python3 -cdf = np.cumsum(q_default) - -@jit -def compute_stopping_time(w_bar, seed=1234): - - np.random.seed(seed) - t = 1 - while True: - # Generate a wage draw - w = w_default[qe.random.draw(cdf)] - # Stop when the draw is above the reservation wage - if w >= w_bar: - stopping_time = t - break - else: - t += 1 - return stopping_time - -@jit -def compute_mean_stopping_time(w_bar, num_reps=100000): - obs = np.empty(num_reps) - for i in range(num_reps): - obs[i] = compute_stopping_time(w_bar, seed=i) - return obs.mean() - -c_vals = np.linspace(10, 40, 25) -stop_times = np.empty_like(c_vals) -for i, c in enumerate(c_vals): +cdf = jnp.cumsum(q_default) + +@jax.jit +def compute_stopping_time(w_bar, key): + def body_fun(state): + t, key, done = state + key, subkey = jax.random.split(key) + u = jax.random.uniform(subkey) + w = w_default[jnp.searchsorted(cdf, u)] + done = w >= w_bar + t = jnp.where(done, t, t + 1) + return t, key, done + + def cond_fun(state): + t, _, done = state + return jnp.logical_not(done) + + initial_state = (1, key, False) + t_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state) + return t_final + +@jax.jit +def compute_mean_stopping_time(w_bar, num_reps=100000, seed=1234): + key = jax.random.PRNGKey(seed) + keys = jax.random.split(key, num_reps) + obs = jax.vmap(compute_stopping_time, in_axes=(None, 0))(w_bar, keys) + return jnp.mean(obs) + +c_vals = jnp.linspace(10, 40, 25) + +def compute_stop_time_for_c(c): mcm = McCallModel(c=c) w_bar = compute_reservation_wage_two(mcm) - stop_times[i] = compute_mean_stopping_time(w_bar) + return compute_mean_stopping_time(w_bar) + +stop_times = jax.vmap(compute_stop_time_for_c)(c_vals) fig, ax = plt.subplots() @@ -789,48 +770,41 @@ Once your code is working, investigate how the reservation wage changes with $c$ Here is one solution: ```{code-cell} python3 -mccall_data_continuous = [ - ('c', float64), # unemployment compensation - ('β', float64), # discount factor - ('σ', float64), # scale parameter in lognormal distribution - ('μ', float64), # location parameter in lognormal distribution - ('w_draws', float64[:]) # draws of wages for Monte Carlo -] - -@jitclass(mccall_data_continuous) -class McCallModelContinuous: - - def __init__(self, c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000): - - self.c, self.β, self.σ, self.μ = c, β, σ, μ - - # Draw and store shocks - np.random.seed(1234) - s = np.random.randn(mc_size) - self.w_draws = np.exp(μ+ σ * s) - - -@jit +class McCallModelContinuous(NamedTuple): + c: float # unemployment compensation + β: float # discount factor + σ: float # scale parameter in lognormal distribution + μ: float # location parameter in lognormal distribution + w_draws: jnp.ndarray # draws of wages for Monte Carlo + +def create_mccall_continuous(c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000, seed=1234): + key = jax.random.PRNGKey(seed) + s = jax.random.normal(key, (mc_size,)) + w_draws = jnp.exp(μ + σ * s) + return McCallModelContinuous(c=c, β=β, σ=σ, μ=μ, w_draws=w_draws) + +@jax.jit def compute_reservation_wage_continuous(mcmc, max_iter=500, tol=1e-5): - c, β, σ, μ, w_draws = mcmc.c, mcmc.β, mcmc.σ, mcmc.μ, mcmc.w_draws - - h = np.mean(w_draws) / (1 - β) # initial guess - i = 0 - error = tol + 1 - while i < max_iter and error > tol: - - integral = np.mean(np.maximum(w_draws / (1 - β), h)) + + h = jnp.mean(w_draws) / (1 - β) # initial guess + + def body_fun(state): + h, i, error = state + integral = jnp.mean(jnp.maximum(w_draws / (1 - β), h)) h_next = c + β * integral - - error = 
np.abs(h_next - h) - i += 1 - - h = h_next - + error = jnp.abs(h_next - h) + return h_next, i + 1, error + + def cond_fun(state): + h, i, error = state + return jnp.logical_and(i < max_iter, error > tol) + + initial_state = (h, 0, tol + 1) + h_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state) + # == Now compute the reservation wage == # - - return (1 - β) * h + return (1 - β) * h_final ``` Now we investigate how the reservation wage changes with $c$ and @@ -840,15 +814,17 @@ We will do this using a contour plot. ```{code-cell} python3 grid_size = 25 -R = np.empty((grid_size, grid_size)) +c_vals = jnp.linspace(10.0, 30.0, grid_size) +β_vals = jnp.linspace(0.9, 0.99, grid_size) -c_vals = np.linspace(10.0, 30.0, grid_size) -β_vals = np.linspace(0.9, 0.99, grid_size) +def compute_R_element(c, β): + mcmc = create_mccall_continuous(c=c, β=β) + return compute_reservation_wage_continuous(mcmc) -for i, c in enumerate(c_vals): - for j, β in enumerate(β_vals): - mcmc = McCallModelContinuous(c=c, β=β) - R[i, j] = compute_reservation_wage_continuous(mcmc) +# Create meshgrid and vectorize computation +c_grid, β_grid = jnp.meshgrid(c_vals, β_vals, indexing='ij') +compute_R_vectorized = jax.vmap(jax.vmap(compute_R_element, in_axes=(None, 0)), in_axes=(0, None)) +R = compute_R_vectorized(c_vals, β_vals) ``` ```{code-cell} python3 From a1c0a5e623a190a5bf3cdd4d0d35ccb8f8871c78 Mon Sep 17 00:00:00 2001 From: Longye Tian <133612246+longye-tian@users.noreply.github.com> Date: Thu, 18 Sep 2025 12:06:15 +1000 Subject: [PATCH 02/10] Update lectures/mccall_model.md Co-authored-by: Humphrey Yang <39026988+HumphreyYang@users.noreply.github.com> --- lectures/mccall_model.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lectures/mccall_model.md b/lectures/mccall_model.md index 45017beb7..a79d213d7 100644 --- a/lectures/mccall_model.md +++ b/lectures/mccall_model.md @@ -378,8 +378,8 @@ Default parameter values are embedded in the class. 
```{code-cell} python3 class McCallModel(NamedTuple): - c: float = 25 # unemployment compensation - β: float = 0.99 # discount factor + c: float = 25 # unemployment compensation + β: float = 0.99 # discount factor w: jnp.ndarray = w_default # array of wage values, w[i] = wage at state i q: jnp.ndarray = q_default # array of probabilities From f66be678dfa0e8e14c6b1251c19e9b81de0e01eb Mon Sep 17 00:00:00 2001 From: Longye Tian <133612246+longye-tian@users.noreply.github.com> Date: Thu, 18 Sep 2025 12:06:30 +1000 Subject: [PATCH 03/10] Update lectures/mccall_model.md Co-authored-by: Humphrey Yang <39026988+HumphreyYang@users.noreply.github.com> --- lectures/mccall_model.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lectures/mccall_model.md b/lectures/mccall_model.md index a79d213d7..1f21920d7 100644 --- a/lectures/mccall_model.md +++ b/lectures/mccall_model.md @@ -344,7 +344,7 @@ Our default for $q$, the distribution of the state process, will be ```{code-cell} python3 n, a, b = 50, 200, 100 # default parameters -q_default = jnp.array(BetaBinomial(n, a, b).pdf()) # default choice of q as JAX array +q_default = jnp.array(BetaBinomial(n, a, b).pdf()) ``` Our default set of values for wages will be From 94d23397511a6652c8447ca82dee0643d0c17f36 Mon Sep 17 00:00:00 2001 From: Longye Tian <133612246+longye-tian@users.noreply.github.com> Date: Thu, 18 Sep 2025 12:06:39 +1000 Subject: [PATCH 04/10] Update lectures/mccall_model.md Co-authored-by: Humphrey Yang <39026988+HumphreyYang@users.noreply.github.com> --- lectures/mccall_model.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lectures/mccall_model.md b/lectures/mccall_model.md index 1f21920d7..196789ca7 100644 --- a/lectures/mccall_model.md +++ b/lectures/mccall_model.md @@ -676,7 +676,7 @@ def compute_stopping_time(w_bar, key): return t_final @jax.jit -def compute_mean_stopping_time(w_bar, num_reps=100000, seed=1234): +def compute_mean_stopping_time(w_bar, num_reps=100000, seed=0): key = jax.random.PRNGKey(seed) keys = jax.random.split(key, num_reps) obs = jax.vmap(compute_stopping_time, in_axes=(None, 0))(w_bar, keys) From 01c5e805f5519ccdd639485d16c6196fd37c70a5 Mon Sep 17 00:00:00 2001 From: Longye Tian <133612246+longye-tian@users.noreply.github.com> Date: Thu, 18 Sep 2025 12:06:47 +1000 Subject: [PATCH 05/10] Update lectures/mccall_model.md Co-authored-by: Humphrey Yang <39026988+HumphreyYang@users.noreply.github.com> --- lectures/mccall_model.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lectures/mccall_model.md b/lectures/mccall_model.md index 196789ca7..6295aa1bc 100644 --- a/lectures/mccall_model.md +++ b/lectures/mccall_model.md @@ -777,7 +777,7 @@ class McCallModelContinuous(NamedTuple): μ: float # location parameter in lognormal distribution w_draws: jnp.ndarray # draws of wages for Monte Carlo -def create_mccall_continuous(c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000, seed=1234): +def create_mccall_continuous(c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000, seed=0): key = jax.random.PRNGKey(seed) s = jax.random.normal(key, (mc_size,)) w_draws = jnp.exp(μ + σ * s) From 3711f5bbdc66fb7638cbddedc4e2e9c25dc52d89 Mon Sep 17 00:00:00 2001 From: Longye Tian <133612246+longye-tian@users.noreply.github.com> Date: Thu, 18 Sep 2025 12:06:56 +1000 Subject: [PATCH 06/10] Update lectures/mccall_model.md Co-authored-by: Humphrey Yang <39026988+HumphreyYang@users.noreply.github.com> --- lectures/mccall_model.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git 
a/lectures/mccall_model.md b/lectures/mccall_model.md index 6295aa1bc..9079eb60c 100644 --- a/lectures/mccall_model.md +++ b/lectures/mccall_model.md @@ -823,7 +823,10 @@ def compute_R_element(c, β): # Create meshgrid and vectorize computation c_grid, β_grid = jnp.meshgrid(c_vals, β_vals, indexing='ij') -compute_R_vectorized = jax.vmap(jax.vmap(compute_R_element, in_axes=(None, 0)), in_axes=(0, None)) +compute_R_vectorized = jax.vmap( + jax.vmap(compute_R_element, + in_axes=(None, 0)), + in_axes=(0, None)) R = compute_R_vectorized(c_vals, β_vals) ``` From db5a05f673443642a5e41d97506e4408b2c36c5d Mon Sep 17 00:00:00 2001 From: Longye Tian <133612246+longye-tian@users.noreply.github.com> Date: Thu, 18 Sep 2025 12:07:04 +1000 Subject: [PATCH 07/10] Update lectures/mccall_model.md Co-authored-by: Humphrey Yang <39026988+HumphreyYang@users.noreply.github.com> --- lectures/mccall_model.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lectures/mccall_model.md b/lectures/mccall_model.md index 9079eb60c..711376890 100644 --- a/lectures/mccall_model.md +++ b/lectures/mccall_model.md @@ -771,10 +771,10 @@ Here is one solution: ```{code-cell} python3 class McCallModelContinuous(NamedTuple): - c: float # unemployment compensation - β: float # discount factor - σ: float # scale parameter in lognormal distribution - μ: float # location parameter in lognormal distribution + c: float # unemployment compensation + β: float # discount factor + σ: float # scale parameter in lognormal distribution + μ: float # location parameter in lognormal distribution w_draws: jnp.ndarray # draws of wages for Monte Carlo def create_mccall_continuous(c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000, seed=0): From 75dc290e42739258721c7917c28c924fb69e9700 Mon Sep 17 00:00:00 2001 From: Longye Tian Date: Thu, 18 Sep 2025 12:12:09 +1000 Subject: [PATCH 08/10] Update mccall_model.md --- lectures/mccall_model.md | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/lectures/mccall_model.md b/lectures/mccall_model.md index 711376890..2cded1d3b 100644 --- a/lectures/mccall_model.md +++ b/lectures/mccall_model.md @@ -34,11 +34,11 @@ and the pros and cons as they themselves see them." -- Robert E. Lucas, Jr. In addition to what's in Anaconda, this lecture will need the following libraries: -```{code-cell} ipython +```{code-cell} ipython3 --- tags: [hide-output] --- -!pip install quantecon +!pip install quantecon jax ``` ## Overview @@ -64,6 +64,7 @@ import matplotlib.pyplot as plt import numpy as np import jax import jax.numpy as jnp +import jax.random as jr from typing import NamedTuple import quantecon as qe from quantecon.distributions import BetaBinomial @@ -367,7 +368,6 @@ plt.show() We are going to use JAX to accelerate our code. -* JAX provides automatic differentiation and JIT compilation capabilities. * We'll use NamedTuple for our model class to maintain immutability, which works well with JAX's functional programming paradigm. 
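To see what immutability means in practice, here is a small sketch (the `ToyModel` class and `toy_value` function are made up for illustration): fields are fixed at construction, updated copies are created with `_replace`, and instances pass straight through `jax.jit` because NamedTuples are pytrees.

```{code-cell} python3
class ToyModel(NamedTuple):
    # illustrative parameters only
    c: float = 25.0
    β: float = 0.99

@jax.jit
def toy_value(model):
    # NamedTuples are pytrees, so jitted functions accept them directly
    return model.c / (1 - model.β)

m = ToyModel()
m_high_c = m._replace(c=30.0)   # a new instance; m itself is unchanged
toy_value(m), toy_value(m_high_c)
```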
Here's a class that stores the data and computes the values of state-action pairs, @@ -455,7 +455,7 @@ def compute_reservation_wage(mcm, max_iter=500, tol=1e-6): # Simplify names c, β, w, q = mcm.c, mcm.β, mcm.w, mcm.q - # == First compute the value function == # + # First compute the value function n = len(w) v = w / (1 - β) # initial guess @@ -474,7 +474,7 @@ def compute_reservation_wage(mcm, max_iter=500, tol=1e-6): initial_state = (v, 0, tol + 1) v_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state) - # == Now compute the reservation wage == # + # Now compute the reservation wage return (1 - β) * (c + β * (v_final @ q)) ``` @@ -606,7 +606,7 @@ def compute_reservation_wage_two(mcm, max_iter=500, tol=1e-5): # Simplify names c, β, w, q = mcm.c, mcm.β, mcm.w, mcm.q - # == First compute h == # + # First compute h h = (w @ q) / (1 - β) def body_fun(state): @@ -623,7 +623,7 @@ def compute_reservation_wage_two(mcm, max_iter=500, tol=1e-5): initial_state = (h, 0, tol + 1) h_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state) - # == Now compute the reservation wage == # + # Now compute the reservation wage return (1 - β) * h_final ``` @@ -660,8 +660,8 @@ cdf = jnp.cumsum(q_default) def compute_stopping_time(w_bar, key): def body_fun(state): t, key, done = state - key, subkey = jax.random.split(key) - u = jax.random.uniform(subkey) + key, subkey = jr.split(key) + u = jr.uniform(subkey) w = w_default[jnp.searchsorted(cdf, u)] done = w >= w_bar t = jnp.where(done, t, t + 1) @@ -676,9 +676,9 @@ def compute_stopping_time(w_bar, key): return t_final @jax.jit -def compute_mean_stopping_time(w_bar, num_reps=100000, seed=0): - key = jax.random.PRNGKey(seed) - keys = jax.random.split(key, num_reps) +def compute_mean_stopping_time(w_bar, num_reps=100000, seed=1234): + key = jr.PRNGKey(seed) + keys = jr.split(key, num_reps) obs = jax.vmap(compute_stopping_time, in_axes=(None, 0))(w_bar, keys) return jnp.mean(obs) @@ -777,9 +777,9 @@ class McCallModelContinuous(NamedTuple): μ: float # location parameter in lognormal distribution w_draws: jnp.ndarray # draws of wages for Monte Carlo -def create_mccall_continuous(c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000, seed=0): - key = jax.random.PRNGKey(seed) - s = jax.random.normal(key, (mc_size,)) +def create_mccall_continuous(c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000, seed=1234): + key = jr.PRNGKey(seed) + s = jr.normal(key, (mc_size,)) w_draws = jnp.exp(μ + σ * s) return McCallModelContinuous(c=c, β=β, σ=σ, μ=μ, w_draws=w_draws) @@ -803,7 +803,7 @@ def compute_reservation_wage_continuous(mcmc, max_iter=500, tol=1e-5): initial_state = (h, 0, tol + 1) h_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state) - # == Now compute the reservation wage == # + # Now compute the reservation wage return (1 - β) * h_final ``` From 1a748dcf35c02481d7acfe696e4512765b8f4c3e Mon Sep 17 00:00:00 2001 From: Longye Tian Date: Thu, 18 Sep 2025 12:14:35 +1000 Subject: [PATCH 09/10] Update mccall_model.md --- lectures/mccall_model.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/lectures/mccall_model.md b/lectures/mccall_model.md index 2cded1d3b..07151a840 100644 --- a/lectures/mccall_model.md +++ b/lectures/mccall_model.md @@ -370,11 +370,7 @@ We are going to use JAX to accelerate our code. * We'll use NamedTuple for our model class to maintain immutability, which works well with JAX's functional programming paradigm. -Here's a class that stores the data and computes the values of state-action pairs, -i.e. 
the value in the maximum bracket on the right hand side of the Bellman equation {eq}`odu_pv2p`, -given the current state and an arbitrary feasible action. - -Default parameter values are embedded in the class. +Here's a class that stores the model parameters with default values, and a separate function that computes the values of state-action pairs (i.e., the value in the maximum bracket on the right hand side of the Bellman equation {eq}`odu_pv2p`). ```{code-cell} python3 class McCallModel(NamedTuple): From d9cae8695fd1c5cbb508d6732a674c3253955a18 Mon Sep 17 00:00:00 2001 From: Longye Tian Date: Sun, 21 Sep 2025 12:42:14 +1000 Subject: [PATCH 10/10] Update mccall_model.md --- lectures/mccall_model.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lectures/mccall_model.md b/lectures/mccall_model.md index 07151a840..6205e276b 100644 --- a/lectures/mccall_model.md +++ b/lectures/mccall_model.md @@ -93,9 +93,11 @@ At time $t$, our agent has two choices: The agent is infinitely lived and aims to maximize the expected discounted sum of earnings -$$ -\mathbb{E} \sum_{t=0}^{\infty} \beta^t y_t -$$ +```{math} +:label: obj_model + +{\mathbb E} \sum_{t=0}^\infty \beta^t u(y_t) +``` The constant $\beta$ lies in $(0, 1)$ and is called a **discount factor**. @@ -140,7 +142,7 @@ $w \in \mathbb{W}$. In particular, the agent has wage offer $w$ in hand. More precisely, $v^*(w)$ denotes the value of the objective function -{eq}`objective` when an agent in this situation makes *optimal* decisions now +{eq}`obj_model` when an agent in this situation makes *optimal* decisions now and at all future points in time. Of course $v^*(w)$ is not trivial to calculate because we don't yet know
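For reference, the objective {eq}`obj_model` pins down the value of accepting an offer once we assume linear utility, $u(y) = y$ (this is the assumption behind the `accept = w[i] / (1 - β)` line in the code above): an agent who accepts wage $w$ earns $w$ in every subsequent period, so the value of accepting is

$$
w + \beta w + \beta^2 w + \cdots = \frac{w}{1 - \beta}
$$

while rejecting yields the unemployment compensation $c$ today plus the discounted expected continuation value, which is the `reject = c + β * (v @ q)` branch of the maximum in {eq}`odu_pv2p`.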