@@ -11,8 +11,6 @@ kernelspec:
1111 name : python3
1212---
1313
14- +++ {"user_expressions": [ ] }
15-
1614``` {raw} html
1715<div id="qe-notebook-header" align="right" style="text-align:right;">
1816 <a href="https://quantecon.org/" title="quantecon.org">
@@ -92,7 +90,6 @@ Firm = namedtuple('Firm', ['s', 'S', 'mu', 'sigma'])
9290firm = Firm(s=10, S=100, mu=1.0, sigma=0.5)
9391```
9492
95-
9693## Example 1: marginal distributions
9794
9895Now let’s look at the marginal distribution $\psi_T$ of $X_T$ for some fixed
@@ -168,19 +165,19 @@ def update_X(X, firm, D):
168165 return res
169166
170167
171- def shift_firms_forward(x_init, firm, sample_dates,
168+ def shift_firms_forward(x_init, firm, sample_dates,
172169 key, num_firms=50_000, sim_length=750):
173-
170+
174171 X = res = jnp.full((num_firms, ), x_init)
175172
176173 # Use for loop to update X and collect samples
177174 for i in range(sim_length):
178175 Z = random.normal(key, shape=(num_firms, ))
179176 D = jnp.exp(firm.mu + firm.sigma * Z)
180-
177+
181178 X = update_X(X, firm, D)
182179 _, key = random.split(key)
183-
180+
184181 # draw a sample at the sample dates
185182 if (i+1 in sample_dates):
186183 res = jnp.vstack((res, X))
@@ -200,7 +197,7 @@ fig, ax = plt.subplots()
200197 sample_dates, key).block_until_ready()
201198
202199for i, date in enumerate(sample_dates):
203- plot_kde(X[i, :], ax, label=f't = {date}')
200+ plot_kde(X[i, :], ax, label=f't = {date}')
204201
205202ax.set_xlabel('inventory')
206203ax.set_ylabel('probability')
@@ -224,16 +221,16 @@ Here is an example of the same function in `lax.scan`
224221@jax.jit
225222def shift_firms_forward(x_init, firm, key,
226223 num_firms=50_000, sim_length=750):
227-
224+
228225 s, S, mu, sigma = firm.s, firm.S, firm.mu, firm.sigma
229226 X = jnp.full((num_firms, ), x_init)
230227 Z = random.normal(key, shape=(sim_length, num_firms))
231228 D = jnp.exp(mu + sigma * Z)
232-
229+
233230 # Define the function for each update
234231 def update_X(X, D):
235- res = jnp.where(X <= s,
236- jnp.maximum(S - D, 0),
232+ res = jnp.where(X <= s,
233+ jnp.maximum(S - D, 0),
237234 jnp.maximum(X - D, 0))
238235 return res, res
239236
@@ -263,7 +260,7 @@ fig, ax = plt.subplots()
263260%time X = shift_firms_forward(x_init, firm, key).block_until_ready()
264261
265262for date in sample_dates:
266- plot_kde(X[date, :], ax, label=f't = {date}')
263+ plot_kde(X[date, :], ax, label=f't = {date}')
267264
268265ax.set_xlabel('inventory')
269266ax.set_ylabel('probability')
@@ -290,20 +287,18 @@ fig, ax = plt.subplots()
290287%time X = shift_firms_forward(x_init, firm, key).block_until_ready()
291288
292289for date in sample_dates:
293- plot_kde(X[date, :], ax, label=f't = {date}')
290+ plot_kde(X[date, :], ax, label=f't = {date}')
294291
295292ax.set_xlabel('inventory')
296293ax.set_ylabel('probability')
297294ax.legend()
298295plt.show()
299296```
300297
301-
302-
303298## Example 2: restock frequency
304299
305300Let's go through another example where we calculate the probability of firms
306- having restocks.
301+ having restocks.
307302
308303Specifically we set the starting stock level to 70 ($X_0 = 70$), as we calculate
309304the proportion of firms that need to order twice or more in the first 50
@@ -317,19 +312,19 @@ Again, we start with an easier `for` loop implementation
317312# Define a jitted function for each update
318313@jax.jit
319314def update_stock(n_restock, X, firm, D):
320- n_restock = jnp.where(X <= firm.s,
321- n_restock + 1,
322- n_restock)
323- X = jnp.where(X <= firm.s,
324- jnp.maximum(firm.S - D, 0),
325- jnp.maximum(X - D, 0))
326- return n_restock, X, key
327-
328- def compute_freq(firm, key,
329- x_init=70,
330- sim_length=50,
315+ n_restock = jnp.where(X <= firm.s,
316+ n_restock + 1,
317+ n_restock)
318+ X = jnp.where(X <= firm.s,
319+ jnp.maximum(firm.S - D, 0),
320+ jnp.maximum(X - D, 0))
321+ return n_restock, X, key
322+
323+ def compute_freq(firm, key,
324+ x_init=70,
325+ sim_length=50,
331326 num_firms=1_000_000):
332-
327+
333328 # Prepare initial arrays
334329 X = jnp.full((num_firms, ), x_init)
335330
@@ -343,7 +338,7 @@ def compute_freq(firm, key,
343338 n_restock, X, key = update_stock(
344339 n_restock, X, firm, D)
345340 key = random.fold_in(key, i)
346-
341+
347342 return jnp.mean(n_restock > 1, axis=0)
348343```
349344
@@ -353,17 +348,17 @@ key = random.PRNGKey(27)
353348print(f"Frequency of at least two stock outs = {freq}")
354349```
355350
356- ### Alternative implementation with ` lax.scan `
351+ ### Alternative implementation with ` lax.fori_loop `
357352
358- Now let's write a ` lax.scan ` version that JIT compiles the whole function
353+ Now let's write a ` lax.fori_loop ` version that JIT compiles the whole function
359354
360355``` {code-cell} ipython3
361356@jax.jit
362- def compute_freq(firm, key,
363- x_init=70,
364- sim_length=50,
357+ def compute_freq(firm, key,
358+ x_init=70,
359+ sim_length=50,
365360 num_firms=1_000_000):
366-
361+
367362 s, S, mu, sigma = firm.s, firm.S, firm.mu, firm.sigma
368363 # Prepare initial arrays
369364 X = jnp.full((num_firms, ), x_init)
@@ -372,29 +367,26 @@ def compute_freq(firm, key,
372367
373368 # Stack the restock counter on top of the inventory
374369 restock_count = jnp.zeros((num_firms, ))
375- Xs = jnp.vstack(( X, restock_count) )
370+ Xs = ( X, restock_count)
376371
377372 # Define the function for each update
378- def update_X(Xs, D):
379-
373+ def update_X(i, Xs):
380374 # Separate the inventory and restock counter
381- X = Xs[0]
382- restock_count = Xs[1]
383-
384- restock_count = jnp.where(X <= s,
385- restock_count + 1,
386- restock_count)
387- X = jnp.where(X <= s,
388- jnp.maximum(S - D, 0),
389- jnp.maximum(X - D, 0))
390-
391- Xs = jnp.vstack((X, restock_count))
392- return Xs, None
375+ x, restock_count = Xs[0], Xs[1]
376+ restock_count = jnp.where(x <= s,
377+ restock_count + 1,
378+ restock_count)
379+ x = jnp.where(x <= s,
380+ jnp.maximum(S - D[i], 0),
381+ jnp.maximum(x - D[i], 0))
393382
394- # Use lax.scan to perform the calculations on all states
395- X_final, _ = lax.scan(update_X, Xs, D)
383+ Xs = (x, restock_count)
384+ return Xs
385+
386+ # Use lax.fori_loop to perform the calculations on all states
387+ X_final = lax.fori_loop(0, sim_length, update_X, Xs)
396388
397- return np .mean(X_final[1] > 1)
389+ return jnp .mean(X_final[1] > 1)
398390```
399391
400392Note the time the routine takes to run, as well as the output
0 commit comments