In [1]:
from collections import defaultdict

import torchvision.transforms as T
from torchvision.datasets import MNIST
from tqdm import tqdm

from capsnet import *
from loss_function import MarginLoss

# Seed PyTorch's global RNG so repeated runs start from the same state.
torch.manual_seed(0)

# Choose the accelerator by what is actually usable at runtime.
# `torch.has_mps` / `torch.has_cuda` only report how the binary was built (and
# are deprecated), so a CUDA build on a GPU-less machine would select 'cuda'
# and fail later when tensors are moved. Preference order is unchanged:
# mps, then cuda, then cpu.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Using {device.type}')

Using mps


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Network: a 9x9 convolution (1 -> 256 channels) + ReLU feeding the capsule
# layers from the project-local `capsnet` package.
# NOTE(review): the recorded run of this cell ended with
# `TypeError: AgreementRouter.__init__() missing 1 required positional
# argument: 'out_capsules'`. Presumably one of the capsule layers below (most
# likely the router) is constructed with outdated arguments; the capsnet
# signatures are not visible from this notebook, so confirm and update them.
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 256, 9),
    torch.nn.ReLU(),
    layer.PrimaryCapsule(256, 32, 8, 9, 2),
    layer.Squash('dr'),
    layer.CapsuleTransform(32 * 6 * 6, 8, 10, 16),
    layer.DynamicRouter(10, 16)
).to(device)

# Images are only converted to tensors; no augmentation or normalisation.
transform = T.ToTensor()

# Only the training split triggers a download; the validation split reads
# from the same 'mnist' root directory.
trainset = MNIST('mnist', train=True, transform=transform, download=True)
validset = MNIST('mnist', train=False, transform=transform)

epochs = 100
batch_size = 16

# Training batches are reshuffled every epoch; validation order is fixed.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=8
)
validloader = torch.utils.data.DataLoader(
    validset, batch_size=batch_size, num_workers=8
)

criterion = MarginLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# The scheduler is stepped once per batch, so T_max is counted in optimizer
# steps. It must span the whole run: with a shorter horizon (previously a
# hard-coded 50 epochs against `epochs = 100`) the cosine reaches its minimum
# part-way through and then climbs back to the initial rate by the end.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=len(trainloader) * epochs
)

print(f'CapsNet parameter size: {sum(p.numel() for p in model.parameters()):,}')
print()

def _run_epoch(loader, pbar, training):
    """Run one full pass over ``loader`` and return its averaged metrics.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer``,
    ``scheduler`` and ``device``.

    Args:
        loader: DataLoader yielding ``(inputs, targets)`` batches.
        pbar: tqdm progress bar, advanced by one tick per batch.
        training: if True, the model is put in train mode and the optimizer
            and LR scheduler are stepped after every batch. If False, the
            model is put in eval mode and the forward pass runs under
            ``torch.inference_mode``.

    Returns:
        ``(loss, accuracy)`` where ``loss`` is the mean of the per-batch loss
        values and ``accuracy`` is the percentage of samples whose predicted
        class matches the target.
    """
    model.train(training)
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # mode=False leaves autograd untouched for training batches;
        # mode=True disables gradient tracking for validation batches.
        with torch.inference_mode(mode=not training):
            # Class scores are the L2 norms over the last output dimension.
            outputs = model(inputs).norm(p=2, dim=-1)
            loss = criterion(outputs, targets)

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()  # the LR schedule advances per batch, not per epoch

        loss_sum += loss.item()
        # Count hits per sample instead of averaging per-batch means, so a
        # smaller final batch cannot skew the epoch accuracy.
        n_correct += (outputs.argmax(-1) == targets).sum().item()
        n_seen += targets.numel()
        pbar.update()

    # Loss stays a per-batch average: the reduction used by the criterion is
    # not visible here, so it is not re-weighted by batch size.
    return loss_sum / len(loader), 100 * n_correct / n_seen


# One record per epoch with keys 'loss', 'accuracy', 'val_loss', 'val_accuracy'.
history = []
for i in range(1, epochs + 1):
    record = defaultdict(float)

    print(f'Epoch {i}/{epochs}')
    # A single bar covers both the training and the validation pass.
    with tqdm(total=len(trainloader) + len(validloader)) as pbar:
        record['loss'], record['accuracy'] = _run_epoch(trainloader, pbar, training=True)
        pbar.set_postfix(record)

        record['val_loss'], record['val_accuracy'] = _run_epoch(validloader, pbar, training=False)
        pbar.set_postfix(record)
        history.append(record)

TypeError: AgreementRouter.__init__() missing 1 required positional argument: 'out_capsules'