In [1]:
import torch
import torchvision.transforms as T
from hydra import compose, initialize
from omegaconf import OmegaConf
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from tqdm import tqdm

from bioplnn.models import ConnectomeODEClassifier
from bioplnn.utils import (
    AttrDict,
)

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_float32_matmul_precision("high")
torch.cuda.empty_cache()
!nvidia-smi

Sat Mar 15 23:14:00 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:C4:00.0 Off |                    0 |
| N/A   36C    P0             53W /  300W |       3MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
# Compose the project's Hydra configuration ("config") with the
# `data=mnist_conn` and `model=conn` overrides applied.
with initialize(version_base=None, config_path="../config/", job_name="testing"):
    config = compose(config_name="config", overrides=["data=mnist_conn", "model=conn"])

# Resolve all interpolations into a plain Python container and wrap the result
# in the project's AttrDict helper.
config = AttrDict(OmegaConf.to_container(config, resolve=True))

In [11]:
# Build the connectome ODE classifier.
# NOTE(review): input_size=784 presumably corresponds to a flattened 28x28
# MNIST image, and the connectivity/*.pt paths are resolved by the model
# itself -- confirm against ConnectomeODEClassifier.
model = ConnectomeODEClassifier(
    rnn_kwargs=dict(
        input_size=784,
        hidden_size=47521,
        connectivity_hh="connectivity/sunny/connectivity_hh.pt",
        connectivity_ih="connectivity/sunny/connectivity_ih_mnist.pt",
        output_neurons="connectivity/sunny/output_indices_mnist.pt",
        nonlinearity="Sigmoid",
        batch_first=False,
        # Options for compiling the solver (presumably forwarded to
        # torch.compile -- confirm in the model implementation).
        compile_solver_kwargs=dict(
            mode="max-autotune",
            dynamic=False,
            fullgraph=True,
        ),
    ),
    num_classes=10,
    fc_dim=256,
    dropout=0.5,
)
model = model.to(device)

# Loss: cross-entropy between the model's predictions and the integer labels.
criterion = nn.CrossEntropyLoss()

# Optimizer: Adam over every model parameter.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


In [12]:
# Preprocessing: convert each image to a tensor, then normalise with the
# mean/std values commonly quoted for the MNIST training set.
transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])

# download=True fetches the dataset into `root` on first use and is a no-op
# when the files are already there, so the notebook runs on a fresh checkout.
train_data = MNIST(root="data", train=True, transform=transform, download=True)

# Shuffled mini-batches of 8 images, loaded by 8 worker processes.
train_loader = DataLoader(
    train_data, batch_size=8, num_workers=8, shuffle=True
)

In [None]:
# Training loop: optimise `model` on `train_loader` with `criterion` and
# `optimizer` (all defined in earlier cells).
NUM_EPOCHS = 10  # full passes over the training set
NUM_STEPS = 10  # forwarded to the model as `num_steps` on every forward call

model.train()

# Drop tensors left in the kernel namespace by a previous (possibly
# interrupted) run of this cell so their memory can be released. Each name is
# removed independently and a missing name is simply skipped; no other error
# is silenced.
for _stale_name in ("x", "labels", "preds", "loss"):
    globals().pop(_stale_name, None)
del _stale_name
torch.cuda.empty_cache()

for epoch in range(NUM_EPOCHS):
    for x, labels in tqdm(train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}"):
        x = x.to(device)
        labels = labels.to(device)

        # Signal that a new training iteration is starting. This is the public
        # equivalent of torch._inductor.cudagraph_mark_step_begin() and is
        # needed when compiled code runs under CUDA graphs.
        torch.compiler.cudagraph_mark_step_begin()

        preds = model(x, num_steps=NUM_STEPS)
        loss = criterion(preds, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()