# MNIST CapsNet 
*By Danny Luo*

Implementation of a capsule network model on MNIST in PyTorch, heavily based on Gram.AI's implementation by Kenta Iwasaki and its subsequent versions. CUDA is disabled by default.

This notebook adds explanations and annotations intended to give the reader a deeper understanding of the capsule machinery.

*Capsules Paper*:
* Sara Sabour, Nicholas Frosst, and Geoffrey E. Hinton, *Dynamic Routing Between Capsules* (NIPS 2017)

*References*
* https://gist.github.com/kendricktan/9a776ec6322abaaf03cc9befd35508d4 
* https://github.com/gram-ai/capsule-networks
* https://github.com/naturomics/CapsNet-Tensorflow

In [1]:
import sys
sys.setrecursionlimit(15000)

import torch
import torch.nn.functional as F

from torch import nn
from torch import optim
from torchvision import transforms
from torchvision.datasets.mnist import MNIST
from torch.autograd import Variable
from tqdm import tqdm

import datetime
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
# Fetch the MNIST training split into ./data (downloaded on first run; no transform applied).
dataset = MNIST(train=True, download=True, root='./data')

In [3]:
def squash(t, dim=-1, eps=1e-8):
    """
    Squash non-linearity - Eq (1).

        v = ||s||^2 / (1 + ||s||^2) * s / ||s||

    The norm is computed per capsule vector, along ``dim`` (the last axis by
    default). Each capsule is therefore squashed independently: its direction
    is kept and its length is mapped into [0, 1).

    Args:
        t: tensor whose ``dim`` axis holds the capsule vector components.
        dim: axis along which the vector norm is taken.
        eps: small constant added under the square root so all-zero vectors
            map to zero instead of NaN (0 / 0).

    Returns:
        Tensor with the same shape as ``t``.
    """
    squared_norm = (t ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    return scale * t / torch.sqrt(squared_norm + eps)


def softmax(input, dim=1):
    """
    Softmax along a specific dimension.

    The original implementation moved ``dim`` to the last axis, flattened to
    2-D, called ``F.softmax`` without an explicit ``dim``, and reshaped back.
    That is equivalent to ``F.softmax(input, dim=dim)``. Calling ``F.softmax``
    without ``dim`` is deprecated, so the axis is passed explicitly here.

    Args:
        input: tensor of any rank.
        dim: axis along which the values are normalised to sum to 1.

    Returns:
        Tensor with the same shape as ``input``.
    """
    return F.softmax(input, dim=dim)

def index_to_one_hot(index_tensor, num_classes=10):
    """
    Convert a 1-D tensor of class indices to one-hot rows.

    e.g. [2, 5] (with 10 classes) becomes:
        [
            [0 0 1 0 0 0 0 0 0 0]
            [0 0 0 0 0 1 0 0 0 0]
        ]

    Args:
        index_tensor: 1-D tensor of class indices in [0, num_classes).
        num_classes: length of each one-hot row.

    Returns:
        Float tensor of shape (len(index_tensor), num_classes). It is built on
        the same device as ``index_tensor`` so it also works for CUDA inputs.
    """
    index_tensor = index_tensor.long()
    identity = torch.eye(num_classes, device=index_tensor.device)
    return identity.index_select(dim=0, index=index_tensor)

### CapsNet Architecture

MNIST Input: (N, 1, 28, 28), where N is the batch size and 28 x 28 are the pixel dimensions of an MNIST image.

1. **Conv1**: 256 2D convolution kernels of size 9 x 9 with stride 1 (28 -> 20), followed by ReLU. Output: (N, 256, 20, 20)
2. **PrimaryCaps**: 32 channels of convolutional 8D capsules; each capsule contains 8 convolutional units with a 9 x 9 kernel and stride 2 (20 -> 6). Output: (N, 32 \* 6 \* 6 = 1152, 8), i.e. 1152 capsule vectors $u_i$.

     *Routing between PrimaryCaps and DigitCaps.*
     
3. **DigitCaps**: One 16D capsule per digit class, i.e. 10 output vectors $v_j$. Output: (N, 10, 16)
    
    #### Decoder
     
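4. The decoder takes the masked DigitCaps output, flattened to (N, 10 \* 16 = 160), and reconstructs the 28 x 28 = 784 pixel image: FC ReLU (size 512) -> FC ReLU (size 1024) -> FC Sigmoid (size 784)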

In [4]:
class CapsuleLayer(nn.Module):
    """
    A capsule layer, used both for PrimaryCaps and for DigitCaps.

    The mode is selected by ``num_routings``:

    * ``num_routings == -1``: primary (convolutional) capsules.
      ``num_capsules`` independent ``Conv2d(in_channels, out_channels)``
      layers are applied to the input. Each output is flattened per sample
      into a column, and the columns are concatenated along the last axis.
      The result has shape (N, out_channels * H_out * W_out, num_capsules)
      and is then squashed.
    * otherwise: routed capsules. One weight matrix W_ij of shape
      (in_channels, out_channels) is learned for every output capsule j
      (``num_capsules``) and every input capsule i (``num_routings``).
      Outputs are computed with dynamic routing (Procedure 1).

    Args:
        num_capsules: number of conv capsule units (primary mode) or number
            of output capsules j (routed mode).
        in_channels: conv input channels (primary) or input capsule
            dimension (routed).
        out_channels: conv output channels (primary) or output capsule
            dimension (routed).
        num_routings: -1 for primary mode, otherwise the number of input
            capsules i.
        kernel_size, stride: Conv2d settings (primary mode only).
        num_routing_iterations: number of routing iterations (routed mode only).

    Raises:
        ValueError: in routed mode, if ``num_routing_iterations < 1`` (no
            output would ever be computed).
    """

    def __init__(self, num_capsules, in_channels, out_channels, num_routings, kernel_size=3, stride=1, num_routing_iterations=5):
        super().__init__()
        self.num_capsules = num_capsules
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.num_routings = num_routings
        self.num_routing_iterations = num_routing_iterations

        if num_routings == -1:
            # Primary layer: one convolution per capsule component.
            self.capsules = nn.ModuleList(
                [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0)
                 for _ in range(num_capsules)])
        else:
            if num_routing_iterations < 1:
                raise ValueError('num_routing_iterations must be >= 1, got {}'.format(num_routing_iterations))
            # W_ij: an (in_channels x out_channels) matrix for every output capsule j
            # (num_capsules) and every input capsule i (num_routings),
            # e.g. 10 x 1152 x 8 x 16 for DigitCaps.
            self.weights = nn.Parameter(torch.randn(num_capsules, num_routings, in_channels, out_channels))

    def forward(self, x):
        """
        Primary mode: returns squashed capsules of shape
        (N, out_channels * H_out * W_out, num_capsules).

        Routed mode: ``x`` is expected to be (N, num_routings, in_channels);
        returns (num_capsules, N, 1, 1, out_channels).
        """
        if self.num_routings == -1:
            # Each conv: (N, out_channels, H, W) -> (N, out_channels*H*W, 1);
            # concatenating gives one num_capsules-D vector per position/channel.
            outputs = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
            return squash(torch.cat(outputs, dim=-1))

        # Prediction vectors u_hat_{j|i} = u_i W_ij (Eq. 2), via a broadcast matmul:
        #   x[None, :, :, None, :]   -> (1, N, num_routings, 1, in_channels)
        #   self.weights[:, None]    -> (num_capsules, 1, num_routings, in_channels, out_channels)
        #   pred                     -> (num_capsules, N, num_routings, 1, out_channels)
        pred = x[None, :, :, None, :].matmul(self.weights[:, None, :, :, :])

        # Routing logits b_ij, initialised to zero. zeros_like keeps them on
        # pred's device and dtype; a plain torch.zeros would always be on the CPU.
        # Every entry along the last axis receives the same update, so this
        # acts as one scalar b_ij per (j, sample, i).
        logits = torch.zeros_like(pred)

        for iteration in range(self.num_routing_iterations):
            # Eq. 3: coupling coefficients c_ij = softmax of b_ij over the
            # output capsules j, which is axis 0 of pred.
            coupling = F.softmax(logits, dim=0)
            # s_j = sum_i c_ij * u_hat_{j|i};  v_j = squash(s_j)
            outputs = squash((coupling * pred).sum(dim=2, keepdim=True))

            # Agreement u_hat_{j|i} . v_j updates the logits. After the last
            # iteration the update would never be used, so it is skipped.
            if iteration < self.num_routing_iterations - 1:
                logits = logits + (pred * outputs).sum(dim=-1, keepdim=True)

        return outputs

In [42]:
class CapsNet(nn.Module):
    """
    CapsNet for MNIST: Conv1 -> PrimaryCaps -> DigitCaps, plus a
    fully-connected decoder that reconstructs the input from the masked
    DigitCaps output.

    Args:
        num_classes: number of digit capsules / classes (10 for MNIST).
            Passed explicitly instead of being read from a global defined in
            a later notebook cell.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.num_classes = num_classes

        # (N, 1, 28, 28) -> (N, 256, 20, 20)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
        # (N, 256, 20, 20) -> (N, 32 * 6 * 6 = 1152, 8)
        self.primarycaps = CapsuleLayer(num_capsules=8, num_routings=-1, in_channels=256, out_channels=32,
                                        kernel_size=9, stride=2)
        # Routed from 1152 8-D capsules to one 16-D capsule per class.
        self.digitcaps = CapsuleLayer(num_capsules=num_classes, num_routings=32 * 6 * 6, in_channels=8,
                                      out_channels=16)

        # Decoder: masked, flattened DigitCaps (16 * num_classes) -> 784 = 28 * 28 pixels.
        self.decoder = nn.Sequential(
            nn.Linear(16 * num_classes, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )

    def forward(self, x, y=None):
        """
        Args:
            x: image batch, (N, 1, 28, 28).
            y: optional one-hot labels, (N, num_classes), used to mask the
               DigitCaps output before reconstruction. If None, the capsule
               with the highest score is selected for each sample.

        Returns:
            classes: softmax over the DigitCaps vector lengths, (N, num_classes).
            reconstructions: decoder output, (N, 784).
        """
        x = F.relu(self.conv1(x), inplace=True)  # (N, 256, 20, 20)
        x = self.primarycaps(x)                   # (N, 1152, 8)
        # digitcaps returns (num_classes, N, 1, 1, 16). Squeeze only the two
        # singleton axes, because a bare .squeeze() would also drop the batch
        # axis when N == 1. Then swap to (N, num_classes, 16).
        x = self.digitcaps(x).squeeze(3).squeeze(2).transpose(0, 1)

        # Capsule length ||v_j|| is the class score (summed over the 16-D capsule axis).
        classes = (x ** 2).sum(dim=-1) ** 0.5
        classes = F.softmax(classes, dim=-1)

        if y is None:
            # For each sample, pick the most active capsule.
            _, max_length_indices = classes.max(dim=1)
            y = torch.eye(self.num_classes, device=x.device).index_select(dim=0, index=max_length_indices)

        # Zero all capsules except the selected one, flatten, and decode.
        reconstructions = self.decoder((x * y[:, :, None]).reshape(x.size(0), -1))
        return classes, reconstructions
           

## Capsule Loss

In [6]:
class CapsuleLoss(nn.Module):
    """
    Total CapsNet loss: the margin loss (m+ = 0.9, m- = 0.1, lambda = 0.5) plus
    a summed squared-error reconstruction loss scaled by 0.0005, averaged
    over the batch.
    """

    def __init__(self):
        super().__init__()
        # Summed squared error. reduction='sum' replaces the deprecated size_average=False.
        self.reconstruction_loss = nn.MSELoss(reduction='sum')

    def forward(self, images, labels, classes, reconstructions):
        """
        Args:
            images: input batch. It is flattened per sample to match
                ``reconstructions``.
            labels: one-hot targets, same shape as ``classes``.
            classes: class scores, (N, num_classes).
            reconstructions: decoder output, (N, pixels).

        Returns:
            Scalar loss averaged over the batch size N.
        """
        # ReLU here is simply max(0, ...), as in the paper.
        left = F.relu(0.9 - classes) ** 2   # max(0, m+ - score)^2, penalises a weak true class
        right = F.relu(classes - 0.1) ** 2  # max(0, score - m-)^2, penalises strong wrong classes

        margin_loss = labels * left + 0.5 * (1. - labels) * right
        margin_loss = margin_loss.sum()

        # The decoder emits flat (N, 784) rows while images arrive as
        # (N, 1, 28, 28). These shapes cannot broadcast, so flatten the images.
        images = images.reshape(reconstructions.size(0), -1)
        reconstruction_loss = self.reconstruction_loss(reconstructions, images)

        return (margin_loss + 0.0005 * reconstruction_loss) / images.size(0)

### Training Model

In [7]:
# Loading data.
# Use './data', the same root as the earlier dataset cell, so MNIST is
# downloaded only once into a portable, non-temporary location.
train_loader = torch.utils.data.DataLoader(
    MNIST(root='./data', download=True, train=True,
          transform=transforms.ToTensor()),
    batch_size=4, shuffle=True)

# The test set is not shuffled: order does not affect evaluation, and a
# fixed order keeps runs reproducible.
test_loader = torch.utils.data.DataLoader(
    MNIST(root='./data', download=True, train=False,
          transform=transforms.ToTensor()),
    batch_size=4, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [None]:
# Globals
NUM_CLASSES = 10  # MNIST digit classes
EPOCH = 2         # number of training epochs

# Model
model = CapsNet()

# CUDA is disabled by default; uncomment this and the .cuda() lines below to use a GPU.
#model.cuda()

optimizer = optim.Adam(model.parameters())
capsule_loss = CapsuleLoss()

for e in range(EPOCH):
    # Training
    train_loss = 0.0

    model.train()
    for img, target in tqdm(train_loader, desc='Training'):
        target = index_to_one_hot(target)

        #img = img.cuda()
        #target = target.cuda()

        optimizer.zero_grad()

        classes, reconstructions = model(img, target)  # CapsNet.forward

        loss = capsule_loss(img, target, classes, reconstructions)
        loss.backward()
        optimizer.step()

        # .item() extracts the Python float; indexing a 0-dim tensor with [0] fails.
        train_loss += loss.item()

    # capsule_loss is already a per-batch mean, so average it over the batches.
    print('Epoch {}, Training: Avg Loss: {:.4f}'.format(e + 1, train_loss / len(train_loader)))
   

In [95]:
import datetime

# After training, save only the learned parameters (state_dict).
# The timestamp uses '-' rather than ':' because ':' is not allowed in
# Windows file names.
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
torch.save(model.state_dict(), 'capsulestraining_' + timestamp + '.pt')

In [None]:
# Testing
correct = 0
test_loss = 0.0

model.eval()
# Evaluation needs no autograd graph, and skipping it saves memory and time.
with torch.no_grad():
    for img, target in tqdm(test_loader, desc='test set'):
        target_index = target
        target = index_to_one_hot(target)

        #img = img.cuda()
        #target = target.cuda()

        classes, reconstructions = model(img, target)

        # capsule_loss returns a per-batch mean. Re-weight it by the batch size
        # so that dividing by the dataset size below gives a per-sample mean.
        test_loss += capsule_loss(img, target, classes, reconstructions).item() * img.size(0)

        # The predicted class is the index of the highest class score.
        pred = classes.max(1, keepdim=True)[1].cpu()
        correct += pred.eq(target_index.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)
correct = 100. * correct / len(test_loader.dataset)  # accuracy in percent
print('Test Set: Avg Loss: {:.4f}, Accuracy: {:.4f}'.format(
    test_loss, correct))