In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as functional
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib.pyplot import *
from torch.optim import Adam
from torch.autograd import Variable

import numpy as np
from tensorboardX import SummaryWriter

In [2]:
#Load train set and test set and normalize the images in range [-1,1]
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

#50000 images training
trainset = torchvision.datasets.CIFAR10(root='./dataCaps', train=True,
                                        download=True, transform=transform)
#We load 4 samples per batchreduce the traininset to 12500
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)#batch size changed from 4 to 100

#print(len(trainset))
#10000 images test
testset = torchvision.datasets.CIFAR10(root='./dataCaps', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)#batch size changed from 4 to 100


classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [3]:
class ConvLayer(nn.Module):
    def __init__(self, in_channels=3, out_channels=256, kernel_size=9):
        
        super(ConvLayer, self).__init__()

        self.conv = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=1
                             )
        #We want to transform 32x32x3 in 20x20x256 

    def forward(self, x):
        return F.relu(self.conv(x))

In [4]:
class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=32, in_channels=256, out_channels=8, kernel_size=9):
        super(PrimaryCaps, self).__init__()
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=num_capsules, kernel_size=kernel_size, stride=2, padding=0) 
                          for _ in range(out_channels)])
    def forward(self, x):
        u = [capsule(x) for capsule in self.capsules]
        u = torch.stack(u, dim=4) #it was 1 here but it is like this
        u = u.view(x.size(0), 32 * 8 * 8, -1)
        #output = output.view(x.size(0), self.num_capsules*(self.gridsize)*(self.gridsize), -1)
        #change shape of tensor to 32 capsules of 8x8x8  gridsize=8
        return self.squash(u)
    
    #Squash function maintaining the direction of the vector (use instead of RELU activation fct in CNN)
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor


In [5]:
class DigitCaps(nn.Module):
    def __init__(self, num_capsules=10, num_routes=32 * 8 * 8, in_channels=8, out_channels=16):
        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules

        
        self.W = nn.Parameter(0.01*torch.randn(1, num_routes, num_capsules, out_channels, in_channels))#0.01 added from danielhavir
        #self.bias = nn.Parameter(torch.rand(4,1, num_capsules, out_channels))
        
    def forward(self, x):
        batch_size = x.size(0)

        
        x = torch.stack([x], dim=2).unsqueeze(4) #supp [x]*self.num_capsules
        #x = x.unsqueeze(2).unsqueeze(4)
        W = torch.cat([self.W] * batch_size, dim=0) #concatenate in zero dimension
        
        u_hat = torch.matmul(W, x) #matmul = matrix multiplication
        u_hat = u_hat.squeeze()
        temp_u_hat = u_hat.detach()
        num_capsules_in = x.shape[1]
        num_capsules_out = W.shape[2]
        
        #b_ij = Variable(x.new(batch_size, num_capsules_in, num_capsules_out, 1).zero_())
        b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1))
        
        
        num_iterations = 3 #number of routing
        #this is the routing algorithm
        for iteration in range(num_iterations-1):
            #compute coupling coefficient, Conceptually: measure how likely capsule i may activate capsule j
            c_ij = F.softmax(b_ij, dim=2)
            
            c_ij = torch.cat([c_ij], dim=0)
            #print("c_ij",c_ij.size(),"uhat",u_hat.size()),"bij"
            
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            #print("ok")
            #s_j = s_j + self.bias #before without bias
            
            v_j = self.squash(s_j)
            delta = (temp_u_hat * v_j).sum(dim=0, keepdim=True)
            b_ij = b_ij + delta
        
        c_ij = F.softmax(b_ij, dim=2)
        c_ij = torch.cat([c_ij], dim=0)
            
        s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
        v_j = self.squash(s_j)
        
        return v_j.squeeze(1)
    
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor

In [6]:
class Decoder(nn.Module): #difference between decoder and reconstruction ????
    def __init__(self):
        super(Decoder, self).__init__()
#num capsule=10, capsule_size=16, imsize=32, img_channel=3        
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(16 * 10, 512), #nn.Linear(capsule_size*num_capsules, 512),
            #nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            #nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 3072), #nn.Linear(1024, imsize*imsize*img_channel),
            nn.Sigmoid()
        )
        
    def forward(self, output, data):
        batch_size = output.size(0)
        classes = torch.norm(output, dim=2)
        #classes = F.softmax(classes)
        
        _, max_length_indices = classes.max(dim=1)

        masked = torch.eye(10)
    
        masked = masked.index_select(dim=0, index=max_length_indices).unsqueeze(2)
        #print("x",x.size(),"masked",masked[:, :, None, None].size())
        #decoder_input = (output * masked[:, :, None, None].squeeze(3)).view(batch_size, -1)

        #reconstructions = self.reconstraction_layers(decoder_input)
        reconstructions = self.reconstraction_layers( (output*masked).view(batch_size, -1) )
        reconstructions = reconstructions.view(-1, 3, 32, 32)
        return reconstructions

In [7]:
class CapsNet(nn.Module):
    def __init__(self):
        super(CapsNet, self).__init__()
        self.conv_layer = ConvLayer()
        self.primary_capsules = PrimaryCaps()
        self.digit_capsules = DigitCaps()
        self.decoder = Decoder()
        
        self.mse_loss = nn.MSELoss()
        
    def forward(self, x):
        output = self.conv_layer(x)
        #print("conv:",output.size())
        output = self.primary_capsules(output)
        #print("prim:",output.size())
        output = self.digit_capsules(output)
        #print("digit:",output.size())
        
        reconstructions = self.decoder(output, x)
        return output, reconstructions
    
    def loss(self, dataset_im, dataset_labs, caps_output, reconstructions):
        #return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)
        marg_loss = self.margin_loss(caps_output, dataset_labs)
        rec_loss = self.reconstruction_loss(dataset_im, reconstructions)
        total_loss = (marg_loss + 0.0005 * rec_loss).mean() #0.0005=regularization factor
        return total_loss, rec_loss.mean(), marg_loss.mean()
    #loss_lambda=0.5 ( in margin loss ), recon_loss_scale=5e-4
    def margin_loss(self, y_pred, labels):
        #lambda = 0.5 
        
        norm_y_pred = torch.norm(y_pred.type(torch.FloatTensor))
        left = labels * torch.max(torch.tensor([0, 0.9 - norm_y_pred]))**2
        x = torch.max(torch.tensor([0, norm_y_pred - 0.1]))**2
        right = 0.5 * (1 - labels) * x
        loss = left + right
        return loss
    
    def reconstruction_loss(self, data, reconstructions):
        batch_size = reconstructions.size(0)
        
        reconstructions = reconstructions.view(batch_size, -1)
        #print('data',data[0])
        data = data.view(batch_size, -1)#here it was data[0]
        #print('rec',(reconstructions).size(),'data',data.size())
        loss = nn.MSELoss()(reconstructions,data)
        #self.mse_loss(reconstructions.view(reconstructions.size(0), -1), data.view(reconstructions.size(0), -1))
        return loss 

In [None]:
# Training network over 2 epochs
capsule_net = CapsNet()
#Define Loss function and optimizer
# Loss Function: cross entropy
# Optimizer: SGD

optimizer = Adam(capsule_net.parameters())
#optimizer = optim.SGD(capsule_net.parameters(), lr=0.001, momentum=0.9)
epochs=2
capsule_net.train()
dataiter = iter(testloader)
images, labels = dataiter.next()
#print('inputs',inputs.size(),'labels',labels.size())
total = 0
correct = 0
writer = SummaryWriter(logdir='CapsNet_3/accuracy')

for epoch in range(epochs):  # loop over the dataset multiple times
    writer = SummaryWriter(logdir='CapsNet_3/epoch%d'%(epoch))
    #writer = SummaryWriter(logdir='CapsNet_3/accuracy_epoch%d'%(epoch))
    running_loss = 0.0
    
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        
        inputs, labels = Variable(inputs), Variable(labels)
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs, reconstructions = capsule_net(inputs)
        xy=torch.norm(outputs, dim=2)
        _,predicted=torch.max(xy,1)
        #print("label_gdtrue:",labels,"predicted",predicted)
        #print('inputs',inputs.size(),'labels',labels,'outpu',outputs.size(), 'rec', reconstructions.size())
        #print(outputs)
        
        loss, rec_loss, marginal_loss = capsule_net.loss(inputs, labels, outputs, reconstructions)
        loss.backward()
        optimizer.step()
        
        #print('epoch:',epoch,'batch:',i)
        
        # print statistics
        running_loss += loss.item()
        total += labels.size(0) #count the number of labels with right shape
        correct += (predicted == labels).sum().item() #count the number 
        
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
            writer.add_scalar('accuracy', 100*correct/total, i)
            writer.add_scalar('loss', running_loss / 2000, i)
            #writer.add_scalar('loss and accuracy',{'loss': running_loss / 2000,'accuracy': 100*correct/total}, i)
            correct = 0
            total = 0
            
    writer.close()
print('Finished Training')

[1,  2000] loss: -6.118
[1,  4000] loss: -6.444
[1,  6000] loss: -6.373
[1,  8000] loss: -6.348
[1, 10000] loss: -6.471
[1, 12000] loss: -6.359
[2,  2000] loss: -6.441
[2,  4000] loss: -6.626
[2,  6000] loss: -6.756
[2,  8000] loss: -7.668


In [None]:
#training saving
PATH = './cifar_CapsNet_SGD_2epochs_3.6.pth'
torch.save(capsule_net.state_dict(), PATH)

In [None]:
PATH = './cifar_CapsNet_SGD_2epochs_3.6.pth'
#Load previously net from choosen training
capsule_net = CapsNet()
capsule_net.load_state_dict(torch.load(PATH))


capsule_net.eval()
test_loss = 0
correct = 0
total = 0

for i, datas in enumerate(testloader, 0):
    inputs, labels = datas
    
    target = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
        
        
    output, reconstructions = capsule_net(inputs)
    xy=torch.norm(output, dim=2)
    _,predicted=torch.max(xy,1)
    loss, rec_loss, marg_loss = capsule_net.loss(inputs, labels, output, reconstructions)
    test_loss += loss.item()
    #print('labels',labels,'predicted',predicted)
    
    total += labels.size(0) #count the number of labels with right shape
    correct += (predicted == labels).sum().item() #count the number of right labels 

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))    
    
print (test_loss / len(testloader))

In [14]:
print(labels)

tensor([3, 5, 1, 7])


In [15]:
print(predicted)

tensor([9, 8, 3, 1])


In [16]:
for i, datas in enumerate(testloader, 0):
    inputs, labels = datas
    print(labels)

tensor([3, 8, 8, 0])
tensor([6, 6, 1, 6])
tensor([3, 1, 0, 9])
tensor([5, 7, 9, 8])
tensor([5, 7, 8, 6])
tensor([7, 0, 4, 9])
tensor([5, 2, 4, 0])
tensor([9, 6, 6, 5])
tensor([4, 5, 9, 2])
tensor([4, 1, 9, 5])
tensor([4, 6, 5, 6])
tensor([0, 9, 3, 9])
tensor([7, 6, 9, 8])
tensor([0, 3, 8, 8])
tensor([7, 7, 4, 6])
tensor([7, 3, 6, 3])
tensor([6, 2, 1, 2])
tensor([3, 7, 2, 6])
tensor([8, 8, 0, 2])
tensor([9, 3, 3, 8])
tensor([8, 1, 1, 7])
tensor([2, 5, 2, 7])
tensor([8, 9, 0, 3])
tensor([8, 6, 4, 6])
tensor([6, 0, 0, 7])
tensor([4, 5, 6, 3])
tensor([1, 1, 3, 6])
tensor([8, 7, 4, 0])
tensor([6, 2, 1, 3])
tensor([0, 4, 2, 7])
tensor([8, 3, 1, 2])
tensor([8, 0, 8, 3])
tensor([5, 2, 4, 1])
tensor([8, 9, 1, 2])
tensor([9, 7, 2, 9])
tensor([6, 5, 6, 3])
tensor([8, 7, 6, 2])
tensor([5, 2, 8, 9])
tensor([6, 0, 0, 5])
tensor([2, 9, 5, 4])
tensor([2, 1, 6, 6])
tensor([8, 4, 8, 4])
tensor([5, 0, 9, 9])
tensor([9, 8, 9, 9])
tensor([3, 7, 5, 0])
tensor([0, 5, 2, 2])
tensor([3, 8, 6, 3])
tensor([4, 0,

tensor([9, 4, 0, 9])
tensor([4, 9, 5, 7])
tensor([5, 5, 9, 5])
tensor([3, 0, 1, 9])
tensor([7, 2, 4, 1])
tensor([0, 8, 0, 3])
tensor([1, 7, 0, 0])
tensor([4, 8, 6, 2])
tensor([4, 0, 0, 9])
tensor([0, 8, 4, 5])
tensor([9, 3, 9, 0])
tensor([5, 6, 5, 0])
tensor([1, 4, 8, 1])
tensor([0, 5, 2, 1])
tensor([0, 2, 8, 1])
tensor([5, 6, 7, 7])
tensor([2, 6, 2, 5])
tensor([0, 1, 4, 2])
tensor([5, 4, 6, 2])
tensor([2, 1, 7, 2])
tensor([8, 5, 5, 3])
tensor([0, 4, 8, 3])
tensor([7, 6, 3, 8])
tensor([1, 0, 1, 3])
tensor([3, 0, 7, 4])
tensor([9, 5, 3, 6])
tensor([0, 1, 4, 4])
tensor([4, 4, 2, 2])
tensor([5, 8, 1, 5])
tensor([9, 8, 1, 1])
tensor([5, 3, 9, 9])
tensor([7, 6, 5, 0])
tensor([8, 4, 7, 0])
tensor([9, 2, 8, 4])
tensor([7, 1, 3, 9])
tensor([6, 8, 9, 0])
tensor([4, 9, 6, 7])
tensor([8, 9, 4, 8])
tensor([9, 7, 2, 5])
tensor([3, 7, 1, 0])
tensor([2, 9, 5, 5])
tensor([8, 5, 4, 2])
tensor([8, 3, 5, 5])
tensor([7, 7, 8, 6])
tensor([2, 8, 2, 3])
tensor([5, 6, 8, 0])
tensor([2, 3, 7, 0])
tensor([1, 9,

tensor([5, 4, 5, 6])
tensor([4, 7, 9, 4])
tensor([2, 0, 6, 4])
tensor([0, 0, 6, 4])
tensor([6, 1, 9, 5])
tensor([5, 2, 2, 6])
tensor([3, 4, 5, 9])
tensor([1, 7, 2, 3])
tensor([9, 6, 5, 0])
tensor([2, 9, 7, 1])
tensor([7, 2, 2, 0])
tensor([8, 6, 4, 3])
tensor([2, 7, 7, 0])
tensor([4, 1, 6, 5])
tensor([1, 3, 0, 3])
tensor([9, 0, 0, 2])
tensor([5, 0, 4, 0])
tensor([1, 9, 8, 4])
tensor([9, 4, 2, 4])
tensor([3, 3, 4, 0])
tensor([4, 3, 2, 8])
tensor([9, 1, 5, 8])
tensor([1, 8, 2, 4])
tensor([5, 2, 4, 1])
tensor([1, 6, 6, 8])
tensor([5, 2, 2, 5])
tensor([0, 8, 2, 3])
tensor([6, 2, 9, 6])
tensor([1, 4, 5, 9])
tensor([0, 1, 0, 0])
tensor([8, 1, 1, 6])
tensor([6, 9, 5, 4])
tensor([1, 7, 8, 6])
tensor([9, 1, 7, 6])
tensor([0, 9, 3, 5])
tensor([3, 2, 5, 3])
tensor([4, 9, 7, 1])
tensor([4, 4, 6, 1])
tensor([3, 8, 8, 0])
tensor([6, 7, 7, 6])
tensor([7, 2, 3, 2])
tensor([2, 6, 2, 7])
tensor([4, 0, 3, 6])
tensor([2, 6, 3, 3])
tensor([0, 9, 5, 1])
tensor([1, 5, 3, 6])
tensor([4, 3, 4, 1])
tensor([0, 4,

tensor([0, 8, 2, 0])
tensor([0, 2, 4, 8])
tensor([6, 2, 4, 6])
tensor([3, 5, 1, 5])
tensor([3, 7, 2, 2])
tensor([9, 8, 0, 0])
tensor([0, 3, 4, 4])
tensor([6, 1, 6, 7])
tensor([4, 4, 3, 9])
tensor([4, 0, 8, 0])
tensor([4, 6, 5, 7])
tensor([9, 7, 0, 5])
tensor([7, 7, 3, 1])
tensor([9, 3, 0, 9])
tensor([5, 3, 7, 9])
tensor([4, 4, 1, 7])
tensor([7, 1, 4, 1])
tensor([2, 8, 7, 0])
tensor([0, 4, 7, 2])
tensor([9, 7, 6, 9])
tensor([3, 5, 8, 0])
tensor([3, 6, 8, 3])
tensor([2, 4, 7, 1])
tensor([1, 3, 9, 7])
tensor([5, 1, 0, 8])
tensor([7, 0, 1, 6])
tensor([9, 3, 2, 7])
tensor([7, 8, 1, 0])
tensor([3, 4, 6, 7])
tensor([5, 2, 0, 1])
tensor([5, 5, 1, 4])
tensor([1, 3, 0, 8])
tensor([6, 2, 1, 3])
tensor([6, 4, 1, 9])
tensor([0, 4, 1, 0])
tensor([1, 9, 8, 6])
tensor([9, 2, 4, 7])
tensor([2, 2, 7, 4])
tensor([9, 1, 3, 2])
tensor([6, 3, 4, 4])
tensor([9, 4, 8, 2])
tensor([6, 6, 1, 6])
tensor([3, 6, 5, 8])
tensor([4, 6, 7, 1])
tensor([9, 3, 6, 7])
tensor([6, 0, 7, 1])
tensor([9, 5, 2, 6])
tensor([7, 7,

tensor([9, 0, 7, 7])
tensor([5, 4, 2, 6])
tensor([4, 5, 7, 7])
tensor([8, 7, 2, 6])
tensor([2, 2, 4, 4])
tensor([0, 7, 1, 3])
tensor([9, 6, 0, 0])
tensor([2, 3, 8, 2])
tensor([2, 4, 3, 5])
tensor([2, 9, 1, 0])
tensor([0, 6, 5, 5])
tensor([7, 9, 9, 6])
tensor([5, 5, 0, 5])
tensor([7, 1, 6, 6])
tensor([4, 1, 4, 4])
tensor([1, 5, 0, 0])
tensor([4, 5, 8, 4])
tensor([8, 3, 0, 5])
tensor([0, 5, 3, 1])
tensor([6, 7, 0, 9])
tensor([1, 5, 7, 6])
tensor([5, 5, 5, 6])
tensor([0, 0, 1, 7])
tensor([5, 1, 9, 2])
tensor([4, 1, 3, 7])
tensor([8, 2, 0, 9])
tensor([6, 6, 0, 6])
tensor([5, 8, 2, 7])
tensor([4, 0, 2, 7])
tensor([7, 8, 8, 7])
tensor([0, 4, 9, 1])
tensor([4, 4, 3, 5])
tensor([4, 6, 2, 3])
tensor([1, 0, 3, 3])
tensor([3, 6, 3, 1])
tensor([2, 8, 9, 7])
tensor([9, 3, 8, 7])
tensor([3, 1, 7, 7])
tensor([3, 2, 2, 8])
tensor([9, 5, 9, 2])
tensor([1, 7, 4, 4])
tensor([0, 5, 7, 1])
tensor([5, 4, 0, 8])
tensor([4, 9, 8, 7])
tensor([8, 4, 2, 3])
tensor([4, 0, 5, 4])
tensor([1, 8, 2, 5])
tensor([4, 5,

tensor([6, 8, 9, 6])
tensor([2, 0, 4, 9])
tensor([4, 9, 3, 9])
tensor([6, 6, 7, 0])
tensor([9, 7, 1, 8])
tensor([6, 0, 6, 7])
tensor([4, 1, 9, 4])
tensor([6, 7, 9, 8])
tensor([3, 9, 2, 1])
tensor([2, 7, 6, 1])
tensor([0, 0, 5, 6])
tensor([0, 4, 3, 2])
tensor([8, 8, 0, 6])
tensor([9, 5, 2, 8])
tensor([7, 0, 6, 5])
tensor([9, 7, 2, 3])
tensor([6, 9, 6, 2])
tensor([2, 4, 1, 0])
tensor([5, 0, 8, 9])
tensor([3, 5, 9, 3])
tensor([8, 1, 6, 3])
tensor([7, 5, 6, 2])
tensor([0, 2, 8, 2])
tensor([8, 7, 7, 8])
tensor([1, 0, 8, 9])
tensor([7, 0, 3, 8])
tensor([0, 5, 9, 5])
tensor([8, 4, 2, 0])
tensor([9, 2, 2, 4])
tensor([4, 9, 2, 2])
tensor([2, 5, 1, 3])
tensor([2, 0, 0, 4])
tensor([0, 6, 5, 8])
tensor([0, 5, 8, 6])
tensor([4, 8, 5, 2])
tensor([9, 7, 9, 7])
tensor([1, 0, 1, 9])
tensor([6, 9, 2, 7])
tensor([9, 4, 4, 0])
tensor([6, 2, 4, 1])
tensor([3, 7, 2, 8])
tensor([5, 9, 0, 3])
tensor([2, 3, 2, 7])
tensor([6, 3, 2, 5])
tensor([9, 0, 5, 9])
tensor([9, 8, 7, 7])
tensor([4, 8, 6, 5])
tensor([2, 3,