In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as functional
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.optim import Adam
from torch.autograd import Variable

import numpy as np
from tensorboardX import SummaryWriter

In [4]:
#Load train set and test set and normalize the images in range [-1,1]
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(( 0.1307,), ( 0.3081,))])

#50000 images training
trainset = torchvision.datasets.MNIST(root='./mnist', train=True,
                                        download=True, transform=transform)
#We load 4 samples per batch reduce the traininset to 12500
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=0)

#print(len(trainset))
#10000 images test
testset = torchvision.datasets.MNIST(root='./mnist', train=False,
                                       download=True, transform=transform)
#We load 4 samples per batchreduce the traininset to 2500
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=0)


classes = list(range(10))

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting ./mnist/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting ./mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST/raw

Processing...


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Done!


In [5]:
class ConvLayer(nn.Module):
    def __init__(self, in_channels=1, out_channels=256, kernel_size=9):
        
        super(ConvLayer, self).__init__()

        self.conv = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=1
                             )
        self.relu = nn.ReLU(inplace=True)
        #We want to transform 32x32x3 in 20x20x256 

    def forward(self, x):
        return self.relu(self.conv(x))

In [6]:
class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=32, in_channels=256, out_channels=8, kernel_size=9):
        super(PrimaryCaps, self).__init__()
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=num_capsules, kernel_size=kernel_size, stride=2, padding=0) 
                          for _ in range(out_channels)])
    def forward(self, x):
        u = [capsule(x) for capsule in self.capsules]
        u = torch.stack(u, dim=4) 
        u = u.view(x.size(0), 32 * 6 * 6, -1)
        #change shape of tensor to 32 capsules of 8x6x6  gridsize=6
        return self.squash(u)
    
    #Squash function maintaining the direction of the vector (use instead of RELU activation fct in CNN)
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm)+ 1e-8)
        return output_tensor

In [7]:
class DigitCaps(nn.Module):
    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16):
        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules

        
        self.W = nn.Parameter(0.01*torch.randn(1, num_capsules, num_routes, out_channels, in_channels))#0.01 added from danielhavir
        
        
    def forward(self, x):
        batch_size = x.size(0)

        x = x.unsqueeze(1)
        x = x.unsqueeze(4)
        
        
        u_hat = torch.matmul(self.W, x) #matmul = matrix multiplication
        
        u_hat = u_hat.squeeze(-1)
        temp_u_hat = u_hat.detach()
        
        b_ij = torch.zeros(batch_size,  self.num_capsules, self.num_routes,1)
        num_iterations = 3 #number of routing

        #this is the routing algorithm
        for iteration in range(num_iterations-1):
            #compute coupling coefficient, Conceptually: measure how likely capsule i may activate capsule j
            c_ij = F.softmax(b_ij, dim=1)
            
            s_j = (c_ij * temp_u_hat).sum(dim=2)
            
            v_j = self.squash(s_j)
            delta = torch.matmul(temp_u_hat, v_j.unsqueeze(-1))
            b_ij = b_ij + delta
        
        c_ij = F.softmax(b_ij, dim=2)
        s_j = (c_ij * u_hat).sum(dim=2)
        
        v_j = self.squash(s_j)
        
        return v_j
    
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor

In [8]:
class Decoder(nn.Module): #difference between decoder and reconstruction ????
    def __init__(self):
        super(Decoder, self).__init__()
#num capsule=10, capsule_size=16, imsize=32, img_channel=3        
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(16 * 10, 512), #nn.Linear(capsule_size*num_capsules, 512),
            #nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            #nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784), #nn.Linear(1024, imsize*imsize*img_channel),
            nn.Sigmoid()
        )
        
    def forward(self, output):
        batch_size = output.size(0)
        classes = torch.norm(output, dim=2)
        
        
        _, max_length_indices = classes.max(dim=1)

        masked = torch.eye(10)
    
        masked = masked.index_select(dim=0, index=max_length_indices).unsqueeze(2)
        reconstructions = self.reconstraction_layers( (output*masked).view(batch_size, -1) )
        reconstructions = reconstructions.view(-1, 1, 28, 28)
        return reconstructions

In [9]:
class CapsNet(nn.Module):
    def __init__(self):
        super(CapsNet, self).__init__()
        self.conv_layer = ConvLayer()
        self.primary_capsules = PrimaryCaps()
        self.digit_capsules = DigitCaps()
        self.decoder = Decoder()
        
        self.mse_loss = nn.MSELoss()
        
    def forward(self, x):
        #print ('original:',x.size()) #size=4,3,32,32
        output = self.conv_layer(x)
        #print("conv:",output.size()) #size=4,256,24,24
        output = self.primary_capsules(output)
        #print("prim:",output.size()) #size=4,2048,8
        output = self.digit_capsules(output)
        #print("digit:",output.size()) #size=4,10,16
        preds = torch.norm(output, dim=-1)
        reconstructions = self.decoder(output)
        return preds, reconstructions
    


In [10]:
class MarginLoss(nn.Module):
    def __init__(self, size_average=False, loss_lambda=0.5):
        super(MarginLoss, self).__init__()
        self.size_average = size_average
        self.m_plus = 0.9
        self.m_minus = 0.1
        self.loss_lambda = loss_lambda

    def forward(self, inputs, labels):
        
        left=F.relu(self.m_plus - inputs)**2
        right=F.relu(inputs - self.m_minus)**2
        L_k = labels * left + self.loss_lambda * (1 - labels) * right
        L_k = L_k.sum(dim=1)
        if self.size_average:
            return L_k.mean()
        else:
            return L_k.sum()

class CapsuleLoss(nn.Module):
    def __init__(self, loss_lambda=0.5, recon_loss_scale=5e-4, size_average=False):

        super(CapsuleLoss, self).__init__()
        self.size_average = size_average
        self.margin_loss = MarginLoss(size_average=size_average, loss_lambda=loss_lambda)
        self.reconstruction_loss = nn.MSELoss(size_average=size_average)
        self.recon_loss_scale = recon_loss_scale

    def forward(self, inputs, labels, images, reconstructions):
        margin_loss = self.margin_loss(inputs, labels)
        reconstruction_loss = self.reconstruction_loss(reconstructions, images)
        caps_loss = (margin_loss + self.recon_loss_scale * reconstruction_loss)

        return caps_loss

In [13]:
# Training network over 2 epochs
capsule_net = CapsNet()

# Optimizer: Adam
criterion = CapsuleLoss()
optimizer = Adam(capsule_net.parameters(), lr=1e-3)
#optimizer = optim.SGD(capsule_net.parameters(), lr=0.001, momentum=0.9)
epochs=20
capsule_net.train()
total = 0
correct = 0
eye = torch.eye(len(classes))


for epoch in range(epochs):  # loop over the dataset multiple times
    writer = SummaryWriter(logdir='CapsNet_10_MNIST_20epochs/training_epoch%d'%(epoch))
    
    running_loss = 0.0
    eye = torch.eye(len(classes))
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs, reconstructions = capsule_net(inputs)
        labels_eye = eye[labels]
        loss = criterion(outputs, labels_eye, inputs, reconstructions)
        loss.backward()
        optimizer.step()
        
        
        _,predicted=torch.max(outputs,1)
        
        # print statistics
        running_loss += loss.item()
        total += labels.size(0) #count the number of labels with right shape
        correct += (predicted == labels).sum().item() #count the number of correct prediction
        
        if i % 500 == 499:    # print every 500 mini-batches
            print('[%d, %5d] loss: %.3f, accuracy: %.3f' %
                  (epoch + 1, i + 1, running_loss / 500, 100*float(correct)/float(total)))
            
            writer.add_scalar('accuracy_train', 100*float(correct)/float(total), i)
            writer.add_scalar('loss_train', running_loss / 500, i)
            
            running_loss = 0.0
            correct = 0
            total = 0
    
    #added check point
    if epoch%2 == 0:
        PATH = './checkpoint.1.pth'
        torch.save(capsule_net.state_dict(), PATH)
        
    writer.close()
print('Finished Training')

[1,   500] loss: 2.896, accuracy: 54.450
[1,  1000] loss: 1.489, accuracy: 92.300
[1,  1500] loss: 1.336, accuracy: 94.650
[1,  2000] loss: 1.271, accuracy: 95.650
[1,  2500] loss: 1.263, accuracy: 96.350
[1,  3000] loss: 1.318, accuracy: 95.300
[1,  3500] loss: 1.334, accuracy: 95.350
[1,  4000] loss: 1.402, accuracy: 93.350
[1,  4500] loss: 1.347, accuracy: 94.000
[1,  5000] loss: 1.271, accuracy: 96.150
[1,  5500] loss: 1.258, accuracy: 96.300
[1,  6000] loss: 1.303, accuracy: 95.350
[1,  6500] loss: 1.354, accuracy: 95.150
[1,  7000] loss: 1.264, accuracy: 96.100
[1,  7500] loss: 1.248, accuracy: 96.450
[1,  8000] loss: 1.267, accuracy: 95.950
[1,  8500] loss: 1.291, accuracy: 95.700
[1,  9000] loss: 1.272, accuracy: 96.050
[1,  9500] loss: 1.244, accuracy: 96.900
[1, 10000] loss: 1.235, accuracy: 97.050
[1, 10500] loss: 1.234, accuracy: 97.400
[1, 11000] loss: 1.261, accuracy: 96.550
[1, 11500] loss: 1.222, accuracy: 97.100
[1, 12000] loss: 1.221, accuracy: 96.750
[1, 12500] loss:

[7, 10500] loss: 0.961, accuracy: 99.500
[7, 11000] loss: 0.970, accuracy: 99.350
[7, 11500] loss: 0.966, accuracy: 99.650
[7, 12000] loss: 0.960, accuracy: 99.700
[7, 12500] loss: 0.967, accuracy: 99.450
[7, 13000] loss: 0.966, accuracy: 99.400
[7, 13500] loss: 0.977, accuracy: 99.300
[7, 14000] loss: 0.963, accuracy: 99.700
[7, 14500] loss: 0.958, accuracy: 99.600
[7, 15000] loss: 0.961, accuracy: 99.500
[8,   500] loss: 0.955, accuracy: 99.550
[8,  1000] loss: 0.949, accuracy: 99.750
[8,  1500] loss: 0.952, accuracy: 99.700
[8,  2000] loss: 0.965, accuracy: 99.650
[8,  2500] loss: 0.960, accuracy: 99.750
[8,  3000] loss: 0.953, accuracy: 99.850
[8,  3500] loss: 0.928, accuracy: 99.800
[8,  4000] loss: 0.953, accuracy: 99.650
[8,  4500] loss: 0.947, accuracy: 99.650
[8,  5000] loss: 0.951, accuracy: 99.750
[8,  5500] loss: 0.953, accuracy: 99.650
[8,  6000] loss: 0.948, accuracy: 99.900
[8,  6500] loss: 0.944, accuracy: 99.850
[8,  7000] loss: 0.949, accuracy: 99.550
[8,  7500] loss:

[14,  4000] loss: 0.886, accuracy: 99.950
[14,  4500] loss: 0.888, accuracy: 99.950
[14,  5000] loss: 0.892, accuracy: 99.950
[14,  5500] loss: 0.893, accuracy: 99.700
[14,  6000] loss: 0.894, accuracy: 99.900
[14,  6500] loss: 0.895, accuracy: 99.900
[14,  7000] loss: 0.891, accuracy: 99.950
[14,  7500] loss: 0.888, accuracy: 99.900
[14,  8000] loss: 0.887, accuracy: 99.900
[14,  8500] loss: 0.889, accuracy: 99.800
[14,  9000] loss: 0.897, accuracy: 99.750
[14,  9500] loss: 0.884, accuracy: 99.900
[14, 10000] loss: 0.878, accuracy: 99.900
[14, 10500] loss: 0.890, accuracy: 99.800
[14, 11000] loss: 0.890, accuracy: 99.950
[14, 11500] loss: 0.889, accuracy: 99.650
[14, 12000] loss: 0.901, accuracy: 99.600
[14, 12500] loss: 0.897, accuracy: 99.900
[14, 13000] loss: 0.878, accuracy: 100.000
[14, 13500] loss: 0.880, accuracy: 99.950
[14, 14000] loss: 0.901, accuracy: 99.600
[14, 14500] loss: 0.895, accuracy: 99.700
[14, 15000] loss: 0.892, accuracy: 99.600
[15,   500] loss: 0.880, accuracy

In [14]:
#training saving
PATH = './MNIST_CapsNet_SGD_20epochs_3.10.pth'
torch.save(capsule_net.state_dict(), PATH)

In [17]:
PATH = './MNIST_CapsNet_SGD_20epochs_3.10.pth'
#Load previously net from choosen training
capsule_net = CapsNet()
capsule_net.load_state_dict(torch.load(PATH))
criterion = CapsuleLoss()
optimizer = Adam(capsule_net.parameters(), lr=1e-3)
capsule_net.eval()
test_loss = 0
correct = 0
total = 0
correct_final = 0
total_final = 0
eye = torch.eye(len(classes))

writer = SummaryWriter(logdir='CapsNet_10_MNIST_20epochs/test.2')
for i, datas in enumerate(testloader, 0):
    inputs, labels = datas
    
    target = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
        
    optimizer.zero_grad()    
    output, reconstructions = capsule_net(inputs)
    
    _,predicted=torch.max(output,1)
    labels_eye = eye[labels]
    loss = criterion(output, labels_eye, inputs, reconstructions)
    test_loss += loss.item()
    
    
    total += labels.size(0) #count the number of labels with right shape
    correct += (predicted == labels).sum().item() #count the number of right labels
    total_final += labels.size(0) #count the number of labels with right shape
    correct_final += (predicted == labels).sum().item() #count the number of right labels 
    if i % 500 == 499:    # print every 500 mini-batches
        print('[%5d] loss: %.3f, accuracy: %.3f' %( i + 1, (test_loss / 500), (100*correct/total) ))
        
        writer.add_scalar('accuracy_test', 100*correct/total, i)
        writer.add_scalar('loss_test', test_loss / 500, i)
        
        test_loss = 0.0
        correct = 0
        total = 0    
writer.close()
print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct_final / total_final))    
    
print (test_loss / len(testloader))

[  500] loss: 0.891, accuracy: 98.950
[ 1000] loss: 0.911, accuracy: 98.200
[ 1500] loss: 0.911, accuracy: 99.300
[ 2000] loss: 0.948, accuracy: 99.650
[ 2500] loss: 0.936, accuracy: 99.550
Accuracy of the network on the 10000 test images: 99 %
0.0


In [16]:
writer = SummaryWriter(logdir='CapsNet_10_MNIST/graph')
writer.add_graph(capsule_net, inputs)
writer.close()