# Dynamic Routing Between Capsules

This notebook implements and demonstrates the method introduced in this paper.

In [1]:
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

# How the vector inputs and outputs of a capsule are computed

The length of an output vector represents the probability that the entity represented by the capsule is present given the input. To do this, a squashing function is introduced that maps the output between 0 and 1.

In [19]:
def squash1(x):
    """https://github.com/higgsfield/Capsule-Network-Tutorial/blob/master/Capsule%20Network.ipynb"""
    squared_norm = (x ** 2).sum(-1, keepdim=True)
    return squared_norm * x / ((1. + squared_norm) * torch.sqrt(squared_norm))

def squash2(x):
    """My interpretation."""
    x_norm = x.norm(dim=-1, keepdim=True) ** 2
    return (x_norm / (1 + x_norm)) * (x / torch.sqrt(x_norm))

def squash3(x, dim=-1):
    """https://github.com/gram-ai/capsule-networks/blob/master/capsule_network.py"""
    squared_norm = (x ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    return scale * x / torch.sqrt(squared_norm)

x = torch.rand(1, 4, 2)
t1 = time.time()
s1 = squash1(x)
print(time.time() - t1)
t1 = time.time()
s2 = squash2(x)
print(time.time() - t1)
t1 = time.time()
s3 = squash3(x)
print(time.time() - t1)
print(s1)
print(s2)
print(s3)

0.0005314350128173828
0.0005106925964355469
0.00038743019104003906
tensor([[[0.1753, 0.1390],
         [0.0590, 0.2599],
         [0.3021, 0.0313],
         [0.2541, 0.0536]]])
tensor([[[0.1753, 0.1390],
         [0.0590, 0.2599],
         [0.3021, 0.0313],
         [0.2541, 0.0536]]])
tensor([[[0.1753, 0.1390],
         [0.0590, 0.2599],
         [0.3021, 0.0313],
         [0.2541, 0.0536]]])


## What is `s` and where does it come from?

`s_j` is the input to capsule *j*. It comes from a weighted sum of prediction vectors `u_hat` and coupling coefficients (more on that later).

In [21]:
N = 4  # number of capsules in layer 0
M = 2  # number of capsules in layer 1
D = 3  # dimension of capsule in layer 0
c = torch.rand(M, N)
u_hat = torch.rand(N, D)

s = torch.empty(M, D)

# Calculating s_j
for j in range(M):
    for i in range(N):
        s[j] += c[j,i] * u_hat[i]
print(s)
print(c @ u_hat)
print(squash2(c @ u_hat))

tensor([[2.0669, 0.7517, 1.7444],
        [1.5092, 0.5192, 1.4241]])
tensor([[2.0669, 0.7517, 1.7444],
        [1.5092, 0.5192, 1.4241]])
tensor([[0.6534, 0.2376, 0.5514],
        [0.5790, 0.1992, 0.5464]])


# Routing

Measure the agreement between the current output of each capsule in the higher layer with each prediction from the lower capsules.

In [23]:
def routing(u_hat, r, l):
    """Computes the output vector for the capsule while updating the coupling
    coefficients.
    
    Args:
        u_hat (D): Prediction from lower level capsule.
        r (int): Number of iterations.
        l (int): Current layer.
    Returns:
        v (D): Output vector.
    """
    b = torch.zeros(M, N)
    
    for i in range(r):
        # for all capsule i in layer l
        c = F.softmax(b, dim=-1)
        
        # for all capsule j in layer (l+1)
        s = c @ u_hat
        print(u_hat)
        
        # for all capsule j in layer (l+1)
        v = squash2(s)
        
        if i < r - 1:
            b += v @ u_hat.transpose(0, 1)
            
        return v
        
v = routing(u_hat, 3, 0)
print(v)

tensor([[0.9684, 0.3446, 0.8407],
        [0.6700, 0.0335, 0.7692],
        [0.6673, 0.9354, 0.4818],
        [0.8452, 0.3606, 0.5036]])
tensor([[0.3920, 0.2083, 0.3229],
        [0.3920, 0.2083, 0.3229]])


# Margin loss for digit existence



In [24]:
m_plus = 0.9
m_minus = 0.1
lam = 0.5
labels = torch.empty(4, 10, dtype=torch.float32).random_(2)
print(labels)

x = torch.rand(4, 10, 16)
v_c = x.norm(p=2, dim=2, keepdim=True)
left = F.relu(m_plus - v_c).view(4, -1)
right = F.relu(v_c - m_minus).view(4, -1)
loss = labels * left + lam * (1.0 - labels) * right
loss = loss.sum(dim=1).mean()
print(loss)

tensor([[1., 0., 1., 1., 1., 0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 0., 0., 1., 1., 0., 0.],
        [1., 1., 0., 0., 1., 0., 1., 0., 0., 1.],
        [0., 1., 1., 0., 1., 1., 0., 1., 1., 0.]])
tensor(4.7228)


# Primary Capsule

Primary capsules consist of a convolutional layer with 32 channels of convolutional 8D capsules. This means that each primary capsule contains 8 conv units with a 9x9 kernel and stride 2. `PrimaryCapsules` has (32x6x6) capsule outputs.

In [76]:
class PrimaryCapsule(nn.Module):
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9):
        super(PrimaryCapsule, self).__init__()
        # Initialize the convolutional capsules
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=2, padding=0) for _ in range(num_capsules)
        ])
        
    def forward(self, x):
        # Prediction for each capsule
        u = [capsule(x) for capsule in self.capsules]
        u = torch.stack(u, dim=1)
        u = u.view(x.size(0), 32 * 6 * 6, -1)
        return squash2(u)
        
cap = PrimaryCapsule()
conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)
x = torch.rand(4, 1, 28, 28)
x = conv1.forward(x)
x = cap(x)
print(x.shape)
w = torch.randn(10, 32 * 6 * 6, 8, 16)
r = w[None, :, :, :, :]
t = x[:, None, :, None, :]
print(r.shape)
print(t.shape)
print((t@r).shape)

torch.Size([4, 1152, 8])
torch.Size([1, 10, 1152, 8, 16])
torch.Size([4, 1, 1152, 1, 8])
torch.Size([4, 10, 1152, 1, 16])


# Digit Capsule
Digit capsules implement the routing mechanism to determine the part-whole relationships based on the predictions of the primary capsules.

In [None]:
class DigitCap(nn.Module):
    def __init__(self, num_capsules=10, in_size=32 * 6 * 6, in_channels=8, out_channels=16,
                num_iterations=3):
        self.num_iterations = 3
        self.W = nn.Parameter(torch.randn(num_capsules, in_size))
    
    def forward(self, x):
        pred = x[:, None, :, None, :] @ self.W[None, :, :, :, :]
        pred = pred.squeeze(3)  # necessary?
        logits = torch.zeros(pred.shape, device=x.device)
        
        for i in range(self.num_iterations):
            probs = F.softmax(logits, dim=2)
            outputs = squash2((probs * pred).sum(dim=2, keepdim=True))