In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim import Adam
import torch.nn.functional as F

import csv
from skimage import io

from PIL import Image
import pandas as pd

import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

import matplotlib.pyplot as plt
import time
import os
import copy
import import_ipynb
import Caps_basics.CapsNet_Layers as CapsNet_Layers 
import Caps_basics.ResNetCaps_E as ResNetCaps_E
import Fixed_weight_loss

# --- Run configuration -------------------------------------------------------
# Architecture to build: "ResNetCaps" is fed 224x224 inputs, anything else
# (here "CapsNet") is fed 32x32 inputs.
model_name = "CapsNet"
# Passed as the second constructor argument of CapsNet_Layers.CapsNet.
FC = True
# Device string used when CUDA is available.
CUDA = "cuda:0"
# When True, an image from the first training batch of every epoch is saved
# under heatmap/epoch_<n>/.
log_hm = False
# Prefix for the text log, the training plot and the checkpoint file names.
implementation_name = "Fixed_weight_loss_train"

def lr_decrease(optimizer, lr_clip):
    """Multiply the learning rate of every parameter group in place.

    Args:
        optimizer: object exposing ``param_groups`` (a sequence of dicts with an
            ``'lr'`` key), e.g. a ``torch.optim.Optimizer``.
        lr_clip: factor applied to each group's current learning rate.
    """
    groups = optimizer.param_groups
    for idx in range(len(groups)):
        groups[idx]['lr'] *= lr_clip
        
def isnan(x):
    """Return whether ``x`` is NaN, using the self-inequality property of NaN.

    Works for Python floats (returns ``bool``) and for tensors/arrays, where the
    comparison is element-wise and the result is a boolean tensor/array.
    """
    # NaN is the only value that compares unequal to itself.
    nan_mask = x != x
    return nan_mask

def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Serialize a checkpoint object to disk with ``torch.save``.

    Args:
        state: any object ``torch.save`` can serialize; in this notebook a dict
            holding the epoch, loss type, architecture and the model/optimizer
            ``state_dict``s.
        filename: destination path (default ``'checkpoint.pth.tar'``).
    """
    torch.save(obj=state, f=filename)

# Input resolution depends on the architecture: ResNetCaps is fed 224x224
# images, CapsNet works on native 32x32 CIFAR-10 images.
resize_dim = (224, 224) if model_name == "ResNetCaps" else (32, 32)

# Normalization uses the commonly cited CIFAR-10 per-channel mean/std.
dataset_transform = transforms.Compose([
    transforms.Resize(resize_dim),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
])

batch_size = 100
NUM_CLASSES = 10

print("CIFAR10")
image_datasets = {
    phase: datasets.CIFAR10('../data', train=(phase == 'train'), download=True,
                            transform=dataset_transform)
    for phase in ('train', 'val')
}
print("Initializing Datasets and Dataloaders...")
# Only the training split is shuffled: evaluation metrics do not depend on order,
# and a fixed validation order keeps per-batch logs/heatmaps comparable across runs.
dataloaders = {
    phase: DataLoader(image_datasets[phase], batch_size=batch_size,
                      shuffle=(phase == 'train'))
    for phase in ('train', 'val')
}
print("Datasets and Dataloaders ready.")

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

In [None]:
print("=> using model CapsuleNET with the new loss")
USE_CUDA = True
# Fall back to the CPU when CUDA is disabled or unavailable, so the model and
# every tensor created below always live on the same device.
device = torch.device(CUDA if USE_CUDA and torch.cuda.is_available() else "cpu")
if model_name == "ResNetCaps":
    model = ResNetCaps_E.ResNetCaps(NUM_CLASSES)
else:
    model = CapsNet_Layers.CapsNet(NUM_CLASSES, FC)
model = model.to(device)
print("using device {}".format(device))

optimizer = Adam(model.parameters(), lr=0.0001)

# Cross-entropy applied to the per-class scores taken from the angle-loss diagonal.
criterion = nn.CrossEntropyLoss().to(device)
# Project loss module providing arc_loss (angular term) and margin_loss.
criterionNew = Fixed_weight_loss.Fixed_weight_loss().to(device)

n_epochs = 100
# One entry per epoch; plotted at the end of the cell.
accuracy_train, loss_train, loss_train_AN, loss_train_ML = [], [], [], []

with open(implementation_name + ".txt", "w") as text_file:
    text_file.write("Model {} epoch max {}\n".format(model_name, n_epochs))

start = time.time()
for epoch in range(n_epochs):
    model.train()
    print('epoch {}:{}'.format(epoch + 1, n_epochs))
    train_loss, train_loss_angle, train_loss_margin = 0.0, 0.0, 0.0
    train_correct, train_seen = 0, 0

    if log_hm:
        folder_name = "heatmap/epoch_" + str(epoch)
        os.makedirs(folder_name, exist_ok=True)

    for batch_id, (data, target_idx) in enumerate(dataloaders['train']):
        if log_hm and batch_id == 0:
            # Save channel 0 of the second image in the first batch. `.item()` is
            # called so the filename holds the class index, not a bound-method repr.
            fig, ax = plt.subplots()
            ax.imshow(data[1, 0, :, :].numpy())
            fig.savefig(folder_name + "/" + str(batch_id) + "image"
                        + str(target_idx[1].item()) + str(epoch) + ".jpg")
            plt.close(fig)  # otherwise one figure per epoch accumulates in memory

        data = data.to(device)
        label = target_idx.to(device)
        # One-hot targets, consumed by the model and by margin_loss.
        target = torch.eye(NUM_CLASSES, device=device).index_select(0, label)
        # Per-sample diagonal matrix with a single 1 at (label, label).
        target_m = torch.diag_embed(target)

        optimizer.zero_grad()

        if model_name == "ResNetCaps":
            # NOTE(review): this branch never defines `output_fc`, so arc_loss
            # below cannot run for ResNetCaps — the path looks unfinished.
            output = model(data)
        else:
            output, _, _, output_fc = model(data, target)

        # Angle loss: keep only the diagonal of each per-sample matrix as class scores.
        L_angle = criterionNew.arc_loss(output_fc.squeeze(), target_m, epoch, batch_id,
                                        "heatmap/", val=0)
        b = torch.stack([torch.diag(L_angle[i]) for i in range(len(L_angle))])
        del L_angle
        loss_AN = criterion(b, label)

        # In CapsNet the model loss takes into account both the encoder and the
        # decoder outputs. NOTE(review): only the margin term on `output` is added here.
        loss_ML = criterionNew.margin_loss(output, target)
        loss = loss_AN + loss_ML

        if torch.isnan(loss):
            print("loss lost")
            break

        loss.backward()
        optimizer.step()

        # Use the actual batch length so a smaller final batch is counted correctly.
        n_in_batch = label.size(0)
        batch_correct = (b.argmax(dim=1) == label).sum().item()
        train_correct += batch_correct
        train_seen += n_in_batch
        train_loss += loss.item()
        train_loss_angle += loss_AN.item()
        train_loss_margin += float(loss_ML)

        if batch_id % 100 == 0:
            print("train diag accuracy:", batch_correct / n_in_batch)
            print("angle loss {} margin loss {}".format(loss_AN.item(), float(loss_ML)))

    n_train_batches = len(dataloaders['train'])
    # Epoch accuracy over every processed sample (not just the logged batches).
    accuracy_train.append(train_correct / max(train_seen, 1))
    loss_train.append(train_loss / n_train_batches)
    loss_train_AN.append(train_loss_angle / n_train_batches)
    loss_train_ML.append(train_loss_margin / n_train_batches)

    if epoch % 10 == 0 and epoch != 0:
        # Periodic validation: eval mode and no autograd graph.
        model.eval()
        test_loss, test_correct, test_seen = 0.0, 0, 0
        start_test = time.time()
        with torch.no_grad():
            for batch_id, (data, target_idx) in enumerate(dataloaders['val']):
                data = data.to(device)
                label = target_idx.to(device)
                target = torch.eye(NUM_CLASSES, device=device).index_select(0, label)
                target_m = torch.diag_embed(target)

                output, _, _, output_fc = model(data, val=1)
                L_angle = criterionNew.arc_loss(output_fc.squeeze(), target_m, epoch,
                                                batch_id, "heatmap/", val=1)
                b = torch.stack([torch.diag(L_angle[i]) for i in range(len(L_angle))])
                loss_AN = criterion(b, label)
                loss_ML = criterionNew.margin_loss(output, target)
                loss = loss_AN + loss_ML
                test_loss += loss.item()

                n_in_batch = label.size(0)
                batch_correct = (b.argmax(dim=1) == label).sum().item()
                test_correct += batch_correct
                test_seen += n_in_batch
                if batch_id % 100 == 0:
                    print("test accuracy:", batch_correct / n_in_batch)
                    print("loss {} margin loss {}".format(loss.item(), float(loss_ML)))
        end_test = time.time()
        val_loss = test_loss / len(dataloaders['val'])
        val_accuracy = test_correct / max(test_seen, 1)
        print("Validation time execution {}".format(end_test - start_test))
        print("Loss value for test phase: {}".format(val_loss))
        print("Accuracy value for test phase: {}".format(val_accuracy))
        with open(implementation_name + ".txt", "a") as text_file:
            text_file.write("EPOCH {}\n".format(epoch))
            text_file.write("Validation time execution {}\n".format(end_test - start_test))
            text_file.write("Loss value for test phase: {}\n".format(val_loss))
            text_file.write("Accuracy value for test phase: {}\n".format(val_accuracy))
end = time.time()

# Summary of the last training epoch.
print("Training time execution {}".format(end - start))
print("Loss value for training phase: {}".format(loss_train[-1]))
print("Accuracy value for training phase: {}".format(accuracy_train[-1]))

with open(implementation_name + ".txt", "a") as text_file:
    text_file.write("EPOCH {}\n".format(epoch))
    text_file.write("Training time execution {}\n".format(end - start))
    text_file.write("Loss value for training phase: {}\n".format(loss_train[-1]))
    text_file.write("Accuracy value for training phase: {}\n".format(accuracy_train[-1]))

epochs = np.arange(1, len(loss_train) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, loss_train, color='g', label='total loss')
ax.plot(epochs, loss_train_AN, color='b', label='angle loss')
ax.plot(epochs, loss_train_ML, color='c', label='margin loss')
ax.plot(epochs, accuracy_train, color='orange', label='accuracy')
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy - Loss')
ax.set_title('Training phase')
ax.legend()
fig.savefig(implementation_name + ".png")

save_checkpoint({
    'epoch': epoch + 1,
    'loss_type': implementation_name,
    'arch': model_name,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}, "checkpoint_" + implementation_name + "_" + model_name + "_" + str(epoch) + ".pth.tar")


In [None]:
# Final evaluation on the validation split. Uses `epoch` from the training cell.
model.eval()
test_loss, test_correct, test_seen = 0.0, 0, 0

start = time.time()

# No autograd graph is needed for evaluation; this saves memory and time.
with torch.no_grad():
    for batch_id, (data, target_idx) in enumerate(dataloaders['val']):
        # One-hot targets, placed on the same device as the data.
        target = torch.eye(NUM_CLASSES).index_select(dim=0, index=target_idx)
        if USE_CUDA:
            data, target = data.to(device), target.to(device)
        # Per-sample diagonal matrix with a single 1 at (label, label).
        target_m = torch.diag_embed(target).to(device)
        label = target.argmax(dim=1)

        output, _, masked, output_fc = model(data, val=1)
        L_angle = criterionNew.arc_loss(output_fc.squeeze(), target_m, epoch, batch_id,
                                        "heatmap/", val=1)
        # Keep only the diagonal of each per-sample matrix as class scores.
        b = torch.stack([torch.diag(L_angle[i]) for i in range(len(L_angle))])
        loss = criterion(b, label)
        loss_ML = criterionNew.margin_loss(output, target)
        loss = loss + loss_ML
        test_loss += loss.item()

        # Use the actual batch length so a smaller final batch is counted correctly.
        n_in_batch = label.size(0)
        batch_correct = (b.argmax(dim=1) == label).sum().item()
        test_correct += batch_correct
        test_seen += n_in_batch

        if batch_id % 100 == 0:
            print("test accuracy:", batch_correct / n_in_batch)
            print("loss {} margin loss {}".format(loss.item(), float(loss_ML)))
end = time.time()

val_loss = test_loss / len(dataloaders['val'])
val_accuracy = test_correct / max(test_seen, 1)
print("Validation time execution {}".format(end - start))
print("Loss value for test phase: {}".format(val_loss))
print("Accuracy value for test phase: {}".format(val_accuracy))
with open(implementation_name + ".txt", "a") as text_file:
    text_file.write("Validation time execution {}\n".format(end - start))
    text_file.write("Loss value for test phase: {}\n".format(val_loss))
    text_file.write("Accuracy value for test phase: {}\n".format(val_accuracy))

In [None]:
# Release cached GPU memory that PyTorch's caching allocator holds but no tensor
# references anymore, so it becomes available again (e.g. to other processes).
torch.cuda.empty_cache()