In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as functional
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.optim import Adam
from torch.autograd import Variable

import numpy as np
from tensorboardX import SummaryWriter

In [13]:
#Load train set and test set and normalize the images in range [-1,1]
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

#50000 images training
trainset = torchvision.datasets.CIFAR10(root='./dataCaps', train=True,
                                        download=True, transform=transform)
#We load 4 samples per batch reduce the traininset to 12500
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=0)

#print(len(trainset))
#10000 images test
testset = torchvision.datasets.CIFAR10(root='./dataCaps', train=False,
                                       download=True, transform=transform)
#We load 4 samples per batchreduce the traininset to 2500
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=0)


classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [14]:
class ConvLayer(nn.Module):
    def __init__(self, in_channels=3, out_channels=256, kernel_size=9):
        
        super(ConvLayer, self).__init__()

        self.conv = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=1
                             )
        self.relu = nn.ReLU(inplace=True)
        #We want to transform 32x32x3 in 20x20x256 

    def forward(self, x):
        return self.relu(self.conv(x))

In [15]:
class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=32, in_channels=256, out_channels=8, kernel_size=9):
        super(PrimaryCaps, self).__init__()
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=num_capsules, kernel_size=kernel_size, stride=2, padding=0) 
                          for _ in range(out_channels)])
    def forward(self, x):
        u = [capsule(x) for capsule in self.capsules]
        u = torch.stack(u, dim=4) 
        u = u.view(x.size(0), 32 * 8 * 8, -1)
        #change shape of tensor to 32 capsules of 8x8x8  gridsize=8
        return self.squash(u)
    
    #Squash function maintaining the direction of the vector (use instead of RELU activation fct in CNN)
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm)+ 1e-8)
        return output_tensor

In [16]:
class DigitCaps(nn.Module):
    def __init__(self, num_capsules=10, num_routes=32 * 8 * 8, in_channels=8, out_channels=16):
        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules

        
        self.W = nn.Parameter(0.01*torch.randn(1, num_capsules, num_routes, out_channels, in_channels))#0.01 added from danielhavir
        
        
    def forward(self, x):
        batch_size = x.size(0)

        x = x.unsqueeze(1)
        x = x.unsqueeze(4)
        
        
        u_hat = torch.matmul(self.W, x) #matmul = matrix multiplication
        
        u_hat = u_hat.squeeze(-1)
        temp_u_hat = u_hat.detach()
        
        b_ij = torch.zeros(batch_size,  self.num_capsules, self.num_routes,1)
        num_iterations = 3 #number of routing

        #this is the routing algorithm
        for iteration in range(num_iterations-1):
            #compute coupling coefficient, Conceptually: measure how likely capsule i may activate capsule j
            c_ij = F.softmax(b_ij, dim=1)
            
            s_j = (c_ij * temp_u_hat).sum(dim=2)
            
            v_j = self.squash(s_j)
            delta = torch.matmul(temp_u_hat, v_j.unsqueeze(-1))
            b_ij = b_ij + delta
        
        c_ij = F.softmax(b_ij, dim=2)
        s_j = (c_ij * u_hat).sum(dim=2)
        
        v_j = self.squash(s_j)
        
        return v_j
    
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor

In [17]:
class Decoder(nn.Module): #difference between decoder and reconstruction ????
    def __init__(self):
        super(Decoder, self).__init__()
#num capsule=10, capsule_size=16, imsize=32, img_channel=3        
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(16 * 10, 512), #nn.Linear(capsule_size*num_capsules, 512),
            #nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            #nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 3072), #nn.Linear(1024, imsize*imsize*img_channel),
            nn.Sigmoid()
        )
        
    def forward(self, output):
        batch_size = output.size(0)
        classes = torch.norm(output, dim=2)
        
        
        _, max_length_indices = classes.max(dim=1)

        masked = torch.eye(10)
    
        masked = masked.index_select(dim=0, index=max_length_indices).unsqueeze(2)
        reconstructions = self.reconstraction_layers( (output*masked).view(batch_size, -1) )
        reconstructions = reconstructions.view(-1, 3, 32, 32)
        return reconstructions

In [18]:
class CapsNet(nn.Module):
    def __init__(self):
        super(CapsNet, self).__init__()
        self.conv_layer = ConvLayer()
        self.primary_capsules = PrimaryCaps()
        self.digit_capsules = DigitCaps()
        self.decoder = Decoder()
        
        self.mse_loss = nn.MSELoss()
        
    def forward(self, x):
        #print ('original:',x.size()) #size=4,3,32,32
        output = self.conv_layer(x)
        #print("conv:",output.size()) #size=4,256,24,24
        output = self.primary_capsules(output)
        #print("prim:",output.size()) #size=4,2048,8
        output = self.digit_capsules(output)
        #print("digit:",output.size()) #size=4,10,16
        preds = torch.norm(output, dim=-1)
        reconstructions = self.decoder(output)
        return preds, reconstructions
    


In [19]:
class MarginLoss(nn.Module):
    def __init__(self, size_average=False, loss_lambda=0.5):
        super(MarginLoss, self).__init__()
        self.size_average = size_average
        self.m_plus = 0.9
        self.m_minus = 0.1
        self.loss_lambda = loss_lambda

    def forward(self, inputs, labels):
        
        left=F.relu(self.m_plus - inputs)**2
        right=F.relu(inputs - self.m_minus)**2
        L_k = labels * left + self.loss_lambda * (1 - labels) * right
        L_k = L_k.sum(dim=1)
        if self.size_average:
            return L_k.mean()
        else:
            return L_k.sum()

class CapsuleLoss(nn.Module):
    def __init__(self, loss_lambda=0.5, recon_loss_scale=5e-4, size_average=False):

        super(CapsuleLoss, self).__init__()
        self.size_average = size_average
        self.margin_loss = MarginLoss(size_average=size_average, loss_lambda=loss_lambda)
        self.reconstruction_loss = nn.MSELoss(size_average=size_average)
        self.recon_loss_scale = recon_loss_scale

    def forward(self, inputs, labels, images, reconstructions):
        margin_loss = self.margin_loss(inputs, labels)
        reconstruction_loss = self.reconstruction_loss(reconstructions, images)
        caps_loss = (margin_loss + self.recon_loss_scale * reconstruction_loss)

        return caps_loss

In [28]:
# Training network over 2 epochs
capsule_net = CapsNet()

# Optimizer: Adam
criterion = CapsuleLoss()
optimizer = Adam(capsule_net.parameters(), lr=1e-3)
#optimizer = optim.SGD(capsule_net.parameters(), lr=0.001, momentum=0.9)
epochs=2
capsule_net.train()
total = 0
correct = 0
eye = torch.eye(len(classes))


for epoch in range(epochs):  # loop over the dataset multiple times
    writer = SummaryWriter(logdir='CapsNet_10/training_epoch%d'%(epoch))
    
    running_loss = 0.0
    eye = torch.eye(len(classes))
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs, reconstructions = capsule_net(inputs)
        labels_eye = eye[labels]
        loss = criterion(outputs, labels_eye, inputs, reconstructions)
        loss.backward()
        optimizer.step()
        
        
        _,predicted=torch.max(outputs,1)
        
        # print statistics
        running_loss += loss.item()
        total += labels.size(0) #count the number of labels with right shape
        correct += (predicted == labels).sum().item() #count the number of correct prediction
        
        if i % 500 == 499:    # print every 500 mini-batches
            print('[%d, %5d] loss: %.3f, accuracy: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000, 100*float(correct)/float(total)))
            
            writer.add_scalar('accuracy_train', 100*float(correct)/float(total), i)
            writer.add_scalar('loss_train', running_loss / 2000, i)
            
            running_loss = 0.0
            correct = 0
            total = 0
    
    #added check point
    if epoch%2 == 2:
        PATH = './checkpoint.1.pth'
        torch.save(capsule_net.state_dict(), PATH)
        
    writer.close()
print('Finished Training')

[1,   500] loss: 0.923, accuracy: 16.400
[1,  1000] loss: 0.842, accuracy: 25.400
[1,  1500] loss: 0.816, accuracy: 30.250
[1,  2000] loss: 0.795, accuracy: 32.600
[1,  2500] loss: 0.783, accuracy: 34.850
[1,  3000] loss: 0.774, accuracy: 37.900
[1,  3500] loss: 0.766, accuracy: 38.450
[1,  4000] loss: 0.764, accuracy: 39.200
[1,  4500] loss: 0.751, accuracy: 41.250
[1,  5000] loss: 0.747, accuracy: 40.800
[1,  5500] loss: 0.737, accuracy: 43.250
[1,  6000] loss: 0.750, accuracy: 40.800
[1,  6500] loss: 0.740, accuracy: 43.250
[1,  7000] loss: 0.737, accuracy: 44.550
[1,  7500] loss: 0.741, accuracy: 43.800
[1,  8000] loss: 0.723, accuracy: 45.750
[1,  8500] loss: 0.719, accuracy: 46.900
[1,  9000] loss: 0.728, accuracy: 45.950
[1,  9500] loss: 0.714, accuracy: 46.950
[1, 10000] loss: 0.725, accuracy: 45.800
[1, 10500] loss: 0.730, accuracy: 45.650
[1, 11000] loss: 0.720, accuracy: 47.650
[1, 11500] loss: 0.722, accuracy: 46.300
[1, 12000] loss: 0.709, accuracy: 48.550
[1, 12500] loss:

In [34]:
#training saving
PATH = './cifar_CapsNet_SGD_2epochs_3.10.pth'
torch.save(capsule_net.state_dict(), PATH)

In [41]:
PATH = './cifar_CapsNet_SGD_2epochs_3.10.pth'
#Load previously net from choosen training
capsule_net = CapsNet()
capsule_net.load_state_dict(torch.load(PATH))
criterion = CapsuleLoss()
optimizer = Adam(capsule_net.parameters(), lr=1e-3)
capsule_net.eval()
test_loss = 0
correct = 0
total = 0
correct_final = 0
total_final = 0
eye = torch.eye(len(classes))

writer = SummaryWriter(logdir='CapsNet_10/test.2')
for i, datas in enumerate(testloader, 0):
    inputs, labels = datas
    
    target = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
        
    optimizer.zero_grad()    
    output, reconstructions = capsule_net(inputs)
    
    _,predicted=torch.max(output,1)
    labels_eye = eye[labels]
    loss = criterion(output, labels_eye, inputs, reconstructions)
    test_loss += loss.item()
    
    
    total += labels.size(0) #count the number of labels with right shape
    correct += (predicted == labels).sum().item() #count the number of right labels
    total_final += labels.size(0) #count the number of labels with right shape
    correct_final += (predicted == labels).sum().item() #count the number of right labels 
    if i % 500 == 499:    # print every 500 mini-batches
        print('[%5d] loss: %.3f, accuracy: %.3f' %( i + 1, (test_loss / 2000), (100*correct/total) ))
        
        writer.add_scalar('accuracy_test', 100*correct/total, i)
        writer.add_scalar('loss_test', test_loss / 2000, i)
        
        test_loss = 0.0
        correct = 0
        total = 0    
writer.close()
print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct_final / total_final))    
    
print (test_loss / len(testloader))

[  500] loss: 0.646, accuracy: 55.150
[ 1000] loss: 0.642, accuracy: 56.300
[ 1500] loss: 0.658, accuracy: 56.950
[ 2000] loss: 0.651, accuracy: 55.300
[ 2500] loss: 0.666, accuracy: 53.900
Accuracy of the network on the 10000 test images: 55 %
0.0
