In [None]:
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim import Adam
import torch.nn.functional as F

import csv
from skimage import io

from PIL import Image
import pandas as pd

import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torchsummary import summary

import matplotlib.pyplot as plt
import matplotlib as mpl
import time
import pathlib
import os
import copy
from datetime import date

import import_ipynb
import ResNetCaps_E
import Dataset_Loader
import losses

# Run configuration.
verbose = False
load_model = False   # True: resume from a checkpoint found in PATH instead of building a new model
pick_model = -1      # checkpoint index to resume from when load_model is True; -1 derives it from the number of files in PATH
# NOTE(review): LFW_use / ATET_use are not read anywhere in this notebook; folderDataset is set by hand below.
LFW_use = False
ATET_use = True
folder = 'ATET'      # sub-folder of Log_model/Cluster_Loss where checkpoints are stored

In [None]:
# Run on the first GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print('CUDA')
else:
    device = torch.device("cpu")

In [None]:
# Resize every image to 224x224, convert to a tensor and normalise per channel.
# NOTE(review): these mean/std values look like the usual CIFAR-10 statistics rather than
# statistics of the face datasets used here - confirm this is intended.
dataset_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
])

#folderDataset = "/home/rita/JupyterProjects/EYE-SEA/DataSets/Verification/ATeT_faces/orl_faces/"
# NOTE(review): this selects LFW while `folder` (used for the checkpoint path) is 'ATET' - keep them in sync.
folderDataset = "/home/rita/JupyterProjects/EYE-SEA/DataSets/Verification/lfw/"
#folderDataset = "/media/Data/rita/EYE-SEA/Verification/Datasets/ATeT_faces/orl_faces/"
batch_size = 130

# 0.8 is forwarded to Folded_Dataset unchanged (presumably the train fraction - confirm in Dataset_Loader).
Train_loader = Dataset_Loader.Folded_Dataset(folderDataset, dataset_transform, 0.8)
# Reshuffle the training split every epoch so the optimiser does not see the same batches in the same order.
dataLoader_generator = torch.utils.data.DataLoader(Train_loader, batch_size=batch_size, shuffle=True)
Test_loader = Dataset_Loader.Folded_Dataset(folderDataset, dataset_transform, 0.8, train=False)
dataLoader_generator_test = torch.utils.data.DataLoader(Test_loader, batch_size=batch_size)

In [None]:
# Checkpoints live in Log_model/Cluster_Loss/<folder>/DIGIT/<today's ISO date>/ and are named by an integer index.
PATH = os.path.join(os.getcwd(), 'Log_model/Cluster_Loss', folder, 'DIGIT/', date.today().isoformat())
pathlib.Path(PATH).mkdir(parents=True, exist_ok=True)
# NOTE(review): "> 2" and "- 2" assume PATH holds two entries that are not checkpoints - confirm.
if len(os.listdir(PATH)) > 2 and load_model:
    print('Loading model from PATH: {}'.format(PATH))
    # NOTE(review): the loaded model uses DigitEnd=False while a fresh one uses the default - confirm intended.
    model = ResNetCaps_E.ResNetCaps_E(DigitEnd=False)
    init = len(os.listdir(PATH)) - 2 if pick_model == -1 else pick_model
    # map_location lets a checkpoint written on a GPU be restored on a CPU-only machine.
    state_dict = torch.load(os.path.join(PATH, str(init)), map_location=device)
    # Weights saved from the nn.DataParallel wrapper below carry a 'module.' prefix on every key;
    # strip it so they load into the bare model.
    state_dict = {
        (key[len('module.'):] if key.startswith('module.') else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)
    model.eval()
else:
    print('Creating a new model')
    init = 0
    model = ResNetCaps_E.ResNetCaps_E()

if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # DataParallel splits each batch along dim 0, e.g. [30, ...] -> 3 x [10, ...] on 3 GPUs.
    model = nn.DataParallel(model)
model = model.to(device)

In [None]:
# Adam over the trainable parameters only (frozen ones are skipped).
# NOTE(review): lr=0.1 is 100x Adam's documented default of 1e-3 - confirm this is intended.
optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=0.1)

In [None]:
class ClusterLoss(nn.Module):
    """Cluster loss computed against per-identity mean embeddings.

    Usage in this notebook:
      1. ``mean_feats(feats, targets)`` on every batch of the reference data
         accumulates a feature sum and a sample count per label;
      2. ``mean_feats_compute()`` turns the accumulators into ``self.M_emb``,
         one mean per row, in the same order as ``self.clusters_labels``;
      3. ``criterion(feats, targets)`` returns the loss of a batch.

    For every distinct label in a batch the loss uses
      * D_intra - the largest distance between one of its samples and its own mean;
      * D_inter - the distance between its mean and the nearest other mean;
    and averages softplus(D_intra - D_inter + alpha), a smooth hinge that is
    small once D_intra + alpha < D_inter.

    The means are constants for the loss: they are detached when accumulated,
    so gradients reach the network through ``feats`` only.
    """

    def __init__(self, alpha=0.2):
        super(ClusterLoss, self).__init__()
        self.alpha = alpha  # margin required between the intra- and inter-cluster distances
        self.ranking_loss = nn.SoftMarginLoss()

        # Parallel accumulators filled by mean_feats(); index j refers to the same label in all three.
        self.clusters_sum = []     # summed feature vector per label
        self.clusters_count = []   # number of samples summed per label
        self.clusters_labels = []  # the labels themselves, as Python ints

    def forward(self, feats, targets):
        """Return the scalar cluster loss of a batch.

        Args:
            feats: batch embeddings, one row per sample.
            targets: 1-D tensor of integer labels, one per row of ``feats``;
                every label must already have a mean in ``self.M_emb``.
        """
        _, D_intra = self.Euclidean_intra(feats, targets)
        _, D_inter = self.Euclidean_inter(targets)
        # SoftMarginLoss(x, y) = mean(log(1 + exp(-y * x))). With y = 1 and
        # x = D_inter - D_intra - alpha this equals softplus(D_intra - D_inter + alpha),
        # which decreases as samples move towards their own mean and away from
        # the other means. (Feeding D_intra - D_inter + alpha with y = 1 would
        # reward the opposite.)
        x = D_inter.to(D_intra.device) - D_intra - self.alpha
        return self.ranking_loss(x, torch.ones_like(x))

    def mean_feats(self, feats, targets):
        """Add this batch's per-label feature sums and sample counts to the accumulators.

        Calls accumulate: a label seen before has its sum and count increased, a
        new label is appended. Assign empty lists to ``clusters_sum``,
        ``clusters_count`` and ``clusters_labels`` to start from scratch.

        Args:
            feats: batch embeddings, one row per sample.
            targets: 1-D tensor of integer labels; may live on a different
                device than ``feats``.
        """
        seen = set()
        for i, label in enumerate(targets.tolist()):
            if label in seen:
                continue
            seen.add(label)
            members = feats[(targets == targets[i]).to(feats.device)]
            # detach(): the means are constants for the loss; this also avoids holding
            # autograd graphs when called outside torch.no_grad().
            batch_sum = members.sum(dim=0).detach()
            if label in self.clusters_labels:
                j = self.clusters_labels.index(label)
                self.clusters_sum[j] = self.clusters_sum[j] + batch_sum
                self.clusters_count[j] += members.size(0)
            else:
                self.clusters_labels.append(label)
                self.clusters_sum.append(batch_sum)
                self.clusters_count.append(members.size(0))

    def mean_feats_compute(self):
        """Set ``self.M_emb`` to the stacked per-label means (row j <-> ``clusters_labels[j]``).

        Raises:
            ValueError: nothing has been accumulated yet.
        """
        if not self.clusters_sum:
            raise ValueError('no cluster statistics accumulated; call mean_feats() first')
        self.M_emb = torch.stack([total / count for total, count in zip(self.clusters_sum, self.clusters_count)])

    def _cluster_index(self, label):
        """Return the row of ``self.M_emb`` holding the mean of ``label``; ValueError if none exists."""
        try:
            return self.clusters_labels.index(label)
        except ValueError:
            raise ValueError('label {} has no cluster mean; it did not occur in the data '
                             'passed to mean_feats()'.format(label))

    def Euclidean_intra(self, feats, targets):
        """Largest sample-to-own-mean distance for every distinct label in the batch.

        Returns:
            (labels, D_intra): the distinct labels in order of first appearance and
            a 1-D tensor with one distance per label, in the same order.
        """
        # presumably an (len(feats), len(M_emb)) matrix of pairwise distances - TODO confirm in losses.py
        D = losses.euclidean_distance(feats, self.M_emb)
        target_intra = []
        D_intra = []
        for i, label in enumerate(targets.tolist()):
            if label in target_intra:
                continue
            target_intra.append(label)
            rows = (targets == targets[i]).to(D.device)
            D_intra.append(D[rows, self._cluster_index(label)].max())
        return target_intra, torch.stack(D_intra)

    def Euclidean_inter(self, targets):
        """Euclidean distance from each batch label's mean to the nearest other mean.

        Returns:
            (labels, D_inter): the distinct labels in order of first appearance and
            a 1-D tensor with one distance per label, in the same order.

        Raises:
            ValueError: fewer than two means exist, so "nearest other mean" is undefined.
        """
        means = self.M_emb
        if means.size(0) < 2:
            raise ValueError('at least two cluster means are required, got {}'.format(means.size(0)))
        target_inter = []
        rows = []
        for label in targets.tolist():
            if label not in target_inter:
                target_inter.append(label)
                rows.append(self._cluster_index(label))
        rows = torch.tensor(rows, device=means.device)
        # Distance from each selected mean to every mean: shape (len(rows), len(means)).
        dist = torch.cdist(means.index_select(0, rows), means)
        # Ignore each mean's zero distance to itself before taking the minimum.
        is_self = torch.zeros_like(dist, dtype=torch.bool)
        is_self[torch.arange(rows.numel(), device=means.device), rows] = True
        D_inter = dist.masked_fill(is_self, float('inf')).min(dim=1).values
        return target_inter, D_inter

    def classification(self, feats):
        """Nearest-mean assignment for every sample.

        Returns a list with one entry per row of ``feats``: a (k, 1) tensor of the
        row indices into ``self.M_emb`` whose distance equals that sample's minimum
        (k > 1 only on ties). These are indices, not labels - map them through
        ``self.clusters_labels``.
        """
        D = losses.euclidean_distance(feats, self.M_emb)
        is_nearest = D == D.min(dim=1, keepdim=True).values
        return [row.nonzero() for row in is_nearest]

                

In [None]:
criterion = ClusterLoss()

# Compute the reference (mean) embedding of every training identity.
# eval(): layers such as BatchNorm/Dropout (if the model has any) run in inference mode and
# BatchNorm running statistics are not modified while embedding the reference set.
model.eval()
with torch.no_grad():
    for in_train, labels_train in dataLoader_generator:
        emb_train = model(in_train.to(device))
        criterion.mean_feats(emb_train.squeeze(), labels_train)
    criterion.mean_feats_compute()
    

In [None]:
n_epochs = 10
loss_list = []  # mean training loss of each epoch, as plain floats
start = time.time()
# The reference means in criterion are not recomputed inside this loop; only the network is updated.
model.train()
for epoch in range(n_epochs):
    print('epoch {}:{}'.format(epoch + 1, n_epochs))
    loss_list_b = []  # per-batch losses of the current epoch
    for batch_id, (in_a, labels) in enumerate(dataLoader_generator):
        in_a = in_a.to(device)
        labels = labels.to(device)
        emb_a = model(in_a)

        optimizer.zero_grad()
        loss = criterion(emb_a.squeeze(), labels)

        if batch_id % 10 == 0:
            print("loss per batch {}".format(loss.item()))

        loss.backward()
        optimizer.step()

        # Store a Python float, not the loss tensor: the tensor stays attached to autograd
        # (and to device memory), and CUDA tensors cannot be plotted directly.
        loss_list_b.append(loss.item())
    if not loss_list_b:
        raise RuntimeError('the training DataLoader produced no batches')
    # Divide by the number of batches (batch_id is the last 0-based index, one less than the count).
    loss_list.append(sum(loss_list_b) / len(loss_list_b))
#torch.save(model.state_dict(), os.path.join(PATH,str(epoch)))
end = time.time()
print('Training time: {} s'.format(end - start))
print('mean time per epoch: {} s'.format((end - start) / n_epochs))

In [None]:
# Reference means used during training (one row per identity).
print('{}'.format(criterion.M_emb))

In [None]:
# Recompute the reference means with the trained weights. mean_feats() adds to whatever is
# already accumulated, so clear the accumulators first; otherwise the new "means" would blend
# the pre-training and post-training embeddings.
criterion.clusters_sum = []
criterion.clusters_count = []
criterion.clusters_labels = []
# eval(): layers such as BatchNorm/Dropout (if present) run in inference mode while embedding.
model.eval()
with torch.no_grad():
    for in_train, labels_train in dataLoader_generator:
        emb_train = model(in_train.to(device))
        criterion.mean_feats(emb_train.squeeze(), labels_train)
    criterion.mean_feats_compute()

In [None]:
# Recomputed reference means and their (num_identities, embedding_size) shape.
print(criterion.M_emb, criterion.M_emb.size(), sep='\n')


In [None]:
# float() accepts plain numbers and 0-d tensors (CPU or CUDA) alike, so the plot works
# however loss_list was filled; the x-axis follows the number of recorded epochs.
epoch_losses = [float(value) for value in loss_list]
epochs = np.arange(1, len(epoch_losses) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, epoch_losses, color='pink')
ax.set_xlabel('epochs')
ax.set_ylabel('loss')
ax.set_title('Training phase')
plt.show()

In [None]:
#@@@@@@@@@@@@@@@@@@@@@@@@@@@@TEST@@@@@@@@@@@@@@@@@@@@@@@@@@@@@#
# Evaluate on the held-out split by nearest-mean classification against criterion.M_emb.
# NOTE(review): criterion(...) looks up every test label in criterion.clusters_labels, so this
# assumes the test split only contains identities present in the training split - confirm.
model.eval()  # inference mode for layers such as BatchNorm/Dropout, if present
n_correct_total = 0
n_seen_total = 0
with torch.no_grad():  # evaluation needs no autograd graphs
    # classification() returns row indices of criterion.M_emb; this maps a row index to its label.
    cluster_labels = torch.tensor([int(label) for label in criterion.clusters_labels], device=device)
    for batch_id, (in_a_test, labels_test) in enumerate(dataLoader_generator_test):
        in_a_test = in_a_test.to(device)
        labels_test = labels_test.to(device)
        emb_a_test = model(in_a_test).squeeze()

        loss_test = criterion(emb_a_test, labels_test)
        if batch_id % 10 == 0:
            print("loss per batch {}".format(loss_test.item()))

        classification_test = criterion.classification(emb_a_test)
        # Each entry lists the nearest mean row(s); on a tie take the first one.
        pred_rows = torch.stack([nearest[0, 0] for nearest in classification_test]).to(device)
        pred_labels_test = cluster_labels[pred_rows]
        n_correct = int((pred_labels_test == labels_test).sum())
        n_correct_total += n_correct
        n_seen_total += labels_test.numel()
        print('batch {}: {}/{} correctly classified'.format(batch_id, n_correct, labels_test.numel()))

if n_seen_total:
    print('test accuracy: {:.4f} ({}/{})'.format(n_correct_total / n_seen_total, n_correct_total, n_seen_total))

In [None]:
# Nearest-mean indices and true labels of the last test batch, one per line.
print(classification_test, labels_test, sep='\n')


In [None]:
# Hand unused cached GPU memory from PyTorch's allocator back to the device (nothing to do without CUDA).
if torch.cuda.is_available():
    torch.cuda.empty_cache()