# Dynamic Routing Between Capsules

This notebook implements and demonstrates the method introduced in this paper.

In [1]:
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

# How the vector inputs and outputs of a capsule are computed

The length of an output vector represents the probability that the entity represented by the capsule is present given the input. To do this, a squashing function is introduced that maps the output between 0 and 1.

In [2]:
def squash1(x):
    """https://github.com/higgsfield/Capsule-Network-Tutorial/blob/master/Capsule%20Network.ipynb"""
    squared_norm = (x ** 2).sum(-1, keepdim=True)
    return squared_norm * x / ((1. + squared_norm) * torch.sqrt(squared_norm))

def squash2(x):
    """My interpretation."""
    x_norm = x.norm(dim=-1, keepdim=True) ** 2
    return (x_norm / (1 + x_norm)) * (x / torch.sqrt(x_norm))

def squash3(x, dim=-1):
    """https://github.com/gram-ai/capsule-networks/blob/master/capsule_network.py"""
    squared_norm = (x ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    return scale * x / torch.sqrt(squared_norm)

x = torch.rand(1, 4, 2)
t1 = time.time()
s1 = squash1(x)
print(time.time() - t1)
t1 = time.time()
s2 = squash2(x)
print(time.time() - t1)
t1 = time.time()
s3 = squash3(x)
print(time.time() - t1)
print(s1)
print(s2)
print(s3)

0.0456385612487793
0.0020627975463867188
0.00041675567626953125
tensor([[[0.4407, 0.2494],
         [0.4033, 0.2699],
         [0.1853, 0.4685],
         [0.2928, 0.2110]]])
tensor([[[0.4407, 0.2494],
         [0.4033, 0.2699],
         [0.1853, 0.4685],
         [0.2928, 0.2110]]])
tensor([[[0.4407, 0.2494],
         [0.4033, 0.2699],
         [0.1853, 0.4685],
         [0.2928, 0.2110]]])


## What is `s` and where does it come from?

`s_j` is the input to capsule *j*. It comes from a weighted sum of prediction vectors `u_hat` and coupling coefficients (more on that later).

In [3]:
N = 4  # number of capsules in layer 0
M = 2  # number of capsules in layer 1
D = 3  # dimension of capsule in layer 0
c = torch.rand(M, N)
u_hat = torch.rand(N, D)

s = torch.empty(M, D)

# Calculating s_j
for j in range(M):
    for i in range(N):
        s[j] += c[j,i] * u_hat[i]
print(s)
print(c @ u_hat)
print(squash2(c @ u_hat))

tensor([[0.9763, 0.5313, 0.7494],
        [1.3963, 0.5850, 1.1037]])
tensor([[0.9763, 0.5313, 0.7494],
        [1.3963, 0.5850, 1.1037]])
tensor([[0.4679, 0.2546, 0.3592],
        [0.5800, 0.2430, 0.4585]])


# Routing

Measure the agreement between the current output of each capsule in the higher layer with each prediction from the lower capsules.

In [4]:
def routing(u_hat, r, l):
    """Computes the output vector for the capsule while updating the coupling
    coefficients.
    
    Args:
        u_hat (D): Prediction from lower level capsule.
        r (int): Number of iterations.
        l (int): Current layer.
    Returns:
        v (D): Output vector.
    """
    b = torch.zeros(M, N)
    
    for i in range(r):
        # for all capsule i in layer l
        c = F.softmax(b, dim=-1)
        
        # for all capsule j in layer (l+1)
        s = c @ u_hat
        print(u_hat)
        
        # for all capsule j in layer (l+1)
        v = squash2(s)
        
        if i < r - 1:
            b += v @ u_hat.transpose(0, 1)
            
        return v
        
v = routing(u_hat, 3, 0)
print(v)

tensor([[0.5693, 0.3967, 0.1622],
        [0.1070, 0.1079, 0.8489],
        [0.0482, 0.4059, 0.0409],
        [0.8526, 0.1019, 0.7829]])
tensor([[0.1808, 0.1161, 0.2104],
        [0.1808, 0.1161, 0.2104]])


# Margin loss for digit existence



In [75]:
m_plus = 0.9
m_minus = 0.1
lam = 0.5
labels = torch.empty(4, 10, dtype=torch.float32).random_(2)

x = torch.rand(4, 10, 16)
v_c = x.norm(p=2, dim=-1, keepdim=True)
left = F.relu(m_plus - v_c).view(4, -1)
right = F.relu(v_c - m_minus).view(4, -1)
loss = labels * left + lam * (1.0 - labels) * right
loss = loss.sum(dim=1).mean()
print(loss)

tensor(7.3688)


# Primary Capsule

Primary capsules consist of a convolutional layer with 32 channels of convolutional 8D capsules. This means that each primary capsule contains 8 conv units with a 9x9 kernel and stride 2. `PrimaryCapsules` has (32x6x6) capsule outputs.

In [77]:
class PrimaryCapsule(nn.Module):
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9):
        super(PrimaryCapsule, self).__init__()
        # Initialize the convolutional capsules
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=2, padding=0) for _ in range(num_capsules)
        ])
        
    def forward(self, x):
        # Prediction for each capsule
        u = [capsule(x) for capsule in self.capsules]
        u = torch.stack(u, dim=1)
        u = u.view(x.size(0), 32 * 6 * 6, -1)
        return squash2(u)
        
cap = PrimaryCapsule()
conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)
x = torch.rand(4, 1, 28, 28)
x = conv1.forward(x)
x = cap(x)

# Digit Capsule
Digit capsules implement the routing mechanism to determine the part-whole relationships based on the predictions of the primary capsules.

  $W_{ij}$ - Learned weights which are multiplied with the capsule output $u_i$ to produce the prediction vectors $\hat{u}_{j|i}$.
  
  $b_{ij}$ - Log prior probabilities that capsule $i$ should be coupled to capsule $j$.
  
  $c_{ij}$ - Coupling coefficients between capsule $i$ and all capsules in the higher layer. These sum to 1 for each capsule in the lower layer.

In [90]:
class DigitCap(nn.Module):
    def __init__(self, num_capsules=10, in_size=32 * 6 * 6, in_channels=8, out_channels=16,
                num_iterations=3):
        super(DigitCap, self).__init__()
        self.num_iterations = 3
        self.W = nn.Parameter(torch.randn(num_capsules, in_size, in_channels, out_channels))
    
    def forward(self, x):
        u_ji = x[:, None, :, None, :] @ self.W[None, :, :, :, :]
        u_ji = u_ji.squeeze(3)
        b_ij = torch.zeros(u_ji.shape, device=x.device)
        
        for i in range(self.num_iterations):
            c_ij = F.softmax(b_ij, dim=2)
            v_j = squash2((c_ij * u_ji).sum(dim=2, keepdim=True))
            
            if i < self.num_iterations - 1:
                a_ij = (u_ji * v_j).sum(dim=-1, keepdim=True)
                b_ij = b_ij + a_ij
        
        return v_j.squeeze(2)
    
dcap = DigitCap()
v = dcap(x)

# Reconstruction as a regularization method

A decoder is proposed which would attempt to reconstruct the digit image using the activation of the corresponding digit vector in `DigitCap`. This is done by masking out all activity vectors except for the one corresponding to the correct label. This is then transformed through 3 fully connected layers. The output is a 784 dimensional vector.

In [121]:
class Decoder(nn.Module):
    def __init__(self, in_channels=16 * 10):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(in_channels, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 784)
        
    def forward(self, x, labels):
        mask = torch.sparse.torch.eye(10)
#         mask = mask.index_select(dim=0, index=labels.squeeze(1))
        x = x * mask[labels.squeeze(1), :, None]
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x.view(-1, 1, 28, 28)
        
x = torch.randn(2, 10, 16)
labels = torch.randint(0, 9, (2, 1))
dec = Decoder()
img = dec(x, labels)
print(F.mse_loss(img, img))

tensor(0., grad_fn=<MeanBackward1>)
