
Commit c1d524a

jstac and claude authored
Refactor cake_eating_egm lecture: update K operator signature and improve clarity (#730)
* Refactor cake_eating_egm lecture: update K operator signature and improve clarity

  - Update K operator to take (c_in, x_in) and return (c_out, x_out) for clarity
  - Rename 'grid' to 's_grid' throughout to emphasize exogenous savings grid
  - Change shock scale parameter from 's' to 'ν' to avoid confusion with savings
  - Update solve_model_time_iter to work with new K signature
  - Fix grammar: change "fixed/calculated" to "fix/calculate" in bullet points
  - Add missing period after "analytical solutions"
  - Remove extra space in "is determined"
  - Expand "EG" abbreviation to "endogenous grid" in comment
  - Change code-cell from ipython to python3 for consistency
  - Add note about Python loops and reference to JAX lecture

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude <noreply@anthropic.com>

* Small changes to intro.

---------

Co-authored-by: Claude <noreply@anthropic.com>
1 parent 2f38808 commit c1d524a

1 file changed

+80
-60
lines changed

lectures/cake_eating_egm.md

Lines changed: 80 additions & 60 deletions
@@ -39,21 +39,26 @@ EGM is a numerical method for implementing policy iteration invented by [Chris C
3939

4040
The original reference is {cite}`Carroll2006`.
4141

42+
For now we will focus on a clean and simple implementation of EGM that stays
43+
close to the underlying mathematics.
44+
45+
Then, in {doc}`the next lecture <cake_eating_egm_jax>`, we will construct a fully vectorized and parallelized version of EGM based on JAX.
46+
4247
Let's start with some standard imports:
4348

44-
```{code-cell} ipython
49+
```{code-cell} python3
4550
import matplotlib.pyplot as plt
4651
import numpy as np
4752
import quantecon as qe
4853
```
4954

5055
## Key Idea
5156

52-
Let's start by reminding ourselves of the theory and then see how the numerics fit in.
57+
First we remind ourselves of the theory and then we turn to numerical methods.
5358

5459
### Theory
5560

56-
Take the model set out in {doc}`Cake Eating IV <cake_eating_time_iter>`, following the same terminology and notation.
61+
We work with the model set out in {doc}`cake_eating_time_iter`, following the same terminology and notation.
5762

5863
The Euler equation is
5964

@@ -79,24 +84,27 @@ u'(c)
7984

8085
### Exogenous Grid
8186

82-
As discussed in {doc}`Cake Eating IV <cake_eating_time_iter>`, to implement the method on a computer, we need a numerical approximation.
87+
As discussed in {doc}`cake_eating_time_iter`, to implement the method on a
88+
computer, we need numerical approximation.
8389

8490
In particular, we represent a policy function by a set of values on a finite grid.
8591

86-
The function itself is reconstructed from this representation when necessary, using interpolation or some other method.
92+
The function itself is reconstructed from this representation when necessary,
93+
using interpolation or some other method.
8794

88-
{doc}`Previously <cake_eating_time_iter>`, to obtain a finite representation of an updated consumption policy, we
95+
Our {doc}`previous strategy <cake_eating_time_iter>` for obtaining a finite representation of an updated consumption policy was to
8996

90-
* fixed a grid of income points $\{x_i\}$
91-
* calculated the consumption value $c_i$ corresponding to each
92-
$x_i$ using {eq}`egm_coledef` and a root-finding routine
97+
* fix a grid of income points $\{x_i\}$
98+
* calculate the consumption value $c_i$ corresponding to each $x_i$ using
99+
{eq}`egm_coledef` and a root-finding routine
93100

94101
Each $c_i$ is then interpreted as the value of the function $K \sigma$ at $x_i$.
95102

96-
Thus, with the points $\{x_i, c_i\}$ in hand, we can reconstruct $K \sigma$ via approximation.
103+
Thus, with the pairs $\{(x_i, c_i)\}$ in hand, we can reconstruct $K \sigma$ via approximation.
97104

98105
Iteration then continues...
99106

107+
100108
### Endogenous Grid
101109

102110
The method discussed above requires a root-finding routine to find the
@@ -105,7 +113,7 @@ $c_i$ corresponding to a given income value $x_i$.
105113
Root-finding is costly because it typically involves a significant number of
106114
function evaluations.
107115

108-
As pointed out by Carroll {cite}`Carroll2006`, we can avoid this if
116+
As pointed out by Carroll {cite}`Carroll2006`, we can avoid this step if
109117
$x_i$ is chosen endogenously.
110118

111119
The only assumption required is that $u'$ is invertible on $(0, \infty)$.
@@ -114,7 +122,7 @@ Let $(u')^{-1}$ be the inverse function of $u'$.
114122

115123
The idea is this:
116124

117-
* First, we fix an *exogenous* grid $\{k_i\}$ for capital ($k = x - c$).
125+
* First, we fix an *exogenous* grid $\{s_i\}$ for savings ($s = x - c$).
118126
* Then we obtain $c_i$ via
119127

120128
```{math}
@@ -123,28 +131,28 @@ The idea is this:
123131
c_i =
124132
(u')^{-1}
125133
\left\{
126-
\beta \int (u' \circ \sigma) (f(k_i) z ) \, f'(k_i) \, z \, \phi(dz)
134+
\beta \int (u' \circ \sigma) (f(s_i) z ) \, f'(s_i) \, z \, \phi(dz)
127135
\right\}
128136
```
129137

130-
* Finally, for each $c_i$ we set $x_i = c_i + k_i$.
138+
* Finally, for each $c_i$ we set $x_i = c_i + s_i$.
131139

132-
It is clear that each $(x_i, c_i)$ pair constructed in this manner satisfies {eq}`egm_coledef`.
140+
Importantly, each $(x_i, c_i)$ pair constructed in this manner satisfies {eq}`egm_coledef`.
133141

134142
With the points $\{x_i, c_i\}$ in hand, we can reconstruct $K \sigma$ via approximation as before.
135143

136-
The name EGM comes from the fact that the grid $\{x_i\}$ is determined **endogenously**.
144+
The name EGM comes from the fact that the grid $\{x_i\}$ is determined **endogenously**.
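As a sanity check on the three steps above, here is a short worked derivation. It assumes the specification adopted in the Implementation section ($u(c) = \ln c$, so $u'(c) = 1/c$ and $(u')^{-1}(y) = 1/y$, together with $f(s) = s^\alpha$) and takes the current policy to be the analytical solution $\sigma^*(x) = (1 - \alpha \beta) x$. The integrand then does not depend on the shock:

```{math}
(u' \circ \sigma^*)(f(s_i) z) \, f'(s_i) \, z
= \frac{\alpha s_i^{\alpha - 1} z}{(1 - \alpha \beta) s_i^{\alpha} z}
= \frac{\alpha}{(1 - \alpha \beta) s_i}
```

Since $\phi$ integrates to one, inverting $u'$ and adding savings gives

```{math}
c_i = \frac{(1 - \alpha \beta) s_i}{\alpha \beta},
\qquad
x_i = c_i + s_i = \frac{s_i}{\alpha \beta},
\qquad
c_i = (1 - \alpha \beta) x_i = \sigma^*(x_i)
```

So, under these assumptions, one EGM step maps $\sigma^*$ back to itself on the endogenous grid, as expected of a fixed point of $K$.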
145+
137146

138147
## Implementation
139148

140-
As in {doc}`Cake Eating IV <cake_eating_time_iter>`, we will start with a simple setting
141-
where
149+
As in {doc}`cake_eating_time_iter`, we will start with a simple setting where
142150

143151
* $u(c) = \ln c$,
144-
* production is Cobb-Douglas, and
152+
* the function $f$ has a Cobb-Douglas specification, and
145153
* the shocks are lognormal.
146154

147-
This will allow us to make comparisons with the analytical solutions
155+
This will allow us to make comparisons with the analytical solutions.
148156

149157
```{code-cell} python3
150158
def v_star(x, α, β, μ):
@@ -164,7 +172,7 @@ def σ_star(x, α, β):
164172
return (1 - α * β) * x
165173
```
166174

167-
We reuse the `Model` structure from {doc}`Cake Eating IV <cake_eating_time_iter>`.
175+
We reuse the `Model` structure from {doc}`cake_eating_time_iter`.
168176

169177
```{code-cell} python3
170178
from typing import NamedTuple, Callable
@@ -174,8 +182,8 @@ class Model(NamedTuple):
174182
f: Callable # production function
175183
β: float # discount factor
176184
μ: float # shock location parameter
177-
s: float # shock scale parameter
178-
grid: np.ndarray # state grid
185+
ν: float # shock scale parameter
186+
s_grid: np.ndarray # exogenous savings grid
179187
shocks: np.ndarray # shock draws
180188
α: float # production function parameter
181189
u_prime: Callable # derivative of utility
@@ -187,7 +195,7 @@ def create_model(u: Callable,
187195
f: Callable,
188196
β: float = 0.96,
189197
μ: float = 0.0,
190-
s: float = 0.1,
198+
ν: float = 0.1,
191199
grid_max: float = 4.0,
192200
grid_size: int = 120,
193201
shock_size: int = 250,
@@ -199,53 +207,59 @@ def create_model(u: Callable,
199207
"""
200208
Creates an instance of the cake eating model.
201209
"""
202-
# Set up grid
203-
grid = np.linspace(1e-4, grid_max, grid_size)
210+
# Set up exogenous savings grid
211+
s_grid = np.linspace(1e-4, grid_max, grid_size)
204212
205213
# Store shocks (with a seed, so results are reproducible)
206214
np.random.seed(seed)
207-
shocks = np.exp(μ + s * np.random.randn(shock_size))
215+
shocks = np.exp(μ + ν * np.random.randn(shock_size))
208216
209-
return Model(u=u, f=f, β=β, μ=μ, s=s, grid=grid, shocks=shocks,
210-
α=α, u_prime=u_prime, f_prime=f_prime, u_prime_inv=u_prime_inv)
217+
return Model(u, f, β, μ, ν, s_grid, shocks, α, u_prime, f_prime, u_prime_inv)
211218
```
212219

213220
### The Operator
214221

215222
Here's an implementation of $K$ using EGM as described above.
216223

217224
```{code-cell} python3
218-
def K(σ_array: np.ndarray, model: Model) -> np.ndarray:
225+
def K(
226+
c_in: np.ndarray, # Consumption values on the endogenous grid
227+
x_in: np.ndarray, # Current endogenous grid
228+
model: Model # Model specification
229+
):
219230
"""
220-
The Coleman-Reffett operator using EGM
231+
An implementation of the Coleman-Reffett operator using EGM.
221232
222233
"""
223234
224235
# Simplify names
225-
f, β = model.f, model.β
226-
f_prime, u_prime = model.f_prime, model.u_prime
227-
u_prime_inv = model.u_prime_inv
228-
grid, shocks = model.grid, model.shocks
229-
230-
# Determine endogenous grid
231-
x = grid + σ_array # x_i = k_i + c_i
236+
u, f, β, μ, ν, s_grid, shocks, α, u_prime, f_prime, u_prime_inv = model
232237
233-
# Linear interpolation of policy using endogenous grid
234-
σ = lambda x_val: np.interp(x_val, x, σ_array)
238+
# Linear interpolation of policy on the endogenous grid
239+
σ = lambda x: np.interp(x, x_in, c_in)
235240
236241
# Allocate memory for new consumption array
237-
c = np.empty_like(grid)
242+
c_out = np.empty_like(s_grid)
238243
239244
# Solve for updated consumption value
240-
for i, k in enumerate(grid):
241-
vals = u_prime(σ(f(k) * shocks)) * f_prime(k) * shocks
242-
c[i] = u_prime_inv(β * np.mean(vals))
245+
for i, s in enumerate(s_grid):
246+
vals = u_prime(σ(f(s) * shocks)) * f_prime(s) * shocks
247+
c_out[i] = u_prime_inv(β * np.mean(vals))
248+
249+
# Determine corresponding endogenous grid
250+
x_out = s_grid + c_out # x_i = s_i + c_i
243251
244-
return c
252+
return c_out, x_out
245253
```
246254

247255
Note the lack of any root-finding algorithm.
248256

257+
```{note}
258+
The routine is still not particularly fast because we are using pure Python loops.
259+
260+
But in the next lecture ({doc}`cake_eating_egm_jax`) we will use a fully vectorized and efficient solution.
261+
```
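To give a rough idea of what removing the Python loop might look like without JAX, here is a minimal NumPy sketch. It is not the implementation from the next lecture: the name `K_vectorized` and its argument list are hypothetical, chosen only for illustration, and the parameter values mirror the defaults shown in this lecture. The idea is to build all next-period income values as an `(n, m)` array via broadcasting and average over the shock axis.

```python
import numpy as np

def K_vectorized(c_in, x_in, s_grid, shocks, β, f, f_prime, u_prime, u_prime_inv):
    """
    Hypothetical loop-free variant of the EGM operator (illustration only).
    """
    # Next-period income for every (savings, shock) pair, shape (n, m)
    y = f(s_grid)[:, None] * shocks[None, :]
    # Linear interpolation of the current policy at all income values at once
    σ_y = np.interp(y.ravel(), x_in, c_in).reshape(y.shape)
    # Integrand of the Euler equation, then Monte Carlo average over shocks
    vals = u_prime(σ_y) * f_prime(s_grid)[:, None] * shocks[None, :]
    c_out = u_prime_inv(β * vals.mean(axis=1))
    # Corresponding endogenous grid: x_i = s_i + c_i
    x_out = s_grid + c_out
    return c_out, x_out

# Example setup (log utility, Cobb-Douglas f, lognormal shocks)
α, β = 0.4, 0.96
f = lambda s: s**α
f_prime = lambda s: α * s**(α - 1)
u_prime = lambda c: 1 / c
u_prime_inv = lambda y: 1 / y

s_grid = np.linspace(1e-4, 4.0, 120)
rng = np.random.default_rng(1234)
shocks = np.exp(0.1 * rng.standard_normal(250))

c_in = np.copy(s_grid)
x_in = s_grid + c_in
c_out, x_out = K_vectorized(c_in, x_in, s_grid, shocks, β,
                            f, f_prime, u_prime, u_prime_inv)
```

The returned pair has the same meaning as in `K` above, so a sketch like this could be dropped into the same iteration scheme; whether it is faster in practice depends on grid and shock sizes.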
262+
249263
### Testing
250264

251265
First we create an instance.
@@ -261,53 +275,53 @@ f_prime = lambda k: α * k**(α - 1)
261275
262276
model = create_model(u=u, f=f, α=α, u_prime=u_prime,
263277
f_prime=f_prime, u_prime_inv=u_prime_inv)
264-
grid = model.grid
278+
s_grid = model.s_grid
265279
```
266280

267281
Here's our solver routine:
268282

269283
```{code-cell} python3
270284
def solve_model_time_iter(model: Model,
271-
σ_init: np.ndarray,
285+
c_init: np.ndarray,
286+
x_init: np.ndarray,
272287
tol: float = 1e-5,
273288
max_iter: int = 1000,
274-
verbose: bool = True) -> np.ndarray:
289+
verbose: bool = True):
275290
"""
276291
Solve the model using time iteration with EGM.
277292
"""
278-
σ = σ_init
293+
c, x = c_init, x_init
279294
error = tol + 1
280295
i = 0
281296
282297
while error > tol and i < max_iter:
283-
σ_new = K(σ, model)
284-
error = np.max(np.abs(σ_new - σ))
285-
σ = σ_new
298+
c_new, x_new = K(c, x, model)
299+
error = np.max(np.abs(c_new - c))
300+
c, x = c_new, x_new
286301
i += 1
287302
if verbose:
288303
print(f"Iteration {i}, error = {error}")
289304
290305
if i == max_iter:
291306
print("Warning: maximum iterations reached")
292307
293-
return σ
308+
return c, x
294309
```
295310

296311
Let's call it:
297312

298313
```{code-cell} python3
299-
σ_init = np.copy(grid)
300-
σ = solve_model_time_iter(model, σ_init)
314+
c_init = np.copy(s_grid)
315+
x_init = s_grid + c_init
316+
c, x = solve_model_time_iter(model, c_init, x_init)
301317
```
302318

303319
Here is a plot of the resulting policy, compared with the true policy:
304320

305321
```{code-cell} python3
306-
x = grid + σ # x_i = k_i + c_i
307-
308322
fig, ax = plt.subplots()
309323
310-
ax.plot(x, σ, lw=2,
324+
ax.plot(x, c, lw=2,
311325
alpha=0.8, label='approximate policy function')
312326
313327
ax.plot(x, σ_star(x, model.α, model.β), 'k--',
@@ -320,16 +334,22 @@ plt.show()
320334
The maximal absolute deviation between the two policies is
321335

322336
```{code-cell} python3
323-
np.max(np.abs(σ - σ_star(x, model.α, model.β)))
337+
np.max(np.abs(c - σ_star(x, model.α, model.β)))
324338
```
325339

326340
Here's the execution time:
327341

328342
```{code-cell} python3
329343
with qe.Timer():
330-
σ = solve_model_time_iter(model, σ_init, verbose=False)
344+
c, x = solve_model_time_iter(model, c_init, x_init, verbose=False)
331345
```
332346

333347
EGM is faster than time iteration because it avoids numerical root-finding.
334348

335349
Instead, we invert the marginal utility function directly, which is much more efficient.
350+
351+
In the {doc}`next lecture <cake_eating_egm_jax>`, we will use a fully vectorized
352+
and efficient version of EGM that is also parallelized using JAX.
353+
354+
This provides an extremely fast way to solve the optimal consumption problem we
355+
have been studying for the last few lectures.
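To illustrate the point about avoiding root-finding, the following sketch compares the two approaches at a single grid point. It is an illustration under stated assumptions rather than code from the lecture: the current policy is taken to be the analytical solution, the bisection routine is a stand-in for whatever root-finder time iteration uses, and the names `rhs`, `euler_diff` and `n_evals` are hypothetical.

```python
import numpy as np

α, β = 0.4, 0.96
rng = np.random.default_rng(0)
shocks = np.exp(0.1 * rng.standard_normal(250))

f = lambda s: s**α
f_prime = lambda s: α * s**(α - 1)
u_prime = lambda c: 1 / c
u_prime_inv = lambda y: 1 / y
σ = lambda x: (1 - α * β) * x      # assumed current policy (analytical solution)

def rhs(s):
    "Right-hand side of the Euler equation at savings level s."
    return β * np.mean(u_prime(σ(f(s) * shocks)) * f_prime(s) * shocks)

# EGM: fix savings s, then one evaluation and one inversion give c and x
s = 0.5
c_egm = u_prime_inv(rhs(s))
x = s + c_egm

# Root-finding: fix income x, solve u'(c) = rhs(x - c) for c by bisection
def euler_diff(c):
    return u_prime(c) - rhs(x - c)

lo, hi = 1e-10, x - 1e-10
n_evals = 0
for _ in range(60):
    mid = 0.5 * (lo + hi)
    n_evals += 1
    # euler_diff is decreasing in c, so a positive value means the root is higher
    if euler_diff(mid) > 0:
        lo = mid
    else:
        hi = mid
c_root = 0.5 * (lo + hi)

print(c_egm, c_root, n_evals)
```

Both routes arrive at the same consumption value for this income level, but the bisection needs many evaluations of the expectation while the EGM step needs one.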
