In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as functional
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib.pyplot import *
from torch.optim import Adam
from torch.autograd import Variable

import numpy as np
from tensorboardX import SummaryWriter

In [5]:
#Load train set and test set and normalize the images in range [-1,1]
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

#50000 images training
trainset = torchvision.datasets.CIFAR10(root='./dataCaps', train=True,
                                        download=True, transform=transform)
#We load 4 samples per batchreduce the traininset to 12500
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)#batch size changed from 4 to 100

#print(len(trainset))
#10000 images test
testset = torchvision.datasets.CIFAR10(root='./dataCaps', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)#batch size changed from 4 to 100


classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./dataCaps/cifar-10-python.tar.gz


0it [00:00, ?it/s]

Extracting ./dataCaps/cifar-10-python.tar.gz to ./dataCaps
Files already downloaded and verified


In [6]:
class ConvLayer(nn.Module):
    def __init__(self, in_channels=3, out_channels=256, kernel_size=9):
        
        super(ConvLayer, self).__init__()

        self.conv = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=1
                             )
        #We want to transform 32x32x3 in 20x20x256 

    def forward(self, x):
        return F.relu(self.conv(x))

In [7]:
class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=32, in_channels=256, out_channels=8, kernel_size=9):
        super(PrimaryCaps, self).__init__()
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=num_capsules, kernel_size=kernel_size, stride=2, padding=0) 
                          for _ in range(out_channels)])
    def forward(self, x):
        u = [capsule(x) for capsule in self.capsules]
        u = torch.stack(u, dim=4) #it was 1 here but it is like this
        u = u.view(x.size(0), 32 * 8 * 8, -1)
        #output = output.view(x.size(0), self.num_capsules*(self.gridsize)*(self.gridsize), -1)
        #change shape of tensor to 32 capsules of 8x8x8  gridsize=8
        return self.squash(u)
    
    #Squash function maintaining the direction of the vector (use instead of RELU activation fct in CNN)
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor


In [8]:
class DigitCaps(nn.Module):
    def __init__(self, num_capsules=10, num_routes=32 * 8 * 8, in_channels=8, out_channels=16):
        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules

        
        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))
        self.bias = nn.Parameter(torch.rand(4,1, num_capsules, out_channels))
        
    def forward(self, x):
        batch_size = x.size(0)

        
        x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)
        #x = x.unsqueeze(2).unsqueeze(4)
        W = torch.cat([self.W] * batch_size, dim=0) #concatenate in zero dimension
        
        u_hat = torch.matmul(W, x) #matmul = matrix multiplication
        u_hat = u_hat.squeeze()
        num_capsules_in = x.shape[1]
        num_capsules_out = W.shape[2]
        
        #b_ij = Variable(x.new(batch_size, num_capsules_in, num_capsules_out, 1).zero_())
        b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1))
        
        
        num_iterations = 3 #number of routing
        #this is the routing algorithm
        for iteration in range(num_iterations):
            #compute coupling coefficient, Conceptually: measure how likely capsule i may activate capsule j
            c_ij = F.softmax(b_ij, dim=2)
            
            c_ij = torch.cat([c_ij], dim=0)
            #print("c_ij",c_ij.size(),"uhat",u_hat.size()),"bij"
            
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            #print("ok")
            #s_j = s_j + self.bias #before without bias
            
            v_j = self.squash(s_j)
            delta = (u_hat * v_j).sum(dim=0, keepdim=True)
            b_ij = b_ij + delta
        return v_j.squeeze(1)
    
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor

In [9]:
class Decoder(nn.Module): #difference between decoder and reconstruction ????
    def __init__(self):
        super(Decoder, self).__init__()
#num capsule=10, capsule_size=16, imsize=32, img_channel=3        
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(16 * 10, 512), #nn.Linear(capsule_size*num_capsules, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 3072), #nn.Linear(1024, imsize*imsize*img_channel),
            nn.Sigmoid()
        )
        
    def forward(self, output, data):
        batch_size = output.size(0)
        classes = torch.norm(output, dim=2)
        classes = F.softmax(classes)
        max_length_indices = classes.max(dim=1)[1].squeeze()
        #_, max_length_indices = classes.max(dim=1)
        #masked = Variable(torch.sparse.torch.eye(10)) #Variable(torch.sparse.torch.eye(num_capsules))
        #masked = masked.index_select(dim=0, index=max_length_indices.squeeze(1).data)
        
        #reconstructions = self.reconstraction_layers((x * masked[:, :, None, None]).view(x.size(0), -1))
        #reconstructions = reconstructions.view(-1, 3, 32, 32) #reconstructions.view(-1, 1, 28, 28)
        masked = Variable(output.new_tensor(torch.eye(10)))
    
        masked = masked.index_select(dim=0, index=max_length_indices.data)
        #print("x",x.size(),"masked",masked[:, :, None, None].size())
        decoder_input = (output * masked[:, :, None, None].squeeze(3)).view(batch_size, -1)

        reconstructions = self.reconstraction_layers(decoder_input)
        reconstructions = reconstructions.view(-1, 3, 32, 32)
        return reconstructions, masked

In [10]:
class CapsNet(nn.Module):
    def __init__(self):
        super(CapsNet, self).__init__()
        self.conv_layer = ConvLayer()
        self.primary_capsules = PrimaryCaps()
        self.digit_capsules = DigitCaps()
        self.decoder = Decoder()
        
        self.mse_loss = nn.MSELoss()
        
    def forward(self, x, target=None):
        output = self.conv_layer(x)
        #print("conv:",output.size())
        output = self.primary_capsules(output)
        #print("prim:",output.size())
        output = self.digit_capsules(output)
        #print("digit:",output.size())
        reconstructions, masked = self.decoder(output, target)
        return output, reconstructions, masked
    
    def loss(self, dataset_im, dataset_labs, caps_output, reconstructions):
        #return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)
        marg_loss = self.margin_loss(caps_output, dataset_labs)
        rec_loss = self.reconstruction_loss(dataset_im, reconstructions)
        total_loss = (marg_loss + 0.0005 * rec_loss).mean() #0.0005=regularization factor
        return total_loss, rec_loss.mean(), marg_loss.mean()
    
    def margin_loss(self, y_pred, labels):
        #lambda = 0.5 
        #print ("test", torch.max(0, 0.9 - torch.norm(y_pred)))
        
        norm_y_pred = torch.norm(y_pred.type(torch.FloatTensor))
        left = labels * torch.max(torch.tensor([0, 0.9 - norm_y_pred]))**2
        x = torch.max(torch.tensor([0, norm_y_pred - 0.1]))**2
        right = 0.5 * (1 - labels) * x
        loss = left + right
        return loss
    
    def reconstruction_loss(self, data, reconstructions):
        batch_size = reconstructions.size(0)
        
        reconstructions = reconstructions.view(batch_size, -1)
        #print('data',data[0])
        data = data.view(batch_size, -1)#here it was data[0]
        #print('rec',(reconstructions).size(),'data',data.size())
        loss = nn.MSELoss()(reconstructions,data)
        #self.mse_loss(reconstructions.view(reconstructions.size(0), -1), data.view(reconstructions.size(0), -1))
        return loss 

In [11]:
# Training network over 2 epochs
capsule_net = CapsNet()
#Define Loss function and optimizer
# Loss Function: cross entropy
# Optimizer: SGD
criterion = nn.CrossEntropyLoss()
#optimizer = Adam(capsule_net.parameters())
optimizer = optim.SGD(capsule_net.parameters(), lr=0.001, momentum=0.9)
epochs=2
capsule_net.train()
dataiter = iter(testloader)
images, labels = dataiter.next()
#print('inputs',inputs.size(),'labels',labels.size())
total = 0
correct = 0
writer = SummaryWriter(logdir='CapsNet_3/accuracy')
for epoch in range(epochs):  # loop over the dataset multiple times

    running_loss = 0.0
    
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        
        inputs, labels = Variable(inputs), Variable(labels)
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs, reconstructions, masked = capsule_net(inputs)
        _, predicted =torch.max(masked, 1)
        #print("label_gdtrue:",labels,"predicted",predicted)
        #print('inputs',inputs.size(),'labels',labels,'outpu',outputs.size(), 'rec', reconstructions.size())
        #print(outputs)
        
        loss, rec_loss, marginal_loss = capsule_net.loss(inputs, labels, predicted, reconstructions)
        loss.backward()
        optimizer.step()
        
        #print('epoch:',epoch,'batch:',i)
        
        # print statistics
        running_loss += loss.item()
        total += labels.size(0) #count the number of labels with right shape
        correct += (predicted == labels).sum().item() #count the number 
        
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
            writer.add_scalar('accuracy', correct/total, i)
            correct = 0
            total = 0
            
writer.close()
print('Finished Training')

  classes = F.softmax(classes)
  masked = Variable(output.new_tensor(torch.eye(10)))


[1,  2000] loss: -150.004
[1,  4000] loss: -150.035
[1,  6000] loss: -147.596
[1,  8000] loss: -153.665
[1, 10000] loss: -145.743
[1, 12000] loss: -146.295
[2,  2000] loss: -143.982
[2,  4000] loss: -145.568
[2,  6000] loss: -146.401
[2,  8000] loss: -146.178
[2, 10000] loss: -154.472
[2, 12000] loss: -144.134
Finished Training


In [14]:
#training saving
PATH = './cifar_CapsNet_SGD_2epochs_3.1.pth'
torch.save(capsule_net.state_dict(), PATH)

In [15]:
PATH = './cifar_CapsNet_SGD_2epochs_3.pth'
#Load previously net from choosen training
capsule_net = CapsNet()
capsule_net.load_state_dict(torch.load(PATH))


capsule_net.eval()
test_loss = 0
correct = 0
total = 0

for i, datas in enumerate(testloader, 0):
    inputs, labels = datas
    
    target = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
        
        
    output, reconstructions, masked = capsule_net(inputs)
    
    loss, rec_loss, marg_loss = capsule_net.loss(inputs, labels, output, reconstructions)
    test_loss += loss.item()
    
    _, predicted =torch.max(masked, 1)
    total += labels.size(0) #count the number of labels with right shape
    correct += (predicted == labels).sum().item() #count the number of right labels 

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))    
    
print (test_loss / len(testloader))

  classes = F.softmax(classes)
  masked = Variable(output.new_tensor(torch.eye(10)))


Accuracy of the network on the 10000 test images: 10 %
-62.770535162907976
