In [1]:
# Switch to the repository checkout: later cells use relative paths
# ("connection/..." index files and the "data" dataset root).
# NOTE(review): this is a user-specific location; adjust it for your own checkout.
%cd ~/om2/bioplnn

/rdma/vast-rdma/user/valmiki/bioplnn


  bkms = self.shell.db.get('bookmarks', {})


In [2]:
import os
import time

import torch
from sklearn.model_selection import ParameterGrid

from bioplnn.dataset import get_dataloaders
from bioplnn.sparse_sgd import SparseSGD
from bioplnn.topography import TopographicalRNN
from bioplnn.utils import AttrDict

In [3]:
# Export CUDA_LAUNCH_BLOCKING=1 for this process (CUDA's switch for synchronous
# kernel launches, typically used so device-side errors surface at the call
# that caused them).
# NOTE(review): presumably a debugging aid; synchronous launches will skew the
# step timings measured later in this notebook, so confirm it is still wanted.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")

In [4]:
# Sweep definition: every key maps to the list of candidate values for that
# setting; ParameterGrid expands the lists into individual configurations.
# All keys except "batch_size" are forwarded to the TopographicalRNN
# constructor by the benchmark cell; "batch_size" goes to the data loaders.
config = {
    "sheet_size": [(10, 10), (100, 100), (1000, 1000), (10000, 10000)],
    "connectivity_std": [1, 10, 100, 1000],
    "synapses_per_neuron": [10, 100, 1000, 10000],
    "num_timesteps": [100],
    "bias": [True],
    "mm_function": ["torch_sparse", "native", "tsgu"],
    "sparse_format": ["torch_sparse", "coo", "csr"],
    "batch_first": [True],
    "adjacency_matrix_path": [None],
    "self_recurrence": [True],
    "input_indices": ["connection/V1_indices_flat.pt"],
    "output_indices": ["connection/V4_indices_flat.pt"],
    "activation": ["relu", "tanh"],
    "batch_size": [4, 16, 64, 256, 1024, 4096],
}
grid = ParameterGrid(config)

# Optimiser hyper-parameters shared by every configuration in the sweep.
lr, momentum = 1e-3, 0.9

# Timing window used by the benchmark cell: the clock starts at batch index
# `start_idx` and stops `num_iters` batches later.
start_idx, num_iters = 20, 20

In [5]:
# Benchmark sweep: for each configuration in `grid`, build the model, run
# `start_idx` untimed warm-up training steps, then report the mean wall-clock
# time of the following `num_iters` training steps.
#
# NOTE(review): the recorded run of this notebook died with a device-side
# "index out of bounds" assert on the very first configuration
# (sheet_size=(10, 10)). Presumably the V1/V4 index files address neurons that
# do not exist on such a small sheet -- confirm and prune the grid accordingly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The loss carries no per-configuration state, so one instance serves the whole sweep.
criterion = torch.nn.CrossEntropyLoss()


def _sync(device):
    """Wait for all queued CUDA work on `device` to finish; no-op on CPU."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _time_training_steps(
    model, optimizer, criterion, loader, device, warmup_steps, timed_steps
):
    """Return the mean seconds per training step, or None if nothing was timed.

    Runs `warmup_steps` untimed optimisation steps followed by `timed_steps`
    timed ones. The loader is re-iterated as often as needed, so loaders with
    fewer than `warmup_steps + timed_steps` batches (e.g. at large batch
    sizes) are still measured over the full window.

    Returns None when `timed_steps` is not positive or when the loader yields
    no batches at all.
    """
    total_steps = warmup_steps + timed_steps
    step = 0
    t_start = None
    while step < total_steps:
        steps_before_pass = step
        for images, labels in loader:
            if step == warmup_steps:
                # Drain pending GPU work so warm-up cost is not billed to the
                # timed window.
                _sync(device)
                t_start = time.perf_counter()

            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            step += 1
            if step == total_steps:
                break
        if step == steps_before_pass:
            # A full pass produced no batches: the loader is empty, so bail
            # out instead of spinning forever.
            return None

    if t_start is None:
        return None
    # Kernels may still be in flight; wait for them before stopping the clock.
    _sync(device)
    return (time.perf_counter() - t_start) / timed_steps


for params in grid:
    print(f"Params: {params}")

    # Split the data-loader setting off from the model constructor arguments.
    model_kwargs = dict(params)
    batch_size = model_kwargs.pop("batch_size")
    params = AttrDict(model_kwargs)

    try:
        model = TopographicalRNN(**params).to(device)
    except Exception as e:
        # Best effort: a configuration the model cannot be built for is
        # reported and skipped rather than aborting the sweep.
        print(e)
        continue

    # Dense SGD for the torch_sparse matmul path or COO weights; SparseSGD for
    # every other combination.
    if params.mm_function == "torch_sparse" or params.sparse_format == "coo":
        optimizer_cls = torch.optim.SGD
    else:
        optimizer_cls = SparseSGD
    optimizer = optimizer_cls(model.parameters(), lr=lr, momentum=momentum)

    # Only the training split is needed for timing.
    train_loader, _ = get_dataloaders(
        dataset="mnist_v1",
        root="data",
        retina_path="connection/V1_indices.npy",
        batch_size=batch_size,
        num_workers=4,
    )

    avg_step_time = _time_training_steps(
        model, optimizer, criterion, train_loader, device, start_idx, num_iters
    )
    if avg_step_time is None:
        print("Avg time: n/a (no training steps were timed)")
    else:
        print(f"Avg time: {avg_step_time:.4f} s")
    print("-" * 80)

    # Drop this configuration's model before the next (possibly much larger)
    # one is built, so two models are never resident at once.
    del model, optimizer, train_loader
    if device.type == "cuda":
        torch.cuda.empty_cache()

Params: {'activation': 'relu', 'adjacency_matrix_path': None, 'batch_first': True, 'batch_size': 4, 'bias': True, 'connectivity_std': 1, 'input_indices': 'connection/V1_indices_flat.pt', 'mm_function': 'torch_sparse', 'num_timesteps': 100, 'output_indices': 'connection/V4_indices_flat.pt', 'self_recurrence': True, 'sheet_size': (10, 10), 'sparse_format': 'torch_sparse', 'synapses_per_neuron': 10}


../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [23,0,0], thread: [0,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [23,0,0], thread: [1,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [23,0,0], thread: [2,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [23,0,0], thread: [3,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [23,0,0], thread: [4,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [23,0,0], thread: [5,0,0] Assertion `-sizes[i] <

RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
