In [19]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image


In [20]:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
traindata = datasets.MNIST(root='./data', train=True,download=True, transform=transform)
train_loader= torch.utils.data.DataLoader(traindata, batch_size=100,shuffle=True, num_workers=2)
testdata = datasets.MNIST(root='./data', train=False,download=True, transform=transform)
test_loader= torch.utils.data.DataLoader(testdata, batch_size=100,shuffle=False, num_workers=2)

In [21]:
class ConvLayer(nn.Module):
    def __init__(self,in_channels=1,out_channels=256,kernel_size=9):
        super(ConvLayer,self).__init__()
        self.conv = nn.Conv2d(in_channels,out_channels,kernel_size,1)
    def forward(self,x):
        return self.conv(x)

In [22]:
class PrimaryCaps(nn.Module):
    def __init__(self,num_capsules=8,in_channels=256,out_channels=32,kernel_size=9):
        super(PrimaryCaps,self).__init__()
        self.capsules = nn.ModuleList(nn.Conv2d(in_channels,out_channels,
                                     kernel_size, 2,0) for _ in range(num_capsules))
    def forward(self,x):
        out = [capsule(x) for capsule in self.capsules]
        out = torch.stack(out,dim=1) 
        out= out.view(out.size(0), 32 * 6 * 6, -1)
        return self.squash(out)
    
    def squash(self,input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor

In [23]:
class DigitCaps(nn.Module):
    def __init__(self,num_capsules=10,num_routes=32*6*6,in_channels=8,out_channels=16):
        super(DigitCaps,self).__init__()
        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules
        self.W = nn.Parameter(torch.randn(1,num_routes,num_capsules,out_channels,in_channels))
    def forward(self,x):
        batch_size = x.size(0)
        x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)

        W = torch.cat([self.W] * batch_size, dim=0)
        u_hat = torch.matmul(W, x)
        

        b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1))
        b_ij = b_ij.cuda()
        num_iterations = 3
        for iteration in range(num_iterations):
            c_ij = F.softmax(b_ij)
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)

            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            v_j = self.squash(s_j)
            m= u_hat.transpose(3, 4)
            #print(u_hat.shape,m.shape,v_j.shape)
            if iteration < num_iterations - 1:
                a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))
                b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)

        return v_j.squeeze(1)
    
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor

In [24]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )
        
    def forward(self, x, data):
        #print(x.shape)
        classes = torch.sqrt((x ** 2).sum(2))
        classes = F.softmax(classes)
        #print(classes.shape)
        
        _, max_length_indices = classes.max(dim=1)
        masked = Variable(torch.eye(10))
        #print(masked.shape)
        #print(max_length_indices)
        masked = masked.cuda()
        masked = masked.index_select(dim=0, index=max_length_indices.squeeze(1).data)
        #print(masked.shape)
        #print(masked[:, :, None, None].shape)
        reconstructions = self.reconstraction_layers((x * masked[:, :, None, None]).view(x.size(0), -1))
        reconstructions = reconstructions.view(-1, 1, 28, 28)
        
        return reconstructions, masked

In [25]:
class CapsNet(nn.Module):
    def __init__(self):
        super(CapsNet, self).__init__()
        self.conv_layer = ConvLayer()
        self.primary_capsules = PrimaryCaps()
        self.digit_capsules = DigitCaps()
        self.decoder = Decoder()
        
        self.mse_loss = nn.MSELoss()
        
    def forward(self, data):
        output = self.digit_capsules(self.primary_capsules(self.conv_layer(data)))
        reconstructions, masked = self.decoder(output, data)
        return output, reconstructions, masked
    
    def loss(self, data, x, target, reconstructions):
        return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)
    
    def margin_loss(self, x, labels, size_average=True):
        batch_size = x.size(0)

        v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))

        left = F.relu(0.9 - v_c).view(batch_size, -1)
        right = F.relu(v_c - 0.1).view(batch_size, -1)

        loss = labels * left + 0.5 * (1.0 - labels) * right
        loss = loss.sum(dim=1).mean()

        return loss
    
    def reconstruction_loss(self, data, reconstructions):
        loss = self.mse_loss(reconstructions.view(reconstructions.size(0), -1), data.view(reconstructions.size(0), -1))
        return loss * 0.05

In [26]:
capsule_net = CapsNet()
capsule_net = capsule_net.cuda()
optimizer = optim.Adam(capsule_net.parameters())

In [27]:
batch_size = 100
n_epochs = 20


for epoch in range(n_epochs):
    capsule_net.train()
    train_loss = 0
    for batch_id, (data, target) in enumerate(train_loader):

        target = torch.eye(10).index_select(dim=0, index=target)
        data, target = Variable(data), Variable(target)
        data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()
        output, reconstructions, masked = capsule_net(data)
        loss = capsule_net.loss(data, output, target, reconstructions)
        loss.backward()
        optimizer.step()

        train_loss += loss.data[0]
        
        if batch_id % 100 == 0:
            print ("train accuracy in epoch{0}:".format(epoch), sum(np.argmax(masked.data.cpu().numpy(), 1) == 
                                   np.argmax(target.data.cpu().numpy(), 1)) / float(batch_size))
        
    print( train_loss / len(train_loader))
        
    capsule_net.eval()
    test_loss = 0
    for batch_id, (data, target) in enumerate(test_loader):

        target = torch.eye(10).index_select(dim=0, index=target)
        data, target = Variable(data), Variable(target)
        data, target = data.cuda(), target.cuda()

        output, reconstructions, masked = capsule_net(data)
        loss = capsule_net.loss(data, output, target, reconstructions)

        test_loss += loss.data[0]
        
        if batch_id % 100 == 0:
            print ("test accuracy in epoch{0}:".format(epoch), sum(np.argmax(masked.data.cpu().numpy(), 1) == 
                                   np.argmax(target.data.cpu().numpy(), 1)) / float(batch_size))
            save_image(reconstructions.data,"./samples/CAPSNET/output.png",nrow=10)
    
    print (test_loss / len(test_loader))



train accuracy in epoch0: 0.15
train accuracy in epoch0: 0.82
train accuracy in epoch0: 0.76
train accuracy in epoch0: 0.83
train accuracy in epoch0: 0.89
train accuracy in epoch0: 0.9
tensor(0.4929, device='cuda:0')
test accuracy in epoch0: 0.94




tensor(0.3059, device='cuda:0')
train accuracy in epoch1: 0.88
train accuracy in epoch1: 0.95
train accuracy in epoch1: 0.9
train accuracy in epoch1: 0.87
train accuracy in epoch1: 0.89
train accuracy in epoch1: 0.93
tensor(0.2661, device='cuda:0')
test accuracy in epoch1: 0.97
tensor(0.2331, device='cuda:0')
train accuracy in epoch2: 0.93
train accuracy in epoch2: 0.89
train accuracy in epoch2: 0.93
train accuracy in epoch2: 0.94
train accuracy in epoch2: 0.93
train accuracy in epoch2: 0.95
tensor(0.2195, device='cuda:0')
test accuracy in epoch2: 0.98
tensor(0.2094, device='cuda:0')
train accuracy in epoch3: 0.96
train accuracy in epoch3: 0.91
train accuracy in epoch3: 0.93
train accuracy in epoch3: 0.95
train accuracy in epoch3: 0.96
train accuracy in epoch3: 0.95
tensor(0.1985, device='cuda:0')
test accuracy in epoch3: 0.99
tensor(0.1918, device='cuda:0')
train accuracy in epoch4: 0.95
train accuracy in epoch4: 0.94
train accuracy in epoch4: 0.96
train accuracy in epoch4: 0.93
train