Skip to content

Commit 02e6a43

Browse files
Use fori_loop (#115)
* Use fori_loop * update vstack --------- Co-authored-by: Humphrey Yang <humphrey.yang@anu.edu.au>
1 parent e6fdcff commit 02e6a43

File tree

1 file changed

+46
-54
lines changed

1 file changed

+46
-54
lines changed

lectures/inventory_dynamics.md

Lines changed: 46 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ kernelspec:
1111
name: python3
1212
---
1313

14-
+++ {"user_expressions": []}
15-
1614
```{raw} html
1715
<div id="qe-notebook-header" align="right" style="text-align:right;">
1816
<a href="https://quantecon.org/" title="quantecon.org">
@@ -92,7 +90,6 @@ Firm = namedtuple('Firm', ['s', 'S', 'mu', 'sigma'])
9290
firm = Firm(s=10, S=100, mu=1.0, sigma=0.5)
9391
```
9492

95-
9693
## Example 1: marginal distributions
9794

9895
Now let’s look at the marginal distribution $\psi_T$ of $X_T$ for some fixed
@@ -168,19 +165,19 @@ def update_X(X, firm, D):
168165
return res
169166
170167
171-
def shift_firms_forward(x_init, firm, sample_dates,
168+
def shift_firms_forward(x_init, firm, sample_dates,
172169
key, num_firms=50_000, sim_length=750):
173-
170+
174171
X = res = jnp.full((num_firms, ), x_init)
175172
176173
# Use for loop to update X and collect samples
177174
for i in range(sim_length):
178175
Z = random.normal(key, shape=(num_firms, ))
179176
D = jnp.exp(firm.mu + firm.sigma * Z)
180-
177+
181178
X = update_X(X, firm, D)
182179
_, key = random.split(key)
183-
180+
184181
# draw a sample at the sample dates
185182
if (i+1 in sample_dates):
186183
res = jnp.vstack((res, X))
@@ -200,7 +197,7 @@ fig, ax = plt.subplots()
200197
sample_dates, key).block_until_ready()
201198
202199
for i, date in enumerate(sample_dates):
203-
plot_kde(X[i, :], ax, label=f't = {date}')
200+
plot_kde(X[i, :], ax, label=f't = {date}')
204201
205202
ax.set_xlabel('inventory')
206203
ax.set_ylabel('probability')
@@ -224,16 +221,16 @@ Here is an example of the same function in `lax.scan`
224221
@jax.jit
225222
def shift_firms_forward(x_init, firm, key,
226223
num_firms=50_000, sim_length=750):
227-
224+
228225
s, S, mu, sigma = firm.s, firm.S, firm.mu, firm.sigma
229226
X = jnp.full((num_firms, ), x_init)
230227
Z = random.normal(key, shape=(sim_length, num_firms))
231228
D = jnp.exp(mu + sigma * Z)
232-
229+
233230
# Define the function for each update
234231
def update_X(X, D):
235-
res = jnp.where(X <= s,
236-
jnp.maximum(S - D, 0),
232+
res = jnp.where(X <= s,
233+
jnp.maximum(S - D, 0),
237234
jnp.maximum(X - D, 0))
238235
return res, res
239236
@@ -263,7 +260,7 @@ fig, ax = plt.subplots()
263260
%time X = shift_firms_forward(x_init, firm, key).block_until_ready()
264261
265262
for date in sample_dates:
266-
plot_kde(X[date, :], ax, label=f't = {date}')
263+
plot_kde(X[date, :], ax, label=f't = {date}')
267264
268265
ax.set_xlabel('inventory')
269266
ax.set_ylabel('probability')
@@ -290,20 +287,18 @@ fig, ax = plt.subplots()
290287
%time X = shift_firms_forward(x_init, firm, key).block_until_ready()
291288
292289
for date in sample_dates:
293-
plot_kde(X[date, :], ax, label=f't = {date}')
290+
plot_kde(X[date, :], ax, label=f't = {date}')
294291
295292
ax.set_xlabel('inventory')
296293
ax.set_ylabel('probability')
297294
ax.legend()
298295
plt.show()
299296
```
300297

301-
302-
303298
## Example 2: restock frequency
304299

305300
Let's go through another example where we calculate the probability of firms
306-
having restocks.
301+
having restocks.
307302

308303
Specifically we set the starting stock level to 70 ($X_0 = 70$), as we calculate
309304
the proportion of firms that need to order twice or more in the first 50
@@ -317,19 +312,19 @@ Again, we start with an easier `for` loop implementation
317312
# Define a jitted function for each update
318313
@jax.jit
319314
def update_stock(n_restock, X, firm, D):
320-
n_restock = jnp.where(X <= firm.s,
321-
n_restock + 1,
322-
n_restock)
323-
X = jnp.where(X <= firm.s,
324-
jnp.maximum(firm.S - D, 0),
325-
jnp.maximum(X - D, 0))
326-
return n_restock, X, key
327-
328-
def compute_freq(firm, key,
329-
x_init=70,
330-
sim_length=50,
315+
n_restock = jnp.where(X <= firm.s,
316+
n_restock + 1,
317+
n_restock)
318+
X = jnp.where(X <= firm.s,
319+
jnp.maximum(firm.S - D, 0),
320+
jnp.maximum(X - D, 0))
321+
return n_restock, X, key
322+
323+
def compute_freq(firm, key,
324+
x_init=70,
325+
sim_length=50,
331326
num_firms=1_000_000):
332-
327+
333328
# Prepare initial arrays
334329
X = jnp.full((num_firms, ), x_init)
335330
@@ -343,7 +338,7 @@ def compute_freq(firm, key,
343338
n_restock, X, key = update_stock(
344339
n_restock, X, firm, D)
345340
key = random.fold_in(key, i)
346-
341+
347342
return jnp.mean(n_restock > 1, axis=0)
348343
```
349344

@@ -353,17 +348,17 @@ key = random.PRNGKey(27)
353348
print(f"Frequency of at least two stock outs = {freq}")
354349
```
355350

356-
### Alternative implementation with `lax.scan`
351+
### Alternative implementation with `lax.fori_loop`
357352

358-
Now let's write a `lax.scan` version that JIT compiles the whole function
353+
Now let's write a `lax.fori_loop` version that JIT compiles the whole function
359354

360355
```{code-cell} ipython3
361356
@jax.jit
362-
def compute_freq(firm, key,
363-
x_init=70,
364-
sim_length=50,
357+
def compute_freq(firm, key,
358+
x_init=70,
359+
sim_length=50,
365360
num_firms=1_000_000):
366-
361+
367362
s, S, mu, sigma = firm.s, firm.S, firm.mu, firm.sigma
368363
# Prepare initial arrays
369364
X = jnp.full((num_firms, ), x_init)
@@ -372,29 +367,26 @@ def compute_freq(firm, key,
372367
373368
# Stack the restock counter on top of the inventory
374369
restock_count = jnp.zeros((num_firms, ))
375-
Xs = jnp.vstack((X, restock_count))
370+
Xs = (X, restock_count)
376371
377372
# Define the function for each update
378-
def update_X(Xs, D):
379-
373+
def update_X(i, Xs):
380374
# Separate the inventory and restock counter
381-
X = Xs[0]
382-
restock_count = Xs[1]
383-
384-
restock_count = jnp.where(X <= s,
385-
restock_count + 1,
386-
restock_count)
387-
X = jnp.where(X <= s,
388-
jnp.maximum(S - D, 0),
389-
jnp.maximum(X - D, 0))
390-
391-
Xs = jnp.vstack((X, restock_count))
392-
return Xs, None
375+
x, restock_count = Xs[0], Xs[1]
376+
restock_count = jnp.where(x <= s,
377+
restock_count + 1,
378+
restock_count)
379+
x = jnp.where(x <= s,
380+
jnp.maximum(S - D[i], 0),
381+
jnp.maximum(x - D[i], 0))
393382
394-
# Use lax.scan to perform the calculations on all states
395-
X_final, _ = lax.scan(update_X, Xs, D)
383+
Xs = (x, restock_count)
384+
return Xs
385+
386+
# Use lax.fori_loop to perform the calculations on all states
387+
X_final = lax.fori_loop(0, sim_length, update_X, Xs)
396388
397-
return np.mean(X_final[1] > 1)
389+
return jnp.mean(X_final[1] > 1)
398390
```
399391

400392
Note the time the routine takes to run, as well as the output

0 commit comments

Comments
 (0)