In [None]:
import numpy as np

In [None]:
def make_sparse_weights(rows, cols, rng=None):
    """Draw a random weight matrix with 1:2 structured sparsity.

    Each row is split into consecutive column pairs (0-1, 2-3, ...) and
    exactly one entry of every pair is kept; the other one is zeroed.

    Parameters
    ----------
    rows, cols : int
        Shape of the weight matrix. ``cols`` must be even so that no pair
        straddles a row boundary.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. When omitted, NumPy's global ``np.random``
        state is used.

    Returns
    -------
    W : ndarray of shape (rows, cols)
        Standard-normal weights with the dropped entries multiplied by zero.
    sparsity_mask : ndarray of shape (rows, cols)
        Float array of 0.0/1.0 with exactly one 1.0 in each column pair.

    Raises
    ------
    ValueError
        If ``cols`` is odd.
    """
    if cols % 2 != 0:
        raise ValueError(
            "cols must be even to form column pairs, got {}".format(cols)
        )

    def draw(*shape):
        # Same draw order as before: the weights first, then the coin flips.
        if rng is None:
            return np.random.randn(*shape)
        return rng.standard_normal(shape)

    W = draw(rows, cols)
    # One fair coin flip per pair: 0 keeps the first entry, 1 keeps the second.
    keep_second = (draw((rows * cols) // 2) >= 0).astype(np.uint8)
    # Row k of the 2x2 identity is the one-hot pattern for "keep entry k".
    sparsity_mask = np.eye(2)[keep_second, :].reshape(rows, cols)
    return W * sparsity_mask, sparsity_mask

In [None]:
# Notebook-wide configuration.
SEED = 0
# Seed NumPy's legacy global RNG, which the np.random.* draws in this notebook
# use, so that Restart & Run All reproduces the same numbers every time.
np.random.seed(SEED)
# Show arrays with three decimal places.
np.set_printoptions(precision=3)

In [None]:
# Problem size: W is a (rows x cols) matrix.
rows, cols = 16, 8
W, sparsity_mask = make_sparse_weights(rows=rows, cols=cols)
(W, sparsity_mask)

In [None]:
# Random input column vector with one entry per column of W.
x = np.random.standard_normal(size=(cols, 1))
x

In [None]:
# Reference result: the ordinary dense matrix-vector product.
a = np.matmul(W, x)
a

The computation can be reordered so that the zero weights never need to be stored: only `V` (the non-zero entries of `W`) and `sparsity_mask` would need to be stored as part of the model.

In [None]:
# Entries of W at the positions the mask keeps, gathered in row-major order
# and packed into half as many columns:
V = np.reshape(W[sparsity_mask != 0], (rows, cols // 2))
V

In [None]:
# Broadcast V against both candidate inputs of every pair and multiply
# elementwise. x is viewed as (2, cols // 2), so slot k of pair j is x[2*j + k].
P = np.multiply(
    np.expand_dims(V, axis=0),
    np.expand_dims(x.reshape(cols // 2, 2).T, axis=1),
)
P.shape

In [None]:
# sparsity_mask says which of the two channels to keep for each (row, pair).
# View the mask as (rows, cols // 2, 2) and move the within-pair axis to the
# front so it has the (2, rows, cols // 2) layout that P is multiplied by.
Q = P * np.moveaxis(sparsity_mask.reshape(rows, cols // 2, 2), -1, 0)
Q

In [None]:
# Summing over the channel axis (0) and the pair axis (2) leaves one value
# per row, which is the desired output.
b = np.sum(Q, axis=(0, 2))
# Check the equivalence with the dense product instead of eyeballing it.
# The small atol covers outputs that land near zero through cancellation.
np.testing.assert_allclose(b, a.flatten(), rtol=1e-7, atol=1e-12)
a.flatten(), b