# Capsule Neural Network (CapsNet) Implementation

Implementation of the paper [Dynamic Routing Between Capsules](https://arxiv.org/pdf/1710.09829.pdf) by Sara Sabour, Nicholas Frosst, and Geoffrey E. Hinton. Used [jindongwang/Pytorch-CapsuleNet](https://github.com/jindongwang/Pytorch-CapsuleNet) and [laubonghaudoi/CapsNet_guide_PyTorch](https://github.com/laubonghaudoi/CapsNet_guide_PyTorch) to clarify some confusions, and borrowed some code.

## Setup PyTorch

In [1]:
!pip install torch torchvision
!pip install matplotlib
!pip install import-ipynb
!pip install tqdm
!pip install pytorch_extras

Collecting import-ipynb
  Downloading https://files.pythonhosted.org/packages/63/35/495e0021bfdcc924c7cdec4e9fbb87c88dd03b9b9b22419444dc370c8a45/import-ipynb-0.1.3.tar.gz
Building wheels for collected packages: import-ipynb
  Building wheel for import-ipynb (setup.py) ... [?25l[?25hdone
  Created wheel for import-ipynb: filename=import_ipynb-0.1.3-cp37-none-any.whl size=2976 sha256=747d0dbbe3f8feb12a696859201a0181659be9def08e73e56c0b05c21d8369a1
  Stored in directory: /root/.cache/pip/wheels/b4/7b/e9/a3a6e496115dffdb4e3085d0ae39ffe8a814eacc44bbf494b5
Successfully built import-ipynb
Installing collected packages: import-ipynb
Successfully installed import-ipynb-0.1.3
Collecting pytorch_extras
  Downloading https://files.pythonhosted.org/packages/66/79/42d7d9a78c27eb897b14790c9759dd9a991f67bc987e9e137527a68db9dc/pytorch-extras-0.1.3.tar.gz
Building wheels for collected packages: pytorch-extras
  Building wheel for pytorch-extras (setup.py) ... [?25l[?25hdone
  Created wheel for pytor

In [2]:
from google.colab import drive
drive.mount("mnt")

Mounted at mnt


In [3]:
%cd "mnt/My Drive"

/content/mnt/My Drive


In [4]:
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch_extras
import torch.nn as nn
import torchvision.utils as tv_utils
import torch.nn.functional as F
from torch.autograd import Variable
import import_ipynb
import load_DrawData_with_transform
import numpy as np
import argparse
from tqdm import tqdm
# %mkdir -p /content/project/
# %cd /content/project/

importing Jupyter notebook from load_DrawData_with_transform.ipynb
Collecting ndjson
  Downloading https://files.pythonhosted.org/packages/70/c9/04ba0056011ba96a58163ebfd666d8385300bd12da1afe661a5a147758d7/ndjson-0.3.1-py2.py3-none-any.whl
Installing collected packages: ndjson
Successfully installed ndjson-0.3.1
Collecting cairocffi
[?25l  Downloading https://files.pythonhosted.org/packages/84/ca/0bffed5116d21251469df200448667e90acaa5131edea869b44a3fbc73d0/cairocffi-1.2.0.tar.gz (70kB)
[K     |████████████████████████████████| 71kB 3.8MB/s 
Building wheels for collected packages: cairocffi
  Building wheel for cairocffi (setup.py) ... [?25l[?25hdone
  Created wheel for cairocffi: filename=cairocffi-1.2.0-cp37-none-any.whl size=89548 sha256=3b7d66e5814ce46344c8ef9cdc89e1872864acc5136a4f190a384f2596ac3bb2
  Stored in directory: /root/.cache/pip/wheels/40/76/48/f1effadceea83b32e7d957dd0f92db4db8b537d7b72b4ef374
Successfully built cairocffi
Installing collected packages: cairocffi
Succ

## CapsNet Modules

In [61]:
def squash(s):
    '''Non-linear "squashing" function to ensure that short vectors get shrunk 
    to almost zero length and long vectors get shrunk to a length slightly 
    below 1. Equation (1) in the paper.
    
    Input:
      s: 	total input vector
    
    Output:
      squashed output vector
    '''
    norm_sqrd = torch.sum(s**2, dim=2, keepdim=True)
    return (norm_sqrd / (1 + norm_sqrd)) * (s / (torch.sqrt(norm_sqrd) + 1e-8))

In [62]:
class PrimaryCapsules(nn.Module):
    '''The layer after Conv1. Section 4 of the paper.
    '''

    def __init__(self):
        super(PrimaryCapsules, self).__init__()
        self.capsules = nn.ModuleList(
            [nn.Conv2d(in_channels=256,
                    out_channels=8,
                    kernel_size=9,
                    stride=2)
            for i in range(32)]
        )

    def forward(self, x):
        '''Section 4 of the paper.

        Input:
        x: outut of the ReLUConv1 layer, of shape (batch_size x 256 x 32 x 32)

        Output:
        squashed PrimaryCapsules output tensor, of shape (batch_size x 4608 x 8)
        '''
        batch_size = x.size(0)
        
        # Get the output of each primary capsule; combine and prepare them to 
        # serve as the input to the next layer (DoodleCapsules)
        all_u = [] 
        assert x.shape==torch.Size([batch_size, 256, 32, 32]), x.shape
        for cap in self.capsules:
            u = cap(x)  # (batch_size x 8 x 12 x 12)
            assert u.shape == torch.Size([batch_size, 8, 12, 12]), u.shape
            u = u.view(batch_size, 8, 144, 1)  # (batch_size x 8 x 144 x 1)
            all_u.append(u)
        all_u = torch.cat(all_u, dim=3)  # (batch_size x 8 x 144 x 32)
        assert all_u.shape == torch.Size([batch_size, 8, 144, 32]), all_u.shape
        all_u = all_u.view(batch_size, 8, -1)  # (batch_size x 8 x 4)
        all_u = torch.transpose(all_u, 1, 2)  # (batch_size x 11460852 x 8)
        all_u = squash(all_u)  # (batch_size x 4608 x 8)
        assert all_u.shape == torch.Size([batch_size, 4608, 8]), all_u.shape
        
        return all_u

In [63]:
class DoodleCapsules(nn.Module):
    '''The layer after PrimaryCapsules. Section 4 of the paper.
    '''

    def __init__(self, opt):
        super(DoodleCapsules, self).__init__()
        self.opt = opt
        self.W = nn.Parameter(torch.randn(1, 4608, opt.n_classes, 8, 24))
    
    def forward(self, u):
        '''Equation (2) and Procedure 1 in the paper.

        Input:
        u: output of the PrimaryCapsules layer, of shape (batch_size x 8)

        Output:
        output tensor of the DoodleCapsules layer, of shape (batch_size x 4608 x n_classes x 1)
        '''
        batch_size = u.size(0)
        
        u = torch.unsqueeze(u, dim=2)  # (batch_size x 1 x 8)
        u = torch.unsqueeze(u, dim=2)  # (batch_size x 1 x 1 x 8)
        u_hat = torch.matmul(u, self.W).squeeze()  # (batch_size x 4608 x n_classes x 24)

        b = Variable(torch.zeros(batch_size, 4608, self.opt.n_classes, 1))  # (batch_size x 4608 x n_classes x 1)
        b = b.cuda()

        for i in range(self.opt.iterations):
            c = F.softmax(b, dim=2)  # (batch_size x 4608 x n_classes x 1)
            s = torch.sum(u_hat * c, dim=1)  # (batch_size x n_classes x 24)
            v = squash(s)  # (batch_size x n_classes x 24)
            b = b + torch.sum(u_hat * v.unsqueeze(1), dim=3, keepdim=True)  # (batch_size x 4608 x n_classes x 1)
        
        return v  # (batch_size x n_classes x 24)

In [64]:
class DoodleDecoder(nn.Module):
    '''Decoder structure to reconstruct the doodle from the output of the DoodleCapsules layer.
    Section 4 of the paper. For an illustrative explaination, see Figure 2 of the paper.
    '''

    def __init__(self, opt):
        super(DoodleDecoder, self).__init__()
        self.opt = opt
        self.layers = nn.Sequential(
            nn.Linear(opt.n_classes * 24, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, opt.image_size * opt.image_size),
            nn.Sigmoid()
        )

    def forward(self, v, target):
        '''Takes a 24-dimensional vector v from the *correct* DoodleCapsules, and 
        learns to decode it into an image of a doodle. Mask out the other (n_classes - 1) classes.
        Section 4 of the paper. For an illustrative explaination, see Figure 2 of the paper.

        Input:
        v: output of DoodleCapsules, of shape (batch_size x n_classes x 24)
        target: one-hot targets, of shape (batch_size, n_classes)

        Output:
        decoder constructed images, of shape (batch_size x image_size^2)
        '''
        # TODO: the true target or the most probable?
        # correct = torch.sqrt((v ** 2).sum(2))  # (batch_size x n_classes)
        # correct = F.softmax(correct, dim=0)  # (batch_size x n_classes)
        # correct = correct.max(dim=1)[1]  # (batch_size)
        batch_size = v.size(0)

        # Create the mask, which is 1 only for the correct class and 0 otherwise
        mask = target.type(torch.FloatTensor).unsqueeze(-1)  # (batch_size x n_classes x 1)
        mask = torch.repeat_interleave(mask, 24, dim=2)  # (batch_size x n_classes x 24)
        mask = mask.cuda()
        assert mask.size() == torch.Size([batch_size, 2, 24]), mask.size()
        
        masked = (v * mask).view(batch_size, -1)  # (batch_size x n_classes x 24)
        result = self.layers(masked)  # (batch_size x f)
        assert result.shape == torch.Size([batch_size, opt.image_size*opt.image_size]), result.shape
        
        return result

In [65]:
class CapsuleNetwork(nn.Module):
    '''Consists of a ReLU Convolution layer, a PrimaryCapsules layer, a DoodleCapsules
    layer, and a Decoder layer. Section 4 of the paper.
    '''

    def __init__(self, opt):
        super(CapsuleNetwork, self).__init__()
        self.opt = opt
        self.ReLUConv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9),
            nn.ReLU()
        )
        self.PrimaryCapsules = PrimaryCapsules()
        self.DoodleCapsules = DoodleCapsules(opt)
        self.DoodleDecoder = DoodleDecoder(opt)
    
    def forward(self, x):
        '''Section 4 of the paper.

        Input:
        the input to the network, of shape (batch_size x image_size x image_size)

        Output:
        the output of the network, of shape (batch_size x n_classes x 24)
        '''
        v = torch.unsqueeze(x, 1)
        v = self.ReLUConv1(v)  # (batch_size x 256 x 32 x 32)
        v = self.PrimaryCapsules(v)  # (batch_size x 4608 x 8)
        v = self.DoodleCapsules(v)  # (batch_size x n_classes x 24)
        return v  # (batch_size x n_classes x 24)

    def marginal_loss(self, v, target):
        '''Section 3, Equation (4) of the paper.

        Input:
        v: the output of the network, of shape (batch_size x n_classes x 24)
        target: the one-hot target, of shape (batch_size x n_classes)
        lambd: a scalor for down-weighting of the loss for absent doodle classes
        
        Output:
        marginal loss (a scalor summed over all batches and classes)
        '''
        v_norm = torch.sqrt((v ** 2).sum(dim=2))  # (batch_size x n_classes)
        
        zeros = torch.zeros(v_norm.size())  # (batch_size x n_classes)
        zeros = zeros.cuda()
        
        max1 = torch.max(zeros, 0.9 - v_norm) ** 2  # (batch_size x n_classes)
        max2 = torch.max(zeros, v_norm - 0.1) ** 2  # (batch_size x n_classes) 
        loss = target * max1 + (1 - target) * 0.5 * max2  # (batch_size x n_classes)
        
        return torch.sum(loss)  # scalor

    def reconstruction_loss(self, data, reconstruction):
        '''Reconstruction for regularization. Ecourages the doodle capsules to 
        encode the instantiation parameters of the input doodle. Section 4.1 of the paper.

        Input:
        data: the real image, of shape (batch_size, image_size, image_size)
        reconstruction: the reconstructed image, of shape (batch_size, image_size * image_size)

        Output:
        reconstruction loss (a scalor summed over all batches and classes)
        '''
        batch_size = data.size(0)
        return torch.sum((reconstruction - data.view(batch_size, -1)) ** 2)

    def loss(self, v, data, target):
        '''Loss is marginal loss + 0.0005 * reconstruction loss. 0.0005 to ensure
        the reconstruction loss does not dominate the training. Section 4.1 of the paper.

        Input:
        v: output of the network, of shape (batch_size x n_classes x 24)
        target: one-hot target, of shape (batch_size x n_classes)
        data: the input to the network (the image), of shape (batch_size, image_size, image_size)
        
        Output:
        averaged loss (a scalor) over batches 
        averaged marginal loss (a scalor) over batches
        averaged reconstruction loss (a scalor) over batches
        '''
        batch_size = data.size(0)
        
        marginal_loss = self.marginal_loss(v, target)  # scalor
        
        reconstruction = self.DoodleDecoder(v, target)  # (batch_size, image_size^2)
        assert reconstruction.shape == torch.Size([batch_size, 40 * 40]), (reconstruction.shape, opt.image_size)
        reconstruction_loss = self.reconstruction_loss(data, reconstruction)  # scalor
        
        loss = marginal_loss + 0.0005 * reconstruction_loss  # scalor
        
        return loss/batch_size, marginal_loss/batch_size, reconstruction_loss/batch_size

## Training

In [66]:
def get_opts():
    parser = argparse.ArgumentParser(description='CapsuleNetwork')
    parser.add_argument('-batch_size', type=int, default=32)
    parser.add_argument('-lr', type=float, default=1e-6)
    parser.add_argument('-epochs', type=int, default=200)
    parser.add_argument('-image_size', type=int, default=40)
    parser.add_argument('-n_classes', type=int, default=2)
    parser.add_argument('-iterations', type=int, default=3)
    parser.add_argument('-print_every', type=int, default=10)
    parser.add_argument('-gamma', type=float, default=0.8)
    opt, _ = parser.parse_known_args()
    return opt

In [67]:
def evaluate(opt, test_loader, model, epoch, num_batches, dataset_type):
    sum_loss = 0
    sum_marginal_loss = 0
    sum_reconstruction_loss = 0
    correct = 0
    num_sample = len(test_loader.dataset)
    num_batch = len(test_loader)

    model.eval()
    for data, target in test_loader:
        data = data.to(torch.float32)
        target = target.to(torch.int64)
        batch_size = data.size(0)
        assert target.size() == torch.Size([batch_size, opt.n_classes])

        # Use GPU if available
        with torch.no_grad():
            data, target = Variable(data), Variable(target)
        data, target = data.cuda(), target.cuda()

        output = model(data)  # (batch_size, n_classes, 24)
        loss, marginal_loss, reconstruction_loss = model.loss(output, data, target)
        sum_loss += loss.item()
        sum_marginal_loss += marginal_loss.item()
        sum_reconstruction_loss += reconstruction_loss.item()

        norms = torch.sqrt(torch.sum(output**2, dim=2))  # (batch_size, n_classes)
        pred = norms.data.max(1, keepdim=True)[1].type(torch.LongTensor)  # (batch_size, 1)
        label = target.max(1, keepdim=True)[1].type(torch.LongTensor)  # (batch_size, 1)
        correct += pred.eq(label.view_as(pred)).sum().item()

    recons = model.DoodleDecoder(output, target)
    recons = recons.view(batch_size, 1, opt.image_size, opt.image_size)
    recons = tv_utils.make_grid(recons.data, normalize=True, scale_each=True)

    sum_loss /= num_batch
    sum_marginal_loss /= num_batch
    sum_reconstruction_loss /= num_batch
    
    print('{}'.format(dataset_type))
    print('\tLoss: {:.4f}   Marginal loss: {:.4f}   Reconstruction loss: {:.4f}'.format(
        sum_loss, sum_marginal_loss, sum_reconstruction_loss))
    print('\tAccuracy: {}/{} {:.4f}'.format(correct, num_sample,
        correct / num_sample))

In [68]:
def train(opt, train_loader, test_loader, model):
    num_sample = len(train_loader.dataset)
    num_batches = len(train_loader)
    train_loss_list = []
    correct = 0

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, opt.gamma)
    model.train()
    for epoch in range(opt.epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            data = data.to(torch.float32)
            batch_size = data.size(0)
            target = target.to(torch.int64)
            assert target.size() == torch.Size([batch_size, opt.n_classes])

            # Use GPU if available
            with torch.no_grad():
                data, target = Variable(data), Variable(target)
            data, target = data.cuda(), target.cuda()
            
            output = model(data)
            loss, marginal_loss, reconstruction_loss = model.loss(output, data, target)

            norms = torch.sqrt(torch.sum(output**2, dim=2))  # (batch_size, n_classes)
            pred = norms.data.max(1, keepdim=True)[1].type(torch.LongTensor)  # (batch_size, 1)
            label = target.max(1, keepdim=True)[1].type(torch.LongTensor)  # (batch_size, 1)
            correct += pred.eq(label.view_as(pred)).sum().item()

            loss.backward()
            optimizer.step()

        if epoch % 1 == 0: 
            train_loss_list.append(loss.item())
        if epoch % 4 == 0:
            print('\nEpoch: {}'.format(epoch))
            evaluate(opt, train_loader, model, epoch, num_batches, 'TRAIN') 
            evaluate(opt, test_loader, model, epoch, num_batches, 'TEST') 
        scheduler.step()
    fig = plt.figure()
    plt.plot([i for i in range(len(train_loss_list))], train_loss_list, '-')

In [69]:
opt = get_opts()

model = CapsuleNetwork(opt)
model = model.cuda()

train_loader = load_DrawData_with_transform.train_loader

train(opt, train_loader, test_loader, model)


Epoch: 0
TRAIN
	Loss: 3563.7421   Marginal loss: 0.4050   Reconstruction loss: 7126673.8864
	Accuracy: 351/700 0.5014
TEST
	Loss: 3532.8020   Marginal loss: 0.4050   Reconstruction loss: 7064793.7059
	Accuracy: 296/518 0.5714

Epoch: 4
TRAIN
	Loss: 3565.1707   Marginal loss: 0.4050   Reconstruction loss: 7129531.0000
	Accuracy: 436/700 0.6229
TEST
	Loss: 3532.7789   Marginal loss: 0.4050   Reconstruction loss: 7064747.3824
	Accuracy: 314/518 0.6062

Epoch: 8
TRAIN
	Loss: 3564.5872   Marginal loss: 0.4050   Reconstruction loss: 7128363.9318
	Accuracy: 432/700 0.6171
TEST
	Loss: 3532.7707   Marginal loss: 0.4050   Reconstruction loss: 7064731.2353
	Accuracy: 327/518 0.6313


KeyboardInterrupt: ignored