In [1]:
import torch
import torchvision.transforms as T
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from tqdm import tqdm

from bioplnn.models import SpatiallyEmbeddedClassifier

In [2]:
# Torch setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_float32_matmul_precision("high")

In [3]:
# Model setup
model = SpatiallyEmbeddedClassifier(
    rnn_kwargs={
        "num_areas": 1,
        "area_kwargs": [
            {
                "in_size": [28, 28],
                "in_channels": 1,
                "out_channels": 32,
            },
        ],
    },
    num_classes=10,
    fc_dim=256,
    dropout=0.5,
).to(device)

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Define the loss function
criterion = nn.CrossEntropyLoss()

In [4]:
# Dataloader setup
transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
train_data = MNIST(root="data", train=True, transform=transform)
train_loader = DataLoader(
    train_data, batch_size=256, num_workers=8, shuffle=True
)

In [5]:
# Define the training loop
model.train()
# Calculate accuracy
correct = 0
total = 0
cum_loss = 0
for epoch in range(10):
    for i, (x, labels) in enumerate(tqdm(train_loader)):
        x = x.to(device)
        labels = labels.to(device)
        torch._inductor.cudagraph_mark_step_begin()
        logits = model(x, num_steps=5)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Calculate running accuracy and loss
        _, predicted = torch.max(logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        cum_loss += loss.item()

    accuracy = correct / total
    print(f"Accuracy: {accuracy:.2%}")
    print(f"Loss: {cum_loss / len(train_loader):.4f}")

100%|██████████| 235/235 [00:08<00:00, 28.54it/s]


Accuracy: 20.89%
Loss: 2.0974


100%|██████████| 235/235 [00:05<00:00, 46.27it/s]


Accuracy: 23.93%
Loss: 4.0279


100%|██████████| 235/235 [00:05<00:00, 46.65it/s]


Accuracy: 25.13%
Loss: 5.9420


100%|██████████| 235/235 [00:04<00:00, 47.62it/s]


Accuracy: 25.80%
Loss: 7.8426


100%|██████████| 235/235 [00:05<00:00, 46.33it/s]


Accuracy: 26.32%
Loss: 9.7150


100%|██████████| 235/235 [00:05<00:00, 45.98it/s]


Accuracy: 26.94%
Loss: 11.5414


100%|██████████| 235/235 [00:05<00:00, 44.78it/s]


Accuracy: 27.74%
Loss: 13.3005


100%|██████████| 235/235 [00:05<00:00, 46.10it/s]


Accuracy: 28.42%
Loss: 15.0455


100%|██████████| 235/235 [00:04<00:00, 47.25it/s]


Accuracy: 29.06%
Loss: 16.7629


100%|██████████| 235/235 [00:04<00:00, 48.53it/s]

Accuracy: 29.64%
Loss: 18.4573



