In [1]:
import numpy as np
import scipy.sparse as sp
import numba

In [2]:
# Problem size: A and B below are both N x N.
N = int(1e6)
# Seed NumPy's global RNG so the matrices are reproducible. The random draws
# below consume that RNG in source order, so reordering them changes A and B.
np.random.seed(1)

# A: diagonal matrix in DIA format with N standard-normal values stored on the
# main diagonal (offset 0).
A = sp.dia_matrix((np.random.randn(1, N), [0]), shape=(N, N))
# B: CSR matrix built from (values, (row_indices, col_indices)) triplets.
B = sp.csr_matrix(
    (
        np.random.randn(N * 4),  # 4N elements
        (
            # Row indices: every row 0..N-1 repeated 4 times, i.e. 4 triplets
            # per row.
            np.repeat(np.arange(N), 4),
            # Column indices: with high=None, randint draws from [0, low),
            # so these are uniform in [0, N).
            np.random.randint(low=N, high=None, size=N * 4)
        )
    ),
    shape=(N, N)
)
# Bring B into canonical CSR form in place (column indices sorted within each
# row, duplicate entries merged).
B.sort_indices()
B.sum_duplicates()

In [3]:
def naive_multiply(x, y):
    """Baseline: multiply ``x`` by ``y`` with the ``*`` operator.

    The product is then normalised in place by calling its
    ``sort_indices()`` and ``sum_duplicates()`` methods, in that order.

    Parameters
    ----------
    x, y : operands of ``x * y``; the result must provide ``sort_indices``
        and ``sum_duplicates`` (in this notebook, scipy sparse matrices).

    Returns
    -------
    The normalised product.
    """
    product = x * y
    # Same two normalisation steps, same order: sort first, then merge.
    for normalise in (product.sort_indices, product.sum_duplicates):
        normalise()
    return product

# Reference result from the baseline implementation; the other variants are
# compared against it.
C1 = naive_multiply(A, B)

In [4]:
# Baseline timing: the generic product plus the sort/merge normalisation steps.
%timeit naive_multiply(A, B)

149 ms ± 2.43 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
def multiply_in_numpy(x, y):
    """Scale each row of ``y`` by the matching entry of ``x.data`` using NumPy.

    Multiplying by a diagonal matrix on the left leaves the sparsity pattern
    of ``y`` untouched and only rescales the stored values, so the result is
    a copy of ``y`` whose ``data`` array is multiplied in place.

    Parameters
    ----------
    x : object whose ``data`` attribute, once flattened, holds one scale
        factor per row of ``y`` (here: a DIA matrix with only the main
        diagonal).
    y : CSR matrix (``indptr``, ``data``, ``shape`` and ``copy()`` are used).

    Returns
    -------
    A copy of ``y`` with rescaled values; ``y`` itself is not modified.
    """
    scale_per_row = x.data.ravel()
    # Number of stored entries in each row, from consecutive indptr offsets.
    entries_per_row = y.indptr[1:] - y.indptr[:-1]
    # For every stored entry, the index of the row it belongs to.
    row_of_entry = np.repeat(np.arange(y.shape[0]), entries_per_row)
    scaled = y.copy()
    scaled.data *= scale_per_row[row_of_entry]
    return scaled

# Result of the NumPy-vectorised variant, kept so it can be checked against C1.
C2 = multiply_in_numpy(A, B)

In [6]:
# Compare the complete CSR representation, not only the stored values: equal
# `data` arrays alone would not reveal a difference in sparsity structure, and
# np.array_equal also returns a clean False when the array lengths differ.
all(
    np.array_equal(getattr(C1, part), getattr(C2, part))
    for part in ("indptr", "indices", "data")
)

True

In [7]:
# Timing of the NumPy-vectorised variant on the same inputs as the baseline.
%timeit multiply_in_numpy(A, B)

34 ms ± 1.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
@numba.njit
def numba_core(indptr, data_x, data_y, data_z):
    """Write ``data_x[row] * data_y[k]`` into ``data_z[k]`` for every stored entry.

    ``indptr`` delimits, for each row, the half-open range of positions
    ``[indptr[row], indptr[row + 1])`` in ``data_y`` / ``data_z`` belonging
    to that row. The number of rows processed is ``data_x.shape[0]``.

    Parameters
    ----------
    indptr : row-offset array; indexed at ``row`` and ``row + 1``.
    data_x : one scale factor per row.
    data_y : input values, indexed by entry position.
    data_z : output buffer, filled in place at the same positions.
    """
    n_rows = data_x.shape[0]
    for row in range(n_rows):
        scale = data_x[row]
        first = indptr[row]
        past_last = indptr[row + 1]
        for pos in range(first, past_last):
            data_z[pos] = scale * data_y[pos]

def multiply_in_numba(x, y):
    """Scale each row of CSR matrix ``y`` by the diagonal ``x`` via ``numba_core``.

    Parameters
    ----------
    x : object whose flattened ``data`` holds one scale factor per row of
        ``y`` (here: a DIA matrix with only the main diagonal).
    y : CSR matrix; its ``indptr``, ``indices``, ``data`` and ``shape`` are
        used.

    Returns
    -------
    scipy.sparse.csr_matrix
        Matrix with the sparsity pattern of ``y`` and rescaled values.
        ``y.indices`` and ``y.indptr`` are passed to the constructor without
        an explicit copy, so the result may share those arrays with ``y``.
    """
    # Output buffer with the same length and dtype as y's stored values.
    z_data = np.empty_like(y.data)
    numba_core(y.indptr, x.data.ravel(), y.data, z_data)
    # Take the shape from y rather than the notebook-global N, so the function
    # is correct for inputs of any size and does not rely on outside state.
    return sp.csr_matrix((z_data, y.indices, y.indptr), shape=y.shape)

# Result of the Numba variant, kept so it can be checked against C1.
C3 = multiply_in_numba(A, B)

In [9]:
# Compare the complete CSR representation, not only the stored values: equal
# `data` arrays alone would not reveal a difference in sparsity structure, and
# np.array_equal also returns a clean False when the array lengths differ.
all(
    np.array_equal(getattr(C1, part), getattr(C3, part))
    for part in ("indptr", "indices", "data")
)

True

In [10]:
# Timing of the Numba variant on the same inputs as the baseline.
%timeit multiply_in_numba(A, B)

7.3 ms ± 95.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
