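# Sparse matrix × dense matrix benchmark

Times several ways of multiplying a random $100{,}000 \times 100{,}000$ sparse connectivity matrix (100 synapses per neuron) by a dense input. The matrix is stored as COO, as CSR, and as `torch_sparse` index/value pairs, and each form is timed on GPU and on CPU.

## Imports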

In [1]:
import torch
import torch_sparse

In [2]:
# Record the GPU model, driver and CUDA version the timings below were taken on.
!nvidia-smi

Tue Jan  2 15:14:14 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.86.01    Driver Version: 515.86.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  On   | 00000000:C4:00.0 Off |                    0 |
| N/A   61C    P0    62W / 300W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Parameters

In [18]:
# Benchmark size: a single input column, 100k neurons, 100 synapses per neuron.
batch_size = 1
num_neurons = 100_000
synapses_per_neuron = 100
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [19]:
# Build a random sparse connectivity matrix W of shape (num_neurons, num_neurons):
# column i holds the synapses of neuron i, whose row indices are drawn uniformly at random.
torch.manual_seed(0)  # fixed seed so every re-run benchmarks the same matrix

num_synapses = num_neurons * synapses_per_neuron
# Vectorized replacement for a Python loop over neurons: draw every target row in one
# call, and repeat each neuron index `synapses_per_neuron` times as its column index.
synapse_targets = torch.randint(0, num_neurons, (num_synapses,), device=device)
synapse_roots = torch.arange(num_neurons, device=device).repeat_interleave(synapses_per_neuron)
raw_indices = torch.stack((synapse_targets, synapse_roots))
raw_values = torch.randn(num_synapses, device=device)

# coalesce() sorts the entries and sums duplicate (row, column) pairs that randint can produce.
coo_matrix = torch.sparse_coo_tensor(raw_indices, raw_values, (num_neurons, num_neurons)).coalesce()
# Reuse the coalesced COO data for torch_sparse.spmm instead of coalescing a second time.
indices = coo_matrix.indices()
values = coo_matrix.values()
csr_matrix = coo_matrix.to_sparse_csr()
# No dense copy of W: at num_neurons=100_000 it would hold 1e10 float32 entries (~40 GB).
dense_vector_batched = torch.randn(num_neurons, batch_size, device=device)
bias = torch.randn(num_neurons, 1, device=device)

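## On GPU

Each cell times one way of computing $Wx + b$ (sparse weights $W$, dense input $x$, bias $b$) on `device`.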

In [26]:
%timeit torch.mm(coo_matrix, dense_vector_batched) + bias

351 µs ± 381 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [27]:
%timeit torch.addmm(bias, coo_matrix, dense_vector_batched)

345 µs ± 37 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [28]:
%timeit torch.mm(csr_matrix, dense_vector_batched) + bias

165 µs ± 143 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [29]:
%timeit torch.addmm(bias, csr_matrix, dense_vector_batched)

165 µs ± 307 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [32]:
%timeit torch.sparse.mm(coo_matrix, dense_vector_batched) + bias

351 µs ± 275 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [35]:
%timeit torch.sparse.addmm(bias, csr_matrix, dense_vector_batched)

365 µs ± 8.28 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [33]:
%timeit torch.sparse.mm(csr_matrix, dense_vector_batched) + bias

166 µs ± 108 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [34]:
%timeit torch_sparse.spmm(indices, values, num_neurons, num_neurons, dense_vector_batched) + bias

KeyboardInterrupt: 

In [None]:
# Dense baseline, disabled: a dense num_neurons x num_neurons float32 matrix is
# ~40 GB at num_neurons=100_000. dense_vector_batched is 2-D (num_neurons, batch_size),
# so the dense product would need torch.mm rather than torch.mv.
# %timeit torch.mm(dense_matrix, dense_vector_batched)

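## On CPU

The same products on the CPU. These cells time only the sparse product and omit the `+ bias` term, so their numbers are not directly comparable with the GPU section.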

In [16]:
# Move every benchmark operand, including `bias`, to the CPU so no cell below
# mixes CPU and GPU tensors (bias was previously left on the GPU).
coo_matrix, csr_matrix, dense_vector_batched, bias, indices, values = (
    tensor.to('cpu')
    for tensor in (coo_matrix, csr_matrix, dense_vector_batched, bias, indices, values)
)

In [17]:
%timeit torch.mm(coo_matrix, dense_vector_batched)

1.46 s ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%timeit torch.mm(csr_matrix, dense_vector_batched)

1.23 ms ± 18.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
%timeit torch.sparse.mm(coo_matrix, dense_vector_batched)

18.3 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%timeit torch.sparse.mm(csr_matrix, dense_vector_batched)

2.04 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%timeit torch_sparse.spmm(indices, values, num_neurons, num_neurons, dense_vector_batched)

8 ms ± 111 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
# Dense baseline, disabled: a dense num_neurons x num_neurons float32 matrix is
# ~40 GB at num_neurons=100_000. dense_vector_batched is 2-D (num_neurons, batch_size),
# so the dense product would need torch.mm rather than torch.mv.
# %timeit torch.mm(dense_matrix, dense_vector_batched)