@@ -4,7 +4,7 @@ jupytext:
44 extension : .md
55 format_name : myst
66 format_version : 0.13
7- jupytext_version : 1.16.1
7+ jupytext_version : 1.17.2
88kernelspec :
99 display_name : Python 3 (ipykernel)
1010 language : python
@@ -56,7 +56,8 @@ import numpy as np
5656import jax
5757import jax.numpy as jnp
5858from jax import random, lax
59- from collections import namedtuple
59+ from typing import NamedTuple
60+ from time import time
6061```
6162
6263Here's a description of our GPU:
@@ -97,10 +98,11 @@ and standard normal.
9798Here's a `NamedTuple` class that stores parameters, along with their default values.
9899
99100``` {code-cell} ipython3
100- Parameters = namedtuple('Parameters', ['s', 'S', 'μ', 'σ'])
101-
102- # Create a default instance
103- params = Parameters(s=10, S=100, μ=1.0, σ=0.5)
101+ class ModelParameters(NamedTuple):
102+ s: int = 10
103+ S: int = 100
104+ μ: float = 1.0
105+ σ: float = 0.5
104106```
105107
106108## Cross-sectional distributions
@@ -126,19 +128,21 @@ We will use the following code to update the cross-section of firms by one perio
126128
127129``` {code-cell} ipython3
128130@jax.jit
129- def update_cross_section(params, X_vec, D):
131+ def update_cross_section(params: ModelParameters,
132+ X_vec: jnp.ndarray,
133+ D: jnp.ndarray) -> jnp.ndarray:
130134 """
131135 Update by one period a cross-section of firms with inventory levels given by
132- X_vec, given the vector of demand shocks in D.
133-
134- * D[i] is the demand shock for firm i with current inventory X_vec[i]
136+ X_vec, given the vector of demand shocks in D. Here D[i] is the demand shock
137+ for firm i with current inventory X_vec[i].
135138
136139 """
137140 # Unpack
138141 s, S = params.s, params.S
139142 # Restock if the inventory is below the threshold
140143 X_new = jnp.where(X_vec <= s,
141- jnp.maximum(S - D, 0), jnp.maximum(X_vec - D, 0))
144+ jnp.maximum(S - D, 0),
145+ jnp.maximum(X_vec - D, 0))
142146 return X_new
143147```
144148
@@ -149,17 +153,18 @@ initial distribution $\psi_0$ and a positive integer $T$.
149153
150154In this code we use an ordinary Python `for` loop to step forward through time.
151155
152- While Python loops are slow, this approach is reasonable here because
153- efficiency of outer loops has far less influence on runtime than efficiency of inner loops.
154-
155156(Below we will squeeze out more speed by compiling the outer loop as well as the
156157update rule.)
157158
158159In the code below, the initial distribution $\psi_0$ takes all firms to have
159160initial inventory `x_init`.
160161
161162``` {code-cell} ipython3
162- def compute_cross_section(params, x_init, T, key, num_firms=50_000):
163+ def project_cross_section(params: ModelParameters,
164+ x_init: jnp.ndarray,
165+ T: int,
166+ key: jnp.ndarray,
167+ num_firms: int = 50_000) -> jnp.ndarray:
163168 # Set up initial distribution
164169 X_vec = jnp.full((num_firms, ), x_init)
165170 # Loop
@@ -176,6 +181,7 @@ def compute_cross_section(params, x_init, T, key, num_firms=50_000):
176181We'll use the following specification
177182
178183``` {code-cell} ipython3
184+ params = ModelParameters()
179185x_init = 50
180186T = 500
181187# Initialize random number generator
@@ -185,15 +191,21 @@ key = random.PRNGKey(10)
185191Let's look at the timing.
186192
187193``` {code-cell} ipython3
188- %time X_vec = compute_cross_section(params, \
189- x_init, T, key).block_until_ready()
194+ start_time = time()
195+ X_vec = project_cross_section(
196+ params, x_init, T, key).block_until_ready()
197+ end_time = time()
198+ print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms")
190199```
191200
192201Let's run again to eliminate compile time.
193202
194203``` {code-cell} ipython3
195- %time X_vec = compute_cross_section(params, \
196- x_init, T, key).block_until_ready()
204+ start_time = time()
205+ X_vec = project_cross_section(
206+ params, x_init, T, key).block_until_ready()
207+ end_time = time()
208+ print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms")
197209```
198210
199211Here's a histogram of inventory levels at time $T$.
@@ -218,15 +230,21 @@ through the time dimension.
218230We will do this using `jax.jit` and a `fori_loop`, which is a compiler-ready version of a `for` loop provided by JAX.
219231
220232``` {code-cell} ipython3
221- def compute_cross_section_fori(params, x_init, T, key, num_firms=50_000):
233+ def project_cross_section_fori(
234+ params: ModelParameters,
235+ x_init: jnp.ndarray,
236+ T: int,
237+ key: jnp.ndarray,
238+ num_firms: int = 50_000
239+ ) -> jnp.ndarray:
222240
223241 s, S, μ, σ = params.s, params.S, params.μ, params.σ
224242 X = jnp.full((num_firms, ), x_init)
225243
226244 # Define the function for each update
227- def fori_update(t, inputs ):
245+ def fori_update(t, loop_state ):
228246 # Unpack
229- X, key = inputs
247+ X, key = loop_state
230248 # Draw shocks using key
231249 Z = random.normal(key, shape=(num_firms,))
232250 D = jnp.exp(μ + σ * Z)
@@ -239,90 +257,40 @@ def compute_cross_section_fori(params, x_init, T, key, num_firms=50_000):
239257 return X, subkey
240258
241259 # Loop t from 0 to T, applying fori_update each time.
242- # The initial condition for fori_update is (X, key).
243- X, key = lax.fori_loop(0, T, fori_update, (X, key))
244-
260+ initial_loop_state = X, key
261+ X, key = lax.fori_loop(0, T, fori_update, initial_loop_state)
245262 return X
246263
247264# Compile taking T and num_firms as static (changes trigger recompile)
248- compute_cross_section_fori = jax.jit(
249- compute_cross_section_fori , static_argnums=(2, 4))
265+ project_cross_section_fori = jax.jit(
266+ project_cross_section_fori , static_argnums=(2, 4))
250267```
251268
252269Let's see how fast this runs with compile time.
253270
254271``` {code-cell} ipython3
255- %time X_vec = compute_cross_section_fori(params, \
256- x_init, T, key).block_until_ready()
272+ start_time = time()
273+ X_vec = project_cross_section_fori(
274+ params, x_init, T, key).block_until_ready()
275+ end_time = time()
276+ print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms")
257277```
258278
259279And let's see how fast it runs without compile time.
260280
261281``` {code-cell} ipython3
262- %time X_vec = compute_cross_section_fori(params, \
263- x_init, T, key).block_until_ready()
282+ start_time = time()
283+ X_vec = project_cross_section_fori(
284+ params, x_init, T, key).block_until_ready()
285+ end_time = time()
286+ print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms")
264287```
265288
266289Compared to the original version with a pure Python outer loop, we have
267290produced a nontrivial speed gain.
268291
269292
270- This is due to the fact that we have compiled the whole operation.
271-
272-
273-
274-
275- ### Further vectorization
276-
277- For relatively small problems, we can make this code run even faster by generating
278- all random variables at once.
279-
280- This improves efficiency because we are taking more operations out of the loop.
281-
282- ``` {code-cell} ipython3
283- def compute_cross_section_fori(params, x_init, T, key, num_firms=50_000):
284-
285- s, S, μ, σ = params.s, params.S, params.μ, params.σ
286- X = jnp.full((num_firms, ), x_init)
287- Z = random.normal(key, shape=(T, num_firms))
288- D = jnp.exp(μ + σ * Z)
289-
290- def update_cross_section(i, X):
291- X = jnp.where(X <= s,
292- jnp.maximum(S - D[i, :], 0),
293- jnp.maximum(X - D[i, :], 0))
294- return X
295-
296- X = lax.fori_loop(0, T, update_cross_section, X)
297-
298- return X
299-
300- # Compile taking T and num_firms as static (changes trigger recompile)
301- compute_cross_section_fori = jax.jit(
302- compute_cross_section_fori, static_argnums=(2, 4))
303- ```
304-
305- Let's test it with compile time included.
306-
307- ``` {code-cell} ipython3
308- %time X_vec = compute_cross_section_fori(params, \
309- x_init, T, key).block_until_ready()
310- ```
311-
312- Let's run again to eliminate compile time.
313-
314- ``` {code-cell} ipython3
315- %time X_vec = compute_cross_section_fori(params, \
316- x_init, T, key).block_until_ready()
317- ```
318-
319- On one hand, this version is faster than the previous one, where random variables were
320- generated inside the loop.
321-
322- On the other hand, this implementation consumes far more memory, as we need to
323- store large arrays of random draws.
324-
325- The high memory consumption becomes problematic for large problems.
293+ This is due to the fact that we have compiled the entire sequence of operations.
326294
327295
328296
@@ -364,16 +332,8 @@ num_firms = 10_000
364332sample_dates = 10, 50, 250, 500, 750
365333key = random.PRNGKey(10)
366334
367-
368- %time X = shift_forward_and_sample(x_init, params, \
369- sample_dates, key).block_until_ready()
370- ```
371-
372- We run the code again to eliminate compile time.
373-
374- ``` {code-cell} ipython3
375- %time X = shift_forward_and_sample(x_init, params, \
376- sample_dates, key).block_until_ready()
335+ X = shift_forward_and_sample(
336+ x_init, params, sample_dates, key).block_until_ready()
377337```
378338
379339Let's plot the output.
@@ -464,13 +424,19 @@ def compute_freq(params, key,
464424``` {code-cell} ipython3
465425key = random.PRNGKey(27)
466426
467- %time freq = compute_freq(params, key).block_until_ready()
427+ start_time = time()
428+ freq = compute_freq(params, key).block_until_ready()
429+ end_time = time()
430+ print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms")
468431```
469432
470433We run the code again to get rid of compile time.
471434
472435``` {code-cell} ipython3
473- %time freq = compute_freq(params, key).block_until_ready()
436+ start_time = time()
437+ freq = compute_freq(params, key).block_until_ready()
438+ end_time = time()
439+ print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms")
474440```
475441
476442``` {code-cell} ipython3
@@ -533,13 +499,19 @@ def compute_freq(params, key,
533499Note the time the routine takes to run, as well as the output.
534500
535501``` {code-cell} ipython3
536- %time freq = compute_freq(params, key).block_until_ready()
502+ start_time = time()
503+ freq = compute_freq(params, key).block_until_ready()
504+ end_time = time()
505+ print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms")
537506```
538507
539508We run the code again to eliminate the compile time.
540509
541510``` {code-cell} ipython3
542- %time freq = compute_freq(params, key).block_until_ready()
511+ start_time = time()
512+ freq = compute_freq(params, key).block_until_ready()
513+ end_time = time()
514+ print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms")
543515```
544516
545517``` {code-cell} ipython3
0 commit comments