In [None]:
import numpy as np

In [None]:
class CSR:
    """
    Compressed Sparse Row matrix.
    data:  stored values (nnz,)
    indices: column index for each stored value (nnz,)
    indptr: row pointer (n_rows+1,); row i owns data[indptr[i]:indptr[i+1]]
    shape: (m, n)

    Explicitly stored zeros are kept as entries. If a row stores the same
    column more than once, every operation (matvec, matmat, to_dense)
    treats the repeated entries as summed.
    """
    def __init__(self, data, indices, indptr, shape):
        """
        Wrap already-built CSR arrays.

        Raises ValueError if the arrays do not describe a consistent CSR
        structure for `shape` (wrong lengths, decreasing indptr, or a
        column index outside [0, n)).
        """
        self.data = np.asarray(data)
        self.indices = np.asarray(indices, dtype=np.int64)
        self.indptr = np.asarray(indptr, dtype=np.int64)
        self.shape = shape

        # The vectorised kernels below rely on these invariants, so reject
        # inconsistent input here instead of returning wrong numbers later.
        m, n = shape
        if self.data.ndim != 1 or self.data.shape != self.indices.shape:
            raise ValueError(
                "data and indices must be 1-D arrays of equal length"
            )
        if self.indptr.shape != (m + 1,):
            raise ValueError(
                f"indptr must be 1-D with length n_rows + 1 = {m + 1}"
            )
        if (
            self.indptr[0] < 0
            or self.indptr[-1] > self.data.size
            or np.any(np.diff(self.indptr) < 0)
        ):
            raise ValueError(
                "indptr must be non-decreasing and stay within [0, len(data)]"
            )
        # Only the slice addressed by indptr is ever read.
        active = self.indices[self.indptr[0]:self.indptr[-1]]
        if active.size and (active.min() < 0 or active.max() >= n):
            raise ValueError(f"column indices must lie in [0, {n})")

    @classmethod
    def from_coo(cls, row, col, val, shape):
        """
        Build CSR from COO (row, col, val), duplicates summed.

        Entries are sorted by (row, col); values sharing a coordinate are
        added together in that sorted order, keeping the dtype of `val`.
        The inputs are not modified.

        Raises ValueError if row/col/val are not 1-D arrays of equal length
        or if any coordinate lies outside `shape`.
        """
        m, n = shape
        row = np.asarray(row, dtype=np.int64)
        col = np.asarray(col, dtype=np.int64)
        val = np.asarray(val)

        if not (row.ndim == col.ndim == val.ndim == 1):
            raise ValueError("row, col and val must be 1-D")
        if not (row.size == col.size == val.size):
            raise ValueError("row, col and val must have the same length")
        if row.size:
            if row.min() < 0 or row.max() >= m:
                raise ValueError(f"row indices must lie in [0, {m})")
            if col.min() < 0 or col.max() >= n:
                raise ValueError(f"column indices must lie in [0, {n})")

        # sort by (row, col); fancy indexing copies, so inputs stay untouched
        order = np.lexsort((col, row))
        row, col, val = row[order], col[order], val[order]

        # sum duplicates: after sorting they are adjacent, so a new group
        # starts wherever the coordinate differs from its predecessor
        if val.size:
            is_first = np.ones(val.size, dtype=bool)
            is_first[1:] = (row[1:] != row[:-1]) | (col[1:] != col[:-1])
            group = np.cumsum(is_first) - 1
            repeat = ~is_first
            merged = val[is_first]  # boolean indexing returns a copy
            # ufunc.at accumulates repeated targets one element at a time,
            # in order, so every duplicate is folded into its group leader
            np.add.at(merged, group[repeat], val[repeat])
            row, col, val = row[is_first], col[is_first], merged

        counts = np.bincount(row, minlength=m)  # number of nonzeros per row
        indptr = np.zeros(m + 1, dtype=np.int64)
        np.cumsum(counts, out=indptr[1:])

        return cls(val, col, indptr, shape)

    # y = A @ x  (x is (n,) vector)
    def matvec(self, x):
        """
        Return y = A @ x as a dense (m,) array.

        The result dtype is result_type(data, x, float32), i.e. always a
        floating (or complex) type. Raises ValueError if x is not (n,).
        """
        m, n = self.shape
        x = np.asarray(x)
        if x.shape != (n,):
            raise ValueError(f"x must have shape ({n},), got {x.shape}")
        y = np.zeros(m, dtype=np.result_type(self.data, x, np.float32))

        starts, ends = self.indptr[:-1], self.indptr[1:]
        nonempty = ends > starts
        if nonempty.any():
            stop = self.indptr[-1]
            products = self.data[:stop] * x[self.indices[:stop]]
            # reduceat sums products[s_k:s_{k+1}] for consecutive starts.
            # Passing only the non-empty rows' starts is required: for an
            # empty segment reduceat returns the element at that index, not
            # 0. Skipped empty rows keep the zero they were initialised to.
            y[nonempty] = np.add.reduceat(products, starts[nonempty])
        return y

    # Y = A @ X  (X is (n, k) dense)
    def matmat(self, X):
        """
        Return Y = A @ X as a dense (m, k) array.

        The result dtype is result_type(data, X, float32). Raises
        ValueError if X is not 2-D with n rows.
        """
        m, n = self.shape
        X = np.asarray(X)
        if X.ndim != 2 or X.shape[0] != n:
            raise ValueError(f"X must have shape ({n}, k), got {X.shape}")
        k = X.shape[1]
        Y = np.zeros((m, k), dtype=np.result_type(self.data, X, np.float32))
        # row-by-row O(nnz * k); kept as a loop on purpose so the temporary
        # is (row_nnz, k) instead of a single (nnz, k) array
        for i in range(m):
            start, end = self.indptr[i], self.indptr[i + 1]
            if start == end:
                continue  # empty row: Y[i] is already zero
            cols = self.indices[start:end]          # (rnnz,)
            vals = self.data[start:end][:, None]    # (rnnz,1)
            # weighted sum of rows of X
            # Y[i] = sum_j vals[j] * X[cols[j], :]
            Y[i] = (vals * X[cols]).sum(axis=0)
        return Y

    def to_dense(self):
        """
        Return the matrix as a dense (m, n) array with the dtype of `data`.

        Entries that share a (row, column) position are added together,
        which keeps the dense form consistent with matvec/matmat.
        """
        m, n = self.shape
        dense = np.zeros((m, n), dtype=self.data.dtype)
        first, last = self.indptr[0], self.indptr[-1]
        # expand indptr into one row index per stored entry
        rows = np.repeat(np.arange(m), np.diff(self.indptr))
        # add.at (unlike plain fancy assignment) accumulates repeated
        # positions instead of keeping only the last write
        np.add.at(dense, (rows, self.indices[first:last]), self.data[first:last])
        return dense