## Imports

In [1]:
import torch
import torch_sparse
import torchsparsegradutils as tsgu
import math
from sklearn.model_selection import ParameterGrid
from bioplnn.utils import AttrDict, idx_2D_to_1D
from bioplnn.models import TopographicalRNN
from bioplnn.utils import get_mnist_v1_dataloaders
from bioplnn.sparse_sgd import SparseSGD

In [2]:
# Record the GPU model / driver / CUDA version as context for the timings below.
!nvidia-smi

Wed Jun 19 19:33:44 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:43:00.0 Off |                    0 |
| N/A   50C    P0             56W /  300W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

## Parameters

In [3]:
# Benchmark configuration: a square (num_neurons x num_neurons) connectivity
# matrix with `synapses_per_neuron` sampled entries per neuron, multiplied by
# `batch_size` dense columns.
num_neurons = 100000
synapses_per_neuron = 100
batch_size = 1

# Allow TF32 for float32 matmuls where supported.
torch.set_float32_matmul_precision("high")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

sparsity = 1 - synapses_per_neuron / num_neurons
print(f"Sparsity: {sparsity:.2%}")

Sparsity: 99.90%


In [4]:
# Build a random sparse connectivity matrix of shape (num_neurons, num_neurons):
# column i (neuron i) receives `synapses_per_neuron` entries at randomly drawn
# rows. Duplicate (row, col) pairs are merged when coalescing, so the final
# nnz can be slightly below num_neurons * synapses_per_neuron.
torch.manual_seed(0)  # reproducible matrix -> comparable timings across runs

# Vectorized form of a per-neuron loop: rows are random source neurons, cols
# repeat each neuron index `synapses_per_neuron` times in ascending order.
num_synapses = num_neurons * synapses_per_neuron
synapse_rows = torch.randint(0, num_neurons, (num_synapses,), device=device)
synapse_cols = torch.arange(num_neurons, device=device).repeat_interleave(
    synapses_per_neuron
)
indices = torch.stack((synapse_rows, synapse_cols))
values = torch.randn(num_synapses, device=device)

# The coalesced indices/values pair is reused directly by torch_sparse.spmm.
indices, values = torch_sparse.coalesce(
    indices, values, num_neurons, num_neurons
)

# indices/values already live on `device`, so these sparse tensors do too.
coo_matrix = torch.sparse_coo_tensor(
    indices, values, (num_neurons, num_neurons)
).coalesce()
csr_matrix = coo_matrix.to_sparse_csr()
# A dense copy holds num_neurons**2 float32 entries (~40 GB for 100k neurons);
# uncomment only to benchmark the dense baseline.
# dense_matrix = coo_matrix.to_dense().to(device)
dense_vector_batched = torch.randn(num_neurons, batch_size, device=device)
bias = torch.randn(num_neurons, 1, device=device)

  csr_matrix = coo_matrix.to_sparse_csr().to(device)


## On GPU

### `requires_grad == False`

In [31]:
# torch.sparse.mm with a COO matrix (no autograd).
%timeit torch.sparse.mm(coo_matrix, dense_vector_batched)

346 µs ± 21.6 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [32]:
# torch.sparse.mm with the same matrix in CSR format (no autograd).
%timeit torch.sparse.mm(csr_matrix, dense_vector_batched)

158 µs ± 182 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [33]:
# torchsparsegradutils.sparse_mm with a COO matrix (no autograd).
%timeit tsgu.sparse_mm(coo_matrix, dense_vector_batched)

346 µs ± 36.3 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [34]:
# torchsparsegradutils.sparse_mm with a CSR matrix (no autograd).
%timeit tsgu.sparse_mm(csr_matrix, dense_vector_batched)

159 µs ± 139 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [35]:
# torch_sparse.spmm on the raw (indices, values) pair (no autograd).
%timeit torch_sparse.spmm(indices, values, num_neurons, num_neurons, dense_vector_batched)

771 µs ± 142 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [37]:
%timeit torch.mm(dense_matrix, dense_vector_batched)

24.4 ms ± 203 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### `requires_grad == True`

In [18]:
# Trainable copies of each weight representation. Cloning leaves coo_matrix,
# csr_matrix and values untouched for the no-grad benchmarks above.
coo_weight = coo_matrix.clone()
csr_weight = csr_matrix.clone()
torch_sparse_weight = values.clone()  # torch_sparse trains only the nonzero values
for weight in (coo_weight, csr_weight, torch_sparse_weight):
    weight.requires_grad_(True)
# dense_weight = dense_matrix.clone().requires_grad_(True)

# Stock SGD for the COO tensor and the torch_sparse value vector; the CSR
# weight goes through the project's SparseSGD.
coo_optimizer = torch.optim.SGD([coo_weight], lr=0.01)
csr_optimizer = SparseSGD([csr_weight], lr=0.01)
torch_sparse_optimizer = torch.optim.SGD([torch_sparse_weight], lr=0.01)

In [None]:
%%timeit
# One full training step on the COO weight: forward, zero grads, backward, SGD update.
out = torch.sparse.mm(coo_weight, dense_vector_batched)
coo_optimizer.zero_grad()
out.sum().backward()
coo_optimizer.step()

In [15]:
%%timeit
# One full training step on the CSR weight, updated with the project's SparseSGD.
out = torch.sparse.mm(csr_weight, dense_vector_batched)
csr_optimizer.zero_grad()
out.sum().backward()
csr_optimizer.step()

7.83 ms ± 3.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [16]:
%%timeit
# One full training step with torch_sparse: only the value vector is a trainable
# leaf; `indices` stay fixed.
out = torch_sparse.spmm(indices, torch_sparse_weight, num_neurons, num_neurons, dense_vector_batched)
torch_sparse_optimizer.zero_grad()
out.sum().backward()
torch_sparse_optimizer.step()

1.03 ms ± 191 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## On CPU

In [43]:
# Move every no-grad benchmark input to host memory for the CPU comparison.
# (Dense baseline, if wanted: dense_matrix = coo_matrix.to_dense().to("cpu"))
coo_matrix, csr_matrix, dense_vector_batched, indices, values = (
    tensor.to("cpu")
    for tensor in (coo_matrix, csr_matrix, dense_vector_batched, indices, values)
)

: 

In [28]:
# torch.mm with a COO matrix on CPU.
%timeit torch.mm(coo_matrix, dense_vector_batched)

166 ms ± 1.58 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [29]:
# torch.mm with a CSR matrix on CPU.
%timeit torch.mm(csr_matrix, dense_vector_batched)

10.1 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [30]:
# torch.sparse.mm with a COO matrix on CPU.
%timeit torch.sparse.mm(coo_matrix, dense_vector_batched)

165 ms ± 1.66 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [31]:
# torch.sparse.mm with a CSR matrix on CPU.
%timeit torch.sparse.mm(csr_matrix, dense_vector_batched)

9.7 ms ± 57.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [37]:
# torchsparsegradutils.sparse_mm with a COO matrix on CPU.
%timeit tsgu.sparse_mm(coo_matrix, dense_vector_batched)

166 ms ± 1.75 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [38]:
# torchsparsegradutils.sparse_mm with a CSR matrix on CPU.
%timeit tsgu.sparse_mm(csr_matrix, dense_vector_batched)

9.65 ms ± 30.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [39]:
# torch_sparse.spmm on the raw (indices, values) pair on CPU.
%timeit torch_sparse.spmm(indices, values, num_neurons, num_neurons, dense_vector_batched)

55.3 ms ± 176 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [20]:
# %timeit torch.mm(dense_matrix, dense_vector_batched)

## Testing framework: training sweep over `TopographicalRNN` configurations

In [None]:
# Hyperparameter grid for the training sweep. ParameterGrid expands the
# Cartesian product of every list into one dict per configuration.
config = {
    "sheet_size": [(10, 10), (100, 100), (1000, 1000), (10000, 10000)],
    "connectivity_std": [1, 10, 100, 1000],
    "synapses_per_neuron": [10, 100, 1000, 10000],
    "num_timesteps": [100],
    "bias": [True],
    "mm_function": ["torch_sparse", "native", "tsgu"],
    "sparse_format": ["torch_sparse", "coo", "csr"],
    "batch_first": [True],
    "adjacency_matrix_path": [None],
    "self_recurrence": [True],
    "input_indices": ["connection/V1_indices_flat.pt"],
    "output_indices": ["connection/V4_indices_flat.pt"],
    "activation": ["relu", "tanh"],
    "batch_size": [1, 16, 64, 256, 1024, 4096],
}
# Optimizer hyperparameters shared by every configuration.
lr, momentum = 1e-3, 0.9

grid = ParameterGrid(config)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 10
log_every = 100  # training iterations between loss prints

# DataLoaders depend only on batch_size, so build each pair once and reuse it
# for every model configuration sharing that batch size.
loaders_by_batch_size = {}

for grid_point in grid:
    # batch_size configures the DataLoader, not the model: split it off before
    # forwarding the remaining entries as TopographicalRNN keyword arguments.
    model_kwargs = dict(grid_point)
    batch_size = model_kwargs.pop("batch_size")
    params = AttrDict(model_kwargs)
    try:
        model = TopographicalRNN(**params).to(device)
    except Exception as e:
        # Some grid combinations may be rejected by the constructor (or not fit
        # in memory); report which one and continue with the sweep.
        print(f"Skipping config {dict(params)} (batch_size={batch_size}): {e}")
        continue

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=momentum,
        # Multi-tensor (foreach) path disabled for CSR weights, as in the
        # original setup — presumably unsupported for CSR; TODO confirm.
        foreach=params.sparse_format != "csr",
    )
    criterion = torch.nn.CrossEntropyLoss()
    if batch_size not in loaders_by_batch_size:
        loaders_by_batch_size[batch_size] = get_mnist_v1_dataloaders(
            root="data",
            retina_path="connection/V1_indices.npy",
            batch_size=batch_size,
            num_workers=0,
        )
    train_loader, test_loader = loaders_by_batch_size[batch_size]

    model.train()
    for epoch in range(num_epochs):
        for i, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if i % log_every == 0:
                print(f"Epoch: {epoch}, Loss: {loss.item()}")
    