# Implementing a Capsule Network

Author: YinTaiChen

## Packages

In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import Adam
from torchvision import datasets, transforms

## CUDA

In [2]:
# Global switch read by the rest of the notebook to decide whether the model
# and the batches are moved to the GPU.  It is always defined (True or False)
# so that CPU-only machines do not raise NameError at the first
# `if USE_CUDA:` check.
USE_CUDA = torch.cuda.is_available()

## MNIST dataset

In [3]:
# We show that a discriminatively trained, multi-layer capsule system achieves state-of-the-art performance on MNIST
class Mnist:
    """Build the MNIST training and test ``DataLoader`` objects.

    Args:
        batch_size: number of images per mini-batch, used for both loaders.
        data_dir: directory handed to ``datasets.MNIST`` as its root; the
            dataset is downloaded there when missing (``download=True``).
    """
    def __init__(self, batch_size, data_dir='../data'):
        # Convert images to tensors, then shift/scale them with the given
        # mean and standard deviation (presumably the MNIST statistics).
        dataset_transform = transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])

        train_dataset = datasets.MNIST(data_dir, train=True, download=True, transform=dataset_transform)
        test_dataset = datasets.MNIST(data_dir, train=False, download=True, transform=dataset_transform)

        # Training batches are reshuffled every epoch.
        self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        # Evaluation does not benefit from shuffling; a fixed order keeps the
        # per-batch test output reproducible from run to run.
        self.test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## Conv1

In [4]:
# Conv1 has 256, 9 x 9 convolution kernels with a stride of 1 and ReLU activation.
# out_channels=256
# kernel_size=9
# stride=1
# F.relu(self.conv(x))
class ConvLayer(nn.Module):
    """Single 2-D convolution followed by a ReLU non-linearity.

    Args:
        in_channels: number of channels of the input images.
        out_channels: number of convolution kernels (output feature maps).
        kernel_size: side length of the square convolution kernel.
    """
    def __init__(self, in_channels=1, out_channels=256, kernel_size=9):
        super(ConvLayer, self).__init__()

        # The stride is fixed at 1; padding is left at the Conv2d default.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1)

    def forward(self, x):
        """Return the rectified feature maps computed from ``x``."""
        feature_maps = self.conv(x)
        return F.relu(feature_maps)

## PrimaryCapsules

In [5]:
# The second layer (PrimaryCapsules) is a convolutional capsule layer with 32 channels of convolutional 8D capsules.
# (i.e. each primary capsule contains 8 convolutional units with 9 x 9 kernel and a stride of 2)
class PrimaryCaps(nn.Module):
    """Convolutional capsule layer.

    ``num_capsules`` parallel stride-2 convolutions are applied to the same
    input.  At every (channel, row, column) position the outputs of those
    convolutions are gathered into one ``num_capsules``-dimensional capsule
    vector, which is then squashed.

    Args:
        num_capsules: dimensionality of each capsule vector (number of
            parallel convolutions).
        in_channels: number of channels of the incoming feature maps.
        out_channels: number of output channels of each convolution.
        kernel_size: side length of the square convolution kernels.
    """
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9):
        super(PrimaryCaps, self).__init__()

        # One convolution per capsule dimension; each one shares its weights
        # across all spatial positions.
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=2, padding=0)
            for _ in range(num_capsules)])

    def forward(self, x):
        """Return squashed capsule vectors of shape
        ``(batch, out_channels * H * W, num_capsules)``, where ``H x W`` is
        the spatial size produced by the convolutions.
        """
        u = [capsule(x) for capsule in self.capsules]
        # Stack along a NEW LAST axis -> (batch, out_channels, H, W, num_capsules).
        # Stacking on dim=1 and then calling view() would group neighbouring
        # spatial positions into one "vector" instead of the num_capsules
        # unit outputs that belong to the same position.
        u = torch.stack(u, dim=-1)
        # Flatten channel/row/column into a single capsule index.  The size is
        # inferred, so other input sizes or channel counts work as well
        # (the default configuration on 20x20 feature maps gives 32 * 6 * 6).
        u = u.view(x.size(0), -1, len(self.capsules))
        return self.squash(u)

    def squash(self, input_tensor, epsilon=1e-8):
        """Scale every vector along the last axis to a length in ``[0, 1)``.

        Short vectors shrink towards zero length and long vectors towards
        unit length; the direction is preserved.  ``epsilon`` keeps the
        division (and its gradient) finite for an all-zero vector, which
        would otherwise evaluate to 0/0 = NaN.
        """
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        scale = squared_norm / ((1. + squared_norm) * torch.sqrt(squared_norm + epsilon))
        return scale * input_tensor

## DigitCaps

In [6]:
class DigitCaps(nn.Module):
    """Capsule layer whose outputs are computed by iterative dynamic routing.

    Args:
        num_capsules: number of output capsules (index ``j``).
        num_routes: number of input capsules (index ``i``).
        in_channels: dimensionality of each input capsule vector.
        out_channels: dimensionality of each output capsule vector.
        num_iterations: number of routing iterations performed by
            ``forward``; must be at least 1.
    """
    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16,
                 num_iterations=3):
        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules
        self.num_iterations = num_iterations

        # One (out_channels x in_channels) transformation matrix for every
        # (input capsule, output capsule) pair; the leading axis of size 1
        # broadcasts over the batch.
        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))

    def forward(self, x):
        """Route ``x`` of shape ``(batch, num_routes, in_channels)`` to the
        output capsules.

        Returns:
            Tensor of shape ``(batch, num_capsules, out_channels, 1)``.
        """
        # (batch, num_routes, in_channels) -> (batch, num_routes, 1, in_channels, 1)
        x = x.unsqueeze(2).unsqueeze(4)

        # u_hat is produced by multiplying the output of a capsule (x) in the
        # layer below by a weight matrix W.  matmul broadcasts W over the
        # batch and x over the output capsules, so no explicit copies of
        # either tensor are built.
        # u_hat: (batch, num_routes, num_capsules, out_channels, 1)
        u_hat = torch.matmul(self.W, x)

        # All the routing logits (b_ij) are initialized to zero.
        # b_ij are the log prior probabilities that capsule i should be
        # coupled to capsule j.  They are placed on the same device as the
        # input rather than consulting a notebook-level flag.
        b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1))
        if x.is_cuda:
            b_ij = b_ij.cuda()

        # c_ij are coupling coefficients that are determined by the iterative
        # dynamic routing process.
        for iteration in range(self.num_iterations):
            # Routing softmax: for every input capsule i the coefficients
            # must sum to one over the OUTPUT capsules j, i.e. over dim 2.
            # (Leaving `dim` implicit normalises over dim 1, the input
            # capsules.)
            c_ij = F.softmax(b_ij, dim=2).unsqueeze(4)

            # s_j is a weighted sum over all "prediction vectors" u_hat.
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            # Squash along the out_channels axis (dim 3); the trailing axis
            # has size 1 and does not hold the vector components.
            v_j = self.squash(s_j, dim=3)

            # The coupling logits are refined by the agreement between the
            # current output v_j of each capsule j and the prediction u_hat.
            # The agreement a_ij is simply a scalar product, averaged here
            # over the batch.
            if iteration < self.num_iterations - 1:
                a_ij = (u_hat * v_j).sum(dim=3, keepdim=True)
                b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)

        return v_j.squeeze(1)

    def squash(self, input_tensor, dim=-1, epsilon=1e-8):
        """Scale every vector along ``dim`` to a length in ``[0, 1)`` while
        keeping its direction.

        ``epsilon`` keeps the division (and its gradient) finite for an
        all-zero vector, which would otherwise evaluate to 0/0 = NaN.
        """
        squared_norm = (input_tensor ** 2).sum(dim, keepdim=True)
        scale = squared_norm / ((1. + squared_norm) * torch.sqrt(squared_norm + epsilon))
        return scale * input_tensor

## Decoder

In [7]:
# See Figure 2
class Decoder(nn.Module):
    """Reconstruct a 28x28 image from the digit-capsule outputs.

    Only the capsule with the largest vector length is kept; all other
    capsules are zeroed before the fully connected layers are applied.
    """
    def __init__(self):
        super(Decoder, self).__init__()

        # The output of the digit capsule is fed into a decoder consisting of
        # 3 fully connected layers.  The attribute name (including its
        # spelling) is kept because it is part of the module's state-dict keys.
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )

    def forward(self, x, data):
        """Return ``(reconstructions, masked)``.

        Args:
            x: capsule outputs; indexed here as
                ``(batch, num_classes, capsule_dim, 1)`` and flattened to
                ``16 * 10`` features for the first linear layer.
            data: unused; kept so existing callers keep working.

        Returns:
            reconstructions: tensor of shape ``(batch, 1, 28, 28)`` with
                values in ``[0, 1]`` (sigmoid output).
            masked: one-hot tensor of shape ``(batch, num_classes)`` marking
                the longest capsule of every sample.
        """
        # Vector length of every capsule: (batch, num_classes, 1).
        lengths = torch.sqrt((x ** 2).sum(2))

        # The longest capsule is selected directly from the lengths.  A
        # softmax is not needed for an arg-max, and a softmax with implicit
        # `dim` on this 3-D tensor would normalise across the batch and could
        # change which class wins for a sample.
        _, max_length_indices = lengths.max(dim=1)

        # One-hot rows, created on the same device as the capsule outputs.
        masked = Variable(torch.eye(x.size(1)))
        if x.is_cuda:
            masked = masked.cuda()
        masked = masked.index_select(dim=0, index=max_length_indices.squeeze(1).data)

        # Zero every capsule except the selected one, then flatten.
        reconstructions = self.reconstraction_layers((x * masked[:, :, None, None]).view(x.size(0), -1))
        reconstructions = reconstructions.view(-1, 1, 28, 28)

        return reconstructions, masked

## CapsNet

In [8]:
class CapsNet(nn.Module):
    """Full model: convolution -> primary capsules -> digit capsules, plus a
    decoder that reconstructs the input from the digit capsules.
    """
    def __init__(self):
        super(CapsNet, self).__init__()
        self.conv_layer = ConvLayer()
        self.primary_capsules = PrimaryCaps()
        self.digit_capsules = DigitCaps()
        self.decoder = Decoder()

        self.mse_loss = nn.MSELoss()

    def forward(self, data):
        """Return ``(output, reconstructions, masked)`` for a batch of images.

        ``output`` is the digit-capsule tensor; ``reconstructions`` and
        ``masked`` are whatever the decoder returns for it.
        """
        output = self.digit_capsules(self.primary_capsules(self.conv_layer(data)))
        reconstructions, masked = self.decoder(output, data)
        return output, reconstructions, masked

    def loss(self, data, x, target, reconstructions):
        """Total loss: margin loss plus the scaled reconstruction loss."""
        return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)

    # See 3 Margin loss for digit existence
    def margin_loss(self, x, labels, size_average=True):
        """Margin loss on the capsule lengths.

        For every class ``k`` with one-hot label ``T_k`` and capsule length
        ``|v_k|`` the per-class term is

            T_k * max(0, 0.9 - |v_k|)^2 + 0.5 * (1 - T_k) * max(0, |v_k| - 0.1)^2

        Args:
            x: capsule outputs, reduced over dim 2 to one length per class.
            labels: one-hot targets of shape ``(batch, num_classes)``.
            size_average: average the per-sample losses over the batch when
                True (default); sum them when False.

        Returns:
            Scalar loss tensor.
        """
        batch_size = x.size(0)

        # We are using the length of the instantiation vector to represent
        # the probability that a capsule's entity exists.
        v_c = torch.sqrt((x ** 2).sum(dim=2, keepdim=True))

        # m+ = 0.9, m- = 0.1; both hinge terms are squared.
        left = F.relu(0.9 - v_c).view(batch_size, -1) ** 2
        right = F.relu(v_c - 0.1).view(batch_size, -1) ** 2

        # lambda = 0.5 down-weights the loss for absent classes.
        per_class = labels * left + 0.5 * (1.0 - labels) * right
        per_sample = per_class.sum(dim=1)

        return per_sample.mean() if size_average else per_sample.sum()

    # See 4.1 Reconstruction as a regularization method
    def reconstruction_loss(self, data, reconstructions):
        """Scaled squared error between the reconstructions and the inputs.

        Both tensors are flattened to ``(batch, -1)`` before comparison.
        NOTE(review): ``nn.MSELoss`` averages over all elements, whereas the
        cited method speaks of a *sum* of squared differences; confirm which
        reduction is intended.
        """
        flat_size = reconstructions.size(0)
        loss = self.mse_loss(reconstructions.view(flat_size, -1), data.view(flat_size, -1))

        # We scale down this reconstruction loss by 0.0005 so that it does
        # not dominate the margin loss during training.
        return loss * 0.0005

In [9]:
# Instantiate the network; its parameters are moved to the GPU only when the
# notebook-level USE_CUDA flag is set.
capsule_net = CapsNet()
capsule_net = capsule_net.cuda() if USE_CUDA else capsule_net
# Adam with its default hyper-parameters optimises every model parameter.
optimizer = Adam(params=capsule_net.parameters())

RuntimeError: cuda runtime error (2) : out of memory at /pytorch/torch/lib/THC/generic/THCStorage.cu:58

## Train and Test

In [15]:
batch_size = 100
mnist = Mnist(batch_size)

n_epochs = 30


def one_hot(labels, num_classes=10):
    """Return a ``(len(labels), num_classes)`` float tensor of one-hot rows."""
    return torch.eye(num_classes).index_select(dim=0, index=labels)


def batch_accuracy(masked, target):
    """Fraction of rows whose arg-max in ``masked`` equals the arg-max in
    ``target``.  Uses the actual number of rows, so a smaller final batch is
    still reported correctly.
    """
    predicted = np.argmax(masked.data.cpu().numpy(), 1)
    expected = np.argmax(target.data.cpu().numpy(), 1)
    return float(np.mean(predicted == expected))


for epoch in range(n_epochs):
    # ---- training pass -----------------------------------------------------
    capsule_net.train()
    train_loss = 0
    for batch_id, (data, target) in enumerate(mnist.train_loader):

        # Integer class labels -> one-hot rows expected by the loss.
        target = one_hot(target)

        if USE_CUDA:
            data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()
        output, reconstructions, masked = capsule_net(data)
        loss = capsule_net.loss(data, output, target, reconstructions)
        loss.backward()
        optimizer.step()

        # .item() extracts the Python number from the 0-dim loss tensor
        # (indexing a 0-dim tensor with [0] is an error in current PyTorch).
        train_loss += loss.item()

        if batch_id % 100 == 0:
            print("Epoch ", epoch, " batch_id ", batch_id, " train accuracy:", batch_accuracy(masked, target))

    print("Epoch ", epoch, " train loss: ", train_loss / len(mnist.train_loader))

    # ---- evaluation pass ---------------------------------------------------
    capsule_net.eval()
    test_loss = 0
    # No gradients are needed for evaluation; disabling them avoids building
    # autograd graphs and the memory they retain.
    with torch.no_grad():
        for batch_id, (data, target) in enumerate(mnist.test_loader):

            target = one_hot(target)

            if USE_CUDA:
                data, target = data.cuda(), target.cuda()

            output, reconstructions, masked = capsule_net(data)
            loss = capsule_net.loss(data, output, target, reconstructions)

            test_loss += loss.item()

            if batch_id % 100 == 0:
                print("Epoch", epoch, "batch_id", batch_id, "test accuracy:", batch_accuracy(masked, target))

    print("Epoch ", epoch, " test loss ", test_loss / len(mnist.test_loader))

RuntimeError: cuda runtime error (2) : out of memory at /pytorch/torch/lib/THC/generic/THCStorage.cu:58

## Reconstruction

In [None]:
import matplotlib
import matplotlib.pyplot as plt

def plot_images_separately(images):
    """Plot up to six images side by side in a single row.

    Args:
        images: indexable sequence of 2-D arrays; only the first six entries
            are drawn, and fewer than six are accepted.

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    fig = plt.figure()
    # enumerate() over a slice replaces the Python-2-only xrange() index loop
    # and stops early when fewer than six images are supplied.
    for position, image in enumerate(images[:6], start=1):
        ax = fig.add_subplot(1, 6, position)
        ax.matshow(image, cmap = matplotlib.cm.binary)
        # Hide tick marks on the subplot that was just added (the current axes).
        plt.xticks(np.array([]))
        plt.yticks(np.array([]))
    plt.show()

In [None]:
# Plot channel 0 of the first six images in `data`, i.e. whichever batch the
# train/test loop above left bound to that name.
plot_images_separately(data[:6,0].data.cpu().numpy())

In [None]:
# Plot channel 0 of the first six decoder outputs in `reconstructions`, for
# visual comparison with the input images.
plot_images_separately(reconstructions[:6,0].data.cpu().numpy())