# Capsule Neural Network (CapsNet) Implementation

Implementation of the paper [Dynamic Routing Between Capsules](https://arxiv.org/pdf/1710.09829.pdf) by Sara Sabour, Nicholas Frosst, and Geoffrey E. Hinton. The implementations in [jindongwang/Pytorch-CapsuleNet](https://github.com/jindongwang/Pytorch-CapsuleNet) and [laubonghaudoi/CapsNet_guide_PyTorch](https://github.com/laubonghaudoi/CapsNet_guide_PyTorch) were consulted to resolve some ambiguities in the paper, and some code was borrowed from them.

## Setup PyTorch

In [10]:
# Install dependencies via IPython `!` shell escapes, then create and enter the
# working directory via IPython `%` magics. This cell only runs under IPython/Jupyter.
!pip install torch torchvision
!pip install matplotlib

%mkdir -p /content/project/
%cd /content/project/

/content/project


In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

## CapsNet Modules

In [12]:
def squash(s, dim=2, eps=1e-8):
    '''Non-linear "squashing" function, Equation (1) in the paper.

    Shrinks short vectors to almost zero length and long vectors to a length
    slightly below 1, keeping their direction:

        v = (|s|^2 / (1 + |s|^2)) * (s / |s|)

    Input:
      s:   total input tensor; each capsule vector lies along `dim`
      dim: dimension holding the capsule vector components (default 2, i.e.
           tensors of shape (batch_size x num_capsules x capsule_dim))
      eps: small constant that keeps the norm, and its gradient, finite for
           all-zero vectors

    Output:
      squashed tensor, same shape as `s`
    '''
    norm_sqrd = torch.sum(s ** 2, dim=dim, keepdim=True)
    # eps goes *inside* the sqrt: sqrt has an infinite derivative at 0, which
    # would otherwise produce NaN gradients for an all-zero capsule vector.
    norm = torch.sqrt(norm_sqrd + eps)
    return (norm_sqrd / (1 + norm_sqrd)) * (s / norm)

In [13]:
class PrimaryCapsules(nn.Module):
  '''The capsule layer right after Conv1. Section 4 of the paper.

  Built from `num_capsules` independent convolutional units (32 in the
  paper). Each unit emits one `capsule_dim`-dimensional capsule vector (8D in
  the paper) at every position of its output grid.
  '''

  def __init__(self, in_channels=256, num_capsules=32, capsule_dim=8,
               kernel_size=9, stride=2):
    '''
    Input:
      in_channels:  channels of the incoming feature map (256 after Conv1)
      num_capsules: number of convolutional capsule units (32 in the paper)
      capsule_dim:  dimension of each primary capsule vector (8 in the paper)
      kernel_size:  convolution kernel size (9 in the paper)
      stride:       convolution stride (2 in the paper)
    '''
    super(PrimaryCapsules, self).__init__()
    self.capsule_dim = capsule_dim
    self.capsules = nn.ModuleList(
        [nn.Conv2d(in_channels=in_channels,
                   out_channels=capsule_dim,
                   kernel_size=kernel_size,
                   stride=stride)
         for _ in range(num_capsules)]
    )

  def forward(self, x):
    '''Section 4 of the paper.

    Input:
      x: output of the ReLUConv1 layer, of shape (batch_size x 256 x 20 x 20)
         with the default configuration

    Output:
      squashed PrimaryCapsules output tensor, of shape (batch_size x 1152 x 8)
      with the default configuration; in general
      (batch_size x num_capsules*H_out*W_out x capsule_dim)
    '''
    batch_size = x.size(0)
    all_u = []
    for cap in self.capsules:
        u = cap(x)  # (batch_size x 8 x 6 x 6)
        # Flatten the spatial grid. The trailing singleton dim is the axis the
        # per-unit outputs are concatenated along below.
        u = u.view(batch_size, self.capsule_dim, -1, 1)  # (batch_size x 8 x 36 x 1)
        all_u.append(u)
    all_u = torch.cat(all_u, dim=3)  # (batch_size x 8 x 36 x 32)
    all_u = all_u.view(batch_size, self.capsule_dim, -1)  # (batch_size x 8 x 1152)
    all_u = torch.transpose(all_u, 1, 2)  # (batch_size x 1152 x 8)
    return squash(all_u)  # (batch_size x 1152 x 8)

In [14]:
class DoodleCapsules(nn.Module):
  '''The class-capsule layer after PrimaryCapsules ("DigitCaps" in the paper).
  Section 4 of the paper.

  Produces one 16D capsule per doodle class from the primary capsules, using
  dynamic routing-by-agreement (Procedure 1 in the paper).
  '''

  def __init__(self, opt, num_routes=1152, in_dim=8, out_dim=16):
    '''
    Input:
      opt:        options object; `opt.n_classes` (number of class capsules)
                  and `opt.iterations` (number of routing iterations) are read
      num_routes: number of input (primary) capsules
      in_dim:     dimension of each input capsule
      out_dim:    dimension of each output (class) capsule
    '''
    super(DoodleCapsules, self).__init__()
    self.opt = opt
    # One transformation matrix W_ij (in_dim x out_dim) for every pair of
    # (primary capsule i, class capsule j). The leading 1 broadcasts over the batch.
    self.W = nn.Parameter(torch.randn(1, num_routes, opt.n_classes, in_dim, out_dim))

  def forward(self, u):
    '''Equation (2) and Procedure 1 in the paper.

    Input:
      u: output of the PrimaryCapsules layer, of shape (batch_size x 1152 x 8)

    Output:
      output tensor of the DoodleCapsules layer, of shape
      (batch_size x n_classes x 16)

    Raises:
      ValueError: if opt.iterations < 1 (no routing pass would produce an output)
    '''
    if self.opt.iterations < 1:
        raise ValueError('opt.iterations must be >= 1, got %r' % (self.opt.iterations,))

    batch_size = u.size(0)
    u = u.unsqueeze(2).unsqueeze(2)  # (batch_size x 1152 x 1 x 1 x 8)
    # (B x 1152 x 1 x 1 x 8) @ (1 x 1152 x n x 8 x 16) -> (B x 1152 x n x 1 x 16).
    # Squeeze only that singleton dim: a bare squeeze() would also drop the
    # batch dim when batch_size == 1.
    u_hat = torch.matmul(u, self.W).squeeze(3)  # (batch_size x 1152 x n_classes x 16)

    # Routing logits b_ij start at zero, on the same device/dtype as u_hat.
    b = u_hat.new_zeros(batch_size, u_hat.size(1), u_hat.size(2), 1)  # (batch_size x 1152 x n_classes x 1)

    for _ in range(self.opt.iterations):
        c = F.softmax(b, dim=2)  # coupling coefficients, sum to 1 over classes
        s = torch.sum(u_hat * c, dim=1)  # (batch_size x n_classes x 16)
        v = squash(s)  # (batch_size x n_classes x 16)
        # Agreement: dot product between each prediction and the resulting output.
        b = b + torch.sum(u_hat * v.unsqueeze(1), dim=3, keepdim=True)  # (batch_size x 1152 x n_classes x 1)
    return v  # (batch_size x n_classes x 16)

In [15]:
class DoodleDecoder(nn.Module):
  '''Decoder that reconstructs the doodle from the output of the DoodleCapsules
  layer. Section 4 of the paper; for an illustrative explanation, see Figure 2
  of the paper.
  '''

  def __init__(self, opt):
    '''
    Input:
      opt: options object; `opt.n_classes`, `opt.height` and `opt.width` are read
    '''
    super(DoodleDecoder, self).__init__()
    self.opt = opt
    self.layers = nn.Sequential(
      nn.Linear(opt.n_classes * 16, 512),
      nn.ReLU(inplace=True),
      nn.Linear(512, 1024),
      nn.ReLU(inplace=True),
      nn.Linear(1024, opt.height * opt.width),
      nn.Sigmoid()
    )

  def forward(self, v, target=None):
    '''Decode the 16D vector of one selected DoodleCapsule into a doodle image,
    masking out the other (n_classes - 1) capsules. Section 4 of the paper.

    Input:
      v:      output of DoodleCapsules, of shape (batch_size x n_classes x 16)
      target: one-hot targets, of shape (batch_size x n_classes). During
              training the paper masks with the true class. If None, the
              capsule with the longest output vector (the predicted class) is
              used instead, e.g. at inference time.

    Output:
      decoder-reconstructed images, of shape (batch_size x height*width)
      (batch_size x 784 for 28 x 28 images)
    '''
    if target is None:
        # argmax of the squared length equals argmax of the length.
        lengths_sqrd = (v ** 2).sum(dim=2)  # (batch_size x n_classes)
        target = F.one_hot(lengths_sqrd.argmax(dim=1), num_classes=v.size(1))

    # Mask is 1 only for the selected class and 0 otherwise. It follows v's
    # dtype/device and broadcasts over the 16 capsule dimensions.
    mask = target.to(dtype=v.dtype, device=v.device).unsqueeze(-1)  # (batch_size x n_classes x 1)
    masked = v * mask  # (batch_size x n_classes x 16)
    # The first Linear expects n_classes*16 features, so flatten the capsules.
    return self.layers(masked.reshape(v.size(0), -1))  # (batch_size x height*width)

In [16]:
class CapsuleNetwork(nn.Module):
  '''Consists of a ReLU convolution layer, a PrimaryCapsules layer, a
  DoodleCapsules layer, and a decoder. Section 4 of the paper.
  '''

  def __init__(self, opt):
    '''
    Input:
      opt: options object, passed on to DoodleCapsules and DoodleDecoder
    '''
    super(CapsuleNetwork, self).__init__()
    self.opt = opt
    self.ReLUConv1 = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9),
        # nn.Sequential needs a Module; F.relu is a plain function.
        nn.ReLU(inplace=True),
    )
    self.PrimaryCapsules = PrimaryCapsules()
    self.DoodleCapsules = DoodleCapsules(opt)
    self.DoodleDecoder = DoodleDecoder(opt)

  def forward(self, x):
    '''Section 4 of the paper.

    Input:
      x: the input to the network, of shape (batch_size x 1 x 28 x 28)

    Output:
      the output of the network, of shape (batch_size x n_classes x 16)
    '''
    v = self.ReLUConv1(x)  # (batch_size x 256 x 20 x 20)
    v = self.PrimaryCapsules(v)  # (batch_size x 1152 x 8)
    return self.DoodleCapsules(v)  # (batch_size x n_classes x 16)

  def marginal_loss(self, v, target, lambd=0.5, m_plus=0.9, m_minus=0.1):
    '''Margin loss, Section 3, Equation (4) of the paper.

    Input:
      v:       the output of the network, of shape (batch_size x n_classes x 16)
      target:  the one-hot target, of shape (batch_size x n_classes)
      lambd:   scalar down-weighting the loss for absent doodle classes
      m_plus:  capsule length below which a present class is penalized
      m_minus: capsule length above which an absent class is penalized

    Output:
      margin loss (a scalar summed over all batch items and classes)
    '''
    v_norm = torch.sqrt((v ** 2).sum(dim=2))  # (batch_size x n_classes)
    target = target.to(v_norm.dtype)
    present = F.relu(m_plus - v_norm) ** 2  # max(0, m+ - |v|)^2
    absent = F.relu(v_norm - m_minus) ** 2  # max(0, |v| - m-)^2
    loss = target * present + lambd * (1 - target) * absent  # (batch_size x n_classes)
    return torch.sum(loss)  # scalar

  def reconstruction_loss(self, data, reconstruction):
    '''Sum of squared errors between input and reconstruction, used as a
    regularizer. Encourages the doodle capsules to encode the instantiation
    parameters of the input doodle. Section 4.1 of the paper.

    Input:
      data:           the real image, of shape (batch_size x 784) or
                      (batch_size x 1 x 28 x 28); it is flattened per sample
      reconstruction: the reconstructed image, of shape (batch_size x 784)

    Output:
      reconstruction loss (a scalar summed over all batch items and pixels)
    '''
    data = data.reshape(data.size(0), -1)  # match the decoder's flat output
    return torch.sum((reconstruction - data) ** 2)

  def loss(self, v, data, target, reconstruction_weight=0.0005):
    '''Margin loss + reconstruction_weight * reconstruction loss, averaged over
    the batch. The paper scales the reconstruction loss by 0.0005 so that it
    does not dominate training (Section 4.1).

    Input:
      v:      output of the network, of shape (batch_size x n_classes x 16)
      data:   the input to the network (the image), of shape (batch_size x 784)
              or (batch_size x 1 x 28 x 28)
      target: one-hot target, of shape (batch_size x n_classes)
      reconstruction_weight: scale applied to the reconstruction loss

    Output:
      loss averaged over the batch (a scalar)
    '''
    batch_size = data.size(0)
    margin_loss = self.marginal_loss(v, target)  # scalar
    reconstruction = self.DoodleDecoder(v, target)  # (batch_size x height*width)
    recon_loss = self.reconstruction_loss(data, reconstruction)  # scalar
    return (margin_loss + reconstruction_weight * recon_loss) / batch_size  # scalar

## Training

In [17]:
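# The modules read options as attributes (opt.n_classes, opt.use_cuda, ...),
# so `opt` must be an object with attributes, not a dict.
# NOTE: `train`, `train_loader` and `test_loader` must be defined before running this.
# from types import SimpleNamespace
# opt = SimpleNamespace(
#     iterations=3,
#     lr=0.005,
#     width=28,
#     height=28,
#     n_classes=10,
#     use_cuda=True,
# )

# model = CapsuleNetwork(opt)
# if opt.use_cuda and torch.cuda.is_available(): model.cuda()

# train(model, opt, train_loader, test_loader)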