In [1]:
import math
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
import simplex_coordinates2

verbose = False

class Norm_Arc_loss(nn.Module):
    """Angular-margin logits against a fixed weight matrix, plus a capsule-style margin loss.

    The weight matrix is built once from ``simplex_coordinates2`` and frozen
    (``requires_grad = False``); only the network feeding this module is trained.

    Args:
        s: Scale factor applied to the logits returned by ``arc_loss``.
        in_feature: Accepted for interface compatibility; not used.
        out_feature: Number of classes; the simplex is built with ``out_feature - 1``.
    """

    def __init__(self, s=5.0, in_feature=10,out_feature=10):
        super(Norm_Arc_loss, self).__init__()

        self.d = out_feature - 1
        # acos(-1/d) is the angle between two vertices of a regular simplex
        # with d + 1 vertices centred at the origin; it is used as the margin.
        m = math.acos(-1/self.d)
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.alpha = (1-math.sqrt(self.d +1))/self.d
        self.s = s
        # presumably simplex_coordinates2 returns a (d, d + 1) array of vertex
        # coordinates, one vertex per column -- TODO confirm against that module.
        vertex_simplex = torch.Tensor(simplex_coordinates2.simplex_coordinates2(self.d)).permute(1,0)
        # Extra trailing coordinate per row: alpha times the sum of that row.
        lvalue = self.alpha * vertex_simplex.sum(dim=1)
        self.weight = Parameter(torch.cat((vertex_simplex, lvalue.unsqueeze(1)), 1))
        self.weight.requires_grad = False
        if verbose: print("Norm_Arc_loss _ weight matrix {}".format(self.weight.size()))
        # From ArcFace: thresholds meant to keep cos(theta + m) monotonically
        # decreasing for theta in [0, 180] degrees. Stored but not used by arc_loss.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

        # Arguments are passed in the same order as the labels in the message.
        print("margin {}  cos_m {}  sin_m {}".format(m, self.cos_m, self.sin_m))

    def saveFigure(self,data,epoch,batch_id,folder_name,name_var):
        """Save a heat-map of ``data[1]`` to ``<folder_name>/<batch_id><name_var><epoch>.jpg``.

        Debugging helper. The tick labels are the ten CIFAR-10 class names, so
        ``data[1]`` is expected to be a 10 x 10 matrix.
        """
        # Imported here so the method does not depend on imports made by a later cell.
        import numpy as np
        import matplotlib.pyplot as plt

        classes = ["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"]

        fig, ax = plt.subplots()
        A = data[1,:,:].cpu().detach().numpy()
        im = ax.imshow(A)
        ax.figure.colorbar(im)
        # Ticks and labels are set separately so this works on every matplotlib version.
        ax.set_xticks(np.arange(10))
        ax.set_xticklabels(classes)
        ax.set_yticks(np.arange(10))
        ax.set_yticklabels(classes)
        fig.savefig(folder_name+"/"+str(batch_id)+name_var+str(epoch)+".jpg")
        # Close the figure so repeated calls during training do not accumulate memory.
        plt.close(fig)

    def arc_loss(self, x, label,epoch,batch_id,folder_name,val = 0):
        """Return scaled cosine logits with the angular margin applied.

        Args:
            x: Batch of matrices; every ``x[i]`` must have the same shape as ``self.weight``.
            label: Mask with the same shape as the cosine tensor; entries equal to 1
                select where the margin-shifted value ``phi`` is used (training only).
            epoch, batch_id, folder_name: Not used by the computation; kept so callers
                can pass them for optional ``saveFigure`` debugging.
            val: 0 for training (margin only where ``label`` is 1); any other value
                returns ``phi`` everywhere.

        Returns:
            Tensor of shape ``(batch, rows_of_x, rows_of_weight)`` scaled by ``self.s``.

        Raises:
            ValueError: if the per-sample shape of ``x`` differs from ``self.weight``.
        """
        if verbose: print("ARC LOSS")
        if x.dim() != 3 or x.shape[1:] != self.weight.shape:
            raise ValueError(
                "x dimension and weight dimension do not match: x {} vs weight {}".format(
                    tuple(x.shape), tuple(self.weight.shape)))

        # Row-normalise both operands so the product is the cosine of the angle
        # between every row of x[i] and every row of the weight matrix.
        self.cosine = F.linear(F.normalize(x, dim=-1), F.normalize(self.weight, dim=-1))

        # Rounding can push cos^2 slightly above 1; clamp so sqrt never yields NaN.
        self.sine = torch.sqrt((1.0 - torch.pow(self.cosine, 2)).clamp(0, 1))

        # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m)
        self.phi = self.cosine * self.cos_m - self.sine * self.sin_m

        if val == 0:
            output = (label * self.phi) + ((1.0 - label) * self.cosine)
        else:
            output = self.phi
        output = output * self.s

        return output

    # Capsule margin loss
    def margin_loss(self, x, labels, size_average=True):
        """Margin loss on the L2 length of the vectors along ``dim=2`` of ``x``.

        Uses ``relu(0.9 - |v|)`` for the target class and ``0.5 * relu(|v| - 0.1)``
        for the others (the terms are not squared). ``size_average`` is accepted but
        not used: the loss is always summed over classes and averaged over the batch.
        """
        if verbose: print("x {}".format(x.size()))
        if verbose: print("labels {}".format(labels.size()))
        batch_size = x.size(0)

        v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))  # L2 norm along dim 2
        if verbose: print("v_c {}".format(v_c.size()))
        left = F.relu(0.9 - v_c).view(batch_size, -1)
        right = F.relu(v_c - 0.1).view(batch_size, -1)

        loss = labels * left + 0.5 * (1.0 - labels) * right

        loss = loss.sum(dim=1).mean()

        return loss

    def forward(self,x,labels,L_angle):
        """Return ``L_angle`` plus the capsule margin loss of ``x`` against ``labels``."""
        L_margin = self.margin_loss(x,labels)

        loss = L_angle + L_margin
        return loss

In [2]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim import Adam
import torch.nn.functional as F

import csv
from skimage import io

from PIL import Image
import pandas as pd

import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

import matplotlib.pyplot as plt
import time
import os
import copy
import import_ipynb
import CapsNet_Layers 
import ResNetCaps_E


# Selects the network built later ("ResNetCaps" -> ResNetCaps_E.ResNetCaps,
# anything else -> CapsNet_Layers.CapsNet) and the image size used for resizing.
model_name = "ResNetCaps"

def lr_decrease(optimizer, lr_clip):
    """Multiply the learning rate of every parameter group of ``optimizer`` by ``lr_clip``.

    The groups are updated in place; nothing is returned.
    """
    for group in optimizer.param_groups:
        current_lr = group['lr']
        scaled_lr = current_lr * lr_clip
        group['lr'] = scaled_lr
        
def isnan(x):
    """Return ``x != x``, which is true only for NaN (NaN never equals itself).

    For tensors or arrays the comparison is element-wise and the result has the
    type produced by the operand's own ``!=`` operator.
    """
    differs_from_itself = (x != x)
    return differs_from_itself

# Images are resized to 224x224 for ResNetCaps and to 32x32 for any other model.
resize_dim = (224, 224) if model_name == "ResNetCaps" else (32, 32)

# Per-channel normalisation constants -- presumably the CIFAR-10 mean/std; verify.
_channel_mean = (0.4914, 0.4822, 0.4465)
_channel_std = (0.247, 0.243, 0.261)

dataset_transform = transforms.Compose([
    transforms.Resize(resize_dim),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])


batch_size = 100


NUM_CLASSES = 10
print("CIFAR10")
# Train and validation splits are both downloaded to ../data when missing.
image_datasets = {
    'train': datasets.CIFAR10('../data', train=True, download=True, transform=dataset_transform),
    'val': datasets.CIFAR10('../data', train=False, download=True, transform=dataset_transform),
}
print("Initializing Datasets and Dataloaders...")
# Both loaders shuffle and share the same batch size.
dataloaders = {
    split: torch.utils.data.DataLoader(image_datasets[split], batch_size=batch_size, shuffle=True)
    for split in ('train', 'val')
}
print("Initializing Datasets and Dataloaders...")

# Number of samples available in each split.
dataset_sizes = {split: len(image_datasets[split]) for split in ('train', 'val')}

importing Jupyter notebook from CapsNet_Layers.ipynb
importing Jupyter notebook from ResNetCaps_E.ipynb
CIFAR10
Files already downloaded and verified
Files already downloaded and verified
Initializing Datasets and Dataloaders...
Initializing Datasets and Dataloaders...


In [None]:
print("=> using model CapsuleNET with the new loss")
# Kept because other cells read it; tensors are always moved to `device`,
# which already falls back to the CPU when CUDA is unavailable.
USE_CUDA = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if model_name == "ResNetCaps":
    model = ResNetCaps_E.ResNetCaps(NUM_CLASSES)
else:
    model = CapsNet_Layers.CapsNet(NUM_CLASSES)
model = model.to(device)
print("device: {}".format(device))

optimizer = Adam(model.parameters(), lr=0.0001)

criterion = nn.CrossEntropyLoss().to(device)
# Angular-margin criterion with a fixed weight matrix.
criterionNew = Norm_Arc_loss().to(device)
print(torch.__version__)

n_epochs = 15
x = range(0,n_epochs)
accuracy_train = []   # per-epoch mean batch accuracy (plain floats)
loss_train = []       # per-epoch mean total loss (plain floats)
loss_train_AN = []    # per-epoch mean angular loss (plain floats)
loss_train_ML = []    # margin loss is currently disabled, so this stays empty
start = time.time()

for epoch in range(n_epochs):
    model.train()
    # Accumulate Python floats so no autograd graph or GPU tensor is kept alive.
    train_loss = 0.0
    train_loss_angle = 0.0
    train_loss_margin = 0.0
    train_accuracy = 0.0

    batch_accuracy = []
    folder_name = "heatmap/epoch_"+str(epoch)
    os.makedirs(folder_name, exist_ok=True)

    print('epoch {}:{}'.format(epoch+1, n_epochs))
    for batch_id, (data, target) in enumerate(dataloaders['train']):
        # One-hot encode the integer class labels: (batch,) -> (batch, NUM_CLASSES).
        target = torch.eye(NUM_CLASSES).index_select(dim=0, index=target)
        data, target = data.to(device), target.to(device)

        # One diagonal matrix per sample with a single 1 at (class, class).
        target_m = torch.diag_embed(target)

        optimizer.zero_grad()
        if model_name == "ResNetCaps":
            output = model(data)
        else:
            output, masked, _ = model(data)

        L_angle = criterionNew.arc_loss(output.squeeze(), target_m, epoch, batch_id, folder_name, val=0)

        # Only the diagonal of each logit matrix is used as the class scores.
        b = torch.diagonal(L_angle, dim1=1, dim2=2)
        _, label = torch.max(target, 1)

        loss_AN = criterion(b, label.long())

        # In CapsNet the model loss accounts for both the encoder output and the
        # decoder output; here only the angular loss is used and the margin loss
        # (criterionNew.margin_loss(output, target)) is disabled.
        loss = loss_AN

        if isnan(loss):
            # Abandon the rest of this epoch. Reassigning `epoch` here would have
            # no effect because the for loop overwrites it on the next iteration.
            print("loss lost")
            break

        loss.backward()
        optimizer.step()

        batch_acc = (b.argmax(dim=1) == label).float().mean().item()

        train_loss += loss.item()
        train_loss_angle += loss_AN.item()
        train_accuracy += batch_acc
        batch_accuracy.append(batch_acc)

        if batch_id % 100 == 0:
            print("train diag accuracy:", batch_acc)
            if not model_name == "ResNetCaps":
                print("train margin accuracy:", sum(np.argmax(masked.detach().cpu().numpy(), 1) ==
                                   np.argmax(target.detach().cpu().numpy(), 1)) / float(batch_size))
            print("angle loss {} " .format(loss_AN.item()))

    # Average over the batches actually processed (fewer than the full loader
    # only when the epoch was cut short by a NaN loss).
    batches_done = len(batch_accuracy)
    denom = max(batches_done, 1)
    accuracy_train.append(train_accuracy / denom if batches_done else float('nan'))
    loss_train.append(train_loss / denom)
    loss_train_AN.append(train_loss_angle / denom)

end = time.time()
print("Training time execution {}".format(end-start))
print("Loss value for train phase: {}".format(loss_train[-1]))
print("Accuracy value for train phase: {}".format(accuracy_train[-1]))

In [None]:
# Plot the per-epoch training loss and accuracy collected by the training cell.
epochs = np.arange(1,n_epochs+1)
# The history lists may contain 0-dim torch tensors (possibly on the GPU), which
# matplotlib cannot convert; float() handles tensors and plain numbers alike.
loss_values = [float(value) for value in loss_train]
accuracy_values = [float(value) for value in accuracy_train]

fig, ax = plt.subplots()
ax.plot(epochs, loss_values, color='g', label='Loss')
ax.plot(epochs, accuracy_values, color='orange', label='Accuracy')
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy - Loss')
ax.set_title('Training phase')
ax.legend()
fig.savefig("IBN_Train_Retrained.png")

In [None]:
# Evaluate the trained model on the validation split.
# Relies on `model`, `device`, `criterion`, `criterionNew`, `epoch` and
# `folder_name` left behind by the training cell.
model.eval()
test_loss = 0.0
test_accuracy = 0.0
start = time.time()

# No gradients are needed for evaluation; this avoids building autograd graphs.
with torch.no_grad():
    for batch_id, (data, target) in enumerate(dataloaders['val']):

        # One-hot encode the integer class labels: (batch,) -> (batch, NUM_CLASSES).
        target = torch.eye(NUM_CLASSES).index_select(dim=0, index=target)
        data, target = data.to(device), target.to(device)

        # One diagonal matrix per sample with a single 1 at (class, class).
        target_m = torch.diag_embed(target)

        # Same output handling as the training cell.
        if model_name == "ResNetCaps":
            output = model(data)
        else:
            output, masked, _ = model(data)

        # val=1: the margin-shifted logits are returned for every entry.
        L_angle = criterionNew.arc_loss(output.squeeze(), target_m, epoch, batch_id, folder_name, val=1)

        # Only the diagonal of each logit matrix is used as the class scores.
        b = torch.diagonal(L_angle, dim1=1, dim2=2)
        _, label = torch.max(target, 1)

        loss = criterion(b, label.long())
        loss_ML = criterionNew.margin_loss(output, target)
        loss = loss + loss_ML
        test_loss += loss.item()

        batch_acc = (b.argmax(dim=1) == label).float().mean().item()
        test_accuracy += batch_acc

        if batch_id % 100 == 0:
            print("test accuracy:", batch_acc)
            print("loss {} margin loss {}" .format(loss.item(), loss_ML.item()))

end = time.time()
print("Validation time execution {}".format(end-start))
print("Loss value for test phase: {}".format(test_loss / len(dataloaders['val'])))
print("Accuracy value for test phase: {}".format(test_accuracy / len(dataloaders['val'])))


In [None]:
# Debug check: shape of the most recent `target_m` after reshaping it to
# (batch_size, -1, 1).
reshaped_target = target_m.view(batch_size, -1, 1)
print(reshaped_target.shape)

In [None]:
# Ask PyTorch to release the unused GPU memory it is caching once training and
# evaluation are finished.
torch.cuda.empty_cache()