In [103]:
import numpy as np
import numba

def csr_spmv_single_thread(rowptr, colidx, val, x):
    """Sparse matrix-vector product y = A x for a matrix stored in CSR form.

    Plain-Python reference implementation that walks the rows one by one.

    Parameters
    ----------
    rowptr : sequence of int
        Row offsets; entries ``rowptr[r]`` up to (but excluding)
        ``rowptr[r + 1]`` of ``colidx``/``val`` belong to row ``r``.
        The number of rows is ``len(rowptr) - 1``.
    colidx : sequence of int
        Column index of every stored entry.
    val : sequence of float
        Value of every stored entry, aligned with ``colidx``.
    x : sequence of float
        Dense vector to multiply with, indexed by the values in ``colidx``.

    Returns
    -------
    numpy.ndarray
        float64 vector with one element per matrix row.
    """
    n_rows = len(rowptr) - 1
    result = np.zeros(n_rows, dtype=np.float64)

    for row in range(n_rows):
        # Visit only the stored (non-zero) entries of this row.
        for k in range(rowptr[row], rowptr[row + 1]):
            result[row] += val[k] * x[colidx[k]]

    return result

@numba.njit(parallel=True)
def csr_spmv_multi_thread(rowptr, colidx, val, x):
    """Sparse matrix-vector product y = A x for a CSR matrix, parallel over rows.

    Compiled with Numba. ``parallel=True`` is required for ``numba.prange``
    to actually distribute iterations across threads; with a bare
    ``@numba.njit`` the ``prange`` loop silently degrades to a serial
    ``range`` loop.

    Parameters
    ----------
    rowptr : 1-D integer array
        Row offsets; entries ``rowptr[i]`` up to (but excluding)
        ``rowptr[i + 1]`` of ``colidx``/``val`` belong to row ``i``.
        The number of rows is ``len(rowptr) - 1``.
    colidx : 1-D integer array
        Column index of every stored entry.
    val : 1-D float array
        Value of every stored entry, aligned with ``colidx``.
    x : 1-D float array
        Dense vector to multiply with, indexed by the values in ``colidx``.

    Returns
    -------
    numpy.ndarray
        float64 vector with one element per matrix row.
    """
    num_row = len(rowptr) - 1
    y = np.zeros(num_row, dtype=np.float64)

    # Each iteration owns exactly one output element y[i], so rows can be
    # processed independently without any cross-thread synchronisation.
    for i in numba.prange(num_row):
        # Accumulate in an iteration-local scalar and store once, so the
        # parallel loop performs a single write per row.
        acc = 0.0
        for j in range(rowptr[i], rowptr[i + 1]):
            acc += val[j] * x[colidx[j]]
        y[i] = acc

    return y

In [15]:
# Example 4x4 sparse matrix A in dense layout:
#
#   [[a, 0, b, 0],
#    [0, c, 0, d],
#    [0, e, 0, 0],
#    [0, 0, 0, f]]
#
# where a = 1.11, b = 3.33, c = 2.22, d = 4.44, e = 5.55, f = 6.66

# CSR encoding of A: row i owns entries rowptr[i] .. rowptr[i + 1] - 1
# of colidx (column positions) and val (stored values).
rowptr = np.array([0, 2, 4, 5, 6])
colidx = np.array(
    [0, 2,   # row 0: a, b
     1, 3,   # row 1: c, d
     1,      # row 2: e
     3]      # row 3: f
)
val = np.array(
    [1.11, 3.33,   # row 0: a, b
     2.22, 4.44,   # row 1: c, d
     5.55,         # row 2: e
     6.66],        # row 3: f
    dtype=np.float64,
)

# Dense right-hand-side vector for y = A x.
x = np.array([2.22, 3.33, 4.44, 5.55], dtype=np.float64)

In [21]:
# Time one call of the pure-Python CSR SpMV on the example inputs.
%time y = csr_spmv_single_thread(rowptr, colidx, val, x)

CPU times: user 43 µs, sys: 1 µs, total: 44 µs
Wall time: 47.9 µs


In [23]:
# Time one call of the Numba-compiled CSR SpMV on the same inputs.
# NOTE(review): a single %time run on a 4x4 matrix mostly measures call
# overhead; if this is the first call since the function was (re)defined it
# presumably also includes JIT compilation. TODO: warm up once, then use
# %timeit on a larger matrix for a meaningful comparison.
%time y = csr_spmv_multi_thread(rowptr, colidx, val, x)

CPU times: user 18 µs, sys: 1 µs, total: 19 µs
Wall time: 21 µs


In [24]:
y

array([17.2494, 32.0346, 18.4815, 36.963 ])