In [1]:
import torch
import torchvision.transforms as T
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from tqdm import tqdm

from bioplnn.models import SpatiallyEmbeddedClassifier

In [2]:
# Torch setup
# Set the float32 matmul precision mode to "high" (see the PyTorch docs for
# the exact speed/precision trade-off this selects).
torch.set_float32_matmul_precision("high")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
# Model setup
# A single area that takes 28x28 inputs with 1 channel and produces 32
# output channels.
area_config = dict(in_size=[28, 28], in_channels=1, out_channels=32)
rnn_config = dict(num_areas=1, area_kwargs=[area_config])

# 10-class classifier built on top of the RNN configuration above.
# NOTE(review): fc_dim / dropout presumably configure the classifier head;
# confirm against SpatiallyEmbeddedClassifier.
model = SpatiallyEmbeddedClassifier(
    rnn_kwargs=rnn_config,
    num_classes=10,
    fc_dim=256,
    dropout=0.5,
)
model = model.to(device)

# Loss function: cross-entropy between the logits and integer class labels.
criterion = nn.CrossEntropyLoss()

# Optimizer: Adam over every model parameter.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [6]:
# Dataloader setup
# Convert images to tensors, then normalize the single channel with
# mean 0.1307 / std 0.3081 (the commonly used MNIST statistics).
transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])

# download=True fetches MNIST into `data/` when it is missing and is a no-op
# when the files are already there, so this cell also works on a fresh
# checkout (Restart Kernel -> Run All).
train_data = MNIST(root="data", train=True, transform=transform, download=True)

# Shuffled mini-batches of 256 images, loaded by 8 worker processes.
train_loader = DataLoader(
    train_data, batch_size=256, num_workers=8, shuffle=True
)

In [7]:
# Training loop
# Trains for `num_epochs` passes over `train_loader` and, after each pass,
# prints the training accuracy and mean batch loss of that pass only.
num_epochs = 10

model.train()
for epoch in range(num_epochs):
    # Reset the running statistics at the start of every epoch. If they are
    # only initialised once, the loss sum keeps growing across epochs while
    # the divisor stays at one epoch's batch count, inflating the report.
    correct = 0
    total = 0
    cum_loss = 0.0

    for x, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        x = x.to(device)
        labels = labels.to(device)

        # Mark the start of a new training iteration for CUDA-graph-compiled
        # code (public equivalent of torch._inductor.cudagraph_mark_step_begin).
        # NOTE(review): presumably needed because the model is compiled
        # internally; confirm against SpatiallyEmbeddedClassifier.
        torch.compiler.cudagraph_mark_step_begin()

        logits = model(x, num_steps=2)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running accuracy and loss, measured on the same forward pass that
        # was used for the update (i.e. with the model in train mode).
        predicted = logits.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        cum_loss += loss.item()

    accuracy = correct / total
    print(f"Accuracy: {accuracy:.2%}")
    print(f"Loss: {cum_loss / len(train_loader):.4f}")

  0%|          | 0/7500 [00:00<?, ?it/s]

 18%|█▊        | 1371/7500 [00:17<01:16, 79.86it/s] 


KeyboardInterrupt: 