# Numba jit
`nb.njit`（`nb.jit(nopython=True)` と同じ nopython モード）を使って、Python コードを just-in-time（JIT）コンパイルして高速化する．

## 参考
- [array layout（Numba の `array(float64, 1d, C)` と `array(float64, 1d, A)` の違い）](https://stackoverflow.com/questions/66363148/in-numba-whats-the-difference-between-arrayfloat64-1d-c-and-arrayfloat64)

In [50]:
import numba as nb
import numpy as np

In [51]:
def dot_loop(A, B):
    """Multiply two 2-D arrays with explicit Python triple loops.

    Uncompiled reference implementation of ``A @ B``. The result is a new
    float64 array of shape ``(A.shape[0], B.shape[1])``.

    Raises:
        AssertionError: if ``A.shape[1] != B.shape[0]``.
    """
    assert A.shape[1] == B.shape[0]
    n_rows, _ = A.shape
    # The inner extent is taken from B, as in a plain ``M, N = B.shape``.
    n_inner, n_cols = B.shape
    C = np.zeros((n_rows, n_cols))
    for row in range(n_rows):
        for col in range(n_cols):
            for idx in range(n_inner):
                C[row, col] += A[row, idx] * B[idx, col]
    return C

@nb.njit
def dot_loop_jit(A, B):
    """Triple-loop matrix product compiled by Numba in nopython mode.

    Same algorithm as the pure-Python ``dot_loop``. No signature is given,
    so Numba compiles lazily on the first call for each new combination of
    argument types. Returns a new float64 array of shape
    ``(A.shape[0], B.shape[1])``.

    Raises:
        AssertionError: if ``A.shape[1] != B.shape[0]``.
    """
    assert A.shape[1] == B.shape[0]
    n_rows, _ = A.shape
    n_inner, n_cols = B.shape
    C = np.zeros((n_rows, n_cols))
    for row in range(n_rows):
        for col in range(n_cols):
            for idx in range(n_inner):
                C[row, col] += A[row, idx] * B[idx, col]
    return C

In [52]:
# All-ones operands, so every entry of the (L, N) product equals M.
L, M, N = 200, 100, 200
A, B = np.ones((L, M)), np.ones((M, N))
dot_loop(A, B)

array([[100., 100., 100., ..., 100., 100., 100.],
       [100., 100., 100., ..., 100., 100., 100.],
       [100., 100., 100., ..., 100., 100., 100.],
       ...,
       [100., 100., 100., ..., 100., 100., 100.],
       [100., 100., 100., ..., 100., 100., 100.],
       [100., 100., 100., ..., 100., 100., 100.]])

In [69]:
# Same all-ones inputs; the first call with these argument types also
# triggers Numba's lazy compilation of dot_loop_jit.
L, M, N = 200, 100, 200
A, B = np.ones((L, M)), np.ones((M, N))
dot_loop_jit(A, B)

array([[100., 100., 100., ..., 100., 100., 100.],
       [100., 100., 100., ..., 100., 100., 100.],
       [100., 100., 100., ..., 100., 100., 100.],
       ...,
       [100., 100., 100., ..., 100., 100., 100.],
       [100., 100., 100., ..., 100., 100., 100.],
       [100., 100., 100., ..., 100., 100., 100.]])

## 型指定とメモリアクセス

In [54]:
# Length of the 1-D vectors used by the timing cells that follow.
N = 1_000

In [55]:
def dot(a, b):
    """Inner product via NumPy's ``@`` (uncompiled baseline)."""
    return a @ b

# Eager compilation with an explicit signature. ``float64[:]`` is a 1-D
# float64 array of layout "A" (any layout), so Numba cannot assume the
# operands are contiguous when it lowers ``@``.
@nb.njit(nb.float64(nb.float64[:], nb.float64[:]), cache=True)
def dot_jit(a, b):
    """Inner product compiled for 1-D float64 arrays of any memory layout."""
    return a @ b

# The same element type and rank, but declared C-contiguous
# (equivalent to the shorthand ``nb.float64[::1]``).
c_vector = nb.types.Array(dtype=nb.float64, ndim=1, layout="C")

@nb.njit(nb.float64(c_vector, c_vector), cache=True)
def dot_jit_type(a, b):
    """Inner product compiled only for C-contiguous 1-D float64 arrays."""
    return a @ b

  return a @ b


In [56]:
%%timeit
# Baseline timing: NumPy's own `@` via dot() on two length-N float64 vectors.
# NOTE: both np.ones(N) allocations are inside the timed body, so the
# reported time includes allocation cost, not only the dot product.
a = np.ones(N)
b = np.ones(N)
_ = dot(a, b)

15.6 µs ± 1.61 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [57]:
%%timeit
# Same measurement with dot_jit (signature declares layout "A" arrays).
# NOTE: the np.ones(N) allocations are timed too, as in the baseline cell.
a = np.ones(N)
b = np.ones(N)
_ = dot_jit(a, b)

25.1 µs ± 1.7 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [58]:
%%timeit
# Same measurement with dot_jit_type (signature declares layout "C" arrays).
# NOTE: the np.ones(N) allocations are timed too, as in the baseline cell.
a = np.ones(N)
b = np.ones(N)
_ = dot_jit_type(a, b)

12.8 µs ± 678 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


### 2dim

In [59]:
def dot2(a, b):
    """Matrix product of two 2-D arrays via NumPy's ``@`` (uncompiled baseline)."""
    return a @ b

# "f8[:,:]" is a 2-D float64 array of layout "A" (any layout), for both
# the arguments and the result, so Numba cannot assume contiguity.
@nb.njit("f8[:,:](f8[:,:],f8[:,:])", cache=True)
def dot2_jit(a, b):
    """Matrix product compiled for 2-D float64 arrays of any memory layout."""
    return a @ b

# 2-D C-contiguous float64 array type. It gets its own name instead of
# re-binding ``c_vector``, which would silently replace the 1-D vector type
# of that name used for the 1-D benchmark.
c_matrix = nb.types.Array(dtype=nb.f8, ndim=2, layout="C")

@nb.njit(c_matrix(c_matrix, c_matrix), cache=True)
def dot2_jit_type(a, b):
    """Matrix product compiled only for C-contiguous 2-D float64 arrays."""
    return a @ b

In [60]:
N = 1_000
# np.ones already returns C-ordered arrays; ascontiguousarray just makes the
# contiguity requirement explicit (effectively a no-op here).
a, b = (
    np.ascontiguousarray(np.ones((N, N))),
    np.ascontiguousarray(np.ones((N, N))),
)
# One flag per line: confirm both operands are C-contiguous.
print(a.data.c_contiguous, b.data.c_contiguous, sep="\n")
_ = dot2(a, b)

True
True


In [61]:
N = 1_000
# C-contiguous operands; dot2_jit's "A"-layout signature accepts any layout.
a, b = (
    np.ascontiguousarray(np.ones((N, N))),
    np.ascontiguousarray(np.ones((N, N))),
)
# One flag per line: confirm both operands are C-contiguous.
print(a.data.c_contiguous, b.data.c_contiguous, sep="\n")
_ = dot2_jit(a, b)

True
True


In [62]:
N = 1_000
# dot2_jit_type's signature is declared for C-layout arrays, which is why
# the operands are made (and checked to be) C-contiguous.
a, b = (
    np.ascontiguousarray(np.ones((N, N))),
    np.ascontiguousarray(np.ones((N, N))),
)
# One flag per line: confirm both operands are C-contiguous.
print(a.data.c_contiguous, b.data.c_contiguous, sep="\n")
_ = dot2_jit_type(a, b)

True
True


## 並列化とprange

In [64]:
def dot_loop(A, B):
    """Multiply two 2-D arrays with explicit Python triple loops.

    Uncompiled reference implementation of ``A @ B``. The result is a new
    float64 array of shape ``(A.shape[0], B.shape[1])``.

    Raises:
        AssertionError: if ``A.shape[1] != B.shape[0]``.
    """
    assert A.shape[1] == B.shape[0]
    n_rows, _ = A.shape
    # The inner extent is taken from B, as in a plain ``M, N = B.shape``.
    n_inner, n_cols = B.shape
    C = np.zeros((n_rows, n_cols))
    for row in range(n_rows):
        for col in range(n_cols):
            for idx in range(n_inner):
                C[row, col] += A[row, idx] * B[idx, col]
    return C

@nb.njit
def dot_loop_jit(A, B):
    """Triple-loop matrix product compiled by Numba in nopython mode.

    Serial counterpart of ``dot_loop_jit_prange``. No signature is given, so
    Numba compiles lazily on the first call for each new combination of
    argument types. Returns a new float64 array of shape
    ``(A.shape[0], B.shape[1])``.

    Raises:
        AssertionError: if ``A.shape[1] != B.shape[0]``.
    """
    assert A.shape[1] == B.shape[0]
    n_rows, _ = A.shape
    n_inner, n_cols = B.shape
    C = np.zeros((n_rows, n_cols))
    for row in range(n_rows):
        for col in range(n_cols):
            for idx in range(n_inner):
                C[row, col] += A[row, idx] * B[idx, col]
    return C

@nb.njit(parallel=True)
def dot_loop_jit_prange(A, B):
    """Triple-loop matrix product with the row loop run in parallel by Numba.

    Only the outermost loop uses ``nb.prange``. Numba serialises any prange
    nested inside another prange, so inner pranges only suggest parallelism
    that never happens. The innermost one would also be unsafe: it
    accumulates into ``C[i, j]``, a data race if that loop were ever the
    parallel one. With the parallel loop over rows, each parallel iteration
    writes only to its own row ``C[i, :]``.

    Returns a new float64 array of shape ``(A.shape[0], B.shape[1])``.

    Raises:
        AssertionError: if ``A.shape[1] != B.shape[0]``.
    """
    assert A.shape[1] == B.shape[0]
    L, M = A.shape
    M, N = B.shape
    C = np.zeros((L, N))
    for i in nb.prange(L):
        # Inner loops are ordinary serial loops within each parallel row.
        for j in range(N):
            for k in range(M):
                C[i, j] += A[i, k] * B[k, j]
    return C

In [71]:
# Same all-ones inputs as the serial runs; every entry of the result is M.
L, M, N = 200, 100, 200
A, B = np.ones((L, M)), np.ones((M, N))
dot_loop_jit_prange(A, B)

array([[100., 100., 100., ..., 100., 100., 100.],
       [100., 100., 100., ..., 100., 100., 100.],
       [100., 100., 100., ..., 100., 100., 100.],
       ...,
       [100., 100., 100., ..., 100., 100., 100.],
       [100., 100., 100., ..., 100., 100., 100.],
       [100., 100., 100., ..., 100., 100., 100.]])