Skip to content

Commit 8ff32f3

Browse files
committed
misc
1 parent 0528e32 commit 8ff32f3

File tree

2 files changed

+136
-93
lines changed

2 files changed

+136
-93
lines changed

lectures/jax_lucas.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

lectures/kesten_processes.md

Lines changed: 136 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ jupytext:
44
extension: .md
55
format_name: myst
66
format_version: 0.13
7-
jupytext_version: 1.16.1
7+
jupytext_version: 1.17.2
88
kernelspec:
99
display_name: Python 3 (ipykernel)
1010
language: python
@@ -57,6 +57,9 @@ import jax
5757
import jax.numpy as jnp
5858
from jax import random
5959
from jax import lax
60+
from quantecon import tic, toc
61+
from typing import NamedTuple
62+
from functools import partial
6063
```
6164

6265
Let's check the GPU we are running
@@ -168,29 +171,36 @@ We can investigate this question via simulation and rank-size plots.
168171

169172
The approach will be to
170173

171-
1. generate $M$ draws of $s_T$ when $M$ and $T$ are
172-
large and
174+
1. generate $M$ draws of $s_T$ when $M$ and $T$ are large and
173175
1. plot the largest 1,000 of the resulting draws in a rank-size plot.
174176

175177
(The distribution of $s_T$ will be close to the stationary distribution
176178
when $T$ is large.)
177179

178180
In the simulation, we assume that each of $a_t, b_t$ and $e_t$ is lognormal.
179181

180-
Here's code to update a cross-section of firms according to the dynamics in
181-
[](firm_dynam_ee).
182+
Here's a class to store parameters:
182183

183184
```{code-cell} ipython3
184-
@jax.jit
185-
def update_s(s, s_bar, a_random, b_random, e_random):
186-
exp_a = jnp.exp(a_random)
187-
exp_b = jnp.exp(b_random)
188-
exp_e = jnp.exp(e_random)
189-
190-
s = jnp.where(s < s_bar,
191-
exp_e,
192-
exp_a * s + exp_b)
185+
class Firm(NamedTuple):
186+
μ_a: float = -0.5
187+
σ_a: float = 0.1
188+
μ_b: float = 0.0
189+
σ_b: float = 0.5
190+
μ_e: float = 0.0
191+
σ_e: float = 0.5
192+
s_bar: float = 1.0
193+
194+
#
195+
# Here's code to update a cross-section of firms according to the dynamics in
196+
# [](firm_dynam_ee).
197+
```
193198

199+
```{code-cell} ipython3
200+
@jax.jit
201+
def update_cross_section(s, a, b, e, firm):
202+
μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm
203+
s = jnp.where(s < s_bar, e, a * s + b)
194204
return s
195205
```
196206

@@ -201,51 +211,45 @@ For sufficiently large `T`, the cross-section it returns (the cross-section at
201211
time `T`) corresponds to firm size distribution in (approximate) equilibrium.
202212

203213
```{code-cell} ipython3
204-
def generate_draws(M=1_000_000,
205-
μ_a=-0.5,
206-
σ_a=0.1,
207-
μ_b=0.0,
208-
σ_b=0.5,
209-
μ_e=0.0,
210-
σ_e=0.5,
211-
s_bar=1.0,
212-
T=500,
213-
s_init=1.0,
214-
seed=123):
214+
def generate_cross_section(
215+
firm, M=1_000_000, T=500, s_init=1.0, seed=123
216+
):
215217
218+
μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm
216219
key = random.PRNGKey(seed)
217220
218-
# Initialize the array of s values with the initial value
221+
# Initialize the cross-section to a common value
219222
s = jnp.full((M, ), s_init)
220223
221224
# Perform updates on s for time t
222225
for t in range(T):
223-
keys = random.split(key, 3)
224-
a_random = μ_a + σ_a * random.normal(keys[0], (M, ))
225-
b_random = μ_b + σ_b * random.normal(keys[1], (M, ))
226-
e_random = μ_e + σ_e * random.normal(keys[2], (M, ))
227-
228-
s = update_s(s, s_bar, a_random, b_random, e_random)
229-
230-
# Generate new key for the next iteration
231-
key = random.fold_in(key, t)
226+
key, *subkeys = random.split(key, 4)
227+
a = μ_a + σ_a * random.normal(subkeys[0], (M,))
228+
b = μ_b + σ_b * random.normal(subkeys[1], (M,))
229+
e = μ_e + σ_e * random.normal(subkeys[2], (M,))
230+
# Exponentiate shocks
231+
a, b, e = jax.tree.map(jnp.exp, (a, b, e))
232+
# Update the cross-section of firms
233+
s = update_cross_section(s, a, b, e, firm)
232234
233235
return s
236+
```
234237

235-
%time data = generate_draws().block_until_ready()
238+
```{code-cell} ipython3
239+
firm = Firm()
240+
tic()
241+
data = generate_cross_section(firm).block_until_ready()
242+
toc()
236243
```
237244

238245
Running the above function again so we can see the speed with and without compile time.
239246

240247
```{code-cell} ipython3
241-
%time data = generate_draws().block_until_ready()
248+
tic()
249+
data = generate_cross_section(firm).block_until_ready()
250+
toc()
242251
```
243252

244-
Notice that we do not JIT-compile the `for` loops, since
245-
246-
1. acceleration of the outer loop makes little difference terms of compute
247-
time and
248-
2. compiling the outer loop is often very slow.
249253

250254
Let's produce the rank-size plot and check the distribution:
251255

@@ -265,64 +269,62 @@ The plot produces a straight line, consistent with a Pareto tail.
265269

266270
#### Alternative implementation with `lax.fori_loop`
267271

268-
If the time horizon is not too large, we can try to further accelerate our code
272+
We did not JIT-compile the `for` loop above because
273+
acceleration of outer loops makes relatively little difference terms of
274+
compute time.
275+
276+
However, to maximize performance, let's try squeezing out a bit more speed
269277
by replacing the `for` loop with
270278
[`lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html).
271279

272-
Note, however, that
273-
274-
1. as mentioned above, there is not much speed gain in accelerating outer loops,
275-
2. `lax.fori_loop` has a more complicated syntax, and, most importantly,
276-
3. the `lax.fori_loop` implementation consumes far more memory, as we need to have to
277-
store large matrices of random draws
278-
279-
Hence the code below will fail due to out-of-memory errors when `T` and `M` are large.
280-
281-
Here is the `lax.fori_loop` version:
280+
Here a the `lax.fori_loop` version:
282281

283282
```{code-cell} ipython3
284283
@jax.jit
285-
def generate_draws_lax(μ_a=-0.5,
286-
σ_a=0.1,
287-
μ_b=0.0,
288-
σ_b=0.5,
289-
μ_e=0.0,
290-
σ_e=0.5,
291-
s_bar=1.0,
292-
T=500,
293-
M=500_000,
294-
s_init=1.0,
295-
seed=123):
284+
def generate_cross_section_lax(
285+
firm, T=500, M=500_000, s_init=1.0, seed=123
286+
):
296287
288+
μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm
297289
key = random.PRNGKey(seed)
298-
keys = random.split(key, 3)
299290
300-
# Generate random draws and initial values
301-
a_random = μ_a + σ_a * random.normal(keys[0], (T, M))
302-
b_random = μ_b + σ_b * random.normal(keys[1], (T, M))
303-
e_random = μ_e + σ_e * random.normal(keys[2], (T, M))
291+
# Initial cross section
304292
s = jnp.full((M, ), s_init)
305293
306-
# Define the function for each update
307-
def update_s(i, s):
308-
a, b, e = a_random[i], b_random[i], e_random[i]
309-
s = jnp.where(s < s_bar,
310-
jnp.exp(e),
311-
jnp.exp(a) * s + jnp.exp(b))
312-
return s
313-
314-
# Use lax.scan to perform the calculations on all states
315-
s_final = lax.fori_loop(0, T, update_s, s)
316-
return s_final
317-
318-
%time data = generate_draws_lax().block_until_ready()
294+
def update_cross_section(t, state):
295+
s, key = state
296+
key, *subkeys = jax.random.split(key, 4)
297+
# Generate current random draws
298+
a = μ_a + σ_a * random.normal(subkeys[0], (M,))
299+
b = μ_b + σ_b * random.normal(subkeys[1], (M,))
300+
e = μ_e + σ_e * random.normal(subkeys[2], (M,))
301+
# Exponentiate them
302+
a, b, e = jax.tree.map(jnp.exp, (a, b, e))
303+
# Pull out the t-th cross-section of shocks
304+
s = jnp.where(s < s_bar, e, a * s + b)
305+
new_state = s, key
306+
return new_state
307+
308+
# Use fori_loop
309+
initial_state = s, key
310+
final_s, final_key = lax.fori_loop(
311+
0, T, update_cross_section, initial_state
312+
)
313+
return final_s
314+
315+
# Let's see if we got any speed gain
319316
```
320317

321-
In this case, `M` is small enough for the code to run and
322-
we see some speed gain over the for loop implementation:
318+
```{code-cell} ipython3
319+
tic()
320+
data = generate_cross_section_lax(firm).block_until_ready()
321+
toc()
322+
```
323323

324324
```{code-cell} ipython3
325-
%time data = generate_draws_lax().block_until_ready()
325+
tic()
326+
data = generate_cross_section_lax(firm).block_until_ready()
327+
toc()
326328
```
327329

328330
Here we produce the same rank-size plot:
@@ -336,19 +338,61 @@ ax.set_xlabel("log rank")
336338
ax.set_ylabel("log size")
337339
338340
plt.show()
339-
```
340341
341-
Let's rerun the `for` loop version on smaller `M` to compare the speed
342+
#
343+
# If the time horizon is not too large, we can also try generating all shocks at
344+
# once.
345+
#
346+
# Note, however, that this approach consumes more memory, as we need to have to
347+
# store large matrices of random draws
348+
#
349+
# Hence the code below will fail due to out-of-memory errors when `T` and `M` are large.
350+
```
342351

343352
```{code-cell} ipython3
344-
%time generate_draws(M=500_000).block_until_ready()
353+
@jax.jit
354+
def generate_cross_section_lax(
355+
firm, T=500, M=500_000, s_init=1.0, seed=123
356+
):
357+
358+
μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm
359+
key = random.PRNGKey(seed)
360+
subkey_1, subkey_2, subkey_3 = random.split(key, 3)
361+
362+
# Generate entire sequence of random draws
363+
a = μ_a + σ_a * random.normal(subkey_1, (T, M))
364+
b = μ_b + σ_b * random.normal(subkey_2, (T, M))
365+
e = μ_e + σ_e * random.normal(subkey_3, (T, M))
366+
# Exponentiate them
367+
a, b, e = jax.tree.map(jnp.exp, (a, b, e))
368+
# Initial cross section
369+
s = jnp.full((M, ), s_init)
370+
371+
def update_cross_section(t, s):
372+
# Pull out the t-th cross-section of shocks
373+
a_t, b_t, e_t = a[t], b[t], e[t]
374+
s = jnp.where(s < s_bar, e_t, a_t * s + b_t)
375+
return s
376+
377+
# Use lax.scan to perform the calculations on all states
378+
s_final = lax.fori_loop(0, T, update_cross_section, s)
379+
return s_final
345380
```
346381

347-
Let's run it again to get rid of the compile time.
382+
Here are the run times.
383+
384+
```{code-cell} ipython3
385+
tic()
386+
data = generate_cross_section_lax(firm).block_until_ready()
387+
toc()
388+
```
348389

349390
```{code-cell} ipython3
350-
%time generate_draws(M=500_000).block_until_ready()
391+
tic()
392+
data = generate_cross_section_lax(firm).block_until_ready()
393+
toc()
351394
```
352395

353-
We see that the `lax.fori_loop` version is faster than the `for` loop version
354-
when memory is not an issue.
396+
This second method might be slightly faster in some cases but in general the
397+
relative speed will depend on the size of the cross-section and the length of
398+
the simulation paths.

0 commit comments

Comments
 (0)