In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader


In [2]:
def squash(tensor, dim=-1):
    norm = torch.norm(tensor, dim=dim, keepdim=True)
    scale = (norm**2) / (1 + norm**2)
    return scale * (tensor / (norm + 1e-8))


In [3]:
class PrimaryCapsules(nn.Module):
    def __init__(self, in_channels, out_capsules, capsule_dim, kernel_size=9, stride=2):
        super(PrimaryCapsules, self).__init__()
        self.capsules = nn.Conv2d(in_channels, out_capsules * capsule_dim, kernel_size=kernel_size, stride=stride)
        self.out_capsules = out_capsules
        self.capsule_dim = capsule_dim

    def forward(self, x):
        x = self.capsules(x)
        # Reshape to [batch_size, num_capsules, capsule_dim, height, width]
        batch_size = x.size(0)
        x = x.view(batch_size, self.out_capsules, self.capsule_dim, -1)
        x = x.permute(0, 1, 3, 2).contiguous()
        x = x.view(batch_size, self.out_capsules, -1)
        return squash(x)


In [4]:
class DigitCapsules(nn.Module):
    def __init__(self, in_capsules, in_dim, out_capsules, out_dim, num_routing=3):
        super(DigitCapsules, self).__init__()
        self.in_capsules = in_capsules
        self.in_dim = in_dim
        self.out_capsules = out_capsules
        self.out_dim = out_dim
        self.num_routing = num_routing

        # Weights for transforming primary capsules to digit capsules
        self.weights = nn.Parameter(0.01 * torch.randn(1, in_capsules, out_capsules, out_dim, in_dim))

    def forward(self, x):
        # Expand input and weights to perform matrix multiplication
        x = x.unsqueeze(2).unsqueeze(4)
        x = x.repeat(1, 1, self.out_capsules, 1, 1)
        weights = self.weights.repeat(x.size(0), 1, 1, 1, 1)
        
        # Predict capsules
        u_hat = torch.matmul(weights, x)
        
        # Routing algorithm
        b = torch.zeros(1, self.in_capsules, self.out_capsules, 1).to(x.device)
        for _ in range(self.num_routing):
            c = F.softmax(b, dim=2)
            s = (c * u_hat).sum(dim=1, keepdim=True)
            v = squash(s)
            b = b + (u_hat * v).sum(dim=-1, keepdim=True)
        
        return v.squeeze(1)


In [5]:
class CapsuleNetwork(nn.Module):
    def __init__(self, in_channels=1, num_classes=10, primary_capsules=8, primary_dim=16, digit_dim=16):
        super(CapsuleNetwork, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 256, kernel_size=9, stride=1)
        self.primary_capsules = PrimaryCapsules(256, primary_capsules, primary_dim, kernel_size=9, stride=2)
        self.digit_capsules = DigitCapsules(primary_capsules * 6 * 6, primary_dim, num_classes, digit_dim)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.primary_capsules(x)
        x = self.digit_capsules(x)
        return x


In [6]:
class CapsuleLoss(nn.Module):
    def __init__(self, m_pos=0.9, m_neg=0.1, lambda_=0.5):
        super(CapsuleLoss, self).__init__()
        self.m_pos = m_pos
        self.m_neg = m_neg
        self.lambda_ = lambda_

    def forward(self, outputs, labels):
        batch_size = labels.size(0)
        labels = torch.eye(outputs.size(1)).to(outputs.device).index_select(dim=0, index=labels)
        
        v_norm = torch.norm(outputs, dim=-1)
        loss_pos = labels * F.relu(self.m_pos - v_norm).pow(2)
        loss_neg = (1 - labels) * F.relu(v_norm - self.m_neg).pow(2)
        
        loss = loss_pos + self.lambda_ * loss_neg
        return loss.sum() / batch_size


In [7]:
# Define data loader, optimizer, and device
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CapsuleNetwork().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = CapsuleLoss()

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1000)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:02<00:00, 3.88MB/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1000)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 61.1kB/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1000)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:01<00:00, 1.22MB/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1000)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<?, ?B/s]


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw



RuntimeError: The size of tensor a (288) must match the size of tensor b (8) at non-singleton dimension 1

In [None]:
def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = torch.norm(outputs, dim=-1).argmax(dim=-1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f'Accuracy: {accuracy:.2f}%')

# Test the model
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
evaluate(model, test_loader, device)
