# Import Dependencies

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import tqdm_notebook
from collections import defaultdict

# Load MNIST

In [None]:
# Training-time augmentation: pad each 28x28 digit by 2 pixels and take a
# random 28x28 crop, i.e. a random shift of up to 2 pixels per axis.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop((28, 28), padding=2),
    torchvision.transforms.ToTensor(),
])
# The test split must not be randomly shifted: random augmentation at
# evaluation time makes the reported test accuracy noisy and not reproducible.
tst_transforms = torchvision.transforms.ToTensor()

# Both splits are downloaded into the current directory if not already there.
trn_dataset = torchvision.datasets.MNIST('.', train=True, download=True, transform=transforms)
tst_dataset = torchvision.datasets.MNIST('.', train=False, download=True, transform=tst_transforms)

In [None]:
# Mini-batch size shared by the training and test loaders.
batch_size = 128

# Training batches are reshuffled on every pass over the data.
trn_loader = torch.utils.data.DataLoader(
    dataset=trn_dataset,
    batch_size=batch_size,
    shuffle=True,
)
# Test batches are served in dataset order.
tst_loader = torch.utils.data.DataLoader(
    dataset=tst_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# Define CapsNet

In [None]:
class Decoder(torch.nn.Module):
    """Reconstruct an image from the capsule vector of one chosen class.

    The network is a three-layer MLP (in_features -> 512 -> 1024 ->
    out_features) with ReLU activations and a sigmoid output.
    """

    def __init__(self, in_features, out_features):
        """
        Args:
            in_features: length of a single capsule vector.
            out_features: number of reconstructed values per sample. The
                output is reshaped to 1x28x28 images, so this needs to be
                784 to obtain exactly one image per sample.
        """
        super(Decoder, self).__init__()
        self.decoder = self.assemble_decoder(in_features, out_features)

    def assemble_decoder(self, in_features, out_features):
        """Build the fully-connected reconstruction stack."""
        return torch.nn.Sequential(
            torch.nn.Linear(in_features, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, out_features),
            torch.nn.Sigmoid(),
        )

    def forward(self, x, y):
        """Decode the capsule selected by ``y`` for every sample.

        Args:
            x: capsule tensor indexed as (sample, capsule, feature).
            y: integer tensor with one capsule index per sample.

        Returns:
            Tensor of shape (-1, 1, 28, 28).
        """
        # Select on the tensor's own device: no numpy round trip and no
        # hard-coded ``.cuda()``, so the module also runs on CPU.
        rows = torch.arange(x.size(0), device=x.device)
        cols = y.detach().to(device=x.device, dtype=torch.long)
        x = x[rows, cols, :]
        x = self.decoder(x)
        x = x.view(-1, 1, 28, 28)
        return x

In [None]:
class Norm(torch.nn.Module):
    """Layer that returns the Euclidean length of vectors along the last axis."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the L2 norm of ``x`` over its last dimension (that axis is removed)."""
        return x.norm(p=2, dim=-1)

In [None]:
class Routing(torch.nn.Module):
    """Capsule layer connecting two capsule populations by iterative routing.

    Every input capsule predicts every output capsule through its own
    weight matrix; the predictions are then combined with coupling
    coefficients refined over a number of routing iterations.
    """

    def __init__(self, caps_size_before, caps_size_after, n_capsules_before, n_capsules_after):
        """
        Args:
            caps_size_before: length of each input capsule vector.
            caps_size_after: length of each output capsule vector.
            n_capsules_before: number of input capsules.
            n_capsules_after: number of output capsules.
        """
        super(Routing, self).__init__()
        self.n_capsules_before = n_capsules_before
        self.n_capsules_after = n_capsules_after
        self.caps_size_before = caps_size_before
        self.caps_size_after = caps_size_after

        # The weight shape follows the constructor arguments instead of the
        # previously hard-coded (1152, 10, 16, 8); the scale rule is unchanged:
        # std = sqrt(2 / total number of weights).
        n_in = n_capsules_before * n_capsules_after * caps_size_after * caps_size_before
        std = float(np.sqrt(2 / n_in))
        self.W = torch.nn.Parameter(
            torch.randn(n_capsules_before, n_capsules_after, caps_size_after, caps_size_before) * std,
            requires_grad=True,
        )

    # Equation (1)
    @staticmethod
    def squash(s, eps=1e-8):
        """Shrink vectors along the last axis to a length in [0, 1).

        ``eps`` keeps the normalisation finite for all-zero vectors, which
        would otherwise produce 0/0 = NaN.
        """
        s_norm2 = torch.pow(s, 2).sum(dim=-1, keepdim=True)
        s_norm = torch.sqrt(s_norm2 + eps)
        v = (s_norm2 / (1.0 + s_norm2)) * (s / s_norm)
        return v

    # Equation (2)
    def affine(self, x):
        """Compute the prediction of every input capsule for every output capsule.

        Args:
            x: tensor of shape [batch, n_capsules_before, caps_size_before].

        Returns:
            Tensor of shape
            [batch, n_capsules_before, n_capsules_after, caps_size_after].
        """
        # [batch, n_before, 1, size_before, 1]; matmul broadcasts the singleton
        # axis against W's n_capsules_after axis.
        x = x.unsqueeze(2).unsqueeze(-1)
        x = self.W @ x
        # Drop only the trailing matmul axis. A bare squeeze() would also
        # remove the batch axis whenever the batch holds a single sample.
        return x.squeeze(-1)

    # Equation (3)
    @staticmethod
    def softmax(x, dim=-1):
        """Softmax over ``dim`` using the numerically stable library routine."""
        return torch.softmax(x, dim=dim)

    # Procedure 1 - Routing algorithm.
    def routing(self, u, r, l):
        """Run ``r`` routing iterations over the predictions ``u``.

        Args:
            u: predictions of shape [batch, l[0], l[1], caps_size_after].
            r: number of routing iterations (at least 1).
            l: pair (number of input capsules, number of output capsules).

        Returns:
            Output capsules of shape [batch, l[1], caps_size_after].

        Raises:
            ValueError: if ``r`` is smaller than 1.
        """
        if r < 1:
            raise ValueError('routing needs at least one iteration, got %r' % (r,))

        n_before, n_after = l
        # Routing logits start at zero and live on the same device/dtype as u.
        b = u.new_zeros(u.size(0), n_before, n_after)  # [batch, n_before, n_after]

        for iteration in range(r):
            c = Routing.softmax(b, dim=-1)  # [batch, n_before, n_after]
            s = (c.unsqueeze(-1) * u).sum(1)  # [batch, n_after, size_after]
            v = Routing.squash(s)  # [batch, n_after, size_after]
            # The logits are not needed after the final iteration.
            if iteration < r - 1:
                b = b + (u * v.unsqueeze(1)).sum(-1)
        return v

    def forward(self, x, n_routing_iter):
        """Flatten ``x`` into input capsules, predict, and route.

        Args:
            x: tensor that can be viewed as
                [-1, n_capsules_before, caps_size_before].
            n_routing_iter: number of routing iterations.

        Returns:
            Tensor of shape [batch, n_capsules_after, caps_size_after].
        """
        x = x.view((-1, self.n_capsules_before, self.caps_size_before))
        x = self.affine(x)  # [batch, n_before, n_after, size_after]
        x = self.routing(x, n_routing_iter, (self.n_capsules_before, self.n_capsules_after))
        return x

In [None]:
class PrimaryCapsules(torch.nn.Module):
    """Convolution whose output channels are regrouped into capsule vectors.

    A 9x9, stride-2 convolution maps 256 input channels to 32 * 8 output
    channels; each spatial position then holds 32 capsules of length 8.
    """

    def __init__(self):
        super().__init__()
        # Expected (channels, height, width) of the incoming feature map.
        self.input_shape = (256, 20, 20)
        self.in_channels = self.input_shape[0]
        # Capsule layout: 32 capsule channels, 8 values per capsule.
        self.out_channels = 32
        self.capsule_dim = 8
        self.kernel_size = 9
        self.stride = 2

        self.conv = torch.nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels * self.capsule_dim,
            kernel_size=self.kernel_size,
            stride=self.stride,
        )

    def forward(self, x):
        """Return capsules of shape [batch, height, width, out_channels, capsule_dim]."""
        # Channels-last layout so the channel axis can be split into capsules.
        features = self.conv(x).permute(0, 2, 3, 1).contiguous()
        height = features.size(1)
        width = features.size(2)
        return features.view(-1, height, width, self.out_channels, self.capsule_dim)

In [None]:
class CapsNet(torch.nn.Module):
    """Capsule network: conv stem, primary capsules, routed digit capsules, decoder.

    Class scores are the lengths of the routed capsules; an image
    reconstruction is produced only when labels are supplied.
    """

    def __init__(self, input_shape, n_routing_iter):
        """
        Args:
            input_shape: 3-tuple describing one input sample; its first entry
                is used as the number of input channels.
            n_routing_iter: routing iterations used in every forward pass.
        """
        super().__init__()
        assert len(input_shape) == 3

        self.input_shape = input_shape
        self.n_routing_iter = n_routing_iter

        in_channels = input_shape[0]
        n_pixels = int(np.prod(input_shape))

        self.conv1 = self.assemble_conv1(in_channels, 256, 9)
        self.primary_capsules = PrimaryCapsules()
        self.routing = Routing(8, 16, 6 * 6 * 32, 10)
        self.norm = Norm()
        self.decoder = Decoder(16, n_pixels)

    def n_parameters(self):
        """Return the total number of scalar parameters in the model."""
        sizes = [np.prod(param.shape) for param in self.parameters()]
        return np.sum(sizes)

    def assemble_conv1(self, in_channels, out_channels, kernel_size):
        """Build the first stage: a convolution followed by a ReLU."""
        layers = [
            torch.nn.Conv2d(in_channels, out_channels, kernel_size),
            torch.nn.ReLU(),
        ]
        return torch.nn.Sequential(*layers)

    def forward(self, x, y=None):
        """Compute class scores and, if ``y`` is given, reconstructions.

        Args:
            x: batch of input images.
            y: optional labels selecting which capsule is decoded.

        Returns:
            Tuple ``(scores, reconstruction)``; ``reconstruction`` is ``None``
            when ``y`` is ``None``, otherwise it has shape
            ``(-1,) + input_shape``.
        """
        features = self.conv1(x)
        capsules = self.primary_capsules(features)
        digit_caps = self.routing(capsules, self.n_routing_iter)
        scores = self.norm(digit_caps)

        if y is None:
            reconstruction = None
        else:
            decoded = self.decoder(digit_caps, y)
            reconstruction = decoded.view((-1,) + self.input_shape)
        return scores, reconstruction

In [None]:
# Build the network for single-channel 28x28 inputs with 3 routing iterations,
# then move it to the GPU.
model = CapsNet(input_shape=(1, 28, 28), n_routing_iter=3)
model = model.cuda()
# Bare expression: the notebook displays the module structure.
model

In [None]:
# Report the total number of scalar parameters held by the model.
print('Number of Parameters: %d' % model.n_parameters())

# Define Loss Functions

In [None]:
def to_categorical(y, num_classes):
    """One-hot encode a tensor of class indices.

    Args:
        y: integer tensor of class indices.
        num_classes: number of classes, i.e. the length of each one-hot row.

    Returns:
        Float tensor of shape ``y.shape + (num_classes,)`` on the same device
        as ``y``.
    """
    # Index an identity matrix directly on y's device. This avoids the
    # CPU/numpy round trip and places the result on y's own device rather than
    # on whichever CUDA device happens to be current.
    index = y.detach().long()
    return torch.eye(num_classes, device=index.device)[index]

In [None]:
class MarginLoss(torch.nn.Module):
    """Margin loss over per-class capsule lengths.

    For each class the loss is
    ``T * max(0, m_pos - score)^2 + lamb * (1 - T) * max(0, score - m_neg)^2``
    where ``T`` is 1 for the true class and 0 otherwise.
    """

    def __init__(self, m_pos=0.9, m_neg=0.1, lamb=0.5):
        """
        Args:
            m_pos: margin the true-class score should exceed.
            m_neg: margin the other classes' scores should stay below.
            lamb: weight of the penalty for the non-true classes.
        """
        super(MarginLoss, self).__init__()
        self.m_pos = m_pos
        self.m_neg = m_neg
        self.lamb = lamb

    # Equation (4)
    def forward(self, scores, y):
        """Return the margin loss summed over classes and averaged over the batch.

        Args:
            scores: tensor whose last axis holds one score per class.
            y: integer class labels, one per sample.
        """
        # The class count comes from the scores themselves instead of being
        # fixed at 10, so the loss works for any number of output capsules.
        n_classes = scores.size(-1)
        Tc = to_categorical(y, n_classes).to(dtype=scores.dtype, device=scores.device)

        loss_pos = torch.pow(torch.clamp(self.m_pos - scores, min=0), 2)
        loss_neg = torch.pow(torch.clamp(scores - self.m_neg, min=0), 2)
        loss = Tc * loss_pos + self.lamb * (1 - Tc) * loss_neg
        loss = loss.sum(-1)
        return loss.mean()

In [None]:
class SumSquaredDifferencesLoss(torch.nn.Module):
    """Reconstruction loss: squared error summed over the last two axes, then averaged."""

    def __init__(self):
        super().__init__()

    def forward(self, x_reconstruction, x):
        """Return the mean, over all remaining axes, of the per-image squared error sum."""
        squared_error = (x - x_reconstruction) ** 2
        # Collapse the last axis first, then the axis that becomes last.
        per_image = squared_error.sum(-1).sum(-1)
        return per_image.mean()

In [None]:
class CapsNetLoss(torch.nn.Module):
    """Total training objective: margin loss plus a scaled reconstruction loss."""

    def __init__(self, reconstruction_loss_scale=0.0005):
        """
        Args:
            reconstruction_loss_scale: factor applied to the reconstruction
                loss before it is added to the margin loss.
        """
        super(CapsNetLoss, self).__init__()
        self.digit_existance_criterion = MarginLoss()
        self.digit_reconstruction_criterion = SumSquaredDifferencesLoss()
        self.reconstruction_loss_scale = reconstruction_loss_scale

    def forward(self, x, y, x_reconstruction, scores):
        """Compute the combined loss and its two components.

        Args:
            x: original input images.
            y: class labels.
            x_reconstruction: images produced by the decoder.
            scores: per-class scores produced by the network.

        Returns:
            Tuple ``(loss, margin_loss, reconstruction_loss)`` where
            ``reconstruction_loss`` is already scaled.
        """
        # Use the ``scores`` argument. The previous code read ``y_pred``, a
        # name that only resolved through a notebook-level global.
        margin_loss = self.digit_existance_criterion(scores, y)
        reconstruction_loss = self.reconstruction_loss_scale * \
            self.digit_reconstruction_criterion(x_reconstruction, x)
        loss = margin_loss + reconstruction_loss
        return loss, margin_loss, reconstruction_loss

In [None]:
# Training objective: margin loss plus the reconstruction loss, using the
# default reconstruction scale of 0.0005.
criterion = CapsNetLoss()

# Train

In [None]:
def exponential_decay(optimizer, learning_rate, global_step, decay_steps, decay_rate, staircase=False):
    """Set every parameter group's learning rate to an exponentially decayed value.

    The new rate is ``learning_rate * decay_rate ** (global_step / decay_steps)``;
    with ``staircase=True`` the exponent uses integer division, so the rate
    only changes once every ``decay_steps`` steps.

    Returns:
        The same ``optimizer`` object, updated in place.
    """
    exponent = global_step // decay_steps if staircase else global_step / decay_steps
    decayed_learning_rate = learning_rate * np.power(decay_rate, exponent)

    for group in optimizer.param_groups:
        group['lr'] = decayed_learning_rate

    return optimizer

In [None]:
# Adam over all model parameters; the learning rate set here is overwritten by
# exponential_decay at the start of every training epoch.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08)

In [None]:
def save_checkpoint(epoch, train_accuracy, test_accuracy, model, optimizer, path=None):
    """Write the epoch, both accuracies and the model/optimizer state to disk.

    Args:
        epoch: epoch number stored in the checkpoint.
        train_accuracy: training accuracy stored in the checkpoint.
        test_accuracy: test accuracy stored in the checkpoint.
        model: object whose ``state_dict()`` is saved.
        optimizer: object whose ``state_dict()`` is saved.
        path: destination file; when ``None`` a name is derived from the
            test accuracy and the epoch.
    """
    target = 'checkpoint-%f-%04d.pth' % (test_accuracy, epoch) if path is None else path
    torch.save(
        {
            'epoch': epoch,
            'train_accuracy': train_accuracy,
            'test_accuracy': test_accuracy,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        },
        target,
    )

In [None]:
def show_example(model, x, y, x_reconstruction, y_pred):
    """Plot one input image next to its reconstruction.

    Args:
        model: unused; kept so existing callers keep working.
        x: a single input image tensor.
        y: the image's label (0-d or one-element tensor).
        x_reconstruction: the reconstructed image tensor.
        y_pred: per-class scores for this image.
    """
    image = x.detach().squeeze().cpu().numpy()
    reconstruction = x_reconstruction.detach().squeeze().cpu().numpy()

    # Indexing a batch yields 0-d tensors on current PyTorch, for which
    # ``.numpy()[0]`` raises IndexError. Flattening first handles both 0-d and
    # one-element tensors.
    label = y.detach().reshape(-1)[0].item()
    _, best = torch.max(y_pred.detach(), -1)
    predicted = best.reshape(-1)[0].item()

    fig, ax = plt.subplots(1, 2)
    ax[0].imshow(image, cmap='Greys')
    ax[0].set_title('Input: %d' % label)
    ax[1].imshow(reconstruction, cmap='Greys')
    ax[1].set_title('Output: %d' % predicted)
    plt.show()

In [None]:
def test(model, loader):
    """Measure the classification accuracy of ``model`` over ``loader``.

    Args:
        model: callable returning ``(scores, reconstruction)`` for ``(x, y)``.
        loader: iterable of ``(x, y)`` batches that also supports ``len()``.

    Returns:
        A ``defaultdict`` whose ``'accuracy'`` entry is the fraction of
        samples whose highest-scoring class equals the label.
    """
    metrics = defaultdict(list)
    # Evaluation never back-propagates; without no_grad every batch would
    # build and hold an autograd graph for nothing.
    with torch.no_grad():
        for x, y in tqdm_notebook(loader, total=len(loader)):
            x = x.float().cuda()
            y = y.cuda()
            scores, _ = model(x, y)
            _, y_pred = torch.max(scores, -1)
            metrics['accuracy'].append((y_pred == y).cpu().numpy())
    metrics['accuracy'] = np.concatenate(metrics['accuracy']).mean()
    return metrics

In [None]:
# Best test accuracy seen so far; the training loop saves a checkpoint whenever
# an epoch matches or beats it.
best_tst_accuracy = 0.0

In [None]:
# Counters kept outside the training cell so that re-running it continues
# from where the previous run stopped.
global_epoch = 0
global_step = 0
# Per-step loss values, keyed by loss name; missing keys start as empty lists.
history = defaultdict(list)

In [None]:
n_epochs = 500
for epoch in range(n_epochs):
    print('Epoch %d (%d/%d):' % (global_epoch + 1, epoch + 1, n_epochs))
    # The learning rate is refreshed once per epoch from the running step count.
    optimizer = exponential_decay(optimizer, 0.001, global_step, 100, 0.99)

    for batch_id, (x, y) in tqdm_notebook(enumerate(trn_loader), total=len(trn_loader)):
        x = x.float().cuda()
        y = y.cuda()

        y_pred, x_reconstruction = model(x, y)
        loss, margin_loss, reconstruction_loss = criterion(x, y, x_reconstruction, y_pred)

        # Losses are 0-d tensors on current PyTorch; ``.numpy()[0]`` raises
        # IndexError there, ``.item()`` extracts the Python float.
        history['margin_loss'].append(margin_loss.item())
        history['reconstruction_loss'].append(reconstruction_loss.item())
        history['loss'].append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        global_step += 1

    trn_metrics = test(model, trn_loader)
    tst_metrics = test(model, tst_loader)

    # Reported losses are those of the last training batch of the epoch.
    print('Margin Loss: %f' % history['margin_loss'][-1])
    print('Reconstruction Loss: %f' % history['reconstruction_loss'][-1])
    print('Loss: %f' % history['loss'][-1])
    print('Train Accuracy: %f' % trn_metrics['accuracy'])
    print('Test Accuracy: %f' % tst_metrics['accuracy'])

    # Show a random sample from the last training batch.
    print('Example:')
    idx = np.random.randint(0, len(x))
    show_example(model, x[idx], y[idx], x_reconstruction[idx], y_pred[idx])

    # Checkpoint whenever the test accuracy matches or beats the best so far.
    if tst_metrics['accuracy'] >= best_tst_accuracy:
        best_tst_accuracy = tst_metrics['accuracy']
        save_checkpoint(
            global_epoch + 1,
            trn_metrics['accuracy'],
            tst_metrics['accuracy'],
            model,
            optimizer
        )
    global_epoch += 1

In [None]:
# Rolling-mean window and number of most recent steps to plot.
N = 10
n_points = 1000
plt.figure(figsize=(20, 10))

# Each curve: take the last n_points values of one loss and smooth them with
# a length-N moving average.
for key, line_format in (('loss', '-g'), ('margin_loss', '-b'), ('reconstruction_loss', '-r')):
    rolling_mean = np.convolve(np.asarray(history[key])[-n_points:],
                               np.ones((N,))/N,
                               mode='valid')
    plt.plot(rolling_mean, line_format)

plt.legend(['Total Loss', 'Margin Loss', 'Reconstruction Loss'])
plt.show()

Done!