In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as functional
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib.pyplot import *
from torch.optim import Adam
from torch.autograd import Variable

import numpy as np

In [3]:
#Load train set and test set and normalize the images in range [-1,1]
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

#50000 images training
trainset = torchvision.datasets.CIFAR10(root='./dataCaps', train=True,
                                        download=True, transform=transform)
#We load 4 samples per batchreduce the traininset to 12500
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)#batch size changed from 4 to 100

#print(len(trainset))
#10000 images test
testset = torchvision.datasets.CIFAR10(root='./dataCaps', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)#batch size changed from 4 to 100


classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [4]:
class ConvLayer(nn.Module):
    def __init__(self, in_channels=3, out_channels=256, kernel_size=9):
        
        super(ConvLayer, self).__init__()

        self.conv = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=1
                             )

    def forward(self, x):
        return F.relu(self.conv(x))

In [5]:
class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9):
        super(PrimaryCaps, self).__init__()
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0) 
                          for _ in range(num_capsules)])
    def forward(self, x):
        u = [capsule(x) for capsule in self.capsules]
        u = torch.stack(u, dim=4) #it was 4 here
        u = u.view(x.size(0), 32 * 8 * 8, -1)#output = output.view(x.size(0), self.num_capsules*(self.gridsize)*(self.gridsize), -1)
        return self.squash(u)
    
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor


In [6]:
class DigitCaps(nn.Module):
    def __init__(self, num_capsules=10, num_routes=32 * 8 * 8, in_channels=8, out_channels=16):
        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules

        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))

    def forward(self, x):
        batch_size = x.size(0)
        x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)

        W = torch.cat([self.W] * batch_size, dim=0)
        
        u_hat = torch.matmul(W, x)
        

        b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1))
        #print('b_ij:',b_ij.size())
        #if USE_CUDA:
         #   b_ij = b_ij.cuda()

        num_iterations = 3
        for iteration in range(num_iterations):
            c_ij = F.softmax(b_ij, dim=2)
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)

            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            v_j = self.squash(s_j)
            
            if iteration < num_iterations - 1:
                a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))
                #print('u_hat.transpose(3, 4):',u_hat.transpose(3, 4).size(),'torch.cat([v_j] * self.num_routes, dim=1):',torch.cat([v_j] * self.num_routes, dim=1).size(),'aij:',a_ij.squeeze(4).mean(dim=0, keepdim=True).size())
                aij = a_ij.squeeze(4).mean(dim=0, keepdim=True)
                b_ij = b_ij + aij
                #print('bij:',b_ij.size())

        return v_j.squeeze(1)
    
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor

In [7]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
#num capsule=10, capsule_size=16, imsize=32, img_channel=3        
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(16 * 10, 512), #nn.Linear(capsule_size*num_capsules, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 3072), #nn.Linear(1024, imsize*imsize*img_channel),
            nn.Sigmoid()
        )
        
    def forward(self, x, data):
        classes = torch.sqrt((x ** 2).sum(2))
        classes = F.softmax(classes)
        
        _, max_length_indices = classes.max(dim=1)
        masked = Variable(torch.sparse.torch.eye(10)) #Variable(torch.sparse.torch.eye(num_capsules))
        #if USE_CUDA:
        #    masked = masked.cuda()
        masked = masked.index_select(dim=0, index=max_length_indices.squeeze(1).data)
        
        reconstructions = self.reconstraction_layers((x * masked[:, :, None, None]).view(x.size(0), -1))
        reconstructions = reconstructions.view(-1, 3, 32, 32) #reconstructions.view(-1, 1, 28, 28)
        
        return reconstructions, masked

In [8]:
class CapsNet(nn.Module):
    def __init__(self):
        super(CapsNet, self).__init__()
        self.conv_layer = ConvLayer()
        self.primary_capsules = PrimaryCaps()
        self.digit_capsules = DigitCaps()
        self.decoder = Decoder()
        
        self.mse_loss = nn.MSELoss()
        
    def forward(self, data):
        output = self.digit_capsules(self.primary_capsules(self.conv_layer(data)))
        reconstructions, masked = self.decoder(output, data)
        return output, reconstructions, masked
    
    def loss(self, data, x, target, reconstructions):
        #return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)
        marg_loss = self.margin_loss(x, target)
        rec_loss = self.reconstruction_loss(data, reconstructions)
        total_loss = (marg_loss + 0.0005 * rec_loss).mean()
        return total_loss, rec_loss.mean(), marg_loss.mean()
    
    def margin_loss(self, x, labels, size_average=True):
        batch_size = x.size(0)

        v_c = torch.norm((x), dim=2, keepdim=True)
        #print ((functional.relu(0.9 - v_c)).size())
        left = functional.relu(0.9 - v_c).view(batch_size, -1) **2
        right = functional.relu(v_c - 0.1).view(batch_size, -1) **2
        #print('left:',left.size(),'right:',right.size(), 'labels:',labels.size())
        #print((labels.t()).size())
        #print('1st:',(labels*left.t()).size())
        #print('2nd:',((torch.ones(1,100) - labels.t()) * right.t()).size())
        
        loss = labels * left.t() + 0.5 * (1.0 - labels) * right.t()
        
        loss = loss.sum(dim=1)#.mean()

        return loss
    
    def reconstruction_loss(self, data, reconstructions):
        batch_size = reconstructions.size(0)
        
        reconstructions = reconstructions.view(batch_size, -1)
        #print('data',data[0])
        data = data.view(batch_size, -1)#here it was data[0]
        #print('rec',(reconstructions).size(),'data',data.size())
        loss = nn.MSELoss()(reconstructions,data)
        #self.mse_loss(reconstructions.view(reconstructions.size(0), -1), data.view(reconstructions.size(0), -1))
        return loss 

In [10]:
# Training network over 2 epochs
capsule_net = CapsNet()
#Define Loss function and optimizer
# Loss Function: cross entropy
# Optimizer: SGD
criterion = nn.CrossEntropyLoss()
#optimizer = Adam(capsule_net.parameters())
optimizer = optim.SGD(capsule_net.parameters(), lr=0.001, momentum=0.9)
epochs=2

dataiter = iter(testloader)
images, labels = dataiter.next()
#print('inputs',inputs.size(),'labels',labels.size())
for epoch in range(epochs):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        
        inputs, labels = Variable(inputs), Variable(labels)
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs, reconstructions, masked = capsule_net(inputs)
        #print('inputs',inputs.size(),'labels',labels,'outpu',outputs.size(), 'rec', reconstructions.size())
        #print(outputs)
        loss, rec_loss, marginal_loss = capsule_net.loss(inputs, outputs, labels, reconstructions)
        loss.backward()
        optimizer.step()
        
        #print('epoch:',epoch,'batch:',i)
        
        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

  classes = F.softmax(classes)


[1,  2000] loss: -107.021
[1,  4000] loss: -106.686
[1,  6000] loss: -106.097
[1,  8000] loss: -107.131
[1, 10000] loss: -106.051
[1, 12000] loss: -105.869
[2,  2000] loss: -105.383
[2,  4000] loss: -106.398
[2,  6000] loss: -107.207
[2,  8000] loss: -107.876
[2, 10000] loss: -106.029
[2, 12000] loss: -105.660
Finished Training


In [11]:
#training saving
PATH = './cifar_CapsNet_SGD_2epochs.pth'
torch.save(capsule_net.state_dict(), PATH)

In [11]:
#dataiter = iter(testloader)
#images, labels = dataiter.next()

PATH = './cifar_CapsNet_SGD_2epochs.pth'
#Load previously net from choosen training
capsule_net = CapsNet()
capsule_net.load_state_dict(torch.load(PATH))


capsule_net.eval()
test_loss = 0
correct = 0
total = 0
for i, datas in enumerate(testloader, 0):
    inputs, labels = datas
    #print('images',inputs.size(),'labels',labels.size())
    target = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
        #dataCaps, target = Variable(dataCaps), Variable(target)

        
    output, reconstructions, masked = capsule_net(inputs)
    #print ('outputs',output.size())
    loss, rec_loss, marg_loss = capsule_net.loss(inputs, output, labels, reconstructions)
    test_loss += loss.item()
    outputs= torch.sum(torch.abs(output), 2)
    _, predicted = torch.max(outputs.data, 1) #torch.max(input, dim) return maximum value of all element from input tensor in the given dim
    total += labels.size(0) #count the number of labels with right shape
    correct += (predicted == labels).sum().item() #count the number of right labels 

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))    
    
print (test_loss / len(testloader))

  classes = F.softmax(classes)


Accuracy of the network on the 10000 test images: 40 %
-106.43563184500378
