In [1]:
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim import Adam
import torch.nn.functional as F

import csv
from skimage import io

from PIL import Image
import pandas as pd

import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable



import matplotlib.pyplot as plt
import time
import os
import copy

verbose = False
USE_CUDA = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

MNIST_USE = False
CIFAR10_USE = False
MARVEL_USE = True

In [2]:
class MARVEL_dataset(Dataset):
    def __init__(self, dat_file,train = True, transform = None):   
        self.root_dir = os.path.dirname(dat_file)
        datContent = [i.strip().split(',') for i in open(dat_file).readlines()]
        if train:
            csv_file = os.path.join(self.root_dir, "data_Train.csv")
        else:
            csv_file = os.path.join(self.root_dir, "data_Test.csv")
        with open(csv_file, "w") as f:
            writer = csv.writer(f,delimiter=',')
            writer.writerow(["counter", "set", "class", "label","location"])
            for line in datContent:
                if train and line[1]=='1':
                    if not(line[4] == '-'):
                        writer.writerows([line])  
                if not(train) and line[1] == '2':
                    if not(line[4]=='-'):
                        writer.writerows([line]) 
                
        self.MARVEL_datafile = pd.read_csv(csv_file)       
        self.transform = transform
        
    def __len__(self):
        return len(self.MARVEL_datafile)
    
    def __getitem__(self,idx):
        img_name = self.MARVEL_datafile.iloc[idx,4]
        image = self.__loadfile(img_name)
        target = self.MARVEL_datafile.iloc[idx,2]
        if self.transform:
            image = Image.fromarray(image)
            sample = self.transform(image)
        else:
            sample = image
        return (sample,target)
    
    def __loadfile(self, data_file):
        image = io.imread(data_file)
        if len(image.shape)<3:
            image = np.stack((image,)*3, axis=-1)
        return image

In [3]:
dataset_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),        
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])

batch_size = 100


if CIFAR10_USE: 
    NUM_CLASSES = 10
    print("CIFAR10")
    image_datasets = {'train': datasets.CIFAR10('../data', train=True, download=True, transform=dataset_transform),'val': datasets.CIFAR10('../data', train=False, download=True, transform=dataset_transform)}
    print("Initializing Datasets and Dataloaders...")

    dataloaders = {'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=batch_size, shuffle=True) , 'val': torch.utils.data.DataLoader(image_datasets['val'], batch_size=batch_size, shuffle=True) }
    print("Initializing Datasets and Dataloaders...")

if MARVEL_USE: 
    NUM_CLASSES = 26
    print("MARVEL")
    dat_file = "/home/rita/JupyterProjects/EYE-SEA/DataSets/marveldataset2016-master/FINAL.dat"

    image_datasets = {'train': MARVEL_dataset(dat_file,train = True,transform=dataset_transform),'val': MARVEL_dataset(dat_file,train = False,transform=dataset_transform)}
    print("Initializing Datasets and Dataloaders...")

    dataloaders = {'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=batch_size, shuffle=True) , 'val': torch.utils.data.DataLoader(image_datasets['val'], batch_size=batch_size, shuffle=True) }
    print("Initializing Datasets and Dataloaders...")

if MNIST_USE: 
    NUM_CLASSES = 10
    print("MNIST")
    image_datasets = {'train': datasets.MNIST('../data', train=True, download=True, transform=dataset_transform),'val': datasets.MNIST('../data', train=False, download=True, transform=dataset_transform)}
    print("Initializing Datasets and Dataloaders...")

    dataloaders = {'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=batch_size, shuffle=True) , 'val': torch.utils.data.DataLoader(image_datasets['val'], batch_size=batch_size, shuffle=True) }
    print("Initializing Datasets and Dataloaders...")

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}





MARVEL
Initializing Datasets and Dataloaders...
Initializing Datasets and Dataloaders...


In [4]:
class ConvLayer(nn.Module):
    def __init__(self, in_channels=128,  out_channels=256, kernel_size=9):
           
        super(ConvLayer, self).__init__()

        self.conv = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=1
                             )

    def forward(self, x):
        if verbose: print( "Conv input size{}".format(x.size()))
        output = F.relu(self.conv(x))
        if verbose: print("Conv output feature matrix {}".format(output.shape))
        return output

In [5]:
class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9):

        super(PrimaryCaps, self).__init__()

        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0) 
                          for _ in range(num_capsules)])
    
    def forward(self, x, dimension = 32*6*6):
        u = [capsule(x) for capsule in self.capsules]
        u = torch.stack(u, dim=1)
        if verbose: print( "PrimaryCaps {}".format(u.size()))
        u = u.view(x.size(0), dimension, -1)
        if verbose: print("PrimaryCaps size U {}".format(u.size()))
        output = self.squash(u)
        if verbose: print("Primary Caps output {}".format(output.size()))
        return output
    
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        if verbose: print(output_tensor.size())
        return output_tensor

In [6]:
class DigitCaps(nn.Module):
    def __init__(self, num_capsules=NUM_CLASSES, num_routes=32 * 6 * 6 , in_channels=8,  out_channels=16):
        
        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules

        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))

    def forward(self, x):
        batch_size = x.size(0)
        x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)
        if verbose: print( "DigitCaps x {}, W {}".format(x.size(),self.W.size()))
        W = torch.cat([self.W] * batch_size, dim=0)
        if verbose: print("DigitCaps W {}".format(W.size()))
        u_hat = torch.matmul(W, x)
        if verbose: print("DigitCaps u_hat {}".format(u_hat.size()))
        
        b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1))
        if USE_CUDA:
            b_ij = b_ij.to(device)#cuda()

        num_iterations = 3
        for iteration in range(num_iterations):
            c_ij = F.softmax(b_ij)
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)

            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            v_j = self.squash(s_j)
            
            if iteration < num_iterations - 1:
                a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))
                b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)

        return v_j.squeeze(1)
    
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor

In [7]:

def caps_loss(data, x, target, reconstructions):
    return margin_loss(x, target) + reconstruction_loss(data, reconstructions)

def margin_loss( x, labels, size_average=True):
    batch_size = x.size(0)

    v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))

    left = F.relu(0.9 - v_c).view(batch_size, -1)
    right = F.relu(v_c - 0.1).view(batch_size, -1)

    if verbose : print("Dimensions of labels {}, left {} right {}".format(labels.shape,left.shape,right.shape))
    loss = labels * left + 0.5 * (1.0 - labels) * right
    loss = loss.sum(dim=1).mean()

    return loss

def reconstruction_loss( data, reconstructions):
    mseloss = nn.MSELoss()
    loss = mseloss(reconstructions.view(reconstructions.size(0), -1), data.view(reconstructions.size(0), -1))
    if verbose : print("loss {}".format(loss)) 
    return loss * 0.0005

In [8]:
def margin_loss( x, labels, size_average=True):
    batch_size = x.size(0)

    v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))

    left = F.relu(0.9 - v_c).view(batch_size, -1)
    right = F.relu(v_c - 0.1).view(batch_size, -1)

    loss = labels * left + 0.5 * (1.0 - labels) * right
    loss = loss.sum(dim=1).mean()

    return loss

def model_loss( x, target):
    return margin_loss(x, target)

def decoder(x, data):
        classes = torch.sqrt((x ** 2).sum(2))
        classes = F.softmax(classes)
        
        _, max_length_indices = classes.max(dim=1)
        masked = Variable(torch.eye(NUM_CLASSES))
        if USE_CUDA:
            masked = masked.to(device)#cuda()
        masked = masked.index_select(dim=0, index=max_length_indices.squeeze(1).data)
        
        return masked

In [None]:
from torchsummary import summary
model = torchvision.models.resnet18(pretrained=True)
modules = list(model.children())[:-4]
model=nn.Sequential(*modules)
for param in model.parameters():
    param.requires_grad = False
    
model.layer3 = nn.Sequential(ConvLayer(), PrimaryCaps(), DigitCaps())
loss_train = []        
accuracy_train = []

model = model.to(device)
optimizer = optim.Adam(model.layer3.parameters(),lr = 0.001)
start = time.time()
n_epochs = 3

for epoch in range(n_epochs):
        model.train() 
        
        train_loss = 0
        train_accuracy = 0
        print('Epoch {}/{}'.format(epoch,3-1))
        print('-'*10)
        for batch_id, (inputs, labels) in enumerate(dataloaders['train']):
            #inputs, labels = next(iter(dataloaders['train']))
            if MARVEL_USE: labels = labels-1
            labels =torch.eye(NUM_CLASSES).index_select(dim=0, index=labels)
            inputs, labels = Variable(inputs), Variable(labels)
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = model_loss(outputs, labels)
            masked = decoder(outputs, inputs)
        
            loss.backward()
            optimizer.step()

            train_loss += loss.data[0]
            train_accuracy += (sum(np.argmax(masked.data.cpu().numpy(), 1) == np.argmax(labels.data.cpu().numpy(), 1)) / float(batch_size))
            
            if batch_id % 100 == 0:
                print("train accuracy:", sum(np.argmax(masked.data.cpu().numpy(), 1) == 
                                       np.argmax(labels.data.cpu().numpy(), 1)) / float(batch_size))
                if verbose: print("masked {}".format(np.argmax(masked.data.cpu().numpy(), 1)))
                if verbose: print("labels {}".format(np.argmax(labels.data.cpu().numpy(), 1)))
#                batch_accuracy.append(sum(np.argmax(preds.data.cpu().numpy(), 1) == 
#                                       np.argmax(labels.data.cpu().numpy(), 1)) / float(batch_size))
        
        loss_train.append(train_loss/len(dataloaders['train']))
        accuracy_train.append(train_accuracy/len(dataloaders['train']))
end = time.time()
print("Training time execution {}".format(end-start))
print("Loss value for training phase: {}".format(train_loss / len(dataloaders['train'])))
print("Accuracy value for training phase: {}".format(train_accuracy / len(dataloaders['train'])))

Epoch 0/2
----------




train accuracy: 0.08
masked [17  0 10  3  1  7 14  9  7 20 16 22  3 24  5  9 12  4  1 17  2 20 21 20
 22 22 15  7 12  5  5 13 16 10 16  9  5 18 14  4 21  3  0  0 24  6 24 21
 21  2 14 14 12 14  1 22  0 10 10 10  2 21  1 18 19 19 17 21  3 19 15  0
 23 24 19  5 20  8 20  9  2 20  4 10 19 17 10 12 20 12 10  2 23 13  1 21
  0  4  8  3]
labels [17 19  6 19  0 11 24 11  7 16 11  9 24  2 13 22  2 25 17 17  8  3  2 20
  3  3 25 22  2  9 18  3 20 16 15 17 19 10 14 18  5 13 13 17 20 17 18 12
 12 11  3  7 22 16 10  6 18  5 24 10  2 23  3 23  0 22  8 15 23  9  3 24
 12 20  9 10  4 21  5  6 25 10 11 18 18  5 13 12  8  2  5  6 22 23 17  1
 10 21 22 16]
train accuracy: 0.13
masked [18 15 13 13 21  5  8  1 14  6  8 14  3 13 15 14 19 20  1 13 15 13  5 21
 13 13 13 18  5  8  1 21 15  6  2 25  3  2 18  6 18 13 25 14 25  6 13 14
  5 13  8 13  8  1 13 25  8  7  8  6 18 13 15  2 13  3  6 20 21 25 14  5
 13 20  7 18 14 14  8  4 15 15 12  5 19 13 14 21 19 15 21  5 22  5 18  1
  1 18  8 25]
labels [ 6 15 17  6

train accuracy: 0.29
masked [11  0  6  1 16 15 17  6  6  6  2  8 11  8 13  1  5  1  1 23  8  8 14 11
  9 14 13  6  8 11  1 13  6  1  6 14 22  1  1  7 11 25  9 24 13  9  2  2
 13 25  8 17 19 19 13  0 17 11  9  2 11  1  0  2 20 14  2 25 10 25  1 11
 11 14  1 13  8  2  6  1 19 14 11 11 17 19  9  1  1 15  9  2 14 25 13  6
  1  2  1 11]
labels [12  0  6  4  5 10 11 11  6  3  2  8 24  2 13 20 13 16 21 20  3  8 18 20
  9 22 13  0  8  3  0 13  4 25  7 21 19 10 11 13 11 21  9 25 13 18 12 20
 25 23  8 12 11 14 13  0 12 11 18  2 22 21  0  2 12 14  3 23 10 25 13 21
 25 16  1 18  8 18  3 25 14  3 11 24 12 16 18 20  7 15 18  3  7 12  5  7
 25  3 21 11]
train accuracy: 0.29
masked [14 19 15 14 11  0  0  7 14  7 14  6 11 17  8  6  2 15  9  2 13 17 19  0
  1 14 11 20  0 25 18  0 14  0  7 13 14 11 25  1 24 22 25 21 11 11  5  1
 25 11  2 20 11 11  1  9 20 15  6  4 19 23 23 13  6 14  9  9  0 11 21 15
  7  6  1  5  6  6 19 19  0 15 11  6 17 13 14  1 13 14 14  9 19 15  2  6
 15 25  1 24]
labels [18  7 15 10

train accuracy: 0.46
masked [18 16  6 11 20  3 19  3 11 25  9 15  6 22 14  7  0  9 11  5 19  8 20  4
  5 25  5 10 13  6 19  1  3 16 13 19 22  6  4 14  0 20 19 15 16 19 19  3
 25  5 20 20  8  6  1 19 20 20 23 25 13 25 14 19  1 25  6 25 22 10 20 25
 19  9 17 14  1  0 14 20 19 15 14  0  0 11 18 19  2  9 14  2  0  1  2  5
 25 25 23  5]
labels [18 16 24 11 20 12  7  4 11  0  9 20  6 21 14  1  0  9 11 25 24  8 17  4
  5 16 14 10 13  6 23 21  3 17 23 22  7  6 21 20  0 24  1 10 10 19  7  3
 23 22 25  4  8  6 23 22 25 12 17 10 13 22 16 24 17 25  6 25 22 10  1 25
 19  9  2 14 21  4 14  8 17 15 14  0 21 11  9 16  3  9 14  2  0  1  2  2
 25 23 12 20]
train accuracy: 0.43
masked [11 14  6 11 14 15  9  1 20  0  8  6 17  2  5  9 11 21 25 15 19  1  0 15
 23  6  3 15 17  2  8  8 17 23 25 14  2 25 25  2 19 21 23  0 23  2 19  4
 15 19 18 18 25 18 19 25 20 25  5  0 25  5 17 17  3 19  1 13 11  0  1 25
 25  0 16 14  8  1 19 13 15  0 17 19  9 15 14 11 14 15 25 10 11  6 23 17
  2 19  1  1]
labels [11 12  6 11

train accuracy: 0.45
masked [19 19 19 16 19 15 18 11 25 19  3 16 19 25 11 25  9 25  7 11 17  9 25  5
  6 18  8 11  2 25 21  0  6 10  4 15  8  3 13 17 20 15 25  4 11  7  7 19
  1  8  9 15 25  0 15  5 20  5  7 11 25 17 17 19 20  8 18 25 25 19 16 14
 18 25  2 20 10 19 11 19  9 18 23  0 18 19 11  0 25 20 25 11  8 11  0 17
 25 17 11  5]
labels [ 3 14 24 23 18 15 18 11  7  3 18 10 25 18 11 10  9 19 21 11  1  9  5  4
 21 18  8 11  3  5  1 25  6 10 24 15  8  3 13 12 20 15 23  4 11 19 12 19
 11  8  9 15 20  7 15 20 12  5 12 11 12  3  5 20 12  8 18 23 20  7 16 14
 18  7  3 16  1 19 11 14  9 18 13  0 18 13 11  0 22 21 21 11  8 11  0 18
 14 12 11  7]
train accuracy: 0.42
masked [ 2 14 19 19  6 25 24 25  6 11 20 24 23 19  5 25 25  8 10  9 15 15 20  2
 10 15  0 19 19 19  0  0 25 19  9  8  2 11  5 13 14 19 19 25 25 20  2 18
  6 14 19 23  2  7 10  0 19 11  6 19 18 16 19 19  0 10 25 19  0 15  0 19
 25  6 11 14 20 10 18 10 25 19 11  0  9 22  0 19  8 18  8 21  9 11  5  6
  6 15  8  8]
labels [ 8 14 18 19

In [None]:
model.eval()
test_loss = 0
test_accuracy = 0
start = time.time()
for batch_id, (inputs, labels) in enumerate(dataloaders['val']):
    labels =torch.eye(NUM_CLASSES).index_select(dim=0, index=labels)
    inputs, labels = Variable(inputs), Variable(labels)
    inputs = inputs.to(device)
    labels = labels.to(device)

    outputs = model(inputs)
    masked = decoder(outputs, inputs)
    
    loss = model_loss(outputs, labels)
    test_loss += loss.data[0]
    test_accuracy += (sum(np.argmax(masked.data.cpu().numpy(), 1) == np.argmax(labels.data.cpu().numpy(), 1)) / float(batch_size))
    
    if batch_id % 100 == 0:
            print("test accuracy:", sum(np.argmax(masked.data.cpu().numpy(), 1) == 
                                   np.argmax(labels.data.cpu().numpy(), 1)) / float(batch_size))
            print("masked {}".format(np.argmax(masked.data.cpu().numpy(), 1)))
            print("labels {}".format(np.argmax(labels.data.cpu().numpy(), 1)))
        
            
end = time.time()   
print("Validation time execution {}".format(end-start))
print("Loss value for test phase: {}".format(test_loss / len(dataloaders['val'])))
print("Accuracy value for test phase: {}".format(test_accuracy / len(dataloaders['val'])))


In [None]:
print(model)
import matplotlib.pyplot as plt
n_epochs = 3
epochs = np.arange(1,n_epochs+1)
plt.plot(epochs, loss_train, color='g')
plt.plot(epochs, loss_train, color='orange')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training phase')
plt.show()

In [None]:
torch.cuda.empty_cache()