@@ -4,7 +4,7 @@ jupytext:
44 extension : .md
55 format_name : myst
66 format_version : 0.13
7- jupytext_version : 1.16.1
7+ jupytext_version : 1.17.2
88kernelspec :
99 display_name : Python 3 (ipykernel)
1010 language : python
@@ -57,6 +57,9 @@ import jax
5757import jax.numpy as jnp
5858from jax import random
5959from jax import lax
60+ from quantecon import tic, toc
61+ from typing import NamedTuple
62+ from functools import partial
6063```
6164
6265Let's check the GPU we are running
@@ -168,29 +171,36 @@ We can investigate this question via simulation and rank-size plots.
168171
169172The approach will be to
170173
171- 1 . generate $M$ draws of $s_T$ when $M$ and $T$ are
172- large and
174+ 1 . generate $M$ draws of $s_T$ when $M$ and $T$ are large and
1731751 . plot the largest 1,000 of the resulting draws in a rank-size plot.
174176
175177(The distribution of $s_T$ will be close to the stationary distribution
176178when $T$ is large.)
177179
178180In the simulation, we assume that each of $a_t, b_t$ and $e_t$ is lognormal.
179181
180- Here's code to update a cross-section of firms according to the dynamics in
181- [ ] ( firm_dynam_ee ) .
182+ Here's a class to store parameters:
182183
183184``` {code-cell} ipython3
184- @jax.jit
185- def update_s(s, s_bar, a_random, b_random, e_random):
186- exp_a = jnp.exp(a_random)
187- exp_b = jnp.exp(b_random)
188- exp_e = jnp.exp(e_random)
189-
190- s = jnp.where(s < s_bar,
191- exp_e,
192- exp_a * s + exp_b)
185+ class Firm(NamedTuple):
186+ μ_a: float = -0.5
187+ σ_a: float = 0.1
188+ μ_b: float = 0.0
189+ σ_b: float = 0.5
190+ μ_e: float = 0.0
191+ σ_e: float = 0.5
192+ s_bar: float = 1.0
193+ ```
194+
195+ Here's code to update a cross-section of firms according to the dynamics in
196+ [](firm_dynam_ee).
197+
193198
199+ ``` {code-cell} ipython3
200+ @jax.jit
201+ def update_cross_section(s, a, b, e, firm):
202+ μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm
203+ s = jnp.where(s < s_bar, e, a * s + b)
194204 return s
195205```
196206
@@ -201,51 +211,45 @@ For sufficiently large `T`, the cross-section it returns (the cross-section at
201211time ` T ` ) corresponds to the firm size distribution in (approximate) equilibrium.
202212
203213``` {code-cell} ipython3
204- def generate_draws(M=1_000_000,
205- μ_a=-0.5,
206- σ_a=0.1,
207- μ_b=0.0,
208- σ_b=0.5,
209- μ_e=0.0,
210- σ_e=0.5,
211- s_bar=1.0,
212- T=500,
213- s_init=1.0,
214- seed=123):
214+ def generate_cross_section(
215+ firm, M=1_000_000, T=500, s_init=1.0, seed=123
216+ ):
215217
218+ μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm
216219 key = random.PRNGKey(seed)
217220
218- # Initialize the array of s values with the initial value
221+ # Initialize the cross-section to a common value
219222 s = jnp.full((M, ), s_init)
220223
221224 # Perform updates on s for time t
222225 for t in range(T):
223- keys = random.split(key, 3)
224- a_random = μ_a + σ_a * random.normal(keys[0], (M, ))
225- b_random = μ_b + σ_b * random.normal(keys[1], (M, ))
226- e_random = μ_e + σ_e * random.normal(keys[2], (M, ))
227-
228- s = update_s(s, s_bar, a_random, b_random, e_random)
229-
230- # Generate new key for the next iteration
231- key = random.fold_in(key, t)
226+ key, *subkeys = random.split(key, 4)
227+ a = μ_a + σ_a * random.normal(subkeys[0], (M,))
228+ b = μ_b + σ_b * random.normal(subkeys[1], (M,))
229+ e = μ_e + σ_e * random.normal(subkeys[2], (M,))
230+ # Exponentiate shocks
231+ a, b, e = jax.tree.map(jnp.exp, (a, b, e))
232+ # Update the cross-section of firms
233+ s = update_cross_section(s, a, b, e, firm)
232234
233235 return s
236+ ```
234237
235- %time data = generate_draws().block_until_ready()
238+ ``` {code-cell} ipython3
239+ firm = Firm()
240+ tic()
241+ data = generate_cross_section(firm).block_until_ready()
242+ toc()
236243```
237244
238245Let's run the above function again so we can see the speed with and without compile time.
239246
240247``` {code-cell} ipython3
241- %time data = generate_draws().block_until_ready()
248+ tic()
249+ data = generate_cross_section(firm).block_until_ready()
250+ toc()
242251```
243252
244- Notice that we do not JIT-compile the ` for ` loops, since
245-
246- 1 . acceleration of the outer loop makes little difference terms of compute
247- time and
248- 2 . compiling the outer loop is often very slow.
249253
250254Let's produce the rank-size plot and check the distribution:
251255
@@ -265,64 +269,62 @@ The plot produces a straight line, consistent with a Pareto tail.
265269
266270#### Alternative implementation with ` lax.fori_loop `
267271
268- If the time horizon is not too large, we can try to further accelerate our code
272+ We did not JIT-compile the ` for ` loop above because
273+ acceleration of outer loops makes relatively little difference in terms of
274+ compute time.
275+
276+ However, to maximize performance, let's try squeezing out a bit more speed
269277by replacing the ` for ` loop with
270278[ ` lax.fori_loop ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html ) .
271279
272- Note, however, that
273-
274- 1 . as mentioned above, there is not much speed gain in accelerating outer loops,
275- 2 . ` lax.fori_loop ` has a more complicated syntax, and, most importantly,
276- 3 . the ` lax.fori_loop ` implementation consumes far more memory, as we need to have to
277- store large matrices of random draws
278-
279- Hence the code below will fail due to out-of-memory errors when ` T ` and ` M ` are large.
280-
281- Here is the ` lax.fori_loop ` version:
280+ Here is the ` lax.fori_loop ` version:
282281
283282``` {code-cell} ipython3
284283@jax.jit
285- def generate_draws_lax(μ_a=-0.5,
286- σ_a=0.1,
287- μ_b=0.0,
288- σ_b=0.5,
289- μ_e=0.0,
290- σ_e=0.5,
291- s_bar=1.0,
292- T=500,
293- M=500_000,
294- s_init=1.0,
295- seed=123):
284+ def generate_cross_section_lax(
285+ firm, T=500, M=500_000, s_init=1.0, seed=123
286+ ):
296287
288+ μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm
297289 key = random.PRNGKey(seed)
298- keys = random.split(key, 3)
299290
300- # Generate random draws and initial values
301- a_random = μ_a + σ_a * random.normal(keys[0], (T, M))
302- b_random = μ_b + σ_b * random.normal(keys[1], (T, M))
303- e_random = μ_e + σ_e * random.normal(keys[2], (T, M))
291+ # Initial cross section
304292 s = jnp.full((M, ), s_init)
305293
306- # Define the function for each update
307- def update_s(i, s):
308- a, b, e = a_random[i], b_random[i], e_random[i]
309- s = jnp.where(s < s_bar,
310- jnp.exp(e),
311- jnp.exp(a) * s + jnp.exp(b))
312- return s
313-
314- # Use lax.scan to perform the calculations on all states
315- s_final = lax.fori_loop(0, T, update_s, s)
316- return s_final
317-
318- %time data = generate_draws_lax().block_until_ready()
294+ def update_cross_section(t, state):
295+ s, key = state
296+ key, *subkeys = jax.random.split(key, 4)
297+ # Generate current random draws
298+ a = μ_a + σ_a * random.normal(subkeys[0], (M,))
299+ b = μ_b + σ_b * random.normal(subkeys[1], (M,))
300+ e = μ_e + σ_e * random.normal(subkeys[2], (M,))
301+ # Exponentiate them
302+ a, b, e = jax.tree.map(jnp.exp, (a, b, e))
303+ # Update the cross-section of firms
304+ s = jnp.where(s < s_bar, e, a * s + b)
305+ new_state = s, key
306+ return new_state
307+
308+ # Use fori_loop
309+ initial_state = s, key
310+ final_s, final_key = lax.fori_loop(
311+ 0, T, update_cross_section, initial_state
312+ )
313+ return final_s
314+
315+ # Let's see if we got any speed gain
319316```
320317
321- In this case, ` M ` is small enough for the code to run and
322- we see some speed gain over the for loop implementation:
318+ ``` {code-cell} ipython3
319+ tic()
320+ data = generate_cross_section_lax(firm).block_until_ready()
321+ toc()
322+ ```
323323
324324``` {code-cell} ipython3
325- %time data = generate_draws_lax().block_until_ready()
325+ tic()
326+ data = generate_cross_section_lax(firm).block_until_ready()
327+ toc()
326328```
327329
328330Here we produce the same rank-size plot:
@@ -336,19 +338,61 @@ ax.set_xlabel("log rank")
336338ax.set_ylabel("log size")
337339
338340plt.show()
339- ```
340341
341- Let's rerun the ` for ` loop version on smaller ` M ` to compare the speed
342+ ```
343+
344+ If the time horizon is not too large, we can also try generating all shocks at
345+ once.
346+
347+ Note, however, that this approach consumes more memory, as we need to
348+ store large matrices of random draws.
349+
350+ Hence the code below will fail due to out-of-memory errors when `T` and `M` are large.
342351
343352``` {code-cell} ipython3
344- %time generate_draws(M=500_000).block_until_ready()
353+ @jax.jit
354+ def generate_cross_section_lax(
355+ firm, T=500, M=500_000, s_init=1.0, seed=123
356+ ):
357+
358+ μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm
359+ key = random.PRNGKey(seed)
360+ subkey_1, subkey_2, subkey_3 = random.split(key, 3)
361+
362+ # Generate entire sequence of random draws
363+ a = μ_a + σ_a * random.normal(subkey_1, (T, M))
364+ b = μ_b + σ_b * random.normal(subkey_2, (T, M))
365+ e = μ_e + σ_e * random.normal(subkey_3, (T, M))
366+ # Exponentiate them
367+ a, b, e = jax.tree.map(jnp.exp, (a, b, e))
368+ # Initial cross section
369+ s = jnp.full((M, ), s_init)
370+
371+ def update_cross_section(t, s):
372+ # Pull out the t-th cross-section of shocks
373+ a_t, b_t, e_t = a[t], b[t], e[t]
374+ s = jnp.where(s < s_bar, e_t, a_t * s + b_t)
375+ return s
376+
377+ # Use lax.fori_loop to perform the calculations on all states
378+ s_final = lax.fori_loop(0, T, update_cross_section, s)
379+ return s_final
345380```
346381
347- Let's run it again to get rid of the compile time.
382+ Here are the run times.
383+
384+ ``` {code-cell} ipython3
385+ tic()
386+ data = generate_cross_section_lax(firm).block_until_ready()
387+ toc()
388+ ```
348389
349390``` {code-cell} ipython3
350- %time generate_draws(M=500_000).block_until_ready()
391+ tic()
392+ data = generate_cross_section_lax(firm).block_until_ready()
393+ toc()
351394```
352395
353- We see that the ` lax.fori_loop ` version is faster than the ` for ` loop version
354- when memory is not an issue.
396+ This second method might be slightly faster in some cases but in general the
397+ relative speed will depend on the size of the cross-section and the length of
398+ the simulation paths.
0 commit comments