# Learning numba
## First example


In [21]:
import numba
from numba import jit
import numpy as np
x = np.arange(100).reshape(10, 10)
def go_slow(a):
    """Add the sum of ``tanh`` over the diagonal entries ``a[i, i]`` to ``a``.

    Plain-Python (non-jitted) baseline; ``go_fast`` has the same body but is
    compiled by Numba, so the two can be timed against each other.

    Parameters
    ----------
    a : array indexed as ``a[i, i]`` for ``i`` in ``range(a.shape[0])``
        (assumes a 2-D array with at least ``a.shape[0]`` columns -- TODO confirm).

    Returns
    -------
    The result of ``a + trace`` where ``trace`` is the accumulated tanh sum.
    """
    trace = 0.0
    for i in range(a.shape[0]):   # Numba likes loops
        trace += np.tanh(a[i, i]) # Numba likes NumPy functions
    return a + trace              # Numba likes NumPy broadcasting

# @jit(nopython=True) # Set "nopython" mode for best performance, equivalent to @njit
# @jit(nopython=True, parallel=True, nogil=True, cache=True, inline='always')
@jit(nopython=True, nogil=True, cache=True, inline='always')
def go_fast(a): # Function is compiled to machine code when called the first time
    """Numba-compiled twin of ``go_slow``: add the tanh sum of the diagonal to ``a``.

    The body is deliberately a copy of ``go_slow`` so that the timing
    comparison measures only the effect of the ``@jit`` decorator.

    Parameters
    ----------
    a : array indexed as ``a[i, i]`` for ``i`` in ``range(a.shape[0])``
        (assumes a 2-D array with at least ``a.shape[0]`` columns -- TODO confirm).

    Returns
    -------
    The result of ``a + trace`` where ``trace`` is the accumulated tanh sum.
    """
    # return go_slow(a)  # wrong: cannot be compiled (go_slow is not a jitted function)
    trace = 0.0
    for i in range(a.shape[0]):   # Numba likes loops
        trace += np.tanh(a[i, i]) # Numba likes NumPy functions
    return a + trace              # Numba likes NumPy broadcasting
go_fast(x) # warm-up call: compile first, before timing
%timeit go_fast(x)
%timeit go_slow(x)

1.11 µs ± 12 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
23.5 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [16]:
x = np.arange(100).reshape(10, 10)
@jit(nopython=True, nogil=True, cache=True, inline='always')
def add(a, b):
    """Return ``a + b``; compiled by Numba with ``nopython=True``.

    The cell below calls it with array/array, scalar/scalar and scalar/array
    arguments to compare against the plain ``+`` operator.
    """
    return a+b

# Correct usage: array arguments (the first call compiles the function before it is timed)
add(x, x)
%timeit add(x, x)
%timeit x+x

# Ineffective usage: scalar arguments
%timeit add(2, 3)
%timeit 2+3


# Possibly effective usage: mixed scalar/array arguments
%timeit add(2, x)
%timeit 2+x

1.21 µs ± 23.9 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
578 ns ± 15.5 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
204 ns ± 1.64 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
6.15 ns ± 0.0838 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)
The slowest run took 13.67 times longer than the fastest. This could mean that an intermediate result is being cached.
4.77 µs ± 5.58 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.2 µs ± 10.8 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [22]:
a = list(range(1000))
# @jit(nopython=True, nogil=True, cache=True, inline='always')
@jit(nopython=False, nogil=True, cache=True, inline='always')
def list_add_fast(x,y):
    """Return a copy of ``x`` in which ``y[i]`` has been added to element ``i``.

    Jitted twin of ``list_add_slow`` (identical body) used for the timing
    comparison. Only the indices produced by enumerating ``y`` are updated;
    ``x`` itself is not modified because the work is done on ``x.copy()``.

    NOTE(review): decorated with ``nopython=False``; the recorded output for
    this cell shows "reflected list" deprecation warnings for both arguments
    and a slower timing than the plain-Python version.
    """
    z = x.copy()
    for i, v in enumerate(y):
        z[i]+=v
    return z
def list_add_slow(x,y):
    """Return a copy of ``x`` in which ``y[i]`` has been added to element ``i``.

    Plain-Python baseline with the same body as ``list_add_fast``. Only the
    indices produced by enumerating ``y`` are updated; ``x`` itself is not
    modified because the work is done on ``x.copy()``.
    """
    z = x.copy()
    for i, v in enumerate(y):
        z[i]+=v
    return z
list_add_fast(a,a)
%timeit list_add_fast(a,a)
%timeit list_add_slow(a,a)

Encountered the use of a type that is scheduled for deprecation: type 'reflected list' found for argument 'x' of function 'list_add_fast'.

For more information visit https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-reflection-for-list-and-set-types
[1m
File "C:\Users\YeCanming\AppData\Local\Temp\ipykernel_24984\631882851.py", line 4:[0m
[1m@jit(nopython=False, nogil=True, cache=True, inline='always')
[1mdef list_add_fast(x,y):
[0m[1m^[0m[0m
[0m
Encountered the use of a type that is scheduled for deprecation: type 'reflected list' found for argument 'y' of function 'list_add_fast'.

For more information visit https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-reflection-for-list-and-set-types
[1m
File "C:\Users\YeCanming\AppData\Local\Temp\ipykernel_24984\631882851.py", line 4:[0m
[1m@jit(nopython=False, nogil=True, cache=True, inline='always')
[1mdef list_add_fast(x,y):
[0m[1m^[0m[0m
[0m
Compilation is fal

153 µs ± 2.01 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
65.9 µs ± 1.27 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


## numba性能指南
> note: 实践是检验真理的唯一标准。以下结论只是理论上的，要用真实数据实际跑一跑才知道快不快。

### loops

In [35]:
from numba import njit
# @njit(cache = True)
def ident_np(x):
    """Evaluate ``cos(x)**2 + sin(x)**2`` with plain NumPy array operations.

    Serves as the un-jitted NumPy baseline for the loop-based variant.
    """
    cos_squared = np.cos(x) ** 2
    sin_squared = np.sin(x) ** 2
    return cos_squared + sin_squared

@njit(cache = True)
def ident_loops(x):
    """Evaluate ``cos(x[i])**2 + sin(x[i])**2`` element by element in an explicit loop.

    Numba-compiled counterpart of ``ident_np``: instead of whole-array
    expressions it fills a preallocated result one index at a time.

    Parameters
    ----------
    x : indexable of length ``len(x)`` (assumes a 1-D array -- TODO confirm).

    Returns
    -------
    Array created with ``np.empty_like(x)`` holding the per-element results.
    """
    r = np.empty_like(x)
    n = len(x)
    for i in range(n):
        r[i] = np.cos(x[i]) ** 2 + np.sin(x[i]) ** 2
    return r
x = np.arange(1.e7)
%timeit ident_np(x)
%timeit ident_loops(x)

249 ms ± 27.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
187 ms ± 1.54 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


note: 用 numba 编译的显式循环（187 ms）比纯 NumPy 表达式（249 ms）略快，加速了还是比没加速好；但 NumPy 本身已经是向量化实现，所以差距不大。

### Fastmath

In [32]:
@njit(fastmath=False)
def do_sum(A):
    """Return the sum of ``np.sqrt(x)`` over every ``x`` in ``A`` (fastmath disabled).

    Baseline for ``do_sum_fast``, which has the same body but is compiled
    with ``fastmath=True``.
    """
    acc = 0.
    # without fastmath, this loop must accumulate in strict order
    for x in A:
        acc += np.sqrt(x)
    return acc

@njit(fastmath=True)
def do_sum_fast(A):
    """Return the sum of ``np.sqrt(x)`` over every ``x`` in ``A`` (fastmath enabled).

    Same body as ``do_sum``; only the ``fastmath`` flag differs so the two
    timings isolate its effect.
    """
    acc = 0.
    # with fastmath, the reduction can be vectorized as floating point
    # reassociation is permitted.
    for x in A:
        acc += np.sqrt(x)
    return acc
%timeit do_sum(x)
%timeit do_sum_fast(x)

21.6 ms ± 515 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
21.7 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


在这个例子里，开启与关闭 fastmath 的耗时几乎相同（21.6 ms 对 21.7 ms），区别不大。

### Parallel

In [36]:
@njit(parallel=True, nogil=True)
def ident_parallel(x):
    return np.cos(x) ** 2 + np.sin(x) ** 2
%timeit ident_loops(x)

188 ms ± 2.78 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


也是区别不大

In [39]:
#  out of order execution is valid
from numba import prange
@njit(parallel=True)
def do_sum_parallel(A):
    """Return the sum of ``np.sqrt(A[i])`` over all indices, using a ``prange`` loop.

    Compiled with ``parallel=True``; ``do_sum_parallel_fast`` has the same
    body with ``fastmath=True`` added.
    """
    # each thread can accumulate its own partial sum, and then a cross
    # thread reduction is performed to obtain the result to return
    n = len(A)
    acc = 0.
    for i in prange(n):
        acc += np.sqrt(A[i])
    return acc

@njit(parallel=True, fastmath=True)
def do_sum_parallel_fast(A):
    """Return the sum of ``np.sqrt(A[i])`` over all indices, using a ``prange`` loop.

    Same body as ``do_sum_parallel``; compiled with both ``parallel=True``
    and ``fastmath=True`` so the timings show the combined effect.
    """
    n = len(A)
    acc = 0.
    for i in prange(n):
        acc += np.sqrt(A[i])
    return acc
%timeit do_sum_parallel(x)
%timeit do_sum_parallel_fast(x)

5.7 ms ± 1.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.05 ms ± 239 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


说明 parallel 对显式的 for 循环（prange）的加速比对 NumPy 数组表达式更明显。

### 范数计算（norm）

In [46]:
import scipy
def do_norm(x):
    """Return ``np.linalg.norm(x)`` (default arguments); un-jitted NumPy baseline."""
    norm_value = np.linalg.norm(x)
    return norm_value
@njit(cache=True, fastmath=True)
def do_norm_fast(x):
    """Return ``np.linalg.norm(x)``, compiled by Numba with ``cache`` and ``fastmath`` on.

    Same call as ``do_norm``; only the ``@njit`` decorator differs.
    """
    return np.linalg.norm(x)
@njit(cache=True, parallel = True, fastmath=True)
def do_norm_list(x):
    """Return ``sqrt(sum(x[i]**2))`` computed with an explicit ``prange`` loop.

    Hand-written alternative to ``np.linalg.norm`` compiled with
    ``parallel=True`` and ``fastmath=True``.

    Parameters
    ----------
    x : indexable of length ``len(x)`` (assumes a 1-D numeric array -- TODO confirm).
    """
    n = len(x)
    acc = 0.
    for i in prange(n):
        acc += np.square(x[i])
    return np.sqrt(acc)
def do_norm_slow(x):
    """Return ``sqrt(sum(x[i]**2))`` using an ordinary Python loop.

    Pure-Python baseline for ``do_norm_list``. It iterates with the built-in
    ``range``: this function is not jitted, so Numba's ``prange`` would give
    no parallelism here and would only make the baseline depend on numba.

    Parameters
    ----------
    x : indexable of length ``len(x)`` holding numeric values.

    Returns
    -------
    ``np.sqrt`` of the accumulated sum of squares (``0.0`` for empty input).
    """
    n = len(x)
    acc = 0.
    for i in range(n):
        acc += np.square(x[i])
    return np.sqrt(acc)
x = np.arange(1.e7)
%timeit a = do_norm(x)
%timeit b = do_norm_fast(x)
%timeit c = do_norm_list(x)
# print(f"a={a}, b={b}, c={c}")
# %timeit do_norm_slow(x)

3.08 ms ± 57.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.38 ms ± 149 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.8 ms ± 80.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


NameError: name 'b' is not defined