Commit a5f732a

jstac, mmcky, and HumphreyYang authored
[inventory_dynamics] Code improvements: typing and timing (#224)
* misc
* add install of quantecon package
* updates according to suggestions

---------

Co-authored-by: mmcky <mamckay@gmail.com>
Co-authored-by: Matt McKay <mmcky@users.noreply.github.com>
Co-authored-by: Humphrey Yang <u6474961@anu.edu.au>
1 parent 09ffd3e commit a5f732a

1 file changed, with 73 additions and 101 deletions


lectures/inventory_dynamics.md

Lines changed: 73 additions & 101 deletions
@@ -4,7 +4,7 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.16.1
+    jupytext_version: 1.17.2
 kernelspec:
   display_name: Python 3 (ipykernel)
   language: python
@@ -56,7 +56,8 @@ import numpy as np
 import jax
 import jax.numpy as jnp
 from jax import random, lax
-from collections import namedtuple
+from typing import NamedTuple
+from time import time
 ```

 Here's a description of our GPU:
@@ -97,10 +98,11 @@ and standard normal.
 Here's a `namedtuple` that stores parameters.

 ```{code-cell} ipython3
-Parameters = namedtuple('Parameters', ['s', 'S', 'μ', 'σ'])
-
-# Create a default instance
-params = Parameters(s=10, S=100, μ=1.0, σ=0.5)
+class ModelParameters(NamedTuple):
+    s: int = 10
+    S: int = 100
+    μ: float = 1.0
+    σ: float = 0.5
 ```

 ## Cross-sectional distributions
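
Reviewer note on the hunk above: the move from `collections.namedtuple` to a `typing.NamedTuple` subclass adds type annotations and default values while keeping tuple behaviour. The following is a minimal sketch of how such a class behaves, using only the standard library. The field names and defaults are copied from the diff; the `high_vol` instance and its value are illustrative assumptions, not part of the commit.

```python
from typing import NamedTuple


class ModelParameters(NamedTuple):
    # Field names, types, and defaults copied from the diff
    s: int = 10
    S: int = 100
    μ: float = 1.0
    σ: float = 0.5


# Defaults stand in for the old explicit Parameters(s=10, S=100, μ=1.0, σ=0.5)
params = ModelParameters()

# Instances are immutable; _replace returns a modified copy (illustrative value)
high_vol = params._replace(σ=0.8)

# The instance is still an ordinary tuple, so positional unpacking keeps working
s, S, μ, σ = params
```

Because the instance remains a tuple, code that unpacks or indexes the old `Parameters` object should continue to work unchanged.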
@@ -126,19 +128,21 @@ We will use the following code to update the cross-section of firms by one perio

 ```{code-cell} ipython3
 @jax.jit
-def update_cross_section(params, X_vec, D):
+def update_cross_section(params: ModelParameters,
+                         X_vec: jnp.ndarray,
+                         D: jnp.ndarray) -> jnp.ndarray:
     """
     Update by one period a cross-section of firms with inventory levels given by
-    X_vec, given the vector of demand shocks in D.
-
-    * D[i] is the demand shock for firm i with current inventory X_vec[i]
+    X_vec, given the vector of demand shocks in D. Here D[i] is the demand shock
+    for firm i with current inventory X_vec[i].

     """
     # Unpack
     s, S = params.s, params.S
     # Restock if the inventory is below the threshold
     X_new = jnp.where(X_vec <= s,
-                      jnp.maximum(S - D, 0), jnp.maximum(X_vec - D, 0))
+                      jnp.maximum(S - D, 0),
+                      jnp.maximum(X_vec - D, 0))
     return X_new
 ```

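
Reviewer note on the hunk above: the reformatted `jnp.where` call is behaviour-preserving. As a quick sanity check, here is a small worked example of the same update rule on four firms. The thresholds match the `ModelParameters` defaults; the inventory and demand values are made up for illustration.

```python
import jax.numpy as jnp

s, S = 10.0, 100.0                           # thresholds, as in the defaults
X_vec = jnp.array([5.0, 50.0, 10.0, 80.0])   # illustrative inventories
D = jnp.array([3.0, 60.0, 120.0, 30.0])      # illustrative demand shocks

# Firms at or below s restock to S before demand is met; others sell from
# current stock. Inventory is floored at zero in both branches.
X_new = jnp.where(X_vec <= s,
                  jnp.maximum(S - D, 0),
                  jnp.maximum(X_vec - D, 0))

result = X_new.tolist()
print(result)   # [97.0, 0.0, 0.0, 50.0]
```

Firms 0 and 2 restock (inventory at or below `s`), while firms 1 and 3 sell from existing stock; firms 1 and 2 face demand exceeding available stock and end at zero.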
@@ -149,17 +153,18 @@ initial distribution $\psi_0$ and a positive integer $T$.

 In this code we use an ordinary Python `for` loop to step forward through time

-While Python loops are slow, this approach is reasonable here because
-efficiency of outer loops has far less influence on runtime than efficiency of inner loops.
-
 (Below we will squeeze out more speed by compiling the outer loop as well as the
 update rule.)

 In the code below, the initial distribution $\psi_0$ takes all firms to have
 initial inventory `x_init`.

 ```{code-cell} ipython3
-def compute_cross_section(params, x_init, T, key, num_firms=50_000):
+def project_cross_section(params: ModelParameters,
+                          x_init: jnp.ndarray,
+                          T: int,
+                          key: jnp.ndarray,
+                          num_firms: int = 50_000) -> jnp.ndarray:
     # Set up initial distribution
     X_vec = jnp.full((num_firms, ), x_init)
     # Loop
@@ -176,6 +181,7 @@ def compute_cross_section(params, x_init, T, key, num_firms=50_000):
 We'll use the following specification

 ```{code-cell} ipython3
+params = ModelParameters()
 x_init = 50
 T = 500
 # Initialize random number generator
@@ -185,15 +191,21 @@ key = random.PRNGKey(10)
 Let's look at the timing.

 ```{code-cell} ipython3
-%time X_vec = compute_cross_section(params, \
-    x_init, T, key).block_until_ready()
+start_time = time()
+X_vec = project_cross_section(
+    params, x_init, T, key).block_until_ready()
+end_time = time()
+print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms")
 ```

 Let's run again to eliminate compile time.

 ```{code-cell} ipython3
-%time X_vec = compute_cross_section(params, \
-    x_init, T, key).block_until_ready()
+start_time = time()
+X_vec = project_cross_section(
+    params, x_init, T, key).block_until_ready()
+end_time = time()
+print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms")
 ```

 Here's a histogram of inventory levels at time $T$.
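
Reviewer note on the hunk above: the same five-line timing pattern now appears in several cells. One possible way to factor it out is sketched below. The helper name `timed` is hypothetical and not part of the commit, and the sketch swaps `time()` for `perf_counter()`, a monotonic clock from the standard library that is generally better suited to measuring short intervals. The workload is a plain-Python stand-in so that the sketch runs without JAX.

```python
from time import perf_counter


def timed(fn, *args, **kwargs):
    """Run fn once and return (result, elapsed milliseconds)."""
    start = perf_counter()
    result = fn(*args, **kwargs)
    # JAX dispatches work asynchronously, so wait for the result if it
    # exposes block_until_ready(); plain Python values skip this step.
    if hasattr(result, "block_until_ready"):
        result = result.block_until_ready()
    elapsed_ms = (perf_counter() - start) * 1000
    return result, elapsed_ms


# Stand-in workload so the sketch runs without a GPU or JAX installed
total, elapsed_ms = timed(sum, range(1_000_000))
print(f"Elapsed time: {elapsed_ms:.6f} ms")
```

With such a helper, a lecture cell could read `X_vec, ms = timed(project_cross_section, params, x_init, T, key)`; whether that is clearer for readers than the explicit pattern is a judgment call for the authors.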
@@ -218,15 +230,21 @@ through the time dimension.
 We will do this using `jax.jit` and a `fori_loop`, which is a compiler-ready version of a `for` loop provided by JAX.

 ```{code-cell} ipython3
-def compute_cross_section_fori(params, x_init, T, key, num_firms=50_000):
+def project_cross_section_fori(
+        params: ModelParameters,
+        x_init: jnp.ndarray,
+        T: int,
+        key: jnp.ndarray,
+        num_firms: int = 50_000
+) -> jnp.ndarray:

     s, S, μ, σ = params.s, params.S, params.μ, params.σ
     X = jnp.full((num_firms, ), x_init)

     # Define the function for each update
-    def fori_update(t, inputs):
+    def fori_update(t, loop_state):
         # Unpack
-        X, key = inputs
+        X, key = loop_state
         # Draw shocks using key
         Z = random.normal(key, shape=(num_firms,))
         D = jnp.exp(μ + σ * Z)
@@ -239,90 +257,40 @@ def compute_cross_section_fori(params, x_init, T, key, num_firms=50_000):
         return X, subkey

     # Loop t from 0 to T, applying fori_update each time.
-    # The initial condition for fori_update is (X, key).
-    X, key = lax.fori_loop(0, T, fori_update, (X, key))
-
+    initial_loop_state = X, key
+    X, key = lax.fori_loop(0, T, fori_update, initial_loop_state)
     return X

 # Compile taking T and num_firms as static (changes trigger recompile)
-compute_cross_section_fori = jax.jit(
-    compute_cross_section_fori, static_argnums=(2, 4))
+project_cross_section_fori = jax.jit(
+    project_cross_section_fori, static_argnums=(2, 4))
 ```

 Let's see how fast this runs with compile time.

 ```{code-cell} ipython3
-%time X_vec = compute_cross_section_fori(params, \
-    x_init, T, key).block_until_ready()
+start_time = time()
+X_vec = project_cross_section_fori(
+    params, x_init, T, key).block_until_ready()
+end_time = time()
+print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms")
 ```

 And let's see how fast it runs without compile time.

 ```{code-cell} ipython3
-%time X_vec = compute_cross_section_fori(params, \
-    x_init, T, key).block_until_ready()
+start_time = time()
+X_vec = project_cross_section_fori(
+    params, x_init, T, key).block_until_ready()
+end_time = time()
+print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms")
 ```

 Compared to the original version with a pure Python outer loop, we have
 produced a nontrivial speed gain.


-This is due to the fact that we have compiled the whole operation.
-
-
-
-
-### Further vectorization
-
-For relatively small problems, we can make this code run even faster by generating
-all random variables at once.
-
-This improves efficiency because we are taking more operations out of the loop.
-
-```{code-cell} ipython3
-def compute_cross_section_fori(params, x_init, T, key, num_firms=50_000):
-
-    s, S, μ, σ = params.s, params.S, params.μ, params.σ
-    X = jnp.full((num_firms, ), x_init)
-    Z = random.normal(key, shape=(T, num_firms))
-    D = jnp.exp(μ + σ * Z)
-
-    def update_cross_section(i, X):
-        X = jnp.where(X <= s,
-                      jnp.maximum(S - D[i, :], 0),
-                      jnp.maximum(X - D[i, :], 0))
-        return X
-
-    X = lax.fori_loop(0, T, update_cross_section, X)
-
-    return X
-
-# Compile taking T and num_firms as static (changes trigger recompile)
-compute_cross_section_fori = jax.jit(
-    compute_cross_section_fori, static_argnums=(2, 4))
-```
-
-Let's test it with compile time included.
-
-```{code-cell} ipython3
-%time X_vec = compute_cross_section_fori(params, \
-    x_init, T, key).block_until_ready()
-```
-
-Let's run again to eliminate compile time.
-
-```{code-cell} ipython3
-%time X_vec = compute_cross_section_fori(params, \
-    x_init, T, key).block_until_ready()
-```
-
-On one hand, this version is faster than the previous one, where random variables were
-generated inside the loop.
-
-On the other hand, this implementation consumes far more memory, as we need to
-store large arrays of random draws.
-
-The high memory consumption becomes problematic for large problems.
+This is due to the fact that we have compiled the entire sequence of operations.



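
Reviewer note on the hunk above: the renamed `loop_state` tuple is the carry that `lax.fori_loop` threads through each iteration. The sketch below shows the same pattern in isolation on a toy random walk rather than the inventory model. The function name `random_walk` and all values are illustrative assumptions. It also splits the key before drawing, which is one common convention and may differ from the ordering inside `fori_update`, whose middle lines are not visible in this diff.

```python
import jax
import jax.numpy as jnp
from jax import lax, random


def random_walk(key, T, n):
    X = jnp.zeros(n)

    def body(t, loop_state):
        # The carry must keep the same structure, shapes, and dtypes each step
        X, key = loop_state
        key, subkey = random.split(key)      # fresh subkey for this period
        Z = random.normal(subkey, shape=(n,))
        return X + Z, key

    X, key = lax.fori_loop(0, T, body, (X, key))
    return X


# n determines array shapes, so it must be static under jit; T is marked
# static as well, mirroring the static_argnums choice in the commit.
random_walk = jax.jit(random_walk, static_argnums=(1, 2))

key = random.PRNGKey(0)
X = random_walk(key, 100, 5).block_until_ready()
```

Calling the function twice with the same key returns the same array, since all randomness flows from the key carried in the loop state.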
@@ -364,16 +332,8 @@ num_firms = 10_000
 sample_dates = 10, 50, 250, 500, 750
 key = random.PRNGKey(10)

-
-%time X = shift_forward_and_sample(x_init, params, \
-    sample_dates, key).block_until_ready()
-```
-
-We run the code again to eliminate compile time.
-
-```{code-cell} ipython3
-%time X = shift_forward_and_sample(x_init, params, \
-    sample_dates, key).block_until_ready()
+X = shift_forward_and_sample(
+    x_init, params, sample_dates, key).block_until_ready()
 ```

 Let's plot the output.
@@ -464,13 +424,19 @@ def compute_freq(params, key,
 ```{code-cell} ipython3
 key = random.PRNGKey(27)

-%time freq = compute_freq(params, key).block_until_ready()
+start_time = time()
+freq = compute_freq(params, key).block_until_ready()
+end_time = time()
+print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms")
 ```

 We run the code again to get rid of compile time.

 ```{code-cell} ipython3
-%time freq = compute_freq(params, key).block_until_ready()
+start_time = time()
+freq = compute_freq(params, key).block_until_ready()
+end_time = time()
+print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms")
 ```

 ```{code-cell} ipython3
@@ -533,13 +499,19 @@ def compute_freq(params, key,
 Note the time the routine takes to run, as well as the output

 ```{code-cell} ipython3
-%time freq = compute_freq(params, key).block_until_ready()
+start_time = time()
+freq = compute_freq(params, key).block_until_ready()
+end_time = time()
+print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms")
 ```

 We run the code again to eliminate the compile time.

 ```{code-cell} ipython3
-%time freq = compute_freq(params, key).block_until_ready()
+start_time = time()
+freq = compute_freq(params, key).block_until_ready()
+end_time = time()
+print(f"Elapsed time: {(end_time - start_time) * 1000:.6f} ms")
 ```

 ```{code-cell} ipython3
