In [1]:
from torchvision.datasets import MNIST
import torchvision.transforms as T
import torch
from capsnet import *
from loss_function import MarginLoss
from collections import defaultdict
from tqdm import tqdm

torch.manual_seed(0)
device = torch.device('mps' if torch.has_mps else 'cuda' if torch.has_cuda else 'cpu')

print(f'Using {device.type}')

  from .autonotebook import tqdm as notebook_tqdm


Using mps


In [None]:
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 32, 5, bias=False),
    torch.nn.InstanceNorm2d(32, affine=True),
    torch.nn.ReLU(),
    torch.nn.Conv2d(32, 64, 3, bias=False),
    torch.nn.InstanceNorm2d(64, affine=True),
    torch.nn.ReLU(),
    torch.nn.Conv2d(64, 64, 3, bias=False),
    torch.nn.InstanceNorm2d(64, affine=True),
    torch.nn.ReLU(),
    torch.nn.Conv2d(64, 128, 3, 2, bias=False),
    torch.nn.InstanceNorm2d(128, affine=True),
    torch.nn.ReLU(),
    PrimaryCapsule(128, 8, 16, 9, activation=SquashSA, depthwise=True),
    CapsuleTransform(8, 16, 16, 10),
    SelfAttentionRouter(16, 10, activation=SquashSA)
).to(device)

transform = T.ToTensor()

trainset = MNIST('mnist', train=True, transform=transform, download=True)
validset = MNIST('mnist', train=False, transform=transform)

epochs = 100
batch_size = 16

trainloader = torch.utils.data.DataLoader(trainset, batch_size, True, num_workers=8)
validloader = torch.utils.data.DataLoader(validset, batch_size, num_workers=8)

criterion = MarginLoss()
optimizer = torch.optim.AdamW(model.parameters(), 1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(trainloader) * 50)

print(f'CapsNet parameter size: {sum(p.numel() for p in model.parameters()):,}')
print()

history = []
for i in range(1, epochs + 1):
    record = defaultdict(float)

    print(f'Epoch {i}/{epochs}')
    with tqdm(total=len(trainloader) + len(validloader)) as pbar:
        model.train()
        for inputs, targets in trainloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs).norm(p=2, dim=-1)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            record['loss'] += loss.item()
            record['accuracy'] += (outputs.argmax(-1) == targets).float().mean().item() * 100
            pbar.update()

        record['loss'] /= len(trainloader)
        record['accuracy'] /= len(trainloader)
        pbar.set_postfix(record)

        model.eval()
        for inputs, targets in validloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            with torch.inference_mode():
                outputs = model(inputs).norm(p=2, dim=-1)
                loss = criterion(outputs, targets)

            record['val_loss'] += loss.item()
            record['val_accuracy'] += (outputs.argmax(-1) == targets).float().mean().item() * 100
            pbar.update()

        record['val_loss'] /= len(validloader)
        record['val_accuracy'] /= len(validloader)
        pbar.set_postfix(record)
        history.append(record)

CapsNet parameter size: 161,536

Epoch 1/100


100%|██████████| 4375/4375 [01:23<00:00, 52.32it/s, loss=0.645, accuracy=71.8, val_loss=0.533, val_accuracy=83.1] 


Epoch 2/100


100%|██████████| 4375/4375 [01:23<00:00, 52.33it/s, loss=0.234, accuracy=39, val_loss=0, val_accuracy=0] 


Epoch 3/100


100%|██████████| 4375/4375 [01:21<00:00, 53.70it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 4/100


100%|██████████| 4375/4375 [01:21<00:00, 53.58it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 5/100


100%|██████████| 4375/4375 [01:21<00:00, 53.75it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 6/100


100%|██████████| 4375/4375 [01:21<00:00, 53.94it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 7/100


100%|██████████| 4375/4375 [01:21<00:00, 53.78it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 8/100


100%|██████████| 4375/4375 [01:21<00:00, 53.49it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 9/100


100%|██████████| 4375/4375 [01:24<00:00, 52.00it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 10/100


100%|██████████| 4375/4375 [01:24<00:00, 51.62it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 11/100


100%|██████████| 4375/4375 [01:24<00:00, 51.54it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 12/100


100%|██████████| 4375/4375 [01:24<00:00, 51.72it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 13/100


100%|██████████| 4375/4375 [01:24<00:00, 51.52it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 14/100


100%|██████████| 4375/4375 [01:24<00:00, 51.62it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 15/100


100%|██████████| 4375/4375 [01:24<00:00, 51.52it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 16/100


100%|██████████| 4375/4375 [01:46<00:00, 40.94it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 17/100


100%|██████████| 4375/4375 [18:38<00:00,  3.91it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 18/100


100%|██████████| 4375/4375 [14:49<00:00,  4.92it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 19/100


100%|██████████| 4375/4375 [01:19<00:00, 54.82it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 20/100


100%|██████████| 4375/4375 [01:22<00:00, 53.01it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 21/100


100%|██████████| 4375/4375 [01:24<00:00, 52.01it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 22/100


100%|██████████| 4375/4375 [01:21<00:00, 53.65it/s, loss=0, accuracy=0, val_loss=0, val_accuracy=0] 


Epoch 23/100


 52%|█████▏    | 2282/4375 [00:47<00:43, 48.59it/s]