## Imports

In [1]:
import torch
import torch_sparse
import torchsparsegradutils as tsgu
import math
from sklearn.model_selection import ParameterGrid
from bioplnn.utils import AttrDict, idx_2D_to_1D
from bioplnn.topography import TopographicalRNN
from bioplnn.dataset import get_dataloaders
from bioplnn.sparse_sgd import SparseSGD

In [2]:
# Shell escape: print the NVIDIA driver/GPU status table for the machine running the kernel.
!nvidia-smi

Fri Mar 29 13:56:29 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:44:00.0 Off |                    0 |
| N/A   56C    P0             63W /  300W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

## Parameters

In [4]:
# Problem size shared by the benchmark cells: number of columns in the dense
# operand, side length of the square sparse matrix, and entries drawn per neuron.
batch_size = 1
num_neurons = 100_000
synapses_per_neuron = 100

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
# Build a random square sparse matrix (COO and CSR layouts) together with the
# dense right-hand side and bias used by the benchmark cells.
torch.manual_seed(0)  # fixed seed so re-running this cell rebuilds the same operands

num_entries = num_neurons * synapses_per_neuron

# Row 0 holds randomly drawn indices, row 1 the neuron that owns each entry
# (every neuron index repeated `synapses_per_neuron` times). Drawing everything
# in one call replaces a Python loop over all neurons.
synapses = torch.randint(0, num_neurons, (num_entries,), device=device)
synapse_root = torch.arange(num_neurons, device=device).repeat_interleave(
    synapses_per_neuron
)
indices = torch.stack((synapses, synapse_root))
values = torch.randn(num_entries, device=device)

# Random draws can hit the same (row, col) pair more than once; coalesce the
# raw pair so that `indices`/`values` can be passed to torch_sparse.spmm directly.
indices, values = torch_sparse.coalesce(
    indices, values, num_neurons, num_neurons
)

# The tensors above already live on `device`, so no further transfers are needed.
coo_matrix = torch.sparse_coo_tensor(
    indices, values, (num_neurons, num_neurons)
).coalesce()
csr_matrix = coo_matrix.to_sparse_csr()
# A dense copy is left disabled: num_neurons**2 float32 entries would need
# roughly 40 GB at the default problem size.
# dense_matrix = coo_matrix.to_dense().to(device)
dense_vector_batched = torch.randn(num_neurons, batch_size, device=device)
bias = torch.randn(num_neurons, 1, device=device)

  csr_matrix = coo_matrix.to_sparse_csr().to(device)


## On GPU

In [9]:
# COO matrix: torch.mm product followed by a separate bias addition.
%timeit torch.mm(coo_matrix, dense_vector_batched) + bias

349 µs ± 129 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
# COO matrix: product and bias addition issued as one torch.addmm call.
%timeit torch.addmm(bias, coo_matrix, dense_vector_batched)

345 µs ± 63.7 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [11]:
# CSR matrix: torch.mm product followed by a separate bias addition.
%timeit torch.mm(csr_matrix, dense_vector_batched) + bias

163 µs ± 151 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [12]:
# CSR matrix: product and bias addition issued as one torch.addmm call.
%timeit torch.addmm(bias, csr_matrix, dense_vector_batched)

164 µs ± 106 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [9]:
# COO matrix: same product through the torch.sparse.mm entry point, plus bias.
%timeit torch.sparse.mm(coo_matrix, dense_vector_batched) + bias

349 µs ± 58.3 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [8]:
# COO matrix: product and bias addition issued as one torch.sparse.addmm call.
%timeit torch.sparse.addmm(bias, coo_matrix, dense_vector_batched)

344 µs ± 36.8 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [7]:
# CSR matrix: same product through the torch.sparse.mm entry point, plus bias.
%timeit torch.sparse.mm(csr_matrix, dense_vector_batched) + bias

162 µs ± 149 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [17]:
# CSR matrix: product and bias addition issued as one torch.sparse.addmm call.
%timeit torch.sparse.addmm(bias, csr_matrix, dense_vector_batched)

164 µs ± 190 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [13]:
%timeit sparse_mm(coo_matrix, dense_vector_batched) + bias

349 µs ± 49.9 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [14]:
%timeit sparse_mm(csr_matrix, dense_vector_batched) + bias

162 µs ± 98.1 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [15]:
# torch_sparse.spmm on the raw (indices, values) pair instead of a sparse tensor, plus bias.
%timeit torch_sparse.spmm(indices, values, num_neurons, num_neurons, dense_vector_batched) + bias

782 µs ± 154 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
# Dense baseline, left disabled: `dense_matrix` is only created by a
# commented-out line in this notebook, so the name does not exist.
# %timeit torch.mv(dense_matrix, dense_vector_batched)

In [30]:
%%time
weight = coo_matrix.clone().requires_grad_(True)
out = torch.sparse.mm(weight, dense_vector_batched) + bias
out.sum().backward()

CPU times: user 1.73 ms, sys: 0 ns, total: 1.73 ms
Wall time: 1.79 ms


In [27]:
%%time
weight = csr_matrix.clone().requires_grad_(True)
out = torch.sparse.mm(weight, dense_vector_batched) + bias
out.sum().backward()

CPU times: user 1.77 ms, sys: 0 ns, total: 1.77 ms
Wall time: 1.86 ms


In [23]:
# Display the most recently assigned `weight` tensor (layout, nnz, requires_grad).
# NOTE(review): which tensor this shows depends on which backward cell ran last.
weight

tensor(crow_indices=tensor([      0,     108,     214,  ..., 9994846,
                            9994958, 9995039]),
       col_indices=tensor([  917,  1160,  2078,  ..., 95933, 95997, 98054]),
       values=tensor([-0.6188, -2.0300, -0.6389,  ...,  1.2913, -0.5155,
                       0.7060]), device='cuda:0', size=(100000, 100000),
       nnz=9995039, layout=torch.sparse_csr, requires_grad=True)

## On CPU

In [16]:
# Rebind the benchmark operands to CPU copies so the cells below time the CPU
# kernels. `bias` is not moved; the CPU timings below do not use it.
coo_matrix, csr_matrix = coo_matrix.cpu(), csr_matrix.cpu()
# dense_matrix = dense_matrix.to('cpu')
dense_vector_batched = dense_vector_batched.cpu()
indices, values = indices.cpu(), values.cpu()

In [17]:
# CPU, COO matrix: torch.mm product (no bias term in the CPU runs).
%timeit torch.mm(coo_matrix, dense_vector_batched)

1.46 s ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# CPU, CSR matrix: torch.mm product (no bias term in the CPU runs).
%timeit torch.mm(csr_matrix, dense_vector_batched)

1.23 ms ± 18.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
# CPU, COO matrix: same product through the torch.sparse.mm entry point.
%timeit torch.sparse.mm(coo_matrix, dense_vector_batched)

18.3 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
# CPU, CSR matrix: same product through the torch.sparse.mm entry point.
%timeit torch.sparse.mm(csr_matrix, dense_vector_batched)

2.04 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
# CPU: torch_sparse.spmm on the raw (indices, values) pair instead of a sparse tensor.
%timeit torch_sparse.spmm(indices, values, num_neurons, num_neurons, dense_vector_batched)

8 ms ± 111 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
# Dense baseline, left disabled: `dense_matrix` is only created by a
# commented-out line in this notebook, so the name does not exist.
# %timeit torch.mv(dense_matrix, dense_vector_batched)

## Testing framework

In [None]:
# Sweep definition: every key maps to the list of values to try, and
# ParameterGrid enumerates all combinations of them.
config = {
    "sheet_size": [(10, 10), (100, 100), (1000, 1000), (10000, 10000)],
    "connectivity_std": [1, 10, 100, 1000],
    "synapses_per_neuron": [10, 100, 1000, 10000],
    "num_timesteps": [100],
    "bias": [True],
    "mm_function": ["torch_sparse", "native", "tsgu"],
    "sparse_format": ["torch_sparse", "coo", "csr"],
    "batch_first": [True],
    "adjacency_matrix_path": [None],
    "self_recurrence": [True],
    "input_indices": ["connection/V1_indices_flat.pt"],
    "output_indices": ["connection/V4_indices_flat.pt"],
    "activation": ["relu", "tanh"],
    "batch_size": [1, 16, 64, 256, 1024, 4096],
}

# Optimizer settings used for every configuration in the sweep.
lr, momentum = 1e-3, 0.9

grid = ParameterGrid(config)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Walk every configuration in `grid`, building the model, optimizer, loss and
# MNIST data loaders for each one.
for params in grid:
    params = AttrDict(params)
    # batch_size is consumed by the data loaders below; remove it so it is not
    # forwarded to TopographicalRNN with the remaining entries.
    batch_size = params.batch_size
    del params.batch_size
    try:
        model = TopographicalRNN(**params).to(device)
    except Exception as e:
        # Best-effort sweep: report the combination that could not be built
        # and move on to the next one.
        print(f"Skipping {params}: {e}")
        continue

    # torch.optim.SGD is used when mm_function is "torch_sparse" or the sparse
    # format is "coo"; every other combination uses the project's SparseSGD.
    if params.mm_function == "torch_sparse" or params.sparse_format == "coo":
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=lr,
            momentum=momentum,
        )
    else:
        optimizer = SparseSGD(
            model.parameters(),
            lr=lr,
            momentum=momentum,
        )
    criterion = torch.nn.CrossEntropyLoss()
    train_loader, test_loader = get_dataloaders(
        dataset="mnist",
        root="data",
        retina_path="connection/V1_indices.npy",
        # Use the local copy: the key was deleted from `params` above, so
        # reading `params.batch_size` here would fail.
        batch_size=batch_size,
        num_workers=0,
    )
    