In [1]:
import numpy as np
import numba

@numba.njit
def csr_spmv_single_thread(rowptr, colidx, val, x):
    """CSR format based SpMV (y=Ax), serial numba-compiled baseline.

    Compiled with numba like ``csr_spmv_multi_thread`` so that timing the two
    compares one thread against many, not the Python interpreter against
    compiled code.

    Parameters
    ----------
    rowptr : 1-D integer array of length num_row + 1
        Row ``i`` owns the stored entries ``rowptr[i]:rowptr[i + 1]``.
    colidx : 1-D integer array
        Column index (position in ``x``) of each stored entry.
    val : 1-D array
        Value of each stored entry, parallel to ``colidx``.
    x : 1-D array
        Dense input vector.

    Returns
    -------
    y : 1-D float64 array of length num_row
        ``y[i] = sum(val[j] * x[colidx[j]] for j in row i)``.
    """
    num_row = len(rowptr) - 1
    y = np.zeros(num_row, dtype=np.float64)

    # Plain ``range``: this is the serial baseline, rows are processed in order.
    for i in range(num_row):
        # Accumulate in a local scalar and store once per row.
        acc = 0.0
        for j in range(rowptr[i], rowptr[i + 1]):
            acc += val[j] * x[colidx[j]]
        y[i] = acc

    return y

@numba.njit(parallel=True)
def csr_spmv_multi_thread(rowptr, colidx, val, x):
    """CSR format based SpMV (y=Ax), rows distributed across threads.

    ``numba.prange`` only splits iterations across threads when the function
    is compiled with ``parallel=True``; under a bare ``@numba.njit`` it behaves
    exactly like ``range`` and the loop runs serially.

    Parameters
    ----------
    rowptr : 1-D integer array of length num_row + 1
        Row ``i`` owns the stored entries ``rowptr[i]:rowptr[i + 1]``.
    colidx : 1-D integer array
        Column index (position in ``x``) of each stored entry.
    val : 1-D array
        Value of each stored entry, parallel to ``colidx``.
    x : 1-D array
        Dense input vector.

    Returns
    -------
    y : 1-D float64 array of length num_row
        ``y[i] = sum(val[j] * x[colidx[j]] for j in row i)``.
    """
    num_row = len(rowptr) - 1
    y = np.zeros(num_row, dtype=np.float64)

    # Each iteration reads shared inputs and writes only y[i], so rows are
    # independent and need no synchronisation.
    for i in numba.prange(num_row):
        # ``acc`` is assigned inside the loop body, so it is private to each
        # iteration; y[i] is written once instead of once per nonzero.
        acc = 0.0
        for j in range(rowptr[i], rowptr[i + 1]):
            acc += val[j] * x[colidx[j]]
        y[i] = acc

    return y

In [2]:
# 3x3 example matrix in CSR form:
#     A = [[1.11, 0.00, 2.22],
#          [0.00, 0.00, 3.33],
#          [4.44, 5.55, 6.66]]
rowptr = np.array([0, 2, 3, 6])
colidx = np.array([0, 2, 2, 0, 1, 2])
val = np.array([1.11, 2.22, 3.33, 4.44, 5.55, 6.66], dtype=np.float64)

x = np.array([2.22, 3.33, 4.44], dtype=np.float64)

# Check the CSR invariants up front: the numba kernels do not bounds-check by
# default, so malformed indices would silently read out-of-range memory.
assert rowptr[0] == 0 and np.all(np.diff(rowptr) >= 0), "rowptr must start at 0 and be non-decreasing"
assert rowptr[-1] == len(colidx) == len(val), "rowptr[-1] must equal the number of stored entries"
assert colidx.min() >= 0 and colidx.max() < len(x), "colidx out of range for x"

In [9]:
%time y = csr_spmv_single_thread(rowptr, colidx, val, x)

Wall time: 0 ns


In [10]:
%time y = csr_spmv_multi_thread(rowptr, colidx, val, x)

Wall time: 0 ns


In [11]:
# Verify against a dense reference before displaying: scatter the CSR entries
# into a dense matrix (np.add.at sums duplicate (row, col) entries, matching
# what the CSR kernels compute) and compare with A @ x.
row_ids = np.repeat(np.arange(len(rowptr) - 1), np.diff(rowptr))
A_dense = np.zeros((len(rowptr) - 1, len(x)))
np.add.at(A_dense, (row_ids, colidx), val)
np.testing.assert_allclose(y, A_dense @ x)
# The serial and parallel kernels must agree as well.
np.testing.assert_allclose(csr_spmv_single_thread(rowptr, colidx, val, x), y)
y

array([12.321 , 14.7852, 57.9087])