In [1]:
import numpy as np
import numba
import random
import math
from numba import jit, set_num_threads

def csr_spmv_single_thread(rowptr, colidx, val, x):
    """CSR format based SpMV (y=Ax)"""
    
    num_row = len(rowptr) - 1
    y = np.zeros(num_row, dtype=np.float32)
    
    for i in range(num_row):
        for j in range(rowptr[i], rowptr[i + 1]):
            Ax_data = val[j] * x[colidx[j]]
            y[i] += val[j] * x[colidx[j]]
    
    return y

@jit(nopython=True, parallel=True, nogil=True, fastmath=True)
def csr_spmv_multi_thread(y, num_row, rowptr, colidx, val, x):
    """CSR format based SpMV (y=Ax)"""
    
    for i in numba.prange(num_row):
        row_data = 0.0
        for j in range(rowptr[i], rowptr[i + 1]):
            Ax_data = val[j] * x[colidx[j]]
            row_data += val[j] * x[colidx[j]]
        y[i] = row_data
        
    return y

def sliced_ellpack_spmv_single_thread(N, slice_ptr, colidx, val, x, slice_height):
    """Sliced ELLPACK format based SpMV (y=Ax)"""
    
    y = np.zeros(N, dtype=np.float32)
    slice_count = int(N / slice_height)
    
    for s in range(slice_count):
        row_idx = s * slice_height
        for r in range(row_idx, row_idx + slice_height):
            for i in range(slice_ptr[s] + r - row_idx, slice_ptr[s + 1], slice_height):
                Ax_data = x[colidx[i]] * val[i]
                y[r] += Ax_data
    
    return y

@jit(nopython=True, parallel=True, nogil=True, fastmath=True)
def sliced_ellpack_spmv_multi_thread(y, N, slice_count, slice_ptr, colidx, val, x, slice_height):
    """Sliced ELLPACK format based SpMV (y=Ax)"""
    
    for s in numba.prange(slice_count):
        row_idx = s * slice_height
        loop_y = np.zeros(slice_height, dtype=np.float32)
        for r in range(slice_height):
            row_data = 0.0
            for i in range(slice_ptr[s] + r, slice_ptr[s + 1], slice_height):
                Ax_data = x[colidx[i]] * val[i]
                row_data += Ax_data
            loop_y[r] = row_data
        y[s * slice_height:(s + 1) * slice_height] = loop_y
        
    return y

def random_spmatrix(n_row, n_col, per_nnz):
    """Output a random value sparse matrix"""
    
    sp_matrix = []
    nnz_count = 0
    row_max_nnz = 0
    
    for i in range(n_row):
        row_data = []
        row_nnz_count = 0
        for j in range(n_col):
            r_val = random.randint(0, 100)
            if r_val < per_nnz:
                row_data.append(0)
            else:
                nnz_count += 1
                row_nnz_count += 1
                row_data.append(r_val)
        row_max_nnz = max(row_max_nnz, row_nnz_count)
        sp_matrix.append(row_data)
    
    return sp_matrix, nnz_count, row_max_nnz

def spmatrix_to_CSR(sp_matrix):
    """Convert sparse matrix to CSR format"""
    
    n_row = len(sp_matrix)
    n_col = len(sp_matrix[0])
    
    rowptr = []
    colidx = []
    val = []
    nnz_count = 0
    
    for i in range(n_row):
        rowptr.append(nnz_count)
        for j in range(n_col):
            if sp_matrix[i][j] != 0:
                nnz_count += 1
                colidx.append(j)
                val.append(sp_matrix[i][j])
    rowptr.append(nnz_count)
    
    return np.array(rowptr), np.array(colidx), np.array(val, dtype=np.float32)

def CSR_to_SELLPACK(rowptr, colidx, val, slice_height):
    """Convert CSR format to Sliced ELLPACK format"""
    
    N = len(rowptr) - 1 # number of rows
    slice_number = math.ceil(N / slice_height) # how many slices
    nnz_count = 0
    
    ell_colidx = []
    ell_sliceptr = []
    ell_val = []
    
    for i in range(slice_number):
        max_nnz = 0
        for s in range(slice_height):
            col_count = rowptr[i * slice_height + s + 1] - rowptr[i * slice_height + s]
            max_nnz = max(max_nnz, col_count)
        
        ell_sliceptr.append(nnz_count)
        for j in range(max_nnz):
            for k in range(slice_height):
                idx = i * slice_height + k # row index
                now_ptr = rowptr[idx] # start index of this row
                next_ptr = rowptr[idx + 1] # start index of next row
                nnz_count += 1 # count non-zero number
                if now_ptr + j < next_ptr:
                    ell_colidx.append(colidx[now_ptr + j])
                    ell_val.append(val[now_ptr + j])
                else:
                    ell_colidx.append(-1) # -1 means invalid
                    ell_val.append(0) # padded zero
    ell_sliceptr.append(nnz_count)
    
    return np.array(ell_colidx), np.array(ell_sliceptr), np.array(ell_val, dtype=np.float32)

In [2]:
# set number of rows, columns
n_row = 2000
n_col = 2000
# sparse matrix non-zero percentage
per_nnz = 5
# generate a sparse matrix fill with random value
sp_matrix, nnz_count, row_max_nnz = random_spmatrix(n_row, n_col, per_nnz)
nnz_per = (nnz_count / (n_row * n_col)) * 100
avg_nnz = nnz_count / n_row
print(str(nnz_count) + " non-zero elements in this sparse matrix (" + str(nnz_per) + "%).")
print("Row average non-zero elements: " + str(avg_nnz) + ", row max non-zero elements: " + str(row_max_nnz))

3802461 non-zero elements in this sparse matrix (95.061525%).
Row average non-zero elements: 1901.2305, row max non-zero elements: 1939


In [3]:
# convert sparse matrix to CSR format
csr_rowptr, csr_colidx, csr_val = spmatrix_to_CSR(sp_matrix)

In [14]:
# set slice height
slice_height = 2
# convert CSR to Sliced ELLPACK
ell_colidx, ell_sliceptr, ell_val = CSR_to_SELLPACK(csr_rowptr, csr_colidx, csr_val, slice_height)

In [5]:
# generate x array
x = np.ones(n_col, dtype=np.float64)
x *= 1.23
# generate data
csr_y = np.zeros(n_row, dtype=np.float32)
ell_y = np.zeros(n_row, dtype=np.float32)
slice_count = int(n_row / slice_height)

In [6]:
set_num_threads(4)

In [11]:
%timeit csr_output = csr_spmv_multi_thread(csr_y, n_row, csr_rowptr, csr_colidx, csr_val, x)

1.23 ms ± 36.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [15]:
%timeit ellpack_output = sliced_ellpack_spmv_multi_thread(ell_y, n_row, slice_count, ell_sliceptr, ell_colidx, ell_val, x, slice_height)

562 µs ± 22.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [13]:
sum(csr_output - ellpack_output)

0.0

In [16]:
sliced_ellpack_spmv_multi_thread.parallel_diagnostics(level=4)

 
 Parallel Accelerator Optimizing:  Function sliced_ellpack_spmv_multi_thread, 
<ipython-input-1-4cb6c8eb78c8> (48)  


Parallel loop listing for  Function sliced_ellpack_spmv_multi_thread, <ipython-input-1-4cb6c8eb78c8> (48) 
-----------------------------------------------------------------------------------------------------|loop #ID
@jit(nopython=True, parallel=True, nogil=True, fastmath=True)                                        | 
def sliced_ellpack_spmv_multi_thread(y, N, slice_count, slice_ptr, colidx, val, x, slice_height):    | 
    """Sliced ELLPACK format based SpMV (y=Ax)"""                                                    | 
                                                                                                     | 
    for s in numba.prange(slice_count):--------------------------------------------------------------| #3
        row_idx = s * slice_height                                                                   | 
        loop_y = np.zeros(slice_hei

In [17]:
sliced_ellpack_spmv_multi_thread.inspect_types(pretty=True)

  warn("The pretty_annotate functionality is experimental and might change API",


0
label 0
"y = arg(0, name=y) :: array(float32, 1d, C)"
"N = arg(1, name=N) :: int64"
del N
"slice_count = arg(2, name=slice_count) :: int64"
"slice_ptr = arg(3, name=slice_ptr) :: array(int32, 1d, C)"
"colidx = arg(4, name=colidx) :: array(int32, 1d, C)"
"val = arg(5, name=val) :: array(float32, 1d, C)"
"x = arg(6, name=x) :: array(float64, 1d, C)"
"slice_height = arg(7, name=slice_height) :: int64"

0
label 0
"y = arg(0, name=y) :: array(float32, 1d, C)"
"N = arg(1, name=N) :: int64"
del N
"slice_count = arg(2, name=slice_count) :: int64"
"slice_ptr = arg(3, name=slice_ptr) :: array(int32, 1d, C)"
"colidx = arg(4, name=colidx) :: array(int32, 1d, C)"
"val = arg(5, name=val) :: array(float32, 1d, C)"
"x = arg(6, name=x) :: array(float64, 1d, C)"
"slice_height = arg(7, name=slice_height) :: int64"

0
$row_idx.141 = s * slice_height :: int64
row_idx = $row_idx.141 :: int64
del row_idx
del $20binary_multiply.4

0
$24load_global.5 = global(np: <module 'numpy' from 'C:\\Users\\50621\\anaconda3\\lib\\site-packages\\numpy\\__init__.py'>) :: Module(<module 'numpy' from 'C:\\Users\\50621\\anaconda3\\lib\\site-packages\\numpy\\__init__.py'>)
"$26load_attr.6 = getattr(value=$24load_global.5, attr=zeros) :: Function(<built-in function zeros>)"
del $24load_global.5
$30load_global.8 = global(np: <module 'numpy' from 'C:\\Users\\50621\\anaconda3\\lib\\site-packages\\numpy\\__init__.py'>) :: Module(<module 'numpy' from 'C:\\Users\\50621\\anaconda3\\lib\\site-packages\\numpy\\__init__.py'>)
"$32load_attr.9 = getattr(value=$30load_global.8, attr=float32) :: class(float32)"
del $30load_global.8
"$36call_function_kw.11 = call $26load_attr.6(slice_height, func=$26load_attr.6, args=[Var(slice_height, <ipython-input-1-4cb6c8eb78c8>:52)], kws=[('dtype', Var($32load_attr.9, <ipython-input-1-4cb6c8eb78c8>:54))], vararg=None) :: (int64, class(float32)) -> array(float32, 1d, C)"
del $32load_attr.9
del $26load_attr.6
"loop_y = $loop_y.142 :: array(float32, 1d, C)"

0
$40load_global.12 = global(range: <class 'range'>) :: Function(<class 'range'>)
"$44call_function.14 = call $push_global_to_block.354(slice__height, func=$push_global_to_block.354, args=[Var(slice__height, <ipython-input-1-4cb6c8eb78c8>:52)], kws=(), vararg=None) :: (int64,) -> range_state_int64"
del $40load_global.12
$46get_iter.15 = getiter(value=$44call_function.14) :: range_iter_int64
del $44call_function.14
$phi48.1 = $46get_iter.15 :: range_iter_int64
del $46get_iter.15
jump 48
label 48
"$48for_iter.2 = iternext(value=$46get_iter.15) :: pair<int64, bool>"

0
"$row_data.2.387 = const(float, 0.0) :: <missing>"
row_data = $row_data.144 :: float64
del $const52.3

0
$56load_global.4 = global(range: <class 'range'>) :: Function(<class 'range'>)
"$62binary_subscr.7 = getitem(value=slice__ptr, index=$s.291) :: int32"
$66binary_add.9 = $62binary_subscr.7 + $r.143 :: int64
del $62binary_subscr.7
"$const72.12 = const(int, 1) :: Literal[int](1)"
$74binary_add.13 = $s.291 + $const72.12 :: int64
del $const72.12
"$76binary_subscr.14 = getitem(value=slice__ptr, index=$74binary_add.13) :: int32"
del $74binary_add.13
"$80call_function.16 = call $push_global_to_block.358($66binary_add.9, $76binary_subscr.14, slice__height, func=$push_global_to_block.358, args=[Var($66binary_add.9, <ipython-input-1-4cb6c8eb78c8>:57), Var($76binary_subscr.14, <ipython-input-1-4cb6c8eb78c8>:57), Var(slice__height, <ipython-input-1-4cb6c8eb78c8>:52)], kws=(), vararg=None) :: (int64, int64, int64) -> range_state_int64"

0
"$94binary_subscr.7 = getitem(value=colidx, index=$i.146) :: int32"
"$96binary_subscr.8 = getitem(value=x, index=$94binary_subscr.7) :: float64"
del $94binary_subscr.7
"$102binary_subscr.11 = getitem(value=val, index=$i.146) :: float32"
del i
$Ax_data.147 = $96binary_subscr.8 * $102binary_subscr.11 :: float64
del $96binary_subscr.8
del $102binary_subscr.11
Ax_data = $Ax_data.147 :: float64
del $104binary_multiply.12

0
"$row_data.2.236 = inplace_binop(fn=<built-in function iadd>, immutable_fn=<built-in function add>, lhs=row_data.2, rhs=$Ax_data.147, static_lhs=Undefined, static_rhs=Undefined) :: float64"
del row_data.2
del Ax_data
row_data.1 = $row_data.1.148 :: float64
del $112inplace_add.15
jump 84

0
label 118
del row_data
del $phi86.3
del $phi84.2
del $84for_iter.5
"$loop_y.142[$r.143] = row_data.2 :: (array(float32, 1d, C), int64, float64) -> none"
del row_data.2
del r
jump 48

0
label 128
del $phi50.2
del $phi48.1
del $48for_iter.4
$136binary_multiply.5 = $s.291 * slice__height :: int64
"$const140.7 = const(int, 1) :: Literal[int](1)"
$142binary_add.8 = $s.291 + $const140.7 :: int64
del s
del $const140.7
$146binary_multiply.10 = $142binary_add.8 * slice__height :: int64

0
label 154
del x
del val
del slice_ptr
del slice_height
del row_data.1
del colidx
del $phi14.1
del $phi12.0
del $12for_iter.3
