In [1]:
import numpy as np
import numba
import time

In [2]:
n = 1000000
a = np.random.randn(n)
b = np.random.randn(n)
c = np.zeros(n, dtype='float64')

In [8]:
@numba.njit(parallel=False)
def numba_fun(x,y):
    
    for i in range(1,n):
        y[i] = y[i-1] + x[i]
        
numba_fun(a,c)
print(c[:10])

for i in range(10):
    t1 = time.perf_counter_ns()
    numba_fun(a,c)
    t2 = time.perf_counter_ns()
    print(t2-t1)

[ 0.          0.76442425  0.28357462 -0.07494148 -0.55346375 -0.7930858
 -0.73389193 -2.76701364 -2.15278419 -2.14695208]
2078920
1945753
1870994
1840053
1825063
1809684
1807234
1801520
1798409
1805393


In [12]:
@numba.njit(parallel=True)
def numba_fun(x,y):
    
    for i in numba.prange(1,n):
        y[i] = y[i-1] + x[i]
        
numba.set_num_threads(4)
numba_fun(a,c)
print(c[:10])

for i in range(10):
    t1 = time.perf_counter_ns()
    numba_fun(a,c)
    t2 = time.perf_counter_ns()
    print(t2-t1)

[ 0.          0.76442425  0.28357462 -0.07494148 -0.55346375 -0.7930858
 -0.73389193 -2.76701364 -2.15278419 -2.14695208]
1057031
1059200
1068089
1067249
1086205
1041861
1033879
1032966
1030534
1040900


In [16]:
@numba.njit(parallel=False)
def prange_right_result(x):
    y = np.zeros(4)
    n = x.shape[0]
    for i in numba.prange(n):
        y[:] = y[0] + x[i]
        
    return y

In [17]:
@numba.njit(parallel=True)
def prange_wrong_result(x):
    y = np.zeros(4)
    n = x.shape[0]
    for i in numba.prange(n):
        y[:] = y[0] + x[i]
        
    return y

In [19]:
x = np.random.rand(1000,4)
print(prange_right_result(x))
print(prange_wrong_result(x))

[507.72435016 507.83014467 507.47133022 508.17036267]
[1.20413928 1.3099338  0.95111935 1.6501518 ]


In [20]:
@numba.njit(parallel=False)
def reduction_in_serial(x):
    y = 0.0
    n = x.shape[0]
    for i in numba.prange(n):
        y += x[i]
    return y

In [21]:
@numba.njit(parallel=True)
def reduction_in_parallel(x):
    y = 0.0
    n = x.shape[0]
    for i in numba.prange(n):
        y += x[i]
    return y

In [23]:
x = np.arange(1000)
print(reduction_in_serial(x))
print(reduction_in_parallel(x))

499500.0
499500.0
