In [None]:
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim import Adam
import torch.nn.functional as F

import csv
from skimage import io
from PIL import Image
import pandas as pd

import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

import matplotlib.pyplot as plt
import matplotlib as mpl
import time
import os
import pathlib
import copy
from datetime import date

import import_ipynb
import ResNetCaps_E
import MiniBatch_generator
import losses
import CHIMP_DataLoader
import ArcMarginProduct

# Dataset switches read by the dataset-initialization and loading cells.
ATET_use, LFW_use, CHIM_use = False, False, True

# Checkpoint restore settings: `load_model` enables restoring a saved model and
# `pick_model` names the checkpoint to restore (-1 is special-cased by the
# model-loader cell).
load_model = False
pick_model = -1

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
#------------------------------------DATASET INITIALIZATION -----------------------------------------------------
# Map the dataset switches to a log-folder tag (`folder`) and a location relative
# to the data root (`dataset_folder`). The branches are tested in this order so
# that, if several switches are on, the same dataset wins as with the original
# sequence of independent `if` statements (CHIM over LFW over ATET).
if CHIM_use:
    folder = 'CHIM'
    dataset_folder = "chimpanzee_faces-master/datasets_cropped_chimpanzee_faces/data_CZoo/annotations_czoo.txt"
elif LFW_use:
    folder = 'LFW'
    dataset_folder = "lfw/"
elif ATET_use:
    folder = 'ATET'
    dataset_folder = "ATeT_faces/orl_faces/"
else:
    # Without this guard the cell failed later with an opaque NameError.
    raise ValueError("No dataset selected: set one of ATET_use, LFW_use or CHIM_use to True")

# Preprocessing applied to every image: resize to 224x224, convert to a tensor,
# then normalize each channel.
# NOTE(review): the mean/std values look like the commonly quoted CIFAR-10
# statistics - confirm they are the intended ones for these datasets.
dataset_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])

# Root directory that contains all verification datasets. It can be overridden
# with the VERIFICATION_DATA_ROOT environment variable; the default is the
# location the notebook was originally written for.
DATASET_ROOT = os.environ.get("VERIFICATION_DATA_ROOT",
                              "/home/rita/JupyterProjects/EYE-SEA/DataSets/Verification")
dataset_folder = os.path.join(DATASET_ROOT, dataset_folder)


In [None]:
#----------------------------------------DATASET LOADING----------------------------------------------------
# Value passed as the third positional argument to every dataset constructor
# below (the original note describes it as the mini-batch dimension).
percent = 0.2

if not CHIM_use:
    print('Loading {} dataset'.format(folder))
    dataset = MiniBatch_generator.mini_batch(dataset_folder, dataset_transform, percent)
else:
    print('Loading chimp dataset')
    chimp_args = (dataset_folder, dataset_transform, percent)
    # Two views of the chimpanzee data: one built with hold_out=True and one with train=False.
    dataset = CHIMP_DataLoader.Chimp_Dataset(*chimp_args, hold_out=True)
    dataset_t = CHIMP_DataLoader.Chimp_Dataset(*chimp_args, train=False)

# Number of classes of the dataset, taken as the length of `num_ind_train`.
out_feat = len(dataset.num_ind_train)

In [None]:
#------------------------------------MODEL LOADER ---------------------------------- 

# Checkpoints are kept in a per-dataset, per-day directory below the current
# working directory; the directory is created on first use.
PATH = os.path.join(os.getcwd(), os.path.join('Log_model/ARCFACE', folder, 'DIGIT/', (date.today()).isoformat()))
pathlib.Path(PATH).mkdir(parents=True, exist_ok=True)

saved_entries = os.listdir(PATH)
if len(saved_entries) > 2 and load_model:
    print('Loading model from PATH: {}'.format(PATH))
    # NOTE(review): the restored model is built with DigitEnd=False while a fresh
    # model (below) uses DigitEnd=True - confirm the saved weights match this
    # architecture.
    model = ResNetCaps_E.ResNetCaps_E(DigitEnd=False)
    # pick_model == -1 selects the checkpoint named "<number of entries - 2>";
    # any other value is used directly as the checkpoint name.
    # NOTE(review): the "- 2" presumably skips non-checkpoint entries in the
    # directory - verify against the actual directory contents.
    init = len(saved_entries) - 2 if pick_model == -1 else pick_model
    # map_location lets a checkpoint written on a GPU be restored on whatever
    # device this session runs on (including CPU-only machines).
    state_dict = torch.load(os.path.join(PATH, str(init)), map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
else:
    print('Creating a new model')
    init = 0
    model = ResNetCaps_E.ResNetCaps_E(DigitEnd=True)

In [None]:
#--------------------------Preparation of model-------------------------------------
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print("Let's use", gpu_count, "GPUs!")
    # DataParallel splits each batch along dim 0,
    # e.g. [30, ...] -> [10, ...], [10, ...], [10, ...] on 3 GPUs.
    model = nn.DataParallel(model)

model = model.to(device)

print(out_feat)
criterion = nn.CrossEntropyLoss().to(device)

# ArcFace margin head: 16 input features, one output per class, scale s=32.
margin = ArcMarginProduct.ArcMarginProduct(in_feature=16, out_feature=out_feat, s=32).to(device)

# One SGD optimizer drives both the backbone and the margin head; both parameter
# groups use the same weight decay.
shared_decay = 5e-4
param_groups = [
    {'params': model.parameters(), 'weight_decay': shared_decay},
    {'params': margin.parameters(), 'weight_decay': shared_decay},
]
optimizer = optim.SGD(param_groups, lr=0.001, momentum=0.9, nesterov=True)

In [None]:
#-----------TRAINING PHASE ----------------------------------minibatch organised random everyepoch

n_epochs = 100
# One Python float per epoch (plain numbers, not tensors, so no autograd graph
# is kept alive and the list can be plotted directly).
loss_list_b = []

for epoch in range(n_epochs):
    print('epoch {}:{}'.format(epoch+1, n_epochs))
    model.train()

    # Each "epoch" trains on a single mini-batch produced by the dataset object.
    in_a, labels = dataset.prepare_batch()
    in_a = torch.stack(in_a)
    in_a = in_a.to(device)
    labels = torch.Tensor(labels).to(device)

    # Embed the batch and flatten every embedding to one row per sample.
    emb_a = model(in_a)
    emb_a = emb_a.view(in_a.size(0), -1)

    # The margin head turns embeddings into class logits, using the labels.
    output = margin(emb_a, labels.long())

    loss = criterion(output, labels.long())
    print("lost per batch {}".format(loss))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Store the scalar value only; appending the tensor itself would retain the
    # whole computation graph of every epoch.
    loss_collect = loss.item()
    loss_list_b.append(loss_collect)

    # Save the underlying module's weights so the checkpoint keys carry no
    # "module." prefix when the model is wrapped in nn.DataParallel.
    model_to_save = model.module if isinstance(model, nn.DataParallel) else model
    torch.save(model_to_save.state_dict(), os.path.join(PATH, str(epoch)))

In [None]:
# Convert every recorded loss to a plain float so the curve can be drawn even if
# the list holds tensors, and size the x axis from the number of recorded values
# so that an interrupted training run can still be plotted.
loss_values = [float(batch_loss) for batch_loss in loss_list_b]
epochs = np.arange(1, len(loss_values) + 1)
plt.plot(epochs, loss_values, color='pink')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.title('Training phase')
plt.show()

# Test Phase (has to be the same for all the verification loss methods)

In [None]:
def find_threshold(var, percentile, bins=100):
    """Estimate the value below which `percentile` percent of `var` falls.

    The estimate is read off a histogram-based empirical CDF: the values are
    binned, the cumulative fraction per bin is computed, and the requested
    fraction is linearly interpolated against the bin centers.

    Args:
        var: array-like of numbers (any shape; it is flattened by the histogram).
        percentile: percentage in percent units, i.e. it is multiplied by 0.01
            before the lookup (50 means the median; 0.025 means 0.025 %).
        bins: number of histogram bins (default 100, the original fixed value).

    Returns:
        The interpolated threshold as a NumPy float. Fractions below the first
        CDF value map to the first bin center.

    Raises:
        ValueError: if `var` is empty (the CDF would be 0/0).
    """
    values = np.asarray(var)
    if values.size == 0:
        raise ValueError("find_threshold() needs at least one value")

    hist, bin_edges = np.histogram(values, bins)
    # Integer true division yields a float64 CDF (no float32 truncation).
    cdf = np.cumsum(hist) / hist.sum()
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    threshold = np.interp(percentile * 0.01, cdf, bin_centers)
    return threshold

In [None]:
# Estimate a distance threshold from 10 training batches: for each batch, take
# all pairwise embedding distances (excluding each sample's distance to itself)
# and read a low percentile off their distribution.
thresholds = []
# Switch to evaluation mode first: the training loop leaves the model in train
# mode, where layers such as BatchNorm/Dropout behave differently.
model.eval()
with torch.no_grad():
    for i in range(10):
        in_a, _ = dataset.prepare_batch()
        in_a = torch.stack(in_a).to(device)
        # Flatten to one row per sample, the same way the training cell does.
        emb_a = model(in_a).reshape(in_a.size(0), -1)
        D = losses.euclidean_distance(emb_a, emb_a)
        D_m = D.cpu().numpy()
        # Drop the diagonal (self-distances), leaving N x (N - 1) values.
        off_diagonal = ~np.eye(D_m.shape[0], dtype=bool)
        D_m = D_m[off_diagonal].reshape(D_m.shape[0], D_m.shape[1] - 1)
        # NOTE(review): find_threshold multiplies its percentile by 0.01, so
        # 0.025 selects the 0.025 % point - confirm that 2.5 % was not intended.
        thresholds.append(find_threshold(D_m, 0.025))
    print(np.max(thresholds))


In [None]:
print(thresholds)
torch.set_printoptions(threshold=100000)
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%TEST%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
model.eval()
in_a_test, labels_test = dataset.prepare_batch_test()
in_a_test = torch.stack(in_a_test).to(device)
labels_test = torch.Tensor(labels_test).to(device)

# Inference only: no_grad avoids building an autograd graph for the test batch.
with torch.no_grad():
    # Flatten to one row per sample, the same way the training cell does.
    emb_a = model(in_a_test).reshape(in_a_test.size(0), -1)
    D = losses.euclidean_distance(emb_a, emb_a)

# Zero the diagonal (self-distances) by multiplying with the complement of the
# identity matrix, then move the matrix to NumPy.
mask = torch.eye(D.size(0), device=D.device, dtype=D.dtype)
D_m = ((1. - mask) * D).cpu().numpy()

# 1 where a pair is closer than the threshold, 0 otherwise.
# NOTE(review): this cell uses the mean of `thresholds`, while the final
# evaluation cell uses the max - confirm which one is intended.
b = (D_m < np.mean(thresholds)).astype(int)

print(b)


In [None]:
def accuracy(feats, labels, threshold):
    """Score pairwise verification decisions at a fixed distance threshold.

    All ordered pairs of distinct samples are considered. A pair is "accepted"
    when its distance is strictly below `threshold`.

    Args:
        feats: 2-D tensor with one embedding per row.
        labels: 1-D tensor with one identity label per row of `feats`.
        threshold: distance cut-off (any value convertible with float()).

    Returns:
        A tuple (positives_True, negatives_True, VAL, FAR, threshold):
        positives_True - same-identity pairs with distance < threshold;
        negatives_True - different-identity pairs with distance > threshold;
        VAL - positives_True divided by the number of same-identity pairs;
        FAR - different-identity pairs with distance < threshold (false
              accepts) divided by the number of different-identity pairs;
        threshold - the value that was passed in.
        VAL / FAR are NaN when the corresponding pair set is empty.
    """
    D = losses.euclidean_distance(feats, feats)
    N = D.size(0)
    print(N)

    # [N, N] identity masks built by comparing every label with every other one.
    flat_labels = labels.reshape(-1)
    is_pos = flat_labels.unsqueeze(0).eq(flat_labels.unsqueeze(1))
    is_neg = flat_labels.unsqueeze(0).ne(flat_labels.unsqueeze(1))

    # Exclude each sample's pairing with itself from the positives. Index
    # assignment is used because arithmetic such as `1. - mask` is rejected for
    # boolean tensors.
    diagonal = torch.arange(N, device=is_pos.device)
    is_pos[diagonal, diagonal] = False

    # Flat vectors of distances; no [N, -1] reshape, so identities may have
    # different numbers of samples.
    # `dist_ap` means distance(anchor, positive)
    dist_ap = D[is_pos]
    # `dist_an` means distance(anchor, negative)
    dist_an = D[is_neg]

    cutoff = float(threshold)
    positives_True = int((dist_ap < cutoff).sum().item())
    negatives_True = int((dist_an > cutoff).sum().item())
    # False accepts are the negative pairs that fall below the threshold; the
    # rate of correctly rejected pairs would be 1 - FAR, not FAR.
    false_accepts = int((dist_an < cutoff).sum().item())

    VAL = positives_True / dist_ap.numel() if dist_ap.numel() else float('nan')
    FAR = false_accepts / dist_an.numel() if dist_an.numel() else float('nan')

    return positives_True, negatives_True, VAL, FAR, threshold

In [None]:
# Final evaluation on one test batch: classification loss plus verification
# metrics at the largest threshold estimated earlier.
model.eval()
in_a, labels = dataset.prepare_batch_test()
in_a = torch.stack(in_a).to(device)
labels = torch.Tensor(labels).to(device)

# Inference only: no gradients are needed here.
with torch.no_grad():
    # Flatten to one row per sample, the same way the training cell does.
    emb_a = model(in_a).reshape(in_a.size(0), -1)
    # The criterion is a cross-entropy over class logits, so the embeddings are
    # passed through the margin head exactly as during training instead of
    # being fed to the criterion directly.
    # NOTE(review): assumes the test labels are among the classes the margin
    # head was built for - confirm for hold-out identities.
    loss = criterion(margin(emb_a, labels.long()), labels.long())

P_T, N_T, VAL, FAR, th = accuracy(emb_a, labels, np.max(thresholds))
print("Loss {}. Threshold {}: P_T {} N_T {} VAL {} FAR {}".format(loss, th, P_T, N_T, VAL, FAR))