# Capsule Network (CapsNet) for MNIST Digit Classification
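This project implements a Capsule Network (CapsNet) from scratch in PyTorch to classify handwritten digits from the MNIST dataset. A CapsNet replaces scalar feature detectors with capsules: groups of neurons whose output vector encodes both whether a feature is present (the vector's length) and its pose (the vector's orientation). The architecture is designed to model spatial hierarchies and part-whole relationships that the pooling layers of traditional CNNs tend to discard.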


In [None]:
%pip install torch torchvision matplotlib



##  Importing Required Libraries
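We import PyTorch (tensors, `nn` layers, and the functional API `F`), torchvision for the MNIST dataset and image transforms, and `DataLoader` for batching. NumPy and Matplotlib are also imported, although the cells below do not use them.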


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt


##  Loading the MNIST Dataset
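We download MNIST with torchvision, convert each image to a tensor, and normalize it with the commonly used MNIST mean (0.1307) and standard deviation (0.3081). `DataLoader` then serves batches of 128 images, shuffling the training set every epoch.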


In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
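`ToTensor` scales raw pixel values from 0-255 down to [0, 1], and `Normalize` then applies (x - mean) / std. The plain NumPy sketch below reproduces that arithmetic outside torchvision, purely as an illustration (the helper name `normalize` is hypothetical), to show where black and white pixels end up after the transform:

```python
import numpy as np

MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # same constants as the transform above

def normalize(pixels):
    """Mimic ToTensor + Normalize on raw uint8 pixel values (0-255)."""
    scaled = np.asarray(pixels, dtype=np.float64) / 255.0  # ToTensor: scale to [0, 1]
    return (scaled - MNIST_MEAN) / MNIST_STD              # Normalize: (x - mean) / std

out = normalize([0, 128, 255])
print(out)  # black maps to about -0.42, white to about 2.82
```

The background pixels, which make up most of each MNIST image, therefore become a slightly negative constant rather than zero.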




##  Squash Activation Function
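The squash function rescales a vector $\mathbf{s}$ so that its direction is preserved and its length lies in $[0, 1)$: short vectors shrink toward zero and long vectors approach unit length. This matters in Capsule Networks because a capsule's length is interpreted as the probability that the entity it represents is present.

$$\mathbf{v} = \frac{\|\mathbf{s}\|^2}{1 + \|\mathbf{s}\|^2}\,\frac{\mathbf{s}}{\|\mathbf{s}\|}$$

The implementation adds a small constant (1e-8) to the denominator to avoid division by zero for zero-length vectors.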


In [None]:
def squash(tensor, dim=-1):
    norm = torch.norm(tensor, dim=dim, keepdim=True)
    scale = (norm**2) / (1 + norm**2)
    return scale * tensor / (norm + 1e-8)
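To see the effect of squash numerically, the sketch below mirrors the PyTorch function in NumPy (`squash_np` is a hypothetical helper, not part of the notebook) and applies it to three vectors that share a direction but have lengths 0.1, 1 and 10:

```python
import numpy as np

def squash_np(s, axis=-1, eps=1e-8):
    """NumPy mirror of the PyTorch squash above."""
    norm = np.linalg.norm(s, axis=axis, keepdims=True)
    scale = norm**2 / (1 + norm**2)
    return scale * s / (norm + eps)

direction = np.array([0.6, 0.8])  # a unit vector
s = np.stack([0.1 * direction, 1.0 * direction, 10.0 * direction])
v = squash_np(s)
lengths = np.linalg.norm(v, axis=-1)
print(lengths)  # roughly 0.0099, 0.5 and 0.9901
```

A length of 1 maps to exactly 0.5, short vectors are pushed close to 0, and long vectors saturate just below 1, while `v / lengths` still points along `direction`.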


##  Defining the Capsule Network Architecture
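We build the model by combining a standard convolutional layer with two capsule layers:
- Conv layer: a 9×9 convolution with 256 channels that extracts low-level features
- Primary Capsules: a second convolution whose output channels are grouped into 8-dimensional capsule vectors
- Digit Capsules: ten 16-dimensional capsules, one per digit class, computed from the primary capsules by dynamic routing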


In [None]:
class PrimaryCapsules(nn.Module):
    def __init__(self, in_channels, out_capsules, capsule_dim, kernel_size, stride):
        super().__init__()
        self.capsules = nn.Conv2d(in_channels, out_capsules * capsule_dim, kernel_size, stride)

        self.out_capsules = out_capsules
        self.capsule_dim = capsule_dim

    def forward(self, x):
        batch_size = x.size(0)
        out = self.capsules(x)
        out = out.view(batch_size, self.out_capsules, self.capsule_dim, -1)  # [B, out_capsules, capsule_dim, H*W]
        out = out.permute(0, 3, 1, 2).contiguous()  # [B, H*W, out_capsules, capsule_dim]
        out = out.view(batch_size, -1, self.capsule_dim)  # [B, H*W*out_capsules, capsule_dim]
        return squash(out)
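The `view`/`permute` sequence in `PrimaryCapsules.forward` decides which convolution channels form each capsule. The NumPy trace below repeats the same reshapes on a stand-in array filled with `np.arange` (so every value is traceable) at the sizes this notebook uses. It shows that each capsule is 8 consecutive channels at one spatial position, with capsules ordered by position first and capsule type second:

```python
import numpy as np

B, out_capsules, capsule_dim, height, width = 2, 32, 8, 6, 6

# Stand-in for the conv output: [B, out_capsules * capsule_dim, height, width]
conv_out = np.arange(B * out_capsules * capsule_dim * height * width, dtype=np.float32)
conv_out = conv_out.reshape(B, out_capsules * capsule_dim, height, width)

# Same steps as PrimaryCapsules.forward, before squash
out = conv_out.reshape(B, out_capsules, capsule_dim, height * width)  # split channels into (type, dim)
out = out.transpose(0, 3, 1, 2)                                       # [B, H*W, out_capsules, capsule_dim]
caps = out.reshape(B, -1, capsule_dim)                                # [B, H*W*out_capsules, capsule_dim]

print(caps.shape)  # (2, 1152, 8)
```

Capsule 1 is therefore channels 8-15 at position (0, 0), while capsule 32 is channels 0-7 at the next position (0, 1).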


In [None]:
class DigitCapsules(nn.Module):
    def __init__(self, num_caps_in, dim_caps_in, num_caps_out, dim_caps_out, routing_iters=3):
        super().__init__()
        self.num_caps_out = num_caps_out
        self.dim_caps_out = dim_caps_out
        self.routing_iters = routing_iters

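        # Transformation matrices W_ij, one per (input capsule, output capsule) pair:
        # shape [1, num_caps_in, num_caps_out, dim_caps_out, dim_caps_in]
        self.W = nn.Parameter(0.01 * torch.randn(1, num_caps_in, num_caps_out, dim_caps_out, dim_caps_in))

    def forward(self, x):
        # x: [B, num_caps_in, dim_caps_in]
        batch_size = x.size(0)
        x = x.unsqueeze(2).unsqueeze(4)                         # [B, num_caps_in, 1, dim_caps_in, 1]
        x = x.expand(batch_size, -1, self.num_caps_out, -1, 1)  # [B, num_caps_in, num_caps_out, dim_caps_in, 1]
        W = self.W.expand(batch_size, -1, -1, -1, -1)

        # Prediction vectors u_hat_{j|i} = W_ij u_i: [B, num_caps_in, num_caps_out, dim_caps_out]
        u_hat = torch.matmul(W, x).squeeze(-1)

        # Routing logits b_ij start at zero (uniform coupling)
        b = torch.zeros(batch_size, x.size(1), self.num_caps_out, device=x.device)

        for i in range(self.routing_iters):
            c = F.softmax(b, dim=2)                   # coupling coefficients, sum to 1 over output capsules
            s = (c.unsqueeze(-1) * u_hat).sum(dim=1)  # weighted sum over input capsules: [B, num_caps_out, dim_caps_out]
            v = squash(s)
            if i < self.routing_iters - 1:
                # Agreement update b_ij += u_hat_{j|i} . v_j (skipped after the last iteration, where it is unused)
                b = b + (u_hat * v.unsqueeze(1)).sum(dim=-1)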
        return v
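The routing loop is the least obvious part of `DigitCapsules`. The framework-free NumPy sketch below runs the same routing-by-agreement steps for a single example, with hypothetical toy sizes (6 input capsules, 3 output capsules of dimension 4). Every input capsule is made to predict the same vector for output capsule 0, while the predictions for the other outputs are small and random. Two properties become visible: each input's coupling coefficients sum to 1, and the output that the inputs agree on ends up with most of the coupling.

```python
import numpy as np

def squash_np(s, axis=-1, eps=1e-8):
    norm = np.linalg.norm(s, axis=axis, keepdims=True)
    return (norm**2 / (1 + norm**2)) * s / (norm + eps)

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # subtract max for numerical stability
    return e / e.sum(axis=axis, keepdims=True)

def route(u_hat, iters=3):
    """Routing-by-agreement on prediction vectors u_hat: [num_in, num_out, dim_out]."""
    num_in, num_out, _ = u_hat.shape
    b = np.zeros((num_in, num_out))                 # routing logits
    for i in range(iters):
        c = softmax(b, axis=1)                      # each input splits its vote across outputs
        s = (c[..., None] * u_hat).sum(axis=0)      # [num_out, dim_out]
        v = squash_np(s)
        if i < iters - 1:
            b = b + (u_hat * v[None]).sum(axis=-1)  # raise logits where prediction and output agree
    return c, v

rng = np.random.default_rng(0)
u_hat = 0.1 * rng.normal(size=(6, 3, 4))            # weak, random predictions
u_hat[:, 0, :] = np.array([1.0, 0.0, 0.0, 0.0])     # all inputs agree on output 0
c, v = route(u_hat)
print(c.round(2))
```

After three iterations each row of `c` assigns most of its weight (roughly 0.7) to output 0, which is the "agreement" behaviour the PyTorch layer applies across all 1152 primary capsules at once.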


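##  Assembling the Full Capsule Network
The `CapsuleNet` class chains the three stages: a ReLU-activated 9×9 convolution with 256 channels, a primary capsule layer that yields 32 × 6 × 6 = 1152 capsules of dimension 8, and a digit capsule layer with 10 capsules of dimension 16 and 3 routing iterations. The forward pass returns the length of each digit capsule, which serves as the score for that class.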


In [None]:
class CapsuleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)
        self.primary_caps = PrimaryCapsules(256, 32, 8, kernel_size=9, stride=2)
        self.digit_caps = DigitCapsules(32 * 6 * 6, 8, 10, 16)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.primary_caps(x)
        x = self.digit_caps(x)
        lengths = torch.norm(x, dim=-1)
        return lengths
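The constants in `CapsuleNet` follow from standard convolution arithmetic: with no padding, a Conv2d output side is (input - kernel) // stride + 1. The plain-Python sketch below (the helper `conv_out_size` is hypothetical) derives the 6 × 6 grid behind `32 * 6 * 6` and tallies the model's parameters:

```python
def conv_out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of nn.Conv2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1

h1 = conv_out_size(28, kernel=9, stride=1)  # conv1: 28 -> 20
h2 = conv_out_size(h1, kernel=9, stride=2)  # primary capsules: 20 -> 6
num_primary = 32 * h2 * h2                  # 32 capsule types at each of 6x6 positions

# Parameter counts (weights + biases) per stage
conv1_params = 256 * 1 * 9 * 9 + 256
primary_params = 256 * (32 * 8) * 9 * 9 + (32 * 8)
digit_params = num_primary * 10 * 16 * 8    # DigitCapsules.W only, no bias
total_params = conv1_params + primary_params + digit_params

print(h1, h2, num_primary)  # 20 6 1152
print(total_params)         # 6804224
```

Under these assumptions the total should agree with `sum(p.numel() for p in model.parameters())` for the model built in this notebook. Most of the parameters sit in the primary-capsule convolution, not in the routing weights.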


##  Defining the Training Loop
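The `train` function runs one epoch. For each batch it computes the digit-capsule lengths, measures the mean-squared error between those lengths and the one-hot targets, backpropagates, and updates the weights with the optimizer. It prints the current batch loss every 100 batches and returns the average loss over the epoch. Note that this loop uses MSE rather than the margin loss used in the original CapsNet paper.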


In [None]:
def train(model, optimizer, train_loader, epoch):
    model.train()
    total_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)

        loss = F.mse_loss(output, F.one_hot(target, num_classes=10).float())
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.4f}')
    return total_loss / len(train_loader)
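The original CapsNet paper (Sabour et al., 2017) trains with a margin loss instead of MSE. For each class k it adds $T_k \max(0, m^+ - \|\mathbf{v}_k\|)^2 + \lambda (1 - T_k) \max(0, \|\mathbf{v}_k\| - m^-)^2$ with $m^+ = 0.9$, $m^- = 0.1$ and $\lambda = 0.5$. Below is a minimal PyTorch sketch of that loss, assuming it receives the capsule lengths returned by `CapsuleNet.forward`. It omits the paper's reconstruction regularizer, and the results in this notebook were produced with MSE, so swapping it in would change them.

```python
import torch
import torch.nn.functional as F

def margin_loss(lengths, target, m_pos=0.9, m_neg=0.1, lambda_neg=0.5):
    """Margin loss on digit-capsule lengths.

    lengths: [B, num_classes] capsule lengths (output of CapsuleNet.forward)
    target:  [B] integer class labels
    """
    T = F.one_hot(target, num_classes=lengths.size(1)).float()
    present = T * F.relu(m_pos - lengths) ** 2                    # true-class capsule should be long
    absent = lambda_neg * (1 - T) * F.relu(lengths - m_neg) ** 2  # other capsules should be short
    return (present + absent).sum(dim=1).mean()                   # sum over classes, mean over batch

loss = margin_loss(torch.tensor([[0.5, 0.5]]), torch.tensor([0]))
print(loss.item())  # about 0.24 = (0.9 - 0.5)^2 + 0.5 * (0.5 - 0.1)^2
```

In `train`, it would replace the `F.mse_loss(...)` line as `loss = margin_loss(output, target)`.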


##  Defining the Testing Function
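The `test` function evaluates the model on the test set. The predicted class for each image is the digit capsule with the longest output vector, and the function prints the overall accuracy.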


In [None]:
def test(model, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()

    print(f'Test Accuracy: {correct}/{len(test_loader.dataset)} = {100. * correct / len(test_loader.dataset):.2f}%')


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CapsuleNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(1, 11):
    loss = train(model, optimizer, train_loader, epoch)
    test(model, test_loader)


Epoch 1 [0/60000] Loss: 0.0993
Epoch 1 [12800/60000] Loss: 0.0185
Epoch 1 [25600/60000] Loss: 0.0130
Epoch 1 [38400/60000] Loss: 0.0095
Epoch 1 [51200/60000] Loss: 0.0064
Test Accuracy: 9869/10000 = 98.69%
Epoch 2 [0/60000] Loss: 0.0053
Epoch 2 [12800/60000] Loss: 0.0051
Epoch 2 [25600/60000] Loss: 0.0031
Epoch 2 [38400/60000] Loss: 0.0049
Epoch 2 [51200/60000] Loss: 0.0046
Test Accuracy: 9897/10000 = 98.97%
Epoch 3 [0/60000] Loss: 0.0044
Epoch 3 [12800/60000] Loss: 0.0024
Epoch 3 [25600/60000] Loss: 0.0036
Epoch 3 [38400/60000] Loss: 0.0024
Epoch 3 [51200/60000] Loss: 0.0017
Test Accuracy: 9903/10000 = 99.03%
Epoch 4 [0/60000] Loss: 0.0037
Epoch 4 [12800/60000] Loss: 0.0018
Epoch 4 [25600/60000] Loss: 0.0017
Epoch 4 [38400/60000] Loss: 0.0030
Epoch 4 [51200/60000] Loss: 0.0027
Test Accuracy: 9914/10000 = 99.14%
Epoch 5 [0/60000] Loss: 0.0011
Epoch 5 [12800/60000] Loss: 0.0008
Epoch 5 [25600/60000] Loss: 0.0016
Epoch 5 [38400/60000] Loss: 0.0024
Epoch 5 [51200/60000] Loss: 0.0027
Test 

##  Final Test Accuracy and Observations
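After the full 10 epochs, we observe test accuracy of up to about 99.27%. The log above is cut off partway through epoch 5, when the most recent reported accuracy was 99.14% (after epoch 4). Even this simplified CapsNet, trained with a plain MSE objective on capsule lengths, reaches high accuracy on MNIST within a few epochs.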


##  Conclusion
Capsule Networks offer an alternative to CNNs that explicitly models part-whole relationships, which may be an advantage on tasks where the spatial arrangement of features matters. This project demonstrates a from-scratch PyTorch implementation on the MNIST dataset.

---

🧠 Created by: Mayank Pratap Singh
