In [1]:
import os

import torch
import torchvision.transforms as T
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from tqdm import tqdm

from bioplnn.models import ConnectomeODEClassifier

In [2]:
# Select the compute device: use the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Set PyTorch's float32 matmul precision mode to "high".
torch.set_float32_matmul_precision("high")

print("Using device: {}".format(device))

Sat Mar 15 23:14:00 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:C4:00.0 Off |                    0 |
| N/A   36C    P0             53W /  300W |       3MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
# Model setup
# Directory holding the connectivity tensors loaded below. It defaults to the
# original cluster location so existing runs are unchanged; set the
# BIOPLNN_CONNECTIVITY_PATH environment variable to run the notebook on a
# machine where the files live somewhere else.
connectivity_base_path = os.environ.get(
    "BIOPLNN_CONNECTIVITY_PATH",
    "/om2/user/valmiki/bioplnn/connectivity/sunny",
)
model = ConnectomeODEClassifier(
    rnn_kwargs={
        # presumably 784 = flattened 28x28 MNIST image -- TODO confirm
        "input_size": 784,
        # NOTE(review): assumed to match the size of connectivity_hh -- confirm
        "hidden_size": 47521,
        "connectivity_hh": os.path.join(
            connectivity_base_path, "connectivity_hh.pt"
        ),
        "connectivity_ih": os.path.join(
            connectivity_base_path, "connectivity_ih_mnist.pt"
        ),
        "output_neurons": os.path.join(
            connectivity_base_path, "output_indices_mnist.pt"
        ),
        "nonlinearity": "Sigmoid",
        # Options named for compiling the solver; presumably forwarded to
        # torch.compile by the model -- verify against bioplnn.
        "compile_solver_kwargs": {
            "mode": "max-autotune",
            "dynamic": False,
            "fullgraph": True,
        },
    },
    num_classes=10,
    fc_dim=256,
    dropout=0.5,
).to(device)

# Define the optimizer (Adam over all model parameters, learning rate 1e-3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Define the loss function (cross-entropy between logits and integer labels)
criterion = nn.CrossEntropyLoss()

In [4]:
# Dataloader setup
# Preprocessing pipeline: convert each image to a tensor, then normalise it
# with mean 0.1307 and standard deviation 0.3081.
preprocessing_steps = [
    T.ToTensor(),
    T.Normalize((0.1307,), (0.3081,)),
]
transform = T.Compose(preprocessing_steps)

# MNIST training split, stored under ./data and downloaded when missing.
# The original cell carried two "LNG" markers next to download=True; their
# meaning is not documented in this notebook.
train_data = MNIST(
    root="data",
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 8 samples, loaded in the main process
# (num_workers=0).
train_loader = DataLoader(
    dataset=train_data,
    batch_size=8,
    shuffle=True,
    num_workers=0,
)

In [None]:
# Switch the model to training mode before running the training loop
# (presumably this enables the dropout configured at construction -- confirm
# against ConnectomeODEClassifier).
model.train()
# Number of full passes over the training loader made by the training loop.
n_epochs = 10

In [None]:
# Train for `n_epochs` passes over `train_loader`, printing the training
# accuracy and the mean per-batch loss at the end of every epoch.
for epoch in range(n_epochs):
    # Reset the statistics at the start of each epoch. The loss is averaged
    # over len(train_loader) batches, i.e. exactly one epoch, so carrying the
    # sums across epochs would make the reported loss grow with the epoch
    # index and would blend earlier epochs into the accuracy.
    correct = 0
    total = 0
    cum_loss = 0
    for x, labels in tqdm(train_loader):
        x = x.to(device)
        labels = labels.to(device)

        # Tell the inductor CUDA-graph machinery that a new training
        # iteration begins before the compiled model is invoked.
        torch._inductor.cudagraph_mark_step_begin()
        logits = model(x, num_steps=2)
        loss = criterion(logits, labels)

        # Standard optimisation step: clear stale gradients, backpropagate,
        # update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate this epoch's accuracy and loss statistics.
        predicted = logits.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        cum_loss += loss.item()

    accuracy = correct / total
    print(f"Accuracy: {accuracy:.2%}")
    print(f"Loss: {cum_loss / len(train_loader):.4f}")