In [None]:
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim import Adam
import torch.nn.functional as F

import csv
from skimage import io

from PIL import Image
import pandas as pd

import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

import matplotlib.pyplot as plt
import time
import os
import copy

import import_ipynb
import Caps_basics.CapsNet_Layers_multiFC as CapsNet_Layers_MFC
import Caps_basics.CapsNet_Layers as CapsNet_Layers
import Caps_basics.ResNetCaps_E as ResNetCaps_E
import Fixed_weight_loss

# --- Run configuration -------------------------------------------------------
# Results of the evaluation cell are appended to "<implementation_name>.txt".
implementation_name = "Fixed_weight_loss_train_moreFC"
# "ResNetCaps" or "CapsNet"; any other value selects CapsNet_Layers_MFC.CapsNet_MR.
model_name = "CapsNet"
# Preferred device string, used only when CUDA is available.
CUDA = "cuda:0"

# --- Flags -------------------------------------------------------------------
# When True, evaluation uses the model's FC-head output instead of the digit output.
FC = True
# Not read by the cells shown here.
log_hm, verbose = False, False


def lr_decrease(optimizer, lr_clip):
    """Multiply the learning rate of every parameter group by ``lr_clip``, in place.

    Args:
        optimizer: an optimizer exposing ``param_groups`` (dicts with an ``'lr'`` key).
        lr_clip: multiplicative factor applied to each group's current learning rate.
    """
    for group in optimizer.param_groups:
        scaled_lr = group['lr'] * lr_clip
        group['lr'] = scaled_lr
        
def isnan(x):
    """Return ``x != x``: True for NaN scalars, or an element-wise mask for tensors/arrays."""
    # NaN is the only value that compares unequal to itself.
    not_equal_to_self = x != x
    return not_equal_to_self

def resume_model(name_file, model, optimizer, map_location):
    """Restore model and optimizer state from a checkpoint file, if one exists.

    The checkpoint is read as a dict with keys ``'epoch'``, ``'state_dict'``
    (model weights) and ``'optimizer'`` (optimizer state).

    Args:
        name_file: path to the checkpoint file.
        model: module whose weights are overwritten in place via ``load_state_dict``.
        optimizer: optimizer whose state is overwritten in place via ``load_state_dict``.
        map_location: forwarded to ``torch.load`` to choose where tensors are placed.

    Returns:
        ``(start_epoch, model, optimizer)``. ``start_epoch`` is the checkpoint's
        ``'epoch'`` value, or 0 when no file exists at ``name_file``.
    """
    # Default so the "no checkpoint" path returns cleanly instead of raising
    # UnboundLocalError on the return statement.
    start_epoch = 0
    if not os.path.isfile(name_file):
        print("=> no checkpoint found at '{}'".format(name_file))
        return start_epoch, model, optimizer

    print("=> loading checkpoint '{}'".format(name_file))
    # NOTE: depending on the torch version's `weights_only` default, torch.load may
    # unpickle arbitrary objects - only load checkpoints from trusted sources.
    checkpoint = torch.load(name_file, map_location=map_location)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {})".format(name_file, start_epoch))
    return start_epoch, model, optimizer

In [None]:
# Input resolution depends on the backbone: ResNetCaps is fed 224x224 images,
# the CapsNet variants use CIFAR-10's native 32x32.
resize_dim = (224, 224) if model_name == "ResNetCaps" else (32, 32)

# Per-channel mean / std used to normalise CIFAR-10 images.
dataset_transform = transforms.Compose([
    transforms.Resize(resize_dim),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
])

batch_size = 100
NUM_CLASSES = 10

print("CIFAR10")
print("Initializing Datasets and Dataloaders...")
image_datasets = {
    split: datasets.CIFAR10('../data', train=(split == 'train'), download=True,
                            transform=dataset_transform)
    for split in ('train', 'val')
}
dataloaders = {
    'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=batch_size, shuffle=True),
    # Evaluation does not need shuffling; a fixed order keeps per-batch logs reproducible.
    'val': torch.utils.data.DataLoader(image_datasets['val'], batch_size=batch_size, shuffle=False),
}
print("Datasets and Dataloaders initialized.")

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

In [None]:
print("=> using model CapsuleNET with the new loss")
# Fall back to CPU when CUDA is unavailable; USE_CUDA mirrors that choice so
# later cells only move tensors to the GPU when one is actually in use.
device = torch.device(CUDA if torch.cuda.is_available() else "cpu")
USE_CUDA = device.type == "cuda"

if model_name == "ResNetCaps":
    model = ResNetCaps_E.ResNetCaps(NUM_CLASSES)
elif model_name == "CapsNet":
    model = CapsNet_Layers.CapsNet(NUM_CLASSES, FC)
else:
    model = CapsNet_Layers_MFC.CapsNet_MR(NUM_CLASSES, FC)

# The optimizer is created so resume_model can restore its saved state; it is
# not stepped in the cells shown here.
optimizer = Adam(model.parameters(), lr=0.0001)
# NOTE(review): the checkpoint name hard-codes "CapsNet" regardless of model_name.
model_path = 'FC_1/checkpoint_Fixed_weight_loss_train_CapsNet_49.pth.tar'
# Use the resolved device rather than the literal CUDA string so the checkpoint
# also loads on CPU-only machines.
start_epoch, model, optimizer = resume_model(model_path, model, optimizer, map_location=device)

# Evaluation only: freeze all weights.
for param in model.parameters():
    param.requires_grad = False
model = model.to(device)
if USE_CUDA:
    print('cuda')

criterion = nn.CrossEntropyLoss().to(device)
######NEWLOSS
criterionNew = Fixed_weight_loss.Fixed_weight_loss()
criterionNew = criterionNew.to(device)
#############

In [None]:
# Evaluate the restored model on the CIFAR-10 validation split. The per-batch
# loss is cross-entropy on the diagonal of the fixed-weight arc loss plus the
# margin loss; the summary is printed and appended to "<implementation_name>.txt".
model.eval()
test_loss = 0.0   # sum of per-batch losses
correct = 0       # correctly classified samples
n_seen = 0        # samples evaluated (last batch may be smaller than batch_size)
start = time.time()

with torch.no_grad():
    for batch_id, (data, labels) in enumerate(dataloaders['val']):
        if USE_CUDA:
            data, labels = data.to(device), labels.to(device)

        # One-hot targets, created on the same device as the labels.
        target = torch.eye(NUM_CLASSES, device=labels.device).index_select(dim=0, index=labels)
        # Per-sample (NUM_CLASSES, NUM_CLASSES) matrix with a single 1 on the
        # diagonal at the true class. Built in one call, which avoids indexing a
        # CPU tensor with device-resident indices.
        target_m = torch.diag_embed(target).to(device)

        output_digit, _, _, output_fc = model(data)
        if FC:
            output = output_fc.view(output_fc.size(0), NUM_CLASSES, -1)
        else:
            output = output_digit

        #########NEWLOSS########
        # NOTE(review): output.squeeze() also drops the batch dimension when a batch has one sample.
        L_angle = criterionNew.arc_loss(output.squeeze(), target_m, start_epoch, batch_id, 'heatmap/', val=1)
        # Class scores = diagonal of each sample's matrix.
        # NOTE(review): assumes each L_angle[i] is a 2-D (NUM_CLASSES, NUM_CLASSES) tensor - confirm.
        b = torch.stack([torch.diag(angles) for angles in L_angle])

        loss = criterion(b, labels.long())
        loss_ML = criterionNew.margin_loss(output, target)
        loss = loss + loss_ML
        test_loss += float(loss)

        batch_correct = (b.argmax(dim=1) == labels).sum().item()
        batch_len = labels.size(0)
        correct += batch_correct
        n_seen += batch_len

        if batch_id % 100 == 0:
            print("test accuracy:", batch_correct / float(batch_len))
            print("loss {} margin loss {}".format(float(loss), float(loss_ML)))
end = time.time()

num_batches = len(dataloaders['val'])
avg_loss = test_loss / num_batches if num_batches else float('nan')
# Sample-weighted accuracy: correct even when the last batch is partial.
test_accuracy = correct / n_seen if n_seen else float('nan')

summary = [
    "Validation time execution {}".format(end - start),
    "Loss value for test phase: {}".format(avg_loss),
    "Accuracy value for test phase: {}".format(test_accuracy),
]
for line in summary:
    print(line)
# Append (not overwrite) so successive runs accumulate, one metric per line.
with open(implementation_name + ".txt", "a") as text_file:
    text_file.write("\n".join(summary) + "\n")