Skip to content

Commit eca8220

Browse files
mmckykp992
andauthored
MAINT: update njit to jit (#182)
* update njit to jit * fix typo --------- Co-authored-by: kp992 <kpl992@outlook.com>
1 parent 1886849 commit eca8220

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

lectures/cake_eating_numerical.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ for the purpose of comparing the results of JAX implementation.
323323

324324
```{code-cell} ipython3
325325
import numpy as np
326-
from numba import prange, njit
326+
from numba import prange, jit
327327
from quantecon.optimize import brent_max
328328
```
329329

@@ -344,13 +344,13 @@ def create_cake_eating_model_numba(β=0.96, # discount factor
344344

345345
```{code-cell} ipython3
346346
# Utility function
347-
@njit
347+
@jit
348348
def u_numba(c, cem):
349349
return (c ** (1 - cem.γ)) / (1 - cem.γ)
350350
```
351351

352352
```{code-cell} ipython3
353-
@njit
353+
@jit
354354
def state_action_value_numba(c, x, v_array, cem):
355355
"""
356356
Right hand side of the Bellman equation given x and c.
@@ -363,7 +363,7 @@ def state_action_value_numba(c, x, v_array, cem):
363363
```
364364

365365
```{code-cell} ipython3
366-
@njit
366+
@jit
367367
def T_numba(v, ce):
368368
"""
369369
The Bellman operator. Updates the guess of the value function.
@@ -424,7 +424,7 @@ numba_time = time.time() - in_time
424424

425425
```{code-cell} ipython3
426426
ratio = numba_time/jax_time
427-
print(f"JAX implementation is {ratio} times faster than NumPy.")
427+
print(f"JAX implementation is {ratio} times faster than Numba.")
428428
print(f"JAX time: {jax_time}")
429429
print(f"Numba time: {numba_time}")
430430
```

lectures/inventory_ssd.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ import numpy as np
174174
import matplotlib.pyplot as plt
175175
from collections import namedtuple
176176
import time
177-
from numba import njit, prange
177+
from numba import jit, prange
178178
```
179179

180180
Let's check the GPU we are running
@@ -395,11 +395,11 @@ plot_ts()
395395
Let's try the same operations in Numba in order to compare the speed.
396396

397397
```{code-cell} ipython3
398-
@njit
398+
@jit
399399
def demand_pdf_numba(p, d):
400400
return (1 - p)**d * p
401401
402-
@njit
402+
@jit
403403
def B_numba(x, i_z, a, v, model):
404404
"""
405405
The function B(x, z, a, v) = r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′).
@@ -415,7 +415,7 @@ def B_numba(x, i_z, a, v, model):
415415
return profit + z * cv
416416
417417
418-
@njit(parallel=True)
418+
@jit(parallel=True)
419419
def T_numba(v, model):
420420
"""The Bellman operator."""
421421
c, κ, p, z_vals, Q = model
@@ -428,7 +428,7 @@ def T_numba(v, model):
428428
return new_v
429429
430430
431-
@njit(parallel=True)
431+
@jit(parallel=True)
432432
def get_greedy_numba(v, model):
433433
"""Get a v-greedy policy. Returns a zero-based array."""
434434
c, κ, p, z_vals, Q = model

0 commit comments

Comments
 (0)