## Imports

In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import torch_sparse
import torchsparsegradutils as tsgu
from torch_sparse.tensor import SparseTensor

In [2]:
# Record the GPU model, driver/CUDA versions and current memory use for this run.
!nvidia-smi

Fri Mar  7 19:20:51 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:04:00.0 Off |                    0 |
| N/A   57C    P0             58W /  300W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

## Parameters

In [3]:
# Benchmark configuration: problem size, RNG seed, matmul precision and device.
batch_size = 1  # number of columns in the dense right-hand side
num_neurons = 100_000  # the sparse matrix is num_neurons x num_neurons
synapses_per_neuron = 100  # random entries generated per neuron
seed = 0
# Seed PyTorch's generators so the random indices/values drawn in later cells
# are reproducible across kernel restarts.
torch.manual_seed(seed)
torch.set_float32_matmul_precision("high")
# Fall back to the CPU when no CUDA device is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Nominal sparsity, i.e. before duplicate entries are merged.
print(f"Sparsity: {1 - synapses_per_neuron / num_neurons:.2%}")
# Debug aid (needs `import os` if re-enabled):
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

Sparsity: 99.90%


In [4]:
# Build a random sparse num_neurons x num_neurons matrix in several formats,
# plus the dense right-hand side it is multiplied with.
num_synapses = num_neurons * synapses_per_neuron

# One (row, col) pair per synapse: the row is drawn uniformly at random and the
# column is the neuron that owns the synapse. All pairs are generated in one
# call instead of a Python loop over every neuron, directly on `device`.
indices = torch.stack(
    (
        torch.randint(0, num_neurons, (num_synapses,), device=device),
        torch.arange(num_neurons, device=device).repeat_interleave(
            synapses_per_neuron
        ),
    )
)
values = torch.randn(num_synapses, device=device)

# Sort the entries and merge duplicate (row, col) pairs so the raw
# (indices, values) pair matches the coalesced tensors built below.
indices, values = torch_sparse.coalesce(
    indices, values, num_neurons, num_neurons
)

# The inputs already live on `device`, so no further transfers are needed.
coo_matrix = torch.sparse_coo_tensor(
    indices, values, (num_neurons, num_neurons)
).coalesce()
csr_matrix = coo_matrix.to_sparse_csr()
spt_coo_matrix = SparseTensor.from_torch_sparse_coo_tensor(coo_matrix)
spt_csr_matrix = SparseTensor.from_torch_sparse_csr_tensor(csr_matrix)

# dense_matrix = coo_matrix.to_dense().to(device)
dense_vector_batched = torch.randn(num_neurons, batch_size, device=device)

  csr_matrix = coo_matrix.to_sparse_csr().to(device)


## On GPU

### `requires_grad == False`

In [5]:
# Built-in torch.sparse.mm with the COO matrix.
# NOTE(review): CUDA kernels launch asynchronously and nothing here calls
# torch.cuda.synchronize() — confirm the timing reflects kernel execution.
%timeit torch.sparse.mm(coo_matrix, dense_vector_batched)

345 μs ± 1.35 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [6]:
# Built-in torch.sparse.mm with the CSR matrix.
%timeit torch.sparse.mm(csr_matrix, dense_vector_batched)

155 μs ± 184 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [8]:
# torchsparsegradutils sparse_mm with the COO matrix.
%timeit tsgu.sparse_mm(coo_matrix, dense_vector_batched)

345 μs ± 48.2 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
# torchsparsegradutils sparse_mm with the CSR matrix.
%timeit tsgu.sparse_mm(csr_matrix, dense_vector_batched)

156 μs ± 197 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [11]:
# torch_sparse.spmm on the raw coalesced (indices, values) pair.
%timeit torch_sparse.spmm(indices, values, num_neurons, num_neurons, dense_vector_batched)

781 μs ± 161 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
# SparseTensor.spmm on the SparseTensor built from the CSR matrix.
%timeit spt_csr_matrix.spmm(dense_vector_batched)

988 μs ± 12.5 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [11]:
# SparseTensor.spmm on the SparseTensor built from the COO matrix.
%timeit spt_coo_matrix.spmm(dense_vector_batched)

988 μs ± 16.1 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
# %timeit torch.mm(dense_matrix, dense_vector_batched)

In [3]:
# sns.set(rc={"text.usetex": True})
# Hard-coded timings in microseconds, one (label, time) pair per function.
# NOTE(review): these differ slightly from the %timeit outputs recorded above,
# and the dense figure has no visible measurement in this notebook — confirm.
measurements = [
    ("torch.sparse.mm (COO)", 346),
    ("torch.sparse.mm (CSR)", 158),
    ("tsgu.sparse_mm (COO)", 346),
    ("tsgu.sparse_mm (CSR)", 159),
    ("torch_sparse.spmm", 771),
    ("torch.mm (Dense)", 24400),
]
times = [elapsed for _, elapsed in measurements]
labels = [label for label, _ in measurements]
df = pd.DataFrame({"Time (µs)": times, "MM Function": labels})

In [None]:
# Bar chart of every timing in `df`, with a log-scaled y axis.
fig, ax = plt.subplots(figsize=(10, 5))
sns.barplot(data=df, x="MM Function", y="Time (µs)", palette="flare", ax=ax)
# Rotate the existing tick labels in place; re-assigning them through
# set_xticklabels() can trigger matplotlib's "fixed number of ticks" warning.
plt.setp(ax.get_xticklabels(), rotation=45, horizontalalignment="right")
ax.set_yscale("log")
ax.set_title("Sparse and Dense Matrix Multiplication Function Times")
# Show the figure explicitly so the cell does not echo a Text(...) repr.
plt.show()

In [None]:
# Same chart without the dense entry, on a linear y axis.
# `df` is rebound exactly as before, so later cells keep seeing the
# sparse-only frame; re-running this cell is a no-op on the data.
df = df[df["MM Function"] != "torch.mm (Dense)"]
fig, ax = plt.subplots(figsize=(10, 5))
sns.barplot(data=df, x="MM Function", y="Time (µs)", palette="flare", ax=ax)
# Rotate the existing tick labels in place; re-assigning them through
# set_xticklabels() can trigger matplotlib's "fixed number of ticks" warning.
plt.setp(ax.get_xticklabels(), rotation=45, horizontalalignment="right")
ax.set_title("Sparse Matrix Multiplication Function Times")
# Show the figure explicitly so the cell does not echo a Text(...) repr.
plt.show()

### `requires_grad == True`

In [5]:
# Trainable copies of the weights, one per library, each with its own SGD
# optimizer, for the forward + backward + step benchmarks below.
coo_weight = coo_matrix.clone().requires_grad_(True)
# csr_weight = csr_matrix.clone().requires_grad_(True)
torch_sparse_weight = values.clone().requires_grad_(True)
spt_weight = spt_csr_matrix.clone().requires_grad_(True)
# Take the value tensor from the trainable clone, not from `spt_csr_matrix`:
# the original's values never receive a gradient, so an optimizer built on
# them would turn `zero_grad()` and `step()` into no-ops.
spt_value = spt_weight.storage.value()
# dense_weight = dense_matrix.clone().requires_grad_(True)

coo_optimizer = torch.optim.SGD([coo_weight], lr=0.01)
# csr_optimizer = SparseSGD([csr_weight], lr=0.01)
torch_sparse_optimizer = torch.optim.SGD([torch_sparse_weight], lr=0.01)
spt_optimizer = torch.optim.SGD([spt_value], lr=0.01)

In [6]:
# Release unused memory cached by PyTorch's CUDA allocator before the
# training-step benchmarks.
torch.cuda.empty_cache()

In [7]:
%%timeit
out = tsgu.sparse_mm(coo_weight, dense_vector_batched)
coo_optimizer.zero_grad()
out.sum().backward()
coo_optimizer.step()

OutOfMemoryError: CUDA out of memory. Tried to allocate 19.07 GiB. GPU 0 has a total capacity of 79.14 GiB of which 11.19 GiB is free. Including non-PyTorch memory, this process has 67.94 GiB memory in use. Of the allocated memory 64.87 GiB is allocated by PyTorch, and 2.59 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [14]:
%%timeit
# One training step through torch_sparse.spmm; only the values are trainable.
product = torch_sparse.spmm(
    indices, torch_sparse_weight, num_neurons, num_neurons, dense_vector_batched
)
torch_sparse_optimizer.zero_grad()
loss = product.sum()
loss.backward()
torch_sparse_optimizer.step()

1.03 ms ± 141 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [23]:
%%timeit
# One training step through SparseTensor.spmm.
product = spt_weight.spmm(dense_vector_batched)
spt_optimizer.zero_grad()
loss = product.sum()
loss.backward()
spt_optimizer.step()

3.25 ms ± 112 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## On CPU

In [None]:
# Rebind the benchmark inputs to CPU copies so the timing cells below run on
# the CPU under the same names.
coo_matrix, csr_matrix = coo_matrix.cpu(), csr_matrix.cpu()
# dense_matrix = coo_matrix.to_dense().to("cpu")
dense_vector_batched = dense_vector_batched.cpu()
indices, values = indices.cpu(), values.cpu()

In [None]:
# torch.mm with the sparse COO matrix as the first operand.
%timeit torch.mm(coo_matrix, dense_vector_batched)

In [None]:
# torch.mm with the sparse CSR matrix as the first operand.
%timeit torch.mm(csr_matrix, dense_vector_batched)

In [None]:
# Built-in torch.sparse.mm with the COO matrix.
%timeit torch.sparse.mm(coo_matrix, dense_vector_batched)

In [None]:
# Built-in torch.sparse.mm with the CSR matrix.
%timeit torch.sparse.mm(csr_matrix, dense_vector_batched)

In [None]:
# torchsparsegradutils sparse_mm with the COO matrix.
%timeit tsgu.sparse_mm(coo_matrix, dense_vector_batched)

In [None]:
# torchsparsegradutils sparse_mm with the CSR matrix.
%timeit tsgu.sparse_mm(csr_matrix, dense_vector_batched)

In [None]:
# torch_sparse.spmm on the raw coalesced (indices, values) pair.
%timeit torch_sparse.spmm(indices, values, num_neurons, num_neurons, dense_vector_batched)

In [20]:
# %timeit torch.mm(dense_matrix, dense_vector_batched)