In [1]:
from numba import jit
import numpy as np
import time
################ Pure Python ###############
# Not JIT-compiled: executed as bytecode by the Python interpreter
def python_trace(a):
    """Add the sum of ``tanh`` over the main diagonal of ``a`` to every element.

    Parameters
    ----------
    a : numpy.ndarray
        2-D array. It does not have to be square: the diagonal is walked over
        the first ``min(rows, cols)`` positions.

    Returns
    -------
    numpy.ndarray
        A new array equal to ``a + sum(tanh(a[i, i]))``; ``a`` is not modified.
    """
    # Bound the walk by the shorter axis so a tall input (rows > cols) does
    # not index past the last column.
    n = a.shape[0]
    if a.ndim > 1:
        n = min(n, a.shape[1])

    trace = 0.0
    for i in range(n):
        trace += np.tanh(a[i, i])
    return a + trace


In [2]:
################ Numba ###############
# JIT-compiled by Numba and executed as machine code
@jit(nopython=True) # <--------------- Numba decorator
def numba_trace(a):
    """Add the sum of ``tanh`` over the main diagonal of ``a`` to every element.

    Same computation as ``python_trace``, compiled by Numba in nopython mode.

    Parameters
    ----------
    a : numpy.ndarray
        2-D array. It does not have to be square: the diagonal is walked over
        the first ``min(rows, cols)`` positions.

    Returns
    -------
    numpy.ndarray
        A new array equal to ``a + sum(tanh(a[i, i]))``.
    """
    # Bound the walk by the shorter axis. Numba does not bounds-check array
    # indexing by default, so running i up to a.shape[0] on a tall input
    # (rows > cols) would silently read out-of-range memory.
    n = min(a.shape[0], a.shape[1])

    trace = 0.0
    for i in range(n):
        trace += np.tanh(a[i, i])
    return a + trace

In [3]:
# 100 x 100 matrix of uniform samples drawn from [0, 1).
x = np.random.random_sample((100, 100))

In [4]:
%%time

# Warm-up call: the first invocation makes Numba compile numba_trace for this
# argument type, so the time reported for this cell includes compilation.
numba_trace(x)

CPU times: user 1.41 s, sys: 17.1 ms, total: 1.42 s
Wall time: 405 ms


array([[45.00117695, 45.40615077, 45.84281113, ..., 45.42475777,
        45.46019224, 45.07311551],
       [45.75557008, 44.9514531 , 45.10743309, ..., 45.73538339,
        44.98134267, 45.82383864],
       [45.14951293, 45.34624588, 45.89572805, ..., 45.29192572,
        45.25532091, 45.73589068],
       ...,
       [45.5149149 , 45.55785516, 45.18141423, ..., 45.87580545,
        45.37979148, 45.76496928],
       [45.50736554, 45.38626693, 45.84615668, ..., 45.85398976,
        45.035848  , 45.74804314],
       [45.59778675, 45.13451265, 45.84708884, ..., 45.48008762,
        45.567107  , 45.40221395]])

In [5]:
%%timeit
# Baseline: time the pure-Python implementation on the same input x.
python_trace(x)

73.8 μs ± 508 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [7]:
%%time
# Second call, after the warm-up cell, so the already-compiled function is timed.
# NOTE(review): this is a single %%time measurement, while python_trace was
# measured with %%timeit; use %%timeit here too for a like-for-like comparison.
numba_trace(x)

CPU times: user 23 μs, sys: 2 μs, total: 25 μs
Wall time: 26.5 μs


array([[45.00117695, 45.40615077, 45.84281113, ..., 45.42475777,
        45.46019224, 45.07311551],
       [45.75557008, 44.9514531 , 45.10743309, ..., 45.73538339,
        44.98134267, 45.82383864],
       [45.14951293, 45.34624588, 45.89572805, ..., 45.29192572,
        45.25532091, 45.73589068],
       ...,
       [45.5149149 , 45.55785516, 45.18141423, ..., 45.87580545,
        45.37979148, 45.76496928],
       [45.50736554, 45.38626693, 45.84615668, ..., 45.85398976,
        45.035848  , 45.74804314],
       [45.59778675, 45.13451265, 45.84708884, ..., 45.48008762,
        45.567107  , 45.40221395]])