# In[1]:
# https://leimao.github.io/blog/CSR-Sparse-Matrix-Multiplication/
from __future__ import annotations
from typing import Tuple
import numpy as np


# In[None]:
class CSRMatrix:
    """Sparse matrix stored in Compressed Sparse Row (CSR) format."""

    def __init__(self, indptr: np.ndarray, indices: np.ndarray,
                 data: np.ndarray, shape: Tuple[int, int]) -> None:
        """
        Store the three CSR arrays and the dense shape.

        Example CSR representation of a matrix:

        A = [[10, 20, 0, 0],
             [0, 0, 0, 12],
             [0, 1, 0, 0]]

        data    (v) = [10, 20, 12, 1]   # stored values, row by row
        indices (c) = [0, 1, 3, 1]      # column of each stored value
        indptr  (r) = [0, 2, 3, 4]      # row i occupies positions r[i]:r[i+1]

        Args:
            indptr: Row pointer array; expected length is ``shape[0] + 1``.
            indices: Column index of each stored value.
            data: Stored values; its dtype becomes ``self.dtype``.
            shape: ``(num_rows, num_cols)`` of the dense matrix.
        """
        # Row pointer: row i's entries live at positions indptr[i]:indptr[i+1].
        self.indptr = indptr
        # Column index of each stored value.
        self.indices = indices
        self.data = data
        self.shape = shape

        self.dtype = self.data.dtype

    def toarray(self) -> np.ndarray:
        """
        Convert CSR matrix to a dense numpy array.

        Returns:
            np.ndarray: Dense matrix of shape ``self.shape`` and dtype
            ``self.data.dtype``; positions with no stored value are zero.
            If a coordinate is stored more than once, the later entry wins.
        """
        array = np.zeros(self.shape, dtype=self.data.dtype)
        num_rows = self.shape[0]

        assert num_rows == len(self.indptr) - 1

        for i in range(num_rows):
            # Flat positions indptr[i] .. indptr[i+1]-1 hold row i's entries.
            for k in range(self.indptr[i], self.indptr[i + 1]):
                array[i, self.indices[k]] = self.data[k]

        return array

    def transpose(self) -> CSRMatrix:
        """
        Return the transpose of this matrix as a new CSR matrix.

        Every stored entry (row, col) becomes (col, row). The entries are
        re-ordered by new row, with ties broken by new column. A new row
        pointer is then built from the per-row counts.

        e.g.

        A = [[10, 20, 0, 0],
             [0, 0, 0, 12],
             [0, 1, 0, 0]]

        v = [10, 20, 12, 1]
        c = [0, 1, 3, 1]
        r = [0, 2, 3, 4]

        A^T = [[10, 0, 0],
               [20, 0, 1],
               [0, 0, 0],
               [0, 12, 0]]

        v' = [10, 20, 1, 12]
        c' = [0, 0, 2, 1]
        r' = [0, 1, 3, 3, 4]

        Returns:
            CSRMatrix: Matrix of shape ``(shape[1], shape[0])`` with int32
            ``indptr``/``indices`` and data cast to ``self.dtype``.
        """
        num_rows, num_cols = self.shape  # (3, 4) in the example
        assert len(self.indptr) == num_rows + 1, \
            "indptr must have shape[0] + 1 entries"
        assert self.indptr[-1] == len(self.indices), \
            "indptr[-1] must equal the number of stored values"

        # Row index of every stored value: row i repeated
        # (indptr[i+1] - indptr[i]) times -> [0, 0, 1, 2].
        row_2d_idx = np.repeat(np.arange(num_rows), np.diff(self.indptr))
        col_2d_idx = np.asarray(self.indices)  # [0, 1, 3, 1]

        # Exchange the row and column index.
        new_row_2d_idx = col_2d_idx  # [0, 1, 3, 1]
        new_col_2d_idx = row_2d_idx  # [0, 0, 1, 2]

        # np.lexsort treats the LAST key as primary: sort by new row, then new column.
        ind = np.lexsort((new_col_2d_idx, new_row_2d_idx))  # [0, 1, 3, 2]
        new_row_2d_idx = new_row_2d_idx[ind]  # [0, 1, 1, 3]
        new_col_2d_idx = new_col_2d_idx[ind]  # [0, 0, 2, 1]

        indices = new_col_2d_idx.astype(np.int32)              # [0, 0, 2, 1]
        data = np.asarray(self.data)[ind].astype(self.dtype)   # [10, 20, 1, 12]
        shape = (num_cols, num_rows)                           # (4, 3)

        # Count stored values per new row, then prefix-sum into the row pointer.
        counts = np.bincount(new_row_2d_idx.astype(np.intp),
                             minlength=num_cols)               # [1, 2, 0, 1]
        indptr = np.zeros(num_cols + 1, dtype=np.int32)
        indptr[1:] = np.cumsum(counts)                         # [0, 1, 3, 3, 4]

        return CSRMatrix(indptr=indptr,
                         indices=indices,
                         data=data,
                         shape=shape)

    

        




        





