In [1]:
from torchvision.datasets import MNIST
import torchvision.transforms as T
import torch
from dynamic_routing import InputStem, PrimaryCaps, DigitCaps
from loss_function import MarginLoss
from collections import defaultdict
from tqdm import tqdm

# Seed PyTorch's RNG so weight init and data shuffling are reproducible.
torch.manual_seed(0)

# Pick the best *usable* accelerator. `torch.has_mps` / `torch.has_cuda` only
# report whether PyTorch was built with that backend (and are deprecated), so a
# CUDA-enabled build on a machine without a GPU would wrongly select 'cuda'.
# The `is_available()` checks verify the device can actually be used.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Using {device.type}')

Using mps


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Training settings, kept together so they are easy to find and tune.
batch_size = 32
num_workers = 8
learning_rate = 5e-4
epochs = 1

# CapsNet: input stem -> primary capsules -> digit capsules.
# NOTE(review): 32 * 6 * 6 presumably is the number of primary capsules emitted
# for a 28x28 MNIST image (32 capsule maps of 6x6) — confirm against PrimaryCaps.
model = torch.nn.Sequential(
    InputStem(1),
    PrimaryCaps(256),
    DigitCaps(32 * 6 * 6)
).to(device)

transform = T.ToTensor()

# download=True is a no-op when the files are already present; passing it for
# the test split as well means neither call depends on the other having run.
trainset = MNIST('mnist', train=True, transform=transform, download=True)
validset = MNIST('mnist', train=False, transform=transform, download=True)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers
)
validloader = torch.utils.data.DataLoader(
    validset, batch_size=batch_size, num_workers=num_workers
)

criterion = MarginLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

print(f'CapsNet parameter size: {sum(p.numel() for p in model.parameters()):,}')
print()


def _run_epoch(model, loader, criterion, device, pbar, optimizer=None):
    """Run one pass over `loader` and return `(mean_loss, accuracy_percent)`.

    When `optimizer` is given the model is put in train mode and updated
    (backward + step); otherwise it is put in eval mode and run under
    `torch.inference_mode()`. `pbar` is advanced once per batch.

    Accuracy counts correct predictions over every sample, so a short final
    batch (the 10,000-image test split yields 312 batches of 32 plus one of
    16) is not over-weighted the way a mean of per-batch means would be.
    The loss is still the mean of per-batch loss values, as before: weighting
    it by batch size would require knowing MarginLoss's reduction
    (mean vs. sum) — TODO confirm before changing.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    correct = 0
    seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        if training:
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        else:
            with torch.inference_mode():
                outputs = model(inputs)
                loss = criterion(outputs, targets)

        total_loss += loss.item()
        matches = outputs.argmax(-1) == targets
        correct += matches.sum().item()
        seen += matches.numel()
        pbar.update()

    # Guard against an empty loader instead of dividing by zero.
    if seen == 0:
        return float('nan'), float('nan')
    return total_loss / len(loader), 100 * correct / seen


history = []
for i in range(1, epochs + 1):
    record = defaultdict(float)

    print(f'Epoch {i}/{epochs}')
    with tqdm(total=len(trainloader) + len(validloader)) as pbar:
        record['loss'], record['accuracy'] = _run_epoch(
            model, trainloader, criterion, device, pbar, optimizer=optimizer
        )
        pbar.set_postfix(record)

        record['val_loss'], record['val_accuracy'] = _run_epoch(
            model, validloader, criterion, device, pbar
        )
        pbar.set_postfix(record)
        history.append(record)

CapsNet parameter size: 6,804,384

Epoch 1/1


100%|██████████| 2188/2188 [01:39<00:00, 22.10it/s, loss=0.134, accuracy=96.6, val_loss=0.0974, val_accuracy=98.4]
