
Commit 0cb1b26

jstac and claude committed
Refine optimal savings lecture series with improved clarity and code organization
- Enhance code formatting and comments for better readability across all OS lectures
- Improve mathematical notation and explanations in stochastic optimal savings
- Restructure function definitions in os_egm_jax for better logical flow
- Simplify utility function and Bellman operator implementations
- Add clearer documentation of marginal utility approximations in EGM
- Remove redundant code and improve variable naming throughout

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 1e2d0d3 commit 0cb1b26

File tree

4 files changed: +141 −117 lines changed

lectures/os_egm.md

Lines changed: 4 additions & 2 deletions

````diff
@@ -241,10 +241,12 @@ def K(
     # Allocate memory for new consumption array
     c_out = np.empty_like(s_grid)
 
-    # Solve for updated consumption value
     for i, s in enumerate(s_grid):
+        # Approximate marginal utility ∫ u'(σ(f(s, α)z)) f'(s, α) z ϕ(z)dz
         vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks
-        c_out[i] = u_prime_inv(β * np.mean(vals))
+        mu = np.mean(vals)
+        # Compute consumption
+        c_out[i] = u_prime_inv(β * mu)
 
     # Determine corresponding endogenous grid
     x_out = s_grid + c_out  # x_i = s_i + c_i
````
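The EGM step this hunk annotates can be sketched as a standalone NumPy loop. The log utility, Cobb–Douglas production function, placeholder policy `σ`, and lognormal shock draws below are illustrative assumptions, not values taken from the commit:

```python
import numpy as np

# Illustrative primitives (assumptions, not the lecture's exact values)
β, α = 0.96, 0.4
u_prime = lambda c: 1 / c              # marginal utility for log utility
u_prime_inv = lambda x: 1 / x          # its inverse
f = lambda s, a: s**a                  # Cobb–Douglas production
f_prime = lambda s, a: a * s**(a - 1)

rng = np.random.default_rng(1234)
shocks = np.exp(0.1 * rng.standard_normal(250))  # lognormal draws
s_grid = np.linspace(1e-4, 4.0, 120)             # exogenous savings grid
σ = lambda x: 0.5 * x                  # placeholder current consumption policy

c_out = np.empty_like(s_grid)
for i, s in enumerate(s_grid):
    # Monte Carlo approximation of ∫ u'(σ(f(s, α)z)) f'(s, α) z ϕ(z)dz
    vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks
    # Invert the Euler equation u'(c) = β × (approximated marginal utility)
    c_out[i] = u_prime_inv(β * np.mean(vals))

# Endogenous grid: total resources implied by each savings level
x_out = s_grid + c_out
```

Splitting out `mu = np.mean(vals)`, as the diff does, just names the Monte Carlo estimate before inverting marginal utility.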

lectures/os_egm_jax.md

Lines changed: 33 additions & 27 deletions

````diff
@@ -90,14 +90,16 @@ class Model(NamedTuple):
     α: float  # production function parameter
 
 
-def create_model(β: float = 0.96,
-                 μ: float = 0.0,
-                 s: float = 0.1,
-                 grid_max: float = 4.0,
-                 grid_size: int = 120,
-                 shock_size: int = 250,
-                 seed: int = 1234,
-                 α: float = 0.4) -> Model:
+def create_model(
+        β: float = 0.96,
+        μ: float = 0.0,
+        s: float = 0.1,
+        grid_max: float = 4.0,
+        grid_size: int = 120,
+        shock_size: int = 250,
+        seed: int = 1234,
+        α: float = 0.4
+    ) -> Model:
     """
     Creates an instance of the optimal savings model.
     """
@@ -111,6 +113,17 @@ def create_model(β: float = 0.96,
     return Model(β=β, μ=μ, s=s, s_grid=s_grid, shocks=shocks, α=α)
 ```
 
+
+We define utility and production functions globally.
+
+```{code-cell} python3
+# Define utility and production functions with derivatives
+u = lambda c: jnp.log(c)
+u_prime = lambda c: 1 / c
+u_prime_inv = lambda x: 1 / x
+f = lambda k, α: k**α
+f_prime = lambda k, α: α * k**(α - 1)
+```
 Here's the Coleman-Reffett operator using EGM.
 
 The key JAX feature here is `vmap`, which vectorizes the computation over the grid points.
@@ -135,10 +148,13 @@ def K(
 
     # Define function to compute consumption at a single grid point
     def compute_c(s):
+        # Approximate marginal utility ∫ u'(σ(f(s, α)z)) f'(s, α) z ϕ(z)dz
         vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks
-        return u_prime_inv(β * jnp.mean(vals))
+        mu = jnp.mean(vals)
+        # Calculate consumption
+        return u_prime_inv(β * mu)
 
-    # Vectorize over grid using vmap
+    # Vectorize and calculate on all exogenous grid points
     compute_c_vectorized = jax.vmap(compute_c)
     c_out = compute_c_vectorized(s_grid)
 
@@ -148,18 +164,6 @@ def K(
     return c_out, x_out
 ```
 
-We define utility and production functions globally.
-
-Note that `f` and `f_prime` take `α` as an explicit argument, allowing them to work with JAX's functional programming model.
-
-```{code-cell} python3
-# Define utility and production functions with derivatives
-u = lambda c: jnp.log(c)
-u_prime = lambda c: 1 / c
-u_prime_inv = lambda x: 1 / x
-f = lambda k, α: k**α
-f_prime = lambda k, α: α * k**(α - 1)
-```
 
 Now we create a model instance.
 
@@ -172,11 +176,13 @@ The solver uses JAX's `jax.lax.while_loop` for the iteration and is JIT-compiled
 
 ```{code-cell} python3
 @jax.jit
-def solve_model_time_iter(model: Model,
-                          c_init: jnp.ndarray,
-                          x_init: jnp.ndarray,
-                          tol: float = 1e-5,
-                          max_iter: int = 1000):
+def solve_model_time_iter(
+        model: Model,
+        c_init: jnp.ndarray,
+        x_init: jnp.ndarray,
+        tol: float = 1e-5,
+        max_iter: int = 1000
+    ):
     """
     Solve the model using time iteration with EGM.
     """
````
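The `vmap` pattern in the hunks above can be tried in isolation. The policy `σ`, shock draws, and parameter values here are placeholders for illustration, not the lecture's:

```python
import jax
import jax.numpy as jnp

# Placeholder parameters and shock draws (assumptions, not the lecture's values)
α, β = 0.4, 0.96
shocks = jnp.array([0.9, 1.0, 1.1])

def compute_c(s):
    # Scalar update at one grid point, mirroring the compute_c pattern above
    σ = lambda x: 0.5 * x                  # placeholder current policy
    u_prime = lambda c: 1 / c              # log utility
    u_prime_inv = lambda x: 1 / x
    f = lambda k, a: k**a                  # Cobb–Douglas production
    f_prime = lambda k, a: a * k**(a - 1)
    # Approximate the marginal-utility integral by a mean over shock draws
    vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks
    return u_prime_inv(β * jnp.mean(vals))

# vmap lifts the scalar function to act on the whole grid at once,
# replacing the explicit Python loop of the NumPy version
s_grid = jnp.linspace(0.1, 4.0, 5)
c_out = jax.vmap(compute_c)(s_grid)
```

Writing `compute_c` over a scalar and vectorizing with `jax.vmap` keeps the per-point logic readable while letting JAX fuse the grid dimension into one traced computation.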

lectures/os_numerical.md

Lines changed: 15 additions & 16 deletions

````diff
@@ -92,14 +92,14 @@ This is a form of **successive approximation**, and was discussed in our {doc}`l
 The basic idea is:
 
 1. Take an arbitrary initial guess of $v$.
-1. Obtain an update $w$ defined by
+1. Obtain an update $\hat v$ defined by
 
 $$
-w(x) = \max_{0\leq c \leq x} \{u(c) + \beta v(x-c)\}
+\hat v(x) = \max_{0\leq c \leq x} \{u(c) + \beta v(x-c)\}
 $$
 
-1. Stop if $w$ is approximately equal to $v$, otherwise set
-   $v=w$ and go back to step 2.
+1. Stop if $\hat v$ is approximately equal to $v$, otherwise set
+   $v=\hat v$ and go back to step 2.
 
 Let's write this a bit more mathematically.
 
@@ -109,7 +109,7 @@ We introduce the **Bellman operator** $T$ that takes a function v as an
 argument and returns a new function $Tv$ defined by
 
 $$
-Tv(x) = \max_{0 \leq c \leq x} \{u(c) + \beta v(x - c)\}
+Tv(x) = \max_{0 \leq c \leq x} \{u(c) + \beta v(x - c)\}
 $$
 
 From $v$ we get $Tv$, and applying $T$ to this yields
@@ -206,13 +206,7 @@ Here's the CRRA utility function.
 
 ```{code-cell} python3
 def u(c, γ):
-    """
-    Utility function.
-    """
-    if γ == 1:
-        return np.log(c)
-    else:
-        return (c ** (1 - γ)) / (1 - γ)
+    return (c ** (1 - γ)) / (1 - γ)
 ```
 
 To work with the Bellman equation, let's write it as
@@ -240,8 +234,8 @@ def B(
         Right hand side of the Bellman equation given x and c.
 
     """
-    # Unpack
-    β, γ, x_grid = model.β, model.γ, model.x_grid
+    # Unpack (simplify names)
+    β, γ, x_grid = model
 
     # Convert array v into a function by linear interpolation
     vf = lambda x: np.interp(x, x_grid, v)
@@ -250,7 +244,12 @@ def B(
     return u(c, γ) + β * vf(x - c)
 ```
 
-We now define the Bellman operation:
+We now define the Bellman operator acting on grid points:
+
+$$
+Tv(x_i) = \max_{0 \leq c \leq x_i} B(x_i, c, v)
+\qquad \text{for all } i
+$$
 
 ```{code-cell} python3
 def T(
@@ -280,7 +279,7 @@ model = create_cake_eating_model()
 β, γ, x_grid = model
 ```
 
-Now let's see the iteration of the value function in action.
+Now let's see iteration of the value function in action.
 
 We start from guess $v$ given by $v(x) = u(x)$ for every
 $x$ grid point.
````
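The iteration $\hat v(x) = \max_{0 \leq c \leq x} \{u(c) + \beta v(x-c)\}$ described in this diff can be sketched end to end. The value of γ, the grid sizes, and the tolerance below are illustrative choices, not the lecture's:

```python
import numpy as np

# Illustrative parameters (assumptions, not taken from the lecture)
β, γ = 0.96, 1.5
x_grid = np.linspace(1e-3, 2.5, 120)

def u(c):
    # CRRA utility, assuming γ != 1 (as in the simplified diff version)
    return c**(1 - γ) / (1 - γ)

def B(x, c, v):
    # Right-hand side of the Bellman equation; v holds values on x_grid
    vf = lambda y: np.interp(y, x_grid, v)
    return u(c) + β * vf(x - c)

def T(v):
    # Tv(x_i) = max over a grid of feasible c in (0, x_i]
    v_new = np.empty_like(v)
    for i, x in enumerate(x_grid):
        c_grid = np.linspace(1e-10, x, 100)
        v_new[i] = np.max(B(x, c_grid, v))
    return v_new

v = u(x_grid)                      # initial guess v(x) = u(x)
for _ in range(500):
    v_next = T(v)
    if np.max(np.abs(v_next - v)) < 1e-5:
        break
    v = v_next
```

Because $T$ is a contraction of modulus β, the successive approximations converge regardless of the initial guess; the stopping rule above just checks the sup-norm distance between iterates.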
