In [1]:
from collections import defaultdict

import torch
from torchvision import datasets as D, transforms as T
from tqdm import tqdm

import layer
from loss_function import MarginLoss

# Seed PyTorch's global RNG so weight init and data shuffling are reproducible.
torch.manual_seed(0)

# Pick the best accelerator that is actually usable at runtime. The older
# `torch.has_mps` / `torch.has_cuda` flags are deprecated and only report whether
# PyTorch was *built* with that backend, so a CUDA-enabled build on a machine
# without a GPU would wrongly select 'cuda' and fail on the first `.to(device)`.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Using {device.type}')

  from .autonotebook import tqdm as notebook_tqdm


Using mps


In [2]:
# MNIST images as float tensors via ToTensor. `download=True` is a no-op when the
# files already exist under 'mnist', so either split can be loaded first.
train_dataset = D.MNIST('mnist', train=True, transform=T.ToTensor(), download=True)
valid_dataset = D.MNIST('mnist', train=False, transform=T.ToTensor(), download=True)

batch_size = 32
# persistent_workers keeps the 8 worker processes alive between epochs instead of
# respawning them for every pass over each loader.
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size, shuffle=True, num_workers=8, persistent_workers=True
)
# Evaluation does not need shuffling; a fixed order keeps validation metrics
# deterministic (including which samples fall into the short final batch).
validloader = torch.utils.data.DataLoader(
    valid_dataset, batch_size, shuffle=False, num_workers=8, persistent_workers=True
)

lr = 5e-4
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 256, 9),                     # 9x9 conv, stride 1, no padding: 28x28 -> 20x20
    torch.nn.ReLU(),
    layer.PrimaryCapsule(256, 32, 8, 9, 2),         # presumably 32 maps of 8-D capsules, 9x9/stride 2 -> 6x6 — confirm in layer.py
    layer.Squash('dr'),
    layer.CapsuleTransform(32 * 6 * 6, 8, 10, 16),  # 32*6*6 input capsules -> 10 output capsules of dim 16 (per the args; verify)
    layer.DynamicRouter(10, 16),
).to(device)  # use the device chosen in cell 1; a hard-coded 'mps' breaks on CUDA/CPU hosts

print(f'Model size: {sum(p.numel() for p in model.parameters()):,}')

criterion = MarginLoss()
optimizer = torch.optim.Adam(model.parameters(), lr)

Model size: 6,804,384


In [3]:
print(f'CapsNet parameter size: {sum(p.numel() for p in model.parameters()):,}')
print()


def run_phase(model, loader, criterion, pbar, optimizer=None):
    """Make one pass over `loader`, training the model iff `optimizer` is given.

    Args:
        model: network whose output, L2-normed over the last dim, is used as
            per-class scores (presumably the output capsule lengths).
        loader: DataLoader yielding (inputs, targets) batches.
        criterion: loss, called as criterion(scores, targets).
        pbar: tqdm progress bar, advanced once per batch.
        optimizer: if given, the model is put in train mode and updated;
            otherwise it runs in eval mode under torch.inference_mode.

    Returns:
        (loss, accuracy) for the pass, accuracy in percent; NaNs if the loader
        is empty. Both are averaged per sample rather than per batch, so a short
        final batch (dataset size not a multiple of batch_size) is not
        over-weighted. The loss weighting assumes `criterion` returns a
        per-batch mean -- TODO confirm for MarginLoss.
    """
    training = optimizer is not None
    model.train(training)
    # Follow the model's actual placement so inputs always match its device.
    dev = next(model.parameters()).device

    loss_sum, correct, seen = 0.0, 0, 0
    for inputs, targets in loader:
        inputs = inputs.to(dev)
        targets = targets.to(dev)

        # inference_mode(False) is a no-op, so autograd is active only when training.
        with torch.inference_mode(not training):
            scores = model(inputs).norm(p=2, dim=-1)
            loss = criterion(scores, targets)

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        batch = len(targets)
        loss_sum += loss.item() * batch
        correct += (scores.argmax(-1) == targets).sum().item()
        seen += batch
        pbar.update()

    if seen == 0:
        return float('nan'), float('nan')
    return loss_sum / seen, 100.0 * correct / seen


# NOTE(review): in the recorded run, val accuracy drifted from ~74% down to ~50%
# while train accuracy rose to ~90% -- both well below what MNIST usually yields.
# Worth checking the train/eval-mode behaviour of the custom layers (e.g.
# layer.DynamicRouter) and the target format MarginLoss expects before a long run.
history = []
epochs = 150
for i in range(1, epochs + 1):
    print(f'Epoch {i}/{epochs}')
    record = defaultdict(float)
    with tqdm(total=len(trainloader) + len(validloader)) as pbar:
        record['loss'], record['accuracy'] = run_phase(
            model, trainloader, criterion, pbar, optimizer
        )
        pbar.set_postfix(record)

        record['val_loss'], record['val_accuracy'] = run_phase(
            model, validloader, criterion, pbar
        )
        pbar.set_postfix(record)
    history.append(record)

CapsNet parameter size: 6,804,384

Epoch 1/150


100%|██████████| 2188/2188 [03:23<00:00, 10.77it/s, loss=0.221, accuracy=75.6, val_loss=0.245, val_accuracy=74.2]


Epoch 2/150


100%|██████████| 2188/2188 [03:26<00:00, 10.61it/s, loss=0.167, accuracy=82.8, val_loss=0.264, val_accuracy=77.7]


Epoch 3/150


100%|██████████| 2188/2188 [03:25<00:00, 10.67it/s, loss=0.143, accuracy=85.8, val_loss=0.293, val_accuracy=76.8]


Epoch 4/150


100%|██████████| 2188/2188 [03:21<00:00, 10.88it/s, loss=0.13, accuracy=87, val_loss=0.378, val_accuracy=68.6]


Epoch 5/150


100%|██████████| 2188/2188 [03:24<00:00, 10.69it/s, loss=0.129, accuracy=87.5, val_loss=0.394, val_accuracy=72.3]


Epoch 6/150


100%|██████████| 2188/2188 [03:24<00:00, 10.72it/s, loss=0.12, accuracy=88.2, val_loss=0.474, val_accuracy=68.7]


Epoch 7/150


100%|██████████| 2188/2188 [03:22<00:00, 10.81it/s, loss=0.123, accuracy=88, val_loss=0.641, val_accuracy=61.2]


Epoch 8/150


100%|██████████| 2188/2188 [03:22<00:00, 10.81it/s, loss=0.124, accuracy=87.6, val_loss=0.696, val_accuracy=51.7]


Epoch 9/150


100%|██████████| 2188/2188 [03:24<00:00, 10.68it/s, loss=0.121, accuracy=88.4, val_loss=0.698, val_accuracy=53.2]


Epoch 10/150


100%|██████████| 2188/2188 [03:25<00:00, 10.67it/s, loss=0.124, accuracy=88.1, val_loss=0.744, val_accuracy=56.9]


Epoch 11/150


100%|██████████| 2188/2188 [03:34<00:00, 10.21it/s, loss=0.12, accuracy=88.3, val_loss=0.768, val_accuracy=58.6]


Epoch 12/150


100%|██████████| 2188/2188 [03:38<00:00, 10.00it/s, loss=0.113, accuracy=89.1, val_loss=0.756, val_accuracy=59]


Epoch 13/150


100%|██████████| 2188/2188 [03:33<00:00, 10.24it/s, loss=0.109, accuracy=89.5, val_loss=0.733, val_accuracy=64.9]


Epoch 14/150


100%|██████████| 2188/2188 [03:36<00:00, 10.09it/s, loss=0.102, accuracy=90.3, val_loss=0.856, val_accuracy=51.1]


Epoch 15/150


100%|██████████| 2188/2188 [03:36<00:00, 10.11it/s, loss=0.103, accuracy=90.2, val_loss=0.818, val_accuracy=57.3]


Epoch 16/150


100%|██████████| 2188/2188 [03:33<00:00, 10.23it/s, loss=0.1, accuracy=90.4, val_loss=0.864, val_accuracy=54.6]


Epoch 17/150


100%|██████████| 2188/2188 [03:38<00:00, 10.01it/s, loss=0.0982, accuracy=90.5, val_loss=0.872, val_accuracy=53.8]


Epoch 18/150


100%|██████████| 2188/2188 [03:42<00:00,  9.85it/s, loss=0.101, accuracy=90.5, val_loss=0.86, val_accuracy=53.7]


Epoch 19/150


100%|██████████| 2188/2188 [03:33<00:00, 10.25it/s, loss=0.104, accuracy=90.2, val_loss=0.937, val_accuracy=50]


Epoch 20/150


100%|██████████| 2188/2188 [03:32<00:00, 10.32it/s, loss=0.105, accuracy=89.9, val_loss=0.956, val_accuracy=50.1]


Epoch 21/150


100%|██████████| 2188/2188 [03:36<00:00, 10.12it/s, loss=0.103, accuracy=90, val_loss=0.817, val_accuracy=54]


Epoch 22/150


 79%|███████▊  | 1719/2188 [03:05<00:50,  9.26it/s]


KeyboardInterrupt: 