# Hinton's Capsule Networks

<div style = "text-align:justify"> We have already given a brief introduction to capsules. In this notebook we describe the dynamic routing algorithm between capsules and a capsule network architecture called CapsNet for recognizing and reconstructing handwritten digits from the MNIST dataset. We will also implement this architecture. </div>

First, let's load the required packages.

In [78]:
#run this cell
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, sampler
from torch.optim import lr_scheduler
import numpy as np
import matplotlib.pyplot as plt
import h5py
import time
import torchvision.utils
from capsnet_utils import *
%load_ext autoreload
%autoreload 2
%matplotlib inline


<div style = "text-align:justify"> Think of a capsule as a group of neurons. In traditional CNNs, the activity of a neuron is a scalar, but the activity (output) of a capsule is a vector. So, if a capsule is made up of 8 neurons, its activity is an 8-dimensional vector. The length of the vector represents the probability that an entity/feature is present, and its orientation encodes other instantiation parameters associated with the entity, such as pose, illumination and deformation. Since the length has to be a probability, it must be squashed to lie between 0 and 1. We use the following non-linear function to squash the length.</div>
<br>
<center>$v_j=\frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}\tag{1}$</center> 

<div style = "text-align:justify">Let's write this squash function. Keep in mind that we will write vectorized code that squashes the lengths of all capsules in a particular layer for the entire mini-batch. So $s_j$ will be of dimension *batch_size x #capsules in the layer x capsule_dim x 1.*</div>

In [18]:
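def my_squash(s, dim = 2, eps = 1e-8):
    # Eq. (1): shrink the length of s into [0, 1) while keeping its direction
    s_norm = torch.norm(s, p = 2, dim = dim, keepdim = True)
    s_norm_sqr = s_norm ** 2
    scalar = s_norm_sqr / (1. + s_norm_sqr)
    v = scalar * (s / (s_norm + eps))  # eps avoids 0/0 for zero-length vectors
    return v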
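As a quick sanity check of Eq. (1), here is a minimal standalone sketch (the name `squash` and the `eps` guard are our own assumptions, not part of the notebook's utilities). A vector of length 5 should come out with length 25/26 ≈ 0.96 and the same direction, while a short vector of length 0.05 is pushed close to zero.

```python
import torch

def squash(s, dim=2, eps=1e-8):
    # Eq. (1): new length is ||s||^2 / (1 + ||s||^2); direction is s / ||s||
    norm = torch.norm(s, p=2, dim=dim, keepdim=True)
    scale = norm ** 2 / (1.0 + norm ** 2)
    return scale * s / (norm + eps)  # eps is an assumed guard against zero-length vectors

# batch of 1, two 2D capsules, trailing singleton dim as in the notebook
s = torch.tensor([[[[3.0], [4.0]],        # length 5
                   [[0.03], [0.04]]]])    # length 0.05
v = squash(s, dim=2)
lengths = torch.norm(v, p=2, dim=2).view(-1)  # one length per capsule
print(lengths)
```

Long vectors saturate just below 1 and short vectors shrink towards 0, which is what lets the length act as a presence probability.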

<div style = "text-align:justify"> Just as neurons in a lower layer connect to neurons in a higher layer, lower-layer capsules connect to capsules in the higher layer. But how? First, each capsule predicts the activity of every capsule in the higher layer. This prediction is obtained by multiplying the activity vector of the lower-layer capsule by a weight matrix, which outputs a prediction vector for the higher-layer capsule. See Fig 1.</div> 

<img src="images/capsule_vs_ordinary.jpeg" style="width:450px;height:250px;">
<caption><center> <u> <font color='purple'> **Figure 1** </font></u><font color='purple'>  : **Capsule vs Ordinary neuron [[source]](https://medium.com/ai%C2%B3-theory-practice-business/understanding-hintons-capsule-networks-part-ii-how-capsules-work-153b6ade9f66)** </font></center></caption>
<br>
<div style = "text-align:justify"> Let us say $u_1$, $u_2$ and $u_3$ are the activity vectors of 3 capsules in the lower layer, and consider capsule j in the higher layer. $u_1$ may encode the position of a nose, $u_2$ a mouth and $u_3$ a right eye, while capsule j in the higher layer may encode a face. Matrix $W_{1j}$ encodes the relationship between the nose and the face, e.g. that the face is centred around the nose, the face is 10 times bigger than the nose and its orientation corresponds to the orientation of the nose. The same holds for the other capsules. In other words, the predictions $\hat{u}_{j|1}$, $\hat{u}_{j|2}$ and $\hat{u}_{j|3}$ represent where the face should be according to the detected positions of the nose, mouth and right eye respectively. Mathematically,</div>
<center>$\hat{u}_{j|i} = W_{ij}u_{i}$</center>
This explicit part-whole relationship is what was missing in CNNs!
<div style = "text-align:justify"> Let's code this. Note that again we will write vectorized code. $W$ is a tensor of dimension *1 x # capsules in lower layer x # capsules in higher layer x capsule_dim in higher layer x capsule_dim in lower layer*; its leading dimension is 1 so that the same weights are broadcast over the whole mini-batch. The dimension of $u$ is *batch_size x # capsules in lower layer x capsule_dim in lower layer x 1*.</div>



In [19]:
def predict_activity(W, u):
    
    num_capsules_higher = W.size(2)
    
    # size of W: 1 x # capsules in lower_layer x # capsules in higher_layer 
    #            x capsule_dim in higher layer x capsule_dim in lower layer (dim 0 broadcasts over the batch)
    # size of u: batch_size x # capsules in lower_layer  
    #                       x capsule_dim in lower layer x 1
    # if we stack u along dim 2 # capsules in higher_layer times, size of u will become
    #           batch_size x # capsules in lower_layer x # capsules in higher_layer
    #                      x capsule_dim in lower layer x 1
    # so we can directly invoke torch.matmul to do the multiplication of two tensors and get u_hat whose size is
    #           batch_size x # capsules in lower_layer x # capsules in higher_layer 
    #                      x capsule_dim in higher layer x 1
    
    u = torch.stack([u]*num_capsules_higher, dim = 2)
    u_hat = torch.matmul(W, u) 
    return u_hat
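The `torch.stack` above materializes one copy of `u` per higher-layer capsule. A sketch of an alternative, assuming toy sizes chosen only for illustration: inserting a singleton dimension with `unsqueeze(2)` lets `torch.matmul` broadcast instead, which should give the same `u_hat` without the copy.

```python
import torch

torch.manual_seed(0)
batch_size, n_lower, n_higher, d_lower, d_higher = 2, 5, 3, 8, 16

# W is shared across the batch (leading dim 1), as in the notebook's CapsuleLayer
W = torch.randn(1, n_lower, n_higher, d_higher, d_lower)
u = torch.randn(batch_size, n_lower, d_lower, 1)

# notebook approach: copy u once per higher-layer capsule, then matmul
u_stacked = torch.stack([u] * n_higher, dim=2)   # B x n_lower x n_higher x d_lower x 1
u_hat_stack = torch.matmul(W, u_stacked)

# alternative: a singleton dim that matmul broadcasts over the n_higher axis
u_hat_bcast = torch.matmul(W, u.unsqueeze(2))    # u.unsqueeze(2): B x n_lower x 1 x d_lower x 1

print(u_hat_bcast.shape)                         # B x n_lower x n_higher x d_higher x 1
```

Both versions compute the same $\hat{u}_{j|i} = W_{ij}u_i$ for every pair (i, j); the broadcast one simply avoids allocating the stacked copy of `u`.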

<div style = "text-align:justify">The individual predictions of the lower-layer capsules about the higher-layer capsules are now available. Next we need to decide which capsules in the higher layer should be coupled with which capsules in the lower layer. Suppose capsule i in the lower layer encodes the position of a nose seen from the front, and say the three capsules in the higher layer encode a frontal face, a profile face and an aeroplane. Then capsule i should clearly be coupled with the first capsule in the higher layer. This is the essence of the dynamic routing algorithm: it couples capsules in adjacent layers that agree with each other, and the more they agree, the higher the coupling probability. Remember inverse graphics: the visual scene is deconstructed into patterns/features, which are then matched with agreeing higher-level features to recognize the scene. Coupling probabilities are denoted by $c_{ij}$. Note that</div>
<center>$\sum_{j}c_{ij} = 1$</center>
<div style = "text-align:justify"> Since the $c_{ij}$ form a probability distribution over j for each i, we compute them using a softmax, i.e.,</div>
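<center>$c_{ij} = \frac{e^{b_{ij}}}{\sum_{k}{e^{b_{ik}}}}\tag{3}$</center> where the $b_{ij}$ are the log prior probabilities that capsule i should be coupled to capsule j. (The CapsNet paper notes that these priors can be learned discriminatively at the same time as all the other weights; in the routing procedure of Fig 2 and in our code they are initialized to zero.)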
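To make the normalization axis concrete, here is a small sketch assuming the logits are laid out with lower-layer capsules i on dim 1 and higher-layer capsules j on dim 2 (sizes are illustrative). Since the sum in Eq. (3) runs over j, the softmax must be taken over dim 2; with all logits at zero, every lower-layer capsule spreads its output uniformly over the 10 digit capsules.

```python
import torch
import torch.nn.functional as F

n_lower, n_higher = 4, 10
# logits b_ij: lower-layer capsules i on dim 1, higher-layer capsules j on dim 2
b = torch.zeros(1, n_lower, n_higher, 1)
c = F.softmax(b, dim=2)        # Eq. (3): normalize over j for each i
row_sums = c.sum(dim=2)        # one sum per lower-layer capsule i

# raising a single logit shifts probability mass toward that higher-layer capsule
b[0, 0, 3, 0] = 2.0
c_shifted = F.softmax(b, dim=2)
print(c_shifted[0, 0, :, 0])   # capsule i = 0 now favours j = 3
```

Taking the softmax over dim 1 instead would normalize over the lower-layer capsules i, which is not what Eq. (3) describes.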
<div style = "text-align:justify"> The final prediction $s_j$ for capsule j is a weighted linear combination of the individual predictions, that is,</div>
<center>$s_j=\sum_{i}c_{ij}\hat{u}_{j|i}$</center> <br> 
<div style = "text-align:justify"> $v_j$ is obtained by squashing $s_j$ with Eq. (1).</div>
The logits $b_{ij}$ are updated iteratively as  
<center>$b_{ij} \leftarrow b_{ij} + \hat{u}_{j|i}^T\,v_j$</center> where the dot product $\hat{u}_{j|i}^T\,v_j$ quantifies the agreement between capsule i and capsule j. The number of routing iterations is denoted by r.
<br>
The dynamic routing algorithm is shown in Fig 2. Let's code the dynamic routing algorithm.
<img src="images/routing.png" style="width:550px;height:250px;">
<caption><center> <u> <font color='purple'> **Figure 2** </font></u><font color='purple'>  : **Dynamic Routing Algorithm [[source]](https://arxiv.org/pdf/1710.09829.pdf)** </font></center></caption>
<br>


In [20]:
def dynamic_routing(u_hat, r):
    """
    Ip:
        u_hat: batch_size x #caps in layer l x #caps in layer l+1 x cap_dim in layer l+1 x 1
        r: number of routing iterations    
    Return:
        v: activity vectors in layer l+1.
           size of v is batch_size x #caps in layer l+1 x cap_dim in layer l+1 x 1
    """
    batch_size, num_caps_l, num_caps_l1 = u_hat.size(0), u_hat.size(1), u_hat.size(2)
    # one set of logits b_ij per example, initialized to zero (step 2 in Fig 2)
    b = torch.zeros(batch_size, num_caps_l, num_caps_l1, 1, device = u_hat.device)
    for _ in range(r):
        c = F.softmax(b, dim = 2) # eqn (3): normalize over the capsules j in layer l+1 (step 4)
        c = c.unsqueeze(4)        # batch_size x #caps l x #caps l+1 x 1 x 1, broadcasts against u_hat
        s = torch.sum(c * u_hat, dim = 1) # step 5
        v = my_squash(s, dim = 2) # step 6
        # step 7: agreement u_hat_{j|i} . v_j for every (i, j); v.unsqueeze(1) broadcasts over i
        agreement = torch.matmul(u_hat.transpose(3, 4), v.unsqueeze(1)).squeeze(4)
        b = b + agreement
    return v  
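To see routing-by-agreement at work, here is a toy sketch with hand-made predictions (the helper names `squash` and `route` and the data are hypothetical, chosen only for illustration). Three lower-layer capsules all predict the same vector for higher-layer capsule 0 but contradictory vectors for capsule 1, so after a few iterations every lower-layer capsule should couple more strongly to capsule 0. Like Fig 2, the sketch keeps one set of logits per example.

```python
import torch
import torch.nn.functional as F

def squash(s, dim, eps=1e-8):
    norm = s.norm(dim=dim, keepdim=True)
    return (norm ** 2 / (1 + norm ** 2)) * s / (norm + eps)

def route(u_hat, r):
    # u_hat: B x n_lower x n_higher x d_higher x 1
    B, n_lower, n_higher = u_hat.shape[:3]
    b = torch.zeros(B, n_lower, n_higher, 1)        # logits b_ij, zero-initialized
    for _ in range(r):
        c = F.softmax(b, dim=2)                      # couplings sum to 1 over j
        s = (c.unsqueeze(-1) * u_hat).sum(dim=1)     # B x n_higher x d_higher x 1
        v = squash(s, dim=2)
        # dot product u_hat_{j|i} . v_j for every (i, j)
        b = b + torch.matmul(u_hat.transpose(3, 4), v.unsqueeze(1)).squeeze(-1)
    return v, c

u_hat = torch.zeros(1, 3, 2, 2, 1)
u_hat[0, :, 0, :, 0] = torch.tensor([1.0, 0.0])     # all three agree on higher capsule 0
u_hat[0, 0, 1, :, 0] = torch.tensor([1.0, 0.0])     # ...but disagree about higher capsule 1
u_hat[0, 1, 1, :, 0] = torch.tensor([-1.0, 0.0])
u_hat[0, 2, 1, :, 0] = torch.tensor([0.0, 1.0])

v, c = route(u_hat, r=3)
print(c[0, :, :, 0])                                 # rows: lower capsules i, columns: higher capsules j
```

With r = 3 the couplings to capsule 0 rise above 0.5 for all three lower-layer capsules, and $v_0$ ends up much longer than $v_1$ because the contradictory predictions for capsule 1 largely cancel.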


<div style = "text-align:justify"> Let's come to the CapsNet architecture for handwritten digit recognition. It is shown in Fig 3.</div>
<img src="images/capsnet1.png" style="width:650px;height:150px;">
<caption><center> <u> <font color='purple'> **Figure 3** </font></u><font color='purple'>  : **CapsNet  architecture [[source]](https://arxiv.org/pdf/1710.09829.pdf)** </font></center></caption>
<br>
<div style = "text-align:justify"> CapsNet takes a 28 x 28 grayscale image as input. This is followed by a standard ReLU convolution layer with 256 filters, each of size 9 x 9, with stride (1, 1) and no padding, so the output size is *batch_size x 256 x 20 x 20.* Then come the two capsule layers: the primary capsule layer and the next-level capsule layer, called the digits capsule layer.</div>
<br>
Characteristics of the primary capsule layer are as follows:
- Each capsule is 8D. Each capsule is made up of 8 ReLU conv units with filter size 9 x 9, stride (2, 2) and no padding, so each conv unit reduces 20 x 20 to 6 x 6.
- There are 32 blocks of capsules, so the total number of capsules is 32 \* 6 \* 6 = 1152.
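The sizes in the list above follow from the usual convolution output-size formula, $\lfloor (n + 2p - k)/s \rfloor + 1$ for input size n, kernel size k, stride s and padding p. A small sketch (the helper `conv_out` is our own, not a PyTorch function):

```python
def conv_out(size, kernel, stride=1, padding=0):
    # output size of a convolution without dilation, per spatial dimension
    return (size + 2 * padding - kernel) // stride + 1

after_conv1 = conv_out(28, kernel=9, stride=1)             # first ReLU conv layer
after_primary = conv_out(after_conv1, kernel=9, stride=2)  # each primary-capsule conv unit
num_primary_caps = 32 * after_primary * after_primary      # 32 blocks of 6 x 6 capsules
print(after_conv1, after_primary, num_primary_caps)        # prints: 20 6 1152
```

Note the floor in the second step: (20 - 9) / 2 = 5.5 rounds down to 5, giving 6 x 6 rather than 7 x 7.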

Characteristics of the digits capsule layer are as follows:
- Each capsule is 16D
- There are 10 capsules corresponding to 10 digits
- $W_{ij}$ is a 16 x 8 matrix that maps the 8D activity of capsule i in the primary layer to a 16D prediction for capsule j in the digits capsule layer
- The digits capsule layer's activities are computed by routing; no convolutional/dense units are involved.
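As a rough size check, assuming one separate 16 x 8 matrix $W_{ij}$ for every (primary capsule, digit capsule) pair, laid out as in the `CapsuleLayer` code below, the transformation matrices alone hold 1152 x 10 x 16 x 8 parameters:

```python
import torch

num_primary, num_digits, dim_primary, dim_digit = 1152, 10, 8, 16

# one dim_digit x dim_primary matrix W_ij per (i, j) pair
num_W_params = num_primary * num_digits * dim_digit * dim_primary

# same layout as the notebook's W: 1 x 1152 x 10 x 16 x 8
W = torch.nn.Parameter(torch.randn(1, num_primary, num_digits, dim_digit, dim_primary))
print(num_W_params, W.numel())
```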

Note that the digits capsule layer is based on routing, while the primary capsule layer gets its input from the preceding convolutional layer.

Let's code *CapsuleLayer* class and then *CapsNet* class which uses the former class.

In [21]:
class CapsuleLayer(nn.Module):
    
    def __init__(self, use_routing = True):
        
        super().__init__()
        self.use_routing = use_routing
        if not self.use_routing: # create primary capsule layer
            self.conv_units = nn.ModuleList([nn.Conv2d(256, 32, kernel_size = (9, 9),
                                                          stride = (2, 2)) for i in range(8)])
        else: # digits capsule; so create W as a learnable set of parameters initialized randomly
            self.W = nn.Parameter(torch.randn(1, 1152, 10, 16, 8)) # 1st dim is 1 since we don't know
                                                                   # batch_size yet; it broadcasts over the batch
            self.routing_iterations = 3
    
    def forward(self, x):
        
        if not self.use_routing:  # forward prop through primary capsule layer
            outputs = [m(x) for m in self.conv_units] # a list of 8 outputs, each batch_size x 32 x 6 x 6
            outputs = torch.stack(outputs, dim = 4) # batch_size x 32 x 6 x 6 x 8
            s = outputs.view(x.size(0), -1, 8).unsqueeze(dim = 3) 
                                                # reshape to size batch_size x 1152 x 8 x 1
            v = my_squash(s, 2) # squash along dim 2
            # v has size batch_size x 1152 x 8 x 1
            return v
        
        else: # forward prop through digits capsule by predicting activity and routing
            u_hat = predict_activity(self.W, x)
            v = dynamic_routing(u_hat, self.routing_iterations)
            return v   
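The `view` in the primary capsule layer determines which numbers end up in which capsule. A minimal sketch with random tensors standing in for the 8 conv outputs (toy batch size 2) shows the resulting ordering: capsule index k = channel \* 36 + row \* 6 + col, and its 8 components are the 8 conv units' outputs at that same channel and position.

```python
import torch

torch.manual_seed(0)
batch_size = 2
# stand-ins for the 8 conv units' outputs, each batch_size x 32 x 6 x 6
outputs = [torch.randn(batch_size, 32, 6, 6) for _ in range(8)]
stacked = torch.stack(outputs, dim=4)                # batch_size x 32 x 6 x 6 x 8
caps = stacked.view(batch_size, -1, 8).unsqueeze(3)  # batch_size x 1152 x 8 x 1

# pick one capsule of the second example and rebuild its 8D vector from the conv outputs
channel, row, col = 5, 2, 4
k = channel * 36 + row * 6 + col
expected = torch.stack([o[1, channel, row, col] for o in outputs])
print(k, caps[1, k, :, 0])
```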
    

Next, the *CapsNet* class.

In [20]:
#self explanatory - For a particular reason which will be explained later we will call this CapsNet1
class CapsNet1(nn.Module):
    
    def __init__(self):
        
        super().__init__()       
        self.conv = nn.Conv2d(1, 256, kernel_size = (9, 9), stride = (1, 1))        
        self.relu = nn.ReLU()
        self.primary_capsule = CapsuleLayer(use_routing = False)
        self.digits_capsule = CapsuleLayer(use_routing = True)        
               
    def forward(self, x):
        
        x = self.conv(x)        
        x = self.relu(x)
        x = self.primary_capsule(x)
        x = self.digits_capsule(x)            
        return x       
    

<div style = "text-align:justify"> Now we need to compute the loss. We have two types of losses: the margin loss for digit recognition and the reconstruction loss for digit reconstruction. The margin loss is defined by the following equation.</div>
<img src="images/margin_loss.png" style="width:450px;height:40px;">
<br>
where $T_k$ is 1 iff digit of class $k$ is present, $m^+ = 0.9$, $m^- = 0.1$ and $\lambda = 0.5$. The total margin loss is simply the sum of the losses of all digit capsules.
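For reference, the loss in the image above is, for each digit capsule k (as given in the CapsNet paper),
<center>$L_k = T_k \max(0, m^+ - \|v_k\|)^2 + \lambda (1 - T_k) \max(0, \|v_k\| - m^-)^2$</center>
A small numeric sketch, assuming the capsule lengths are already computed (`margin_loss_from_lengths` is a hypothetical helper, and we use only 3 classes to keep the arithmetic easy):

```python
import torch

def margin_loss_from_lengths(lengths, labels, m_plus=0.9, m_minus=0.1, lambd=0.5):
    # lengths: batch_size x num_classes capsule lengths ||v_k||; labels: one-hot T_k
    present = labels * torch.clamp(m_plus - lengths, min=0) ** 2
    absent = lambd * (1 - labels) * torch.clamp(lengths - m_minus, min=0) ** 2
    return (present + absent).sum(dim=1)  # one loss per example

lengths = torch.tensor([[0.95, 0.30, 0.05],
                        [0.60, 0.20, 0.10]])
labels = torch.tensor([[1.0, 0.0, 0.0],
                       [1.0, 0.0, 0.0]])
print(margin_loss_from_lengths(lengths, labels))
```

In the first example the correct capsule is already longer than $m^+$ and the third capsule is shorter than $m^-$, so only the second capsule contributes, $0.5 \times 0.2^2 = 0.02$. In the second, the correct capsule contributes $(0.9 - 0.6)^2 = 0.09$ and the second capsule $0.5 \times 0.1^2 = 0.005$, for a total of 0.095.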

Let's code this loss.

In [21]:
#self explanatory - For a particular reason which will be explained later we will call this margin_loss1
def margin_loss1(v, labels, size_average = True):
    """
    Ip:
    v is batch_size x 10 x 16 x 1
    labels is batch_size x 10 (one_hot_encoded labels)

    Returns:
    L which is the margin loss
    """
    m_plus = 0.9
    m_minus = 0.1
    lambd = 0.5
    norm_v = torch.norm(v, p = 2, dim = 2).squeeze(-1) # norm_v is batch_size x 10
    # T_k * max(0, m+ - ||v_k||)^2
    Lk_first_term = labels * torch.clamp(m_plus - norm_v, min = 0) ** 2
    # lambda * (1 - T_k) * max(0, ||v_k|| - m-)^2
    Lk_second_term = lambd * (1 - labels) * torch.clamp(norm_v - m_minus, min = 0) ** 2

    Lk = Lk_first_term + Lk_second_term
    L = torch.sum(Lk, dim = 1) # sum over the 10 digit capsules

    if size_average: # average over the batch
        L = L.mean()
    return L
        

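<div style = "text-align:justify"> For the reconstruction loss, we use the sum of squared differences between the reconstructed and the original pixel intensities (a squared-error loss). The combined loss is defined as margin_loss + 0.0005 \* reconstruction_loss; the CapsNet paper scales the reconstruction loss down by 0.0005 so that it does not dominate the margin loss during training.</div>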
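Note that the reconstruction loss used in this notebook is summed over the 784 pixels, not averaged, so per example it is 784 times larger than what `nn.MSELoss` computes with its default `reduction='mean'`. A sketch with random tensors standing in for decoder outputs and images (the `margin` value is a placeholder, not a real result) checks this and forms the combined loss:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size = 4
reconstructed = torch.rand(batch_size, 784)  # stand-in for the decoder's sigmoid outputs
images = torch.rand(batch_size, 28, 28)      # stand-in for pixel intensities in [0, 1]
targets = images.view(batch_size, -1)

# as in reconstruction_loss: sum over pixels, then average over the batch
recon = ((reconstructed - targets) ** 2).sum(dim=1).mean()

# the same value via PyTorch's built-in: sum everything, divide by the batch size
recon_builtin = F.mse_loss(reconstructed, targets, reduction="sum") / batch_size

# the default 'mean' reduction additionally divides by the 784 pixels
recon_mean = F.mse_loss(reconstructed, targets)

margin = torch.tensor(0.02)                  # placeholder margin-loss value
total = margin + 0.0005 * recon
```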

<div style = "text-align:justify">But before we code the reconstruction loss, we need to reconstruct the image from the output vectors of the digits capsule layer. Towards this we will use the decoder shown in Fig 4.</div>
<img src="images/capsnet2.png" style="width:450px;height:150px;">
<caption><center> <u> <font color='purple'> **Figure 4** </font></u><font color='purple'>  : **Decoder for reconstruction [[source]](https://arxiv.org/pdf/1710.09829.pdf)** </font></center></caption>
<br>
<div style = "text-align:justify"> In Fig 4, in the first block, orange indicates the activity vector with the largest length, while grayish blue indicates activities masked to zero: only the longest activity vector is forwarded to the decoder, and all other activity vectors are set to zero. First we code the reconstruction loss, assuming we already have the output of the decoder. Then we code the decoder as part of the *CapsNet* class and use it whenever the boolean variable *use_reconstruction* is set. For ease of implementation, we also make the margin loss and the reconstruction loss methods of the *CapsNet* class. </div> 
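The notebook's `mask` comes from `capsnet_utils` and its code is not shown here. As an illustration only, a hypothetical `mask_by_length` could keep the longest capsule per example as described above; it also accepts optional one-hot labels, since the CapsNet paper masks with the true class during training.

```python
import torch

def mask_by_length(v, labels=None):
    """Zero all digit capsules except one per example (hypothetical helper).

    v: batch_size x 10 x 16 x 1 digit-capsule activities.
    labels: optional one-hot batch_size x 10; if given, keep the true class,
            otherwise keep the capsule with the largest length.
    """
    if labels is None:
        lengths = torch.norm(v, p=2, dim=2).squeeze(-1)   # batch_size x 10
        keep = torch.argmax(lengths, dim=1)                # index of the longest capsule
        labels = torch.zeros_like(lengths).scatter_(1, keep.unsqueeze(1), 1.0)
    return v * labels.unsqueeze(-1).unsqueeze(-1)          # broadcast over the 16 x 1 dims

v = torch.zeros(2, 10, 16, 1)
v[0, 3, 0, 0] = 0.9   # longest capsule of example 0
v[0, 5, 0, 0] = 0.4
v[1, 7, 1, 0] = 0.8   # longest capsule of example 1
masked = mask_by_length(v)
decoder_input = masked.view(-1, 10 * 16 * 1)               # what fc1 receives

labels = torch.zeros(2, 10)
labels[0, 5] = 1.0
labels[1, 7] = 1.0
masked_true = mask_by_length(v, labels)                    # keep the labelled class instead
```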

In [17]:
#self explanatory - For a particular reason which will be explained later we will call this reconstruction_loss1
def reconstruction_loss1(x, images, size_average = True):
    """
    Ip:
    x is batch_size x 784
    images is batch_size x 28 x 28
    
    Returns:
    L which is reconstruction loss
    """
    L = (x - images.view(x.size(0), -1)) ** 2
    L = torch.sum(L, dim = 1)
    if size_average:
        L = L.mean()
    return L

In [22]:
#self explanatory - now it should be clear why we called above class as CapsNet1
     # and the methods as margin_loss1 and reconstruction_loss1
    
class CapsNet(nn.Module):
    
    def __init__(self, use_reconstruction):
        
        super().__init__() 
        self.use_reconstruction = use_reconstruction
        self.conv = nn.Conv2d(1, 256, kernel_size = (9, 9), stride = (1, 1))        
        self.relu = nn.ReLU()
        self.primary_capsule = CapsuleLayer(use_routing = False)
        self.digits_capsule = CapsuleLayer(use_routing = True) 
        
        if self.use_reconstruction:
            self.fc1 = nn.Linear(10 * 16, 512)
            self.relu1 = nn.ReLU()
            self.fc2 = nn.Linear(512, 1024)
            self.relu2 = nn.ReLU()
            self.fc3 = nn.Linear(1024, 784)
            self.sigmoid = nn.Sigmoid()
                
               
    def forward(self, x):
        
        x = self.conv(x)        
        x = self.relu(x)
        x = self.primary_capsule(x)
        x = self.digits_capsule(x) 
        return x
    
    def model_loss(self, x, labels, images, size_average = True):
        mloss = self.margin_loss(x, labels, size_average)
        loss = mloss
        reconstructed = None
        if self.use_reconstruction:
            x = mask(x)
            x = x.view(-1, 10 * 16 * 1)
            x = self.fc1(x)
            x = self.relu1(x)
            x = self.fc2(x)
            x = self.relu2(x)
            x = self.fc3(x)
            reconstructed = self.sigmoid(x)
            rloss = self.reconstruction_loss(reconstructed, images, size_average)
            loss = mloss + 0.0005 * rloss
        return loss, reconstructed 
    
    def margin_loss(self, v, labels, size_average = True):
    
        """
        Ip:
        v is batch_size x 10 x 16 x 1
        labels is batch_size x 10 (one_hot_encoded labels)

        Returns:
        L which is the margin loss
        """
        m_plus = 0.9
        m_minus = 0.1
        lambd = 0.5        
        norm_v = torch.norm(v, p = 2, dim = 2).squeeze(-1) # norm_v is batch_size x 10 (also for batch_size 1)
        Lk_first_term = m_plus - norm_v
        Lk_first_term[Lk_first_term < 0] = 0
        Lk_first_term = Lk_first_term ** 2        
        Lk_first_term = Lk_first_term * labels

        Lk_second_term = norm_v - m_minus
        Lk_second_term[Lk_second_term < 0] = 0
        Lk_second_term = Lk_second_term ** 2
        Lk_second_term = lambd * (1 - labels) * Lk_second_term 

        Lk = Lk_first_term + Lk_second_term
        L = torch.sum(Lk, dim = 1)

        if size_average: # average over the batch
            L = L.mean()
        return L
    
    def reconstruction_loss(self, x, images, size_average = True):
        """
        Ip:
        x is batch_size x 784
        images is batch_size x 28 x 28

        Returns:
        L which is reconstruction loss
        """
        L = (x - images.view(x.size(0), -1)) ** 2
        L = torch.sum(L, dim = 1)
        if size_average:
            L = L.mean()
        return L

All set for training and testing...
<br>
Let's load data first.

In [23]:
train_loader, test_loader = load_mnist()

===> Loading training datasets
===> Loading testing datasets


In [8]:
def train_model(model, data_loader, optimizer, save_file, scheduler = None):
    
    since = time.time()
    train_loss_history = [] 
    num_epochs = model.num_epochs
    num_batches = len(data_loader)
    
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        
        model.train(True)
        running_loss = 0.0
        
        for inputs, labels in data_loader:
            labels = one_hot_encode(labels, 10)
            inputs, labels = Variable(inputs), Variable(labels)
            #do not uncomment the four lines below; we will be working with CPU
            """
            if torch.cuda.is_available():
                inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
            else:
                inputs, labels = Variable(inputs), Variable(labels)
            """
            optimizer.zero_grad()
            outputs = model(inputs)
            # the input images double as the reconstruction targets
            loss, reconstructed = model.model_loss(outputs, labels, inputs)
                    
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        if scheduler: # step the learning-rate schedule once per epoch, after the optimizer updates
            scheduler.step()
        epoch_loss = running_loss / num_batches           
        train_loss_history.append(epoch_loss)        
        print('Train Loss: {:.8f}'.format(epoch_loss))
        
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
    torch.save(model.state_dict(), save_file)
    torch.save(train_loss_history, save_file + '_loss_history')
    return

In [10]:
# train the model - each epoch takes around 4 to 5 minutes on a Titan X GPU
# we provide a pre-trained model, so you need not train
caps_net = CapsNet(use_reconstruction = True)
optimizer = torch.optim.Adam(caps_net.parameters(), lr = 0.01)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size = 10, gamma = 0.95)
caps_net.num_epochs = 1
#if torch.cuda.is_available():
#    caps_net.cuda()
#train_model(caps_net, train_loader, optimizer, 'caps_net_1.pth', exp_lr_scheduler)
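The `exp_lr_scheduler` above multiplies the learning rate by `gamma = 0.95` every `step_size = 10` scheduler steps. A sketch on a dummy parameter, assuming the scheduler is stepped once per epoch after the optimizer (the order recent PyTorch versions expect), shows the resulting schedule:

```python
import torch
from torch.optim import lr_scheduler

dummy = torch.nn.Parameter(torch.zeros(1))   # stands in for the CapsNet parameters
optimizer = torch.optim.Adam([dummy], lr=0.01)
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.95)

lrs = []
for epoch in range(30):
    lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used during this epoch
    optimizer.step()                              # a real epoch would loop over all mini-batches here
    scheduler.step()                              # once per epoch, after the optimizer updates
```

Epochs 0 to 9 use 0.01, epochs 10 to 19 use 0.0095 and epochs 20 to 29 use 0.009025, so this schedule decays quite slowly.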

Let's test for accuracy and also see the reconstructed images.

In [83]:
model = CapsNet(use_reconstruction = True)
#test_model(model, 'caps_net_100.pth', test_loader)

<div style = "text-align:justify"> Some of the original and reconstructed images are shown in Fig 5: the left column shows the original images and the right column the reconstructions. In these examples the CapsNet reconstructions are less noisy and smoother than the originals, and broken digits are completed. </div>


In [None]:
# Run this cell
from IPython.display import HTML, display
display(HTML("<table><tr><td><img src='results/original_image_test_29.png'></td>\
               <td><img src='results/reconstructed_image_test_29.png'> <caption><center>\
               <u> <font color='purple'> <font size = 4>Figure 5 </u><font color='purple'>\
               : Original and reconstructed digits</center></caption></td></tr></table>"))

Congratulations on completing this tutorial! More to come in the next two days. We hope you enjoyed it. Sairam.