Based on the CapsNet implementation by [XifengGuo](https://github.com/XifengGuo/CapsNet-Pytorch).

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from keras.utils import to_categorical
from torchvision import datasets, transforms

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
from sklearn import preprocessing
import math


Using TensorFlow backend.


In [0]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [0]:
class MNISTData(Dataset):
  """MNIST samples read from a CSV file and filtered to a fixed set of labels.

  Column 0 of the CSV is taken as the label and the remaining columns as
  pixel values, which are divided by 255. Any ``mode`` containing "test"
  (case-insensitive) reads ``mnist_test.csv`` and keeps only digits 3 and 7;
  every other mode reads ``mnist_train_small.csv`` and keeps the remaining
  eight digits.

  Args:
    mode: split selector; matched case-insensitively against "test".
    root: directory that contains the CSV files.
  """

  def __init__(self, mode='train', root='sample_data'):
    super().__init__()
    # Decide the split once so that the file name and the label mask can
    # never disagree (previously 'TEST' picked the test file but the
    # training label mask).
    isTest = 'test' in mode.lower()
    fname = 'mnist_test.csv' if isTest else 'mnist_train_small.csv'
    keepLabels = [3, 7] if isTest else [0, 1, 2, 4, 5, 6, 8, 9]

    dataset = pd.read_csv(f'{root}/{fname}').values
    labels = dataset[:, 0]
    mask = np.isin(labels, keepLabels)

    self.xData = dataset[mask, 1:] / 255
    self.yData = labels[mask]

  def __len__(self):
    """Return the number of samples left after label filtering."""
    return len(self.xData)

  def __getitem__(self, idx):
    """Return ``(image, label)`` with the image reshaped to 1x28x28."""
    return np.reshape(self.xData[idx], [1, 28, 28]), self.yData[idx]
    

In [0]:
def squash(x, eps=1e-8):
  """Apply the capsule "squash" non-linearity along dimension 2.

  Each vector along dim 2 is rescaled by ``|v|^2 / (1 + |v|^2) / |v|`` so
  that its length lies in [0, 1) while its direction is kept.

  Args:
    x: tensor with at least three dimensions; capsule vectors lie on dim 2.
    eps: small constant added under the square root so that all-zero
      capsules map to zero instead of NaN (0/0).

  Returns:
    Tensor of the same shape as ``x``.
  """
  # keepdim lets the scale broadcast against x for any number of trailing
  # dimensions, instead of assuming x is exactly 3-D.
  sqNorm = x.pow(2).sum(dim=2, keepdim=True)
  scale = sqNorm / (1 + sqNorm) / torch.sqrt(sqNorm + eps)
  return x * scale

In [0]:
class Routing(nn.Module):
  """Routing-by-agreement between an input and an output capsule layer.

  Holds learnable routing logits ``b`` of shape (inputCaps, outputCaps).
  The forward pass first routes with ``softmax(b)`` and then refines the
  logits ``nIters`` times using the agreement between predictions and the
  current output capsules.

  Args:
    inputCaps: number of input capsules.
    outputCaps: number of output capsules.
    nIters: number of refinement iterations (0 disables refinement).
  """

  def __init__(self, inputCaps, outputCaps, nIters):
    super().__init__()
    self.nIters = nIters
    self.b = nn.Parameter(torch.zeros((inputCaps, outputCaps)))

  def forward(self, uPredict):
    """Route the prediction vectors to output capsules.

    Args:
      uPredict: tensor of shape (batch, inputCaps, outputCaps, outputDim).

    Returns:
      Squashed output capsules of shape (batch, outputCaps, outputDim).
    """
    batchSize, inputCaps, outputCaps, _ = uPredict.size()

    # Coupling coefficients are normalised over the output capsules. The
    # dimension is given explicitly: calling softmax without `dim` is
    # deprecated, and for this 2-D tensor it resolved to dim=1 anyway.
    c = F.softmax(self.b, dim=1)
    s = (c.unsqueeze(2) * uPredict).sum(dim=1)
    v = squash(s)

    if self.nIters > 0:
      bBatch = self.b.expand((batchSize, inputCaps, outputCaps))
      for _ in range(self.nIters):
        # Agreement = dot product between each prediction and the current
        # output capsule; it is accumulated into the per-sample logits.
        bBatch = bBatch + (uPredict * v.unsqueeze(1)).sum(-1)

        # Same normalisation axis as above (output capsules), applied on
        # the 3-D tensor directly instead of flattening and reshaping.
        c = F.softmax(bBatch, dim=2).unsqueeze(3)
        s = (c * uPredict).sum(dim=1)
        v = squash(s)

    return v
    

In [0]:
class DenseCaps(nn.Module):
  """Fully connected capsule layer.

  Every input capsule has its own (inputDim x outputCaps*outputDim) weight
  matrix that produces one prediction vector per output capsule; the
  predictions are then combined by ``routingModule``.

  Args:
    inputCaps: number of input capsules.
    inputDim: length of each input capsule vector.
    outputCaps: number of output capsules.
    outputDim: length of each output capsule vector.
    routingModule: callable applied to the prediction tensor of shape
      (batch, inputCaps, outputCaps, outputDim); its result is returned.
  """

  def __init__(self, inputCaps, inputDim, outputCaps, outputDim, routingModule):
    super().__init__()
    self.inputDim = inputDim
    # The attribute used to be misspelled; the old name is kept as an alias
    # so that any code reading it keeps working.
    self.inpuptDim = inputDim
    self.inputCaps = inputCaps
    self.outputDim = outputDim
    self.outputCaps = outputCaps
    self.weights = nn.Parameter(torch.empty(inputCaps, inputDim, outputCaps * outputDim))
    self.routingModule = routingModule
    self.reset_params()

  def reset_params(self):
    """Initialise the weights uniformly in [-1/sqrt(inputCaps), 1/sqrt(inputCaps)]."""
    stdv = 1 / math.sqrt(self.inputCaps)
    nn.init.uniform_(self.weights, -stdv, stdv)

  def forward(self, capsOutput):
    """Compute prediction vectors and hand them to the routing module.

    Args:
      capsOutput: tensor of shape (batch, inputCaps, inputDim).

    Returns:
      Whatever ``routingModule`` returns for the predictions.
    """
    # (batch, inputCaps, 1, inputDim) @ (inputCaps, inputDim, out*dim)
    # -> (batch, inputCaps, 1, out*dim): one matrix product per input capsule.
    capsOutput = capsOutput.unsqueeze(2)
    uPredict = capsOutput.matmul(self.weights)
    uPredict = uPredict.view(uPredict.size(0), self.inputCaps, self.outputCaps, self.outputDim)
    return self.routingModule(uPredict)
  

In [0]:
class PrimaryCaps(nn.Module):
  """Convolutional capsule layer.

  A single convolution produces ``outputCaps * outputDim`` channels; each
  spatial position of each capsule map becomes one capsule vector of length
  ``outputDim``, which is then squashed.
  """

  def __init__(self, inputChannels, outputCaps, outputDim, kernelSize, stride):
    super().__init__()
    self.inputChannels = inputChannels
    self.outputCaps = outputCaps
    self.outputDim = outputDim
    self.conv = nn.Conv2d(
        inputChannels,
        outputCaps * outputDim,
        kernel_size=kernelSize,
        stride=stride,
    )

  def forward(self, input):
    """Return squashed capsules of shape (batch, outputCaps*H*W, outputDim)."""
    features = self.conv(input)
    batch, _, height, width = features.size()

    # Split the channel axis into (capsule map, vector component).
    caps = features.view(batch, self.outputCaps, self.outputDim, height, width)

    # Move the vector component last, then flatten map and position into a
    # single capsule axis.
    caps = caps.permute(0, 1, 3, 4, 2).contiguous()
    caps = caps.view(caps.size(0), -1, caps.size(4))

    return squash(caps)



In [0]:
class ReconNet(nn.Module):
  """Decoder that reconstructs a flat 784-value image from class capsules.

  All capsules except the one selected by ``target`` are zeroed, and the
  flattened result is passed through three fully connected layers
  (nDim*nClasses -> 512 -> 1024 -> 784) ending in a sigmoid.

  Args:
    nDim: length of each class capsule vector.
    nClasses: number of class capsules.
  """

  def __init__(self, nDim=16, nClasses=10):
    super().__init__()
    self.fc1 = nn.Linear(nDim * nClasses, 512)
    self.fc2 = nn.Linear(512, 1024)
    self.fc3 = nn.Linear(1024, 784)
    self.nDim = nDim
    self.nClasses = nClasses

  def forward(self, x, target):
    """Reconstruct images from the capsule selected by ``target``.

    Args:
      x: capsule tensor of shape (batch, nClasses, nDim).
      target: class index per sample, shape (batch,); any numeric dtype.

    Returns:
      Tensor of shape (batch, 784) with values in [0, 1].
    """
    # Build the one-hot mask on the same device/dtype as the input so the
    # module does not depend on a module-level `device` variable.
    index = target.view(-1, 1).long().to(x.device)
    mask = torch.zeros(x.size(0), self.nClasses, dtype=x.dtype, device=x.device)
    mask.scatter_(1, index, 1.)

    x = x * mask.unsqueeze(2)
    x = x.view(-1, self.nDim * self.nClasses)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    # torch.sigmoid replaces the deprecated F.sigmoid.
    return torch.sigmoid(self.fc3(x))

In [0]:
class CapsNet(nn.Module):
  """Capsule network: conv layer -> primary capsules -> class capsules.

  The number of primary capsules is fixed at 32*6*6, which is what the two
  9x9 convolutions (strides 1 and 2) produce for a 28x28 input.

  Args:
    routingIters: number of routing refinement iterations.
    nClasses: number of output (class) capsules.
  """

  def __init__(self, routingIters, nClasses=10):
    super().__init__()
    self.conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)
    self.primarycaps = PrimaryCaps(
        inputChannels=256, outputCaps=32, outputDim=8, kernelSize=9, stride=2)
    self.numPrimaryCaps = 32 * 6 * 6
    self.digitCaps = DenseCaps(
        self.numPrimaryCaps, 8, nClasses, 16,
        routingModule=Routing(self.numPrimaryCaps, nClasses, routingIters))

  def forward(self, input):
    """Return ``(class capsules, capsule lengths)`` for a batch of images."""
    features = F.relu(self.conv1(input))
    classCaps = self.digitCaps(self.primarycaps(features))

    # The length of each class capsule serves as the class score.
    lengths = torch.sqrt((classCaps ** 2).sum(dim=2))
    return classCaps, lengths

In [0]:
class CapsNetRecon(nn.Module):
  """Couples a capsule network with a reconstruction decoder.

  Args:
    capsnet: callable returning ``(capsules, probs)`` for an input batch.
    reconNet: callable taking ``(capsules, target)`` and returning the
      reconstruction.
    nClasses: number of classes (stored only).
  """

  def __init__(self, capsnet, reconNet, nClasses=10):
    super().__init__()
    self.capsnet, self.reconNet = capsnet, reconNet
    self.nClasses = nClasses

  def forward(self, x, target):
    """Return ``(reconstruction, probs, capsules)`` for the batch ``x``."""
    capsules, probs = self.capsnet(x)
    return self.reconNet(capsules, target), probs, capsules
    

In [0]:
class MarginLoss(nn.Module):
  """Margin loss on capsule lengths.

  For the target class the loss is ``relu(mPos - length)^2``; for every
  other class it is ``lambda_ * relu(length - mNeg)^2``.

  Args:
    mPos: margin the target capsule's length should exceed.
    mNeg: margin the non-target capsules' lengths should stay below.
    lambda_: weight of the non-target term.
  """

  def __init__(self, mPos, mNeg, lambda_):
    super().__init__()
    self.mPos = mPos
    self.mNeg = mNeg
    self.lambda_ = lambda_

  def forward(self, lengths, targets, avg=True):
    """Compute the loss.

    Args:
      lengths: capsule lengths of shape (batch, nClasses).
      targets: class index per sample, shape (batch,); any numeric dtype.
      avg: if True return the mean over all (sample, class) entries,
        otherwise their sum.

    Returns:
      Scalar tensor.
    """
    # One-hot targets are built directly on the device/dtype of `lengths`,
    # so the loss works without a module-level `device` variable.
    index = targets.detach().view(-1, 1).long().to(lengths.device)
    oneHot = torch.zeros_like(lengths).scatter_(1, index, 1.)

    presentLoss = oneHot * F.relu(self.mPos - lengths).pow(2)
    absentLoss = self.lambda_ * (1 - oneHot) * F.relu(lengths - self.mNeg).pow(2)
    losses = presentLoss + absentLoss

    return losses.mean() if avg else losses.sum()
    

In [0]:
class Trainer:
  """Builds the MNIST loaders, the CapsNet + decoder model, and trains/evaluates it.

  Training images are padded by 2 pixels and randomly cropped back to 28x28;
  test images are used as-is so that evaluation is deterministic.

  Args:
    batchSize: mini-batch size for both loaders.
    lr: Adam learning rate.
    epochs: number of passes over the training set performed by ``train``.
    routings: number of routing iterations passed to ``CapsNet``.
  """

  def __init__(self, batchSize=128, lr=1e-3, epochs=100, routings=3):
    self.epochs = epochs

    trainTransform = transforms.Compose([
        transforms.Pad(2), transforms.RandomCrop(28),
        transforms.ToTensor()
    ])
    self.trLoader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=trainTransform),
        batch_size=batchSize, shuffle=True)
    # No augmentation or shuffling on the test split: accuracy should not
    # vary between evaluation runs.
    self.tstLoader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, download=True,
                       transform=transforms.ToTensor()),
        batch_size=batchSize, shuffle=False)

    # `routings` used to be ignored in favour of a hard-coded 3.
    capsnet = CapsNet(routings)
    reconNet = ReconNet()
    self.model = CapsNetRecon(capsnet, reconNet)
    self.model.to(device)

    self.optim = torch.optim.Adam(self.model.parameters(), lr)
    self.lossfn = MarginLoss(0.9, 0.1, 0.5)

  def _trainEpoch(self):
    """Run one pass over the training loader and return the summed batch losses."""
    trLoss = []
    for x, y in self.trLoader:
      self.optim.zero_grad()

      x = x.float().to(device)
      y = y.to(device)
      recon, yPred, _ = self.model(x, y)

      # Total loss = margin loss + 0.05 * reconstruction MSE.
      marginLoss = self.lossfn(yPred, y)
      reconLoss = F.mse_loss(recon, x.view(-1, 784))
      loss = marginLoss + (5e-2) * reconLoss

      loss.backward()
      self.optim.step()

      trLoss.append(loss.item())

    return np.sum(trLoss)

  def train(self):
    """Train for ``self.epochs`` epochs; return the list of per-epoch losses."""
    epLoss = []
    self.model.train()
    for ep in tqdm(range(self.epochs)):
      epLoss.append(self._trainEpoch())
      print(f'\n---EPOCH {ep}---')
      print(f'Train Loss: {epLoss[-1]:.6f}')
    return epLoss

  def evalModel(self):
    """Return classification accuracy (fraction in [0, 1]) on the test loader.

    Only the capsule network is run: the decoder needs a target label,
    which is not required for classification.
    """
    self.model.eval()
    correct = 0
    with torch.no_grad():
      for x, y in tqdm(self.tstLoader):
        x = x.float().to(device)
        y = y.to(device)

        _, probs = self.model.capsnet(x)

        # Labels are class indices, so compare them with the argmax directly.
        correct += (probs.argmax(dim=1) == y).sum().item()
    acc = correct / len(self.tstLoader.dataset)
    print(f'EVALUATION ACCURACY: {acc*100:.4f}%')
    return acc

In [0]:
# Build the trainer for a 15-epoch run. Construction sets up the MNIST loaders
# (download=True), the CapsNet + reconstruction model, the Adam optimiser and the margin loss.
train = Trainer(epochs=15)

  0%|          | 0/9912422 [00:00<?, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


9920512it [00:00, 27621780.28it/s]                            


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw


32768it [00:00, 457532.00it/s]
  1%|          | 16384/1648877 [00:00<00:11, 142039.32it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


1654784it [00:00, 7459778.20it/s]                            
8192it [00:00, 176583.21it/s]


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
Processing...
Done!


In [0]:
# Run the training loop; train() returns the list of per-epoch summed batch losses.
train.train()

  # Remove the CWD from sys.path while we load stuff.
  7%|▋         | 1/15 [03:02<42:29, 182.11s/it]


---EPOCH 0---
Train Loss: 6.364009


 13%|█▎        | 2/15 [06:03<39:26, 182.02s/it]


---EPOCH 1---
Train Loss: 2.230927


 20%|██        | 3/15 [09:05<36:23, 181.94s/it]


---EPOCH 2---
Train Loss: 1.707877


 27%|██▋       | 4/15 [12:07<33:20, 181.86s/it]


---EPOCH 3---
Train Loss: 1.466799


 33%|███▎      | 5/15 [15:09<30:18, 181.82s/it]


---EPOCH 4---
Train Loss: 1.315153


 40%|████      | 6/15 [18:10<27:15, 181.76s/it]


---EPOCH 5---
Train Loss: 1.216962


 47%|████▋     | 7/15 [21:12<24:13, 181.72s/it]


---EPOCH 6---
Train Loss: 1.138959


KeyboardInterrupt: ignored

In [0]:
# Keep a direct handle on the CapsNetRecon instance held by the trainer.
model = train.model

In [0]:
# Save the whole model object (not only its state_dict) to ./capsNet.pth.
torch.save(model,'./capsNet.pth')

In [0]:
# Put the model in evaluation mode; the result is bound to `_`
# (presumably to keep the notebook from echoing the module repr).
_=model.eval()

In [0]:
# Fetch one batch from the training loader. The built-in next() is the
# Python 3 iterator protocol; iterators are not required to expose .next().
x, y = next(iter(train.trLoader))

In [0]:
# Show channel 0 of the batch sample at index 4 in greyscale.
plt.imshow(x[4,0],cmap='gray')

In [0]:
# Move the batch (images and labels) to the compute device as float tensors,
# then run the full model: reconstruction, capsule lengths, class capsules.
x, y = (t.float().to(device) for t in (x, y))
recon, yp, dc = model(x, y)

In [0]:
# Reshape the flat reconstructions into 28x28 images.
recon = recon.view(-1,28,28)

In [0]:
# Move to host memory, detach from the autograd graph and convert to a NumPy array for plotting.
recon = recon.cpu().detach().numpy()


In [0]:
# Show the reconstruction of the sample at index 4 in greyscale.
plt.imshow(recon[4],cmap='gray')

In [0]:
# Inspect the model's second output (yp) for the sample at index 4.
yp[4]

In [0]:
# Index of the largest entry of yp for the sample at index 4.
torch.argmax(yp[4])

## Visualisations

In [0]:
import torch
from torch.autograd import Variable
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

In [0]:
# Load the MNIST split selected by train=False from ../data; ToTensor() converts each image to a tensor.
dataset = MNIST('../data', train=False, transform=ToTensor())
# model.to(torch.device('cpu'))

In [0]:
# (1x28x28 tensor input)
def get_digit_caps(model, image):
    """Return the class capsules produced by ``model.capsnet`` for one image.

    Args:
        model: module exposing a ``capsnet`` callable that returns
            ``(digit_caps, probs)``; must have at least one parameter.
        image: single image tensor; a batch dimension is added in front.

    Returns:
        The capsule tensor for the one-image batch, computed without
        gradient tracking.
    """
    # Run on whichever device the model lives on instead of relying on a
    # module-level `device` variable.
    model_device = next(model.parameters()).device
    input_ = image.unsqueeze(0).to(model_device)
    # `volatile=True` has no effect in current PyTorch; no_grad() is the
    # supported way to disable gradient tracking for inference.
    with torch.no_grad():
        digit_caps, _ = model.capsnet(input_)
    return digit_caps

# takes digit_caps output and target label
def get_reconstruction(model, digit_caps, label):
    """Decode ``digit_caps`` for class ``label`` into a 28x28 NumPy image.

    Args:
        model: module exposing a ``reconNet`` callable taking
            ``(digit_caps, target)``.
        digit_caps: capsule tensor for a one-image batch.
        label: integer class whose capsule should be decoded.

    Returns:
        ``numpy.ndarray`` of shape (28, 28).
    """
    # Create the target next to the capsules so both are on the same device.
    target = torch.tensor([int(label)], dtype=torch.long, device=digit_caps.device)
    # `volatile=True` has no effect in current PyTorch; use no_grad() instead.
    with torch.no_grad():
        reconstruction = model.reconNet(digit_caps, target)
    return reconstruction.detach().cpu().numpy()[0].reshape(28, 28)

# create reconstructions with perturbed digit capsule
def dimension_perturbation_reconstructions(model, digit_caps, label, dimension, dim_values):
    """Reconstruct once per value, overwriting one capsule component each time.

    For every entry of ``dim_values`` a copy of ``digit_caps`` is made, the
    component ``[0, label, dimension]`` is set to that value, and the copy is
    decoded with ``get_reconstruction``. The input tensor is left untouched.

    Returns:
        List of reconstructions in the order of ``dim_values``.
    """
    def _reconstruct_with(value):
        perturbed = digit_caps.clone()
        perturbed[0, label, dimension] = value
        return get_reconstruction(model, perturbed, label)

    return [_reconstruct_with(value) for value in dim_values]

In [0]:
# Get reconstructions
# Collect the first eight dataset images together with their reconstructions
# for the ground-truth class.
images, reconstructions = [], []
for i in range(8):
    image_tensor, label = dataset[i]
    image_on_device = image_tensor.float().to(device)
    digit_caps = get_digit_caps(model.float().to(device), image_on_device)
    reconstructions.append(get_reconstruction(model, digit_caps, label))
    images.append(image_tensor.numpy()[0])

In [0]:
# Plot reconstructions
# Two rows of eight panels: originals on top, reconstructions below, no ticks.
fig, axs = plt.subplots(2, 8, figsize=(16, 4))
for row, row_title in enumerate(('Org image', 'Reconstruction')):
    axs[row, 0].set_ylabel(row_title, size='large')
for i in range(8):
    for row, picture in enumerate((images[i], reconstructions[i])):
        axs[row, i].imshow(picture, cmap='gray')
    for row in (0, 1):
        axs[row, i].set_yticks([])
        axs[row, i].set_xticks([])


In [0]:
# Take sample 420 and, for each of the 16 capsule components, reconstruct it
# with that component set to each of the 11 values -0.25, -0.20, ..., 0.25.
digit, label = dataset[420]
perturbation_values = [0.05*i for i in range(-5, 6)]
digit_caps = get_digit_caps(model, digit)
perturbed_reconstructions = [
    dimension_perturbation_reconstructions(
        model, digit_caps, label, dimension, perturbation_values)
    for dimension in range(16)
]

In [0]:
# Grid of reconstructions: one row per capsule component, one column per
# perturbation value; ticks are hidden on every panel.
fig, axs = plt.subplots(16, 11, figsize=(11*1.5, 16*1.5))
for i in range(16):
    row_axes = axs[i]
    row_axes[0].set_ylabel('dim {}'.format(i), size='large')
    for j in range(11):
        panel = row_axes[j]
        panel.imshow(perturbed_reconstructions[i][j], cmap='gray')
        panel.set_yticks([])
        panel.set_xticks([])