In [37]:
import math
import numpy as np

def CSR_to_SELLPACK(rowptr, colidx, val, slice_height):
    """Convert a CSR sparse matrix to Sliced ELLPACK (SELL) format.

    Rows are grouped into consecutive slices of ``slice_height`` rows. Each
    slice is padded to the length of its longest row and stored column-major
    within the slice: for slice ``i``, entry ``j`` of local row ``k`` lives at
    ``ell_sliceptr[i] + j * slice_height + k``.

    If the number of rows is not a multiple of ``slice_height``, the last
    slice is filled up with padding-only rows, so every slice has exactly
    ``slice_height`` rows and the indexing formula above holds everywhere.

    Parameters
    ----------
    rowptr : sequence of int
        CSR row pointer array of length ``N + 1``.
    colidx : sequence of int
        CSR column indices, aligned with ``val``.
    val : sequence
        CSR non-zero values.
    slice_height : int
        Number of rows per slice; must be >= 1.

    Returns
    -------
    ell_colidx : list
        Column index of each stored entry; ``-1`` marks a padding entry.
    ell_sliceptr : list of int
        Start offset of each slice in ``ell_colidx`` / ``ell_val``, followed
        by the total number of stored entries (length ``slice_number + 1``).
    ell_val : list
        Value of each stored entry; ``0`` for padding entries.

    Raises
    ------
    ValueError
        If ``slice_height`` is less than 1.
    """
    if slice_height < 1:
        raise ValueError("slice_height must be >= 1")

    N = len(rowptr) - 1  # number of rows
    # Exact integer ceiling division (avoids float rounding for huge N).
    slice_number = (N + slice_height - 1) // slice_height
    stored_count = 0  # stored entries so far, padding included

    ell_colidx = []
    ell_sliceptr = []
    ell_val = []

    for i in range(slice_number):
        first_row = i * slice_height
        # The last slice may contain fewer than slice_height real rows.
        real_rows = range(first_row, min(first_row + slice_height, N))
        # Slice width = longest row among ALL real rows of the slice
        # (non-empty because first_row < N for every i < slice_number).
        max_nnz = max(rowptr[r + 1] - rowptr[r] for r in real_rows)

        ell_sliceptr.append(stored_count)
        for j in range(max_nnz):
            for k in range(slice_height):
                idx = first_row + k  # row index
                stored_count += 1
                # Rows past N (short last slice) and rows shorter than the
                # slice width both produce padding entries.
                if idx < N and rowptr[idx] + j < rowptr[idx + 1]:
                    ell_colidx.append(colidx[rowptr[idx] + j])
                    ell_val.append(val[rowptr[idx] + j])
                else:
                    ell_colidx.append(-1)  # -1 means invalid
                    ell_val.append(0)  # padded zero

    ell_sliceptr.append(stored_count)
    return ell_colidx, ell_sliceptr, ell_val

In [38]:
# Worked example: a 4x4 matrix in CSR form, converted with 2-row slices.
#[[a, 0, b, 0],
# [0, c, 0, d],
# [0, e, 0, 0],
# [0, 0, 0, f]]

# a = 1.11, b = 3.33, c = 2.22, d = 4.44, e = 5.55, f = 6.66

rowptr = np.array([0, 2, 4, 5, 6])
colidx = np.array([0, 2, 1, 3, 1, 3])
val = np.array([1.11, 3.33, 2.22, 4.44, 5.55, 6.66], dtype=np.float64)

slice_height = 2
ell_colidx, ell_sliceptr, ell_val = CSR_to_SELLPACK(rowptr, colidx, val, slice_height)

# Expected layout, worked out by hand:
# slice 0 (rows 0-1, width 2) stores a c b d; slice 1 (rows 2-3, width 1) stores e f.
assert ell_sliceptr == [0, 4, 6]
assert ell_colidx == [0, 1, 2, 3, 1, 3]
assert ell_val == [1.11, 2.22, 3.33, 4.44, 5.55, 6.66]

In [39]:
# Column index of every stored entry, slice by slice; -1 would mark a padding entry.
ell_colidx

[0, 1, 2, 3, 1, 3]

In [40]:
# Start offset of each slice in ell_colidx / ell_val, followed by the total stored-entry count.
ell_sliceptr

[0, 4, 6]

In [41]:
# Stored values, in the same order as ell_colidx; padding entries would hold 0.
ell_val

[1.11, 2.22, 3.33, 4.44, 5.55, 6.66]