In [None]:
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim import Adam
import torch.nn.functional as F

import csv
from skimage import io

from PIL import Image
import pandas as pd

import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

import matplotlib.pyplot as plt
import time
import os
import copy

import import_ipynb
import ResNetCaps
import Pets_Loader
import Animals_Loader
import Marvel_Loader
import Interaction_bilinear

# --- Runtime switches -------------------------------------------------------
# `verbose` gates extra debug printing; `USE_CUDA` records that GPU use is wanted.
verbose = False
USE_CUDA = True

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# --- Dataset selection flags ------------------------------------------------
PETS_USE = True
CIFAR10_USE = MARVEL_USE = Animals_USE = False

# --- Model / loss options ---------------------------------------------------
# CRITERION: use the `criterion` loss object rather than the model's own loss.
# ResNetCaps_dim: use 224x224 inputs instead of 32x32.
CRITERION = True
ResNetCaps_dim = False

def lr_decrease(optimizer, lr_clip):
    """Multiply the learning rate of every parameter group by ``lr_clip``.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dicts that
            each contain an ``'lr'`` entry.
        lr_clip: multiplicative factor applied to each group's learning rate.

    Returns:
        None. The parameter groups are updated in place.
    """
    for group in optimizer.param_groups:
        current_lr = group['lr']
        group['lr'] = current_lr * lr_clip
        
def isnan(x):
    """Return the result of ``x != x``.

    NaN is the only floating-point value that compares unequal to itself, so
    for a plain float the result is True exactly when ``x`` is NaN. For other
    types the result is whatever their ``!=`` operator yields.
    """
    differs_from_itself = (x != x)
    return differs_from_itself

# Input resolution fed to the network: 224x224 when ResNetCaps_dim is set, 32x32 otherwise.
dim = (224, 224) if ResNetCaps_dim else (32, 32)

# Resize -> tensor -> per-channel normalisation; the same pipeline is used for train and val.
# NOTE(review): the mean/std constants look like CIFAR-10 statistics, yet they are applied to
# every dataset -- confirm this is intended.
dataset_transform = transforms.Compose([
    transforms.Resize(dim),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])

batch_size = 32

# One row per supported dataset:
#   (enabled flag, name, number of classes, dataset class, data location, extra constructor kwargs)
# Every dataset class is called as cls(location, train=<bool>, transform=..., **extra).
_DATASET_SPECS = (
    (CIFAR10_USE, "CIFAR10", 10, datasets.CIFAR10,
     '../data', {'download': True}),
    # alternative location: /media/Data/rita/EYE-SEA/Datasets/marveldataset2016-master/FINAL.dat
    (MARVEL_USE, "MARVEL", 26, Marvel_Loader.MARVEL_dataset,
     "/home/rita/JupyterProjects/EYE-SEA/DataSets/marveldataset2016-master/FINAL.dat", {}),
    # alternative location: /media/Data/rita/EYE-SEA/Datasets/Pets/Pet_Datasets
    (PETS_USE, "PETS", 37, Pets_Loader.PETS_dataset,
     "/home/rita/JupyterProjects/EYE-SEA/DataSets/Pets/Pet_Datasets", {}),
    # alternative location: /media/Data/rita/EYE-SEA/Datasets/Animals_with_Attributes2/JPEGImages
    (Animals_USE, "Animals", 50, Animals_Loader.Animals_dataset,
     "/home/rita/JupyterProjects/EYE-SEA/DataSets/Animals_with_Attributes2/JPEGImages", {}),
)

_enabled_specs = [spec for spec in _DATASET_SPECS if spec[0]]
if not _enabled_specs:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(
        "No dataset selected: set one of CIFAR10_USE, MARVEL_USE, PETS_USE or Animals_USE to True."
    )

# If several flags are enabled, the last one in the table takes precedence;
# only that dataset is actually constructed.
_, name_dataset, NUM_CLASSES, _dataset_cls, dat_file, _dataset_kwargs = _enabled_specs[-1]

print(name_dataset)
print("Initializing Datasets and Dataloaders...")
image_datasets = {
    split: _dataset_cls(dat_file, train=(split == 'train'), transform=dataset_transform, **_dataset_kwargs)
    for split in ('train', 'val')
}
# Both splits are shuffled and share the same batch size.
dataloaders = {
    split: torch.utils.data.DataLoader(image_datasets[split], batch_size=batch_size, shuffle=True)
    for split in ('train', 'val')
}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

# Arguments forwarded to the model constructor: `not_ibn` flag and the checkpoint
# of the pretrained network for the selected dataset.
not_ibn = True
model_path =  "/home/rita/JupyterProjects/EYE-SEA/ResNet_CAPSNET/PY/ResNetCaps/"+name_dataset+"/checkpoint_99.pth.tar"

In [None]:
#Load model
model = Interaction_bilinear.Bilinear(NUM_CLASSES, batch_size,not_ibn = not_ibn, model_path = model_path, model_name = 'CapsNet')
model = model.to(device)

# Loss and optimizer; only parameters with requires_grad=True are handed to Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),lr = 0.001)
n_epochs = 100
threshold = 20   # after epoch 2, the learning rate is scaled on every epoch divisible by this
lr_clip = 0.01   # multiplicative factor handed to lr_decrease

#train
start = time.time()
accuracy_train = []   # per-epoch accuracy (correct predictions / samples seen), plain floats
loss_train = []       # per-epoch mean batch loss, plain floats so they can be plotted from any device
request_exit = False

for epoch in range(n_epochs):
    model.train()
    train_loss = 0.0     # running sum of batch losses
    train_correct = 0    # running count of correct predictions
    train_samples = 0    # running count of samples actually processed
    n_batches = 0        # number of batches that completed an optimisation step

    print('epoch {}:{}'.format(epoch+1, n_epochs))
    for batch_id, (inputs, labels) in enumerate(dataloaders['train']):
        # MARVEL labels are shifted down by one (presumably 1-based in the data file -- confirm).
        if MARVEL_USE: labels = labels-1
        inputs = inputs.to(device)
        targets = labels.long().to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        if CRITERION:
            loss = criterion(outputs, targets)
        else:
            # The model's own loss takes one-hot encoded labels.
            one_hot = torch.eye(NUM_CLASSES, device=device).index_select(dim=0, index=targets)
            loss = model.model_loss(outputs, one_hot)
        if verbose: print(loss)

        loss_value = loss.item()
        if isnan(loss_value):
            # Abandon the rest of this epoch; no update is made from a NaN loss.
            request_exit = True
            print("lost loss")
            break

        loss.backward()
        optimizer.step()

        predictions = outputs.argmax(dim=1)
        batch_correct = (predictions == targets).sum().item()
        batch_n = targets.size(0)   # real batch size; the last batch may be smaller than batch_size

        train_loss += loss_value
        train_correct += batch_correct
        train_samples += batch_n
        n_batches += 1

        if batch_id % 100 == 0:
            print("train accuracy:", batch_correct / float(batch_n))
            print("train_loss:", loss_value)

            if verbose: print("masked {}".format(predictions.cpu().numpy()))
            if verbose: print("labels {}".format(targets.cpu().numpy()))

    if n_batches == 0:
        # Nothing was trained this epoch (empty loader or NaN on the first batch):
        # record no statistics for it.
        continue

    epoch_loss = train_loss / n_batches
    train_accuracy = train_correct / float(train_samples)
    accuracy_train.append(train_accuracy)
    if epoch > 2:
        # Flag a collapse when accuracy drops below half of the previous recorded epoch.
        if len(accuracy_train) > 1 and accuracy_train[-1]<(accuracy_train[-2]*0.5):
            print("zero loss {}".format(epoch_loss))
            request_exit = True
        if (epoch % threshold)==0:
            print("Learnining rate decrease")
            lr_decrease(optimizer, lr_clip)
    loss_train.append(epoch_loss)

    # NOTE: request_exit is only recorded; it deliberately does not stop the epoch loop.

end = time.time()
print("Training time execution {}".format(end-start))
if loss_train:
    # Report the statistics of the last epoch that was recorded.
    print("Loss value for training phase: {}".format(loss_train[-1]))
    print("Accuracy value for training phase: {}".format(accuracy_train[-1]))
epochs = np.arange(1,len(loss_train)+1)
plt.plot(epochs, loss_train, color='g')
plt.plot(epochs, accuracy_train, color='orange')
plt.xlabel('Epochs')
plt.ylabel('Accuracy - Loss')
plt.title('Training phase')
plt.savefig("CapsNet_Bilinear.png") 

In [None]:
# Plot the per-epoch training loss (green) and accuracy (orange) and save the figure.
# The recorded values are coerced to plain floats first: matplotlib cannot convert
# tensors that live on a CUDA device.
loss_values = [float(value) for value in loss_train]
accuracy_values = [float(value) for value in accuracy_train]
epochs = np.arange(1,len(loss_values)+1)

fig, ax = plt.subplots()
ax.plot(epochs, loss_values, color='g', label='Loss')
ax.plot(epochs, accuracy_values, color='orange', label='Accuracy')
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy - Loss')
ax.set_title('Training phase')
ax.legend()
fig.savefig("Dense_Bilinear.png") 

In [None]:
# Evaluate the model on the validation split and report mean batch loss and accuracy.
model.eval()
test_loss = 0.0      # running sum of batch losses
test_correct = 0     # running count of correct predictions
test_samples = 0     # running count of samples actually evaluated
start = time.time()
# No gradients are needed for evaluation, so skip building the autograd graph.
with torch.no_grad():
    for batch_id, (inputs, labels) in enumerate(dataloaders['val']):
        # MARVEL labels are shifted down by one (presumably 1-based in the data file -- confirm).
        if MARVEL_USE: labels = labels-1
        # The model lives on `device`, so the batch is always moved there.
        inputs = inputs.to(device)
        targets = labels.long().to(device)

        outputs = model(inputs)
        if CRITERION:
            loss = criterion(outputs, targets)
        else:
            # The model's own loss takes one-hot encoded labels.
            one_hot = torch.eye(NUM_CLASSES, device=device).index_select(dim=0, index=targets)
            loss = model.model_loss(outputs, one_hot)

        batch_correct = (outputs.argmax(dim=1) == targets).sum().item()
        batch_n = targets.size(0)   # real batch size; the last batch may be smaller than batch_size

        test_loss += loss.item()
        test_correct += batch_correct
        test_samples += batch_n

        if batch_id % 100 == 0:
            print("test accuracy:", batch_correct / float(batch_n))
end = time.time()

# max(..., 1) keeps the report from dividing by zero on an empty validation loader.
test_accuracy = test_correct / float(max(test_samples, 1))
print("Test time execution {}".format(end-start))
print("Loss value for test phase: {}".format(test_loss /  max(len(dataloaders['val']), 1))) 
print("Accuracy value for test phase: {}".format(test_accuracy))


In [None]:
# Ask PyTorch to release the unused GPU memory it is still caching, so other
# processes can use it once training and evaluation are done.
torch.cuda.empty_cache()