In [1]:
"""
Dynamic Routing Between Capsules
https://arxiv.org/abs/1710.09829

Heavily inspired from implementation by Kenta Iwasaki @ Gram.AI.
"""

import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
from torch.optim import Adam
from torchvision.datasets.mnist import MNIST
import numpy as np

# Select which gpu you want, I have 3. Otherwise remove all .cuda from code
gpu = 2

Capsule Convolutional Layer

In [2]:
class ConvCaps2D(nn.Module):
    def __init__(self):
        super(ConvCaps2D, self).__init__()
        self.capsules = nn.ModuleList([nn.Conv2d(in_channels = 256, out_channels = 8, kernel_size=9, stride=2)
                                       for _ in range(32)])
    def squash(self, tensor, dim=-1):
        norm = (tensor**2).sum(dim=dim, keepdim = True) # norm.size() is (None, 1152, 1)
        scale = norm / (1 + norm)        
        return scale*tensor / torch.sqrt(norm)
        
    def forward(self, x):
        outputs = [capsule(x).view(x.size(0), 8, -1) for capsule in self.capsules] # 32 list of (None, 1, 8, 36)
        outputs = torch.cat(outputs, dim = 2).permute(0, 2, 1)  # outputs.size() is (None, 1152, 8)
        return self.squash(outputs)

DigitCaps Layer

In [3]:
class Caps1D(nn.Module):
    def __init__(self):
        super(Caps1D, self).__init__()
        self.num_caps = 10
        self.num_iterations = 3
        self.routing_matrix = nn.Parameter(torch.randn(10, 1152, 8, 16))
        
    def softmax(self, x, dim = 1):
        transposed_input = x.transpose(dim, len(x.size()) - 1)
        softmaxed_output = F.softmax(transposed_input.contiguous().view(-1, transposed_input.size(-1)))
        return softmaxed_output.view(*transposed_input.size()).transpose(dim, len(x.size()) - 1)

    def squash(self, tensor, dim=-1):
        norm = (tensor**2).sum(dim=dim, keepdim = True) # norm.size() is (None, 1152, 1)
        scale = norm / (1 + norm)        
        return scale*tensor / torch.sqrt(norm)
    
    def forward(self, x):
        # x.size() is (None, 1152, 8)
        '''
        From documentation
        For example, if tensor1 is a j x 1 x n x m Tensor and tensor2 is a k x m x p Tensor, 
        out will be an j x k x n x p Tensor.
        
        We need j = None, 1, n = 1152, k = 10, m = 8, p = 16
        '''
        b = torch.matmul(x[:, None, :, None, :], self.routing_matrix)
        logits = Variable(torch.zeros(b.size())).cuda(gpu)
        
        for i in range(self.num_iterations):
            probs = self.softmax(logits, dim=2)
            outputs = self.squash((probs * b).sum(dim=2, keepdim=True))

            if i != self.num_iterations - 1:
                delta_logits = (b * outputs).sum(dim=-1, keepdim=True)
                logits = logits + delta_logits

        
        outputs = outputs.squeeze()
        classes = (outputs ** 2).sum(dim=-1) ** 0.5
        classes = F.softmax(classes)
        
        return classes

In [4]:
class CapsNet(nn.Module):
    def __init__(self):
        super(CapsNet, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 256, kernel_size = 9, stride = 1)
        
        self.primaryCaps = ConvCaps2D()
        self.digitCaps = Caps1D()
        
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.primaryCaps(x)
        x = self.digitCaps(x)
        
        return x

net = CapsNet().cuda(gpu)
# out = net.forward(Variable(torch.from_numpy(np.random.rand(32, 1, 28, 28).astype('float32')).cuda(gpu)))

In [5]:
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters())

In [6]:
def evaluate(model, X, Y, batch_size = 50):
    results = []
    predicted = []
    
    for i in range(len(X)/batch_size):
        s = i*batch_size
        e = i*batch_size+batch_size
        
        inputs = Variable(torch.from_numpy(X[s:e]).cuda(gpu))
        pred = model(inputs)
        
        predicted += list(np.argmax(pred.data.cpu().numpy(), axis = 1))
        
    acc = sum(Y == predicted)*1.0/(len(Y))
    return acc

Load the data

In [7]:
dataset = MNIST(root='./data', download=True, train=True)

X_train = np.expand_dims(getattr(dataset, 'train_data').numpy().astype('float32'), 1)/255.0
y_train = getattr(dataset, 'train_labels').numpy().astype('int64')

indices = np.random.permutation(len(X_train))
X_train = X_train[indices]
y_train = y_train[indices]

dataset = MNIST(root='./data', download=True, train=False)

X_test = np.expand_dims(getattr(dataset, 'test_data').numpy().astype('float32'), 1)/255.0
y_test = getattr(dataset, 'test_labels').numpy().astype('int64')

indices = np.random.permutation(len(X_test))
X_test = X_test[indices]
y_test = y_test[indices]

In [None]:
batch_size = 200

for epoch in range(50):  # 50 epochs
    print "\nEpoch ", epoch
    
    running_loss = 0.0
    for i in range(len(X_train)/batch_size-1):
        print i, 
        s = i*batch_size
        e = i*batch_size+batch_size
        
        inputs = torch.from_numpy(X_train[s:e])
        labels = torch.LongTensor(np.array(y_train[s:e]))
        
        
        # wrap them in Variable
        inputs, labels = Variable(inputs.cuda(gpu)), Variable(labels.cuda(gpu))

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        
        loss = criterion(outputs, labels)
        loss.backward()
        
        optimizer.step()
        running_loss += loss.data[0]
        
        del inputs, labels
    
    # Validation accuracy
    print "Epoch, Loss - {}, {}".format(epoch, running_loss)
    print "Train - ", evaluate(net, X_train[0:1000], y_train[0:1000], batch_size = 100)
    print "Test - ", evaluate(net, X_test, y_test, 100)


Epoch  0
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102