In [None]:
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim import Adam
import torch.nn.functional as F

import csv
from skimage import io

from PIL import Image
import pandas as pd

import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torchsummary import summary

import matplotlib.pyplot as plt
import matplotlib as mpl
import time
import os
import copy

import import_ipynb
import ResNetCaps_E
import AT_T_triplet_generator
import LFW_triplet_generator

# Run-time switches (verbose and USE_CUDA are not read by any cell visible here).
verbose = False
USE_CUDA = True

# Dataset selection for the data-loading cell; normally exactly one is True.
LFW_use = True
ATET_use = False

# First GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:

# Preprocessing applied to every face image: resize to the network input size,
# convert to a tensor and normalize each of the 3 channels.
dataset_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])
batch_size = 100

# Pick exactly one dataset. LFW wins when both flags are set (matching the
# previous "last assignment wins" behaviour); fail loudly when neither is set
# instead of leaving the loaders undefined for later cells.
if LFW_use:
    folder = "/home/rita/JupyterProjects/EYE-SEA/DataSets/Verification/lfw"
    generator_cls = LFW_triplet_generator.LFW_TripletGenerator
elif ATET_use:
    folder = "/home/rita/JupyterProjects/EYE-SEA/DataSets/Verification/ATeT_faces/orl_faces/"
    generator_cls = AT_T_triplet_generator.AT_T_TripletGenerator
else:
    raise ValueError("Set LFW_use or ATET_use to True to select a dataset.")

# Train and test triplet sources differ only in the `train` flag.
triplet_generator = generator_cls(folder, transform=dataset_transform, hold_out_positive=True)
dataLoader_generator = torch.utils.data.DataLoader(triplet_generator, batch_size=batch_size, shuffle=True)
triplet_generator_test = generator_cls(folder, train=False, transform=dataset_transform, hold_out_positive=True)
dataLoader_generator_test = torch.utils.data.DataLoader(triplet_generator_test, batch_size=batch_size, shuffle=True)

In [None]:
#TripletLoss layer
class TripletLossLayer(torch.nn.Module):
    """FaceNet-style triplet margin loss on squared L2 embedding distances.

    For every triplet i in the batch the loss term is

        max(||a_i - p_i||^2 - ||a_i - n_i||^2 + alpha, 0)

    and the returned loss is the sum of those terms over the batch. The hinge
    is applied per triplet, so one easy triplet cannot mask a violating one.

    Args:
        alpha: margin required between the anchor-negative and
            anchor-positive squared distances.
    """

    def __init__(self, alpha):
        super(TripletLossLayer, self).__init__()
        self.ALPHA = alpha

    @staticmethod
    def _squared_distances(x, y):
        """Return the per-sample squared L2 distance between ``x`` and ``y``.

        Dimension 0 is treated as the batch dimension and all remaining
        dimensions are summed. A tensor with fewer than 2 dimensions is
        treated as a single sample.
        """
        diff = x - y
        if diff.dim() < 2:
            diff = diff.reshape(1, -1)
        return diff.reshape(diff.shape[0], -1).pow(2).sum(dim=1)

    def triplet_loss(self, a, p, n):
        """Compute the loss and distance totals for a batch of triplets.

        Args:
            a, p, n: anchor, positive and negative embeddings with matching
                shapes (batch dimension first is assumed).

        Returns:
            ``[loss, p_dist, n_dist]``, all 0-dim tensors on the inputs' device:
            ``loss`` is the sum of the per-triplet hinge terms; ``p_dist`` /
            ``n_dist`` are the anchor-positive / anchor-negative squared
            distances summed over the batch (per-sample values when the batch
            holds a single triplet).
        """
        p_dists = self._squared_distances(a, p)
        n_dists = self._squared_distances(a, n)
        # clamp(min=0) is the hinge; it follows the inputs' device and dtype,
        # so no separately allocated zero tensor is needed.
        per_triplet = torch.clamp(p_dists - n_dists + self.ALPHA, min=0)
        return [per_triplet.sum(), p_dists.sum(), n_dists.sum()]

    def forward(self, a, p, n):
        """Return ``(loss, p_dist, n_dist)`` and keep the last loss on ``self.loss``."""
        loss, p_dist, n_dist = self.triplet_loss(a, p, n)
        self.loss = loss
        return loss, p_dist, n_dist

In [None]:
# Build the capsule-ResNet embedding network directly on the compute device
# (nn.Module.to returns the module itself).
model = ResNetCaps_E.ResNetCaps_E().to(device)

In [None]:
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
criterion = TripletLossLayer(0.2)
n_epochs = 10

# One inner list per epoch with the distances logged every 100th batch.
p_dist_list_l = []
n_dist_list_l = []

# One value per epoch: mean over all batches of that epoch.
p_dist_list_b = []
n_dist_list_b = []
loss_list_b = []

for epoch in range(n_epochs):
    print('epoch {}:{}'.format(epoch + 1, n_epochs))
    model.train()
    loss_collect = 0.0
    p_dist_collect = 0.0
    n_dist_collect = 0.0
    n_batches = 0
    p_dist_list = []
    n_dist_list = []
    for batch_id, (in_a, in_p, in_n) in enumerate(dataLoader_generator):
        in_a = in_a.to(device)
        in_p = in_p.to(device)
        in_n = in_n.to(device)

        # Compute embeddings for anchor, positive, and negative images.
        emb_a = model(in_a)
        emb_p = model(in_p)
        emb_n = model(in_n)

        optimizer.zero_grad()
        loss, p_dist, n_dist = criterion(emb_a, emb_p, emb_n)

        # Accumulate plain floats: adding the tensors themselves would chain
        # every batch's autograd graph together and keep it alive all epoch.
        loss_collect += loss.item()
        p_dist_collect += p_dist.item()
        n_dist_collect += n_dist.item()
        n_batches += 1

        if batch_id % 100 == 0:
            print("p_dist {} n_dist {}".format(p_dist.item(), n_dist.item()))
            print("loss per batch {}".format(loss.item()))
            p_dist_list.append(p_dist.item())
            n_dist_list.append(n_dist.item())

        loss.backward()
        optimizer.step()

    if n_batches == 0:
        raise RuntimeError("dataLoader_generator produced no batches")

    p_dist_list_l.append(p_dist_list)
    n_dist_list_l.append(n_dist_list)

    # Divide by the number of batches (the last batch_id is one less than it).
    p_dist_list_b.append(p_dist_collect / n_batches)
    n_dist_list_b.append(n_dist_collect / n_batches)
    loss_list_b.append(loss_collect / n_batches)
        

    

In [None]:
# Distances logged during training, one colour shade per epoch
# (greens for anchor-positive, reds/yellows for anchor-negative).
cmap_positive = mpl.cm.summer
cmap_negative = mpl.cm.autumn
n_logged_epochs = len(p_dist_list_l)

fig, ax = plt.subplots(figsize=(10, 10))
for it_l in range(n_logged_epochs):
    shade = it_l / float(max(n_logged_epochs, 1))
    # float() turns 1-element tensors (possibly on GPU or requiring grad)
    # into plain numbers that matplotlib can plot.
    p_vals = [float(d) for d in p_dist_list_l[it_l]]
    n_vals = [float(d) for d in n_dist_list_l[it_l]]
    ax.plot(np.arange(1, len(p_vals) + 1), p_vals, '+', color=cmap_positive(shade))
    ax.plot(np.arange(1, len(n_vals) + 1), n_vals, '*', color=cmap_negative(shade))

ax.set_xlabel('logged batch (every 100th)')
ax.set_ylabel('p_dist,n_dist')
ax.set_title('Training phase')
plt.show()

In [None]:
# Mean anchor-positive vs anchor-negative distance per training epoch.
# Sized from the recorded data so a partially completed run still plots.
epochs = np.arange(1, len(p_dist_list_b) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, [float(d) for d in p_dist_list_b], color='g', label='p_dist')
ax.plot(epochs, [float(d) for d in n_dist_list_b], color='orange', label='n_dist')
ax.set_xlabel('epochs')
ax.set_ylabel('p_dist,n_dist')
ax.set_title('Training phase')
ax.legend()
plt.show()

In [None]:
# Mean training loss per epoch; x positions derived from the data itself
# instead of relying on `epochs` from another cell.
loss_values = [float(v) for v in loss_list_b]
fig, ax = plt.subplots()
ax.plot(np.arange(1, len(loss_values) + 1), loss_values, color='pink')
ax.set_xlabel('epochs')
ax.set_ylabel('loss')
ax.set_title('Training phase')
plt.show()

In [None]:
batch_size = 4
# Use the generator family that matches `folder` from the data-loading cell
# (LFW when LFW_use is set, AT&T otherwise).
if LFW_use:
    test_generator_cls = LFW_triplet_generator.LFW_TripletGenerator
else:
    test_generator_cls = AT_T_triplet_generator.AT_T_TripletGenerator
triplet_generator_test = test_generator_cls(folder, train=False, transform=dataset_transform, hold_out_positive=True)
dataLoader_generator_test = torch.utils.data.DataLoader(triplet_generator_test, batch_size=batch_size, shuffle=True)
size_test = len(dataLoader_generator_test)

# Test: one squared distance per triplet, stored as plain floats.
p_dist_list_Test = []
n_dist_list_Test = []
model.eval()
with torch.no_grad():
    for in_a, in_p, in_n in dataLoader_generator_test:
        # Embed each triplet on its own so p_dist/n_dist are per-sample values.
        for k in range(len(in_a)):
            emb_a = model(in_a[k].unsqueeze(0).to(device))
            emb_p = model(in_p[k].unsqueeze(0).to(device))
            emb_n = model(in_n[k].unsqueeze(0).to(device))

            loss, p_dist, n_dist = criterion(emb_a, emb_p, emb_n)
            p_dist_list_Test.append(p_dist.item())
            n_dist_list_Test.append(n_dist.item())
        


In [None]:
# Mean test distances, then the anchor-positive distance of every test triplet.
p_test_values = [float(d) for d in p_dist_list_Test]
n_test_values = [float(d) for d in n_dist_list_Test]
if not p_test_values or not n_test_values:
    raise ValueError("No test distances were collected; run the test cell first.")

print(sum(p_test_values) / len(p_test_values))
print(sum(n_test_values) / len(n_test_values))

fig, ax = plt.subplots()
ax.plot(range(len(p_test_values)), p_test_values, color='pink')
ax.set_xlabel('elements')
ax.set_ylabel('p_dist_value')
ax.set_title('test')
plt.show()

In [None]:
# Anchor-negative distance of every test triplet; float() accepts both plain
# numbers and 1-element tensors.
n_test_values = [float(d) for d in n_dist_list_Test]
fig, ax = plt.subplots()
ax.plot(range(len(n_test_values)), n_test_values, color='g')
ax.set_xlabel('elements')
ax.set_ylabel('n_dist_value')
ax.set_title('test')
plt.show()


In [None]:
# Verification metrics at a squared-L2 distance threshold (FaceNet
# definitions): a pair is accepted as "same identity" when dist <= threshold.
#   TA  = anchor/positive (same-identity) pairs accepted
#   FA  = anchor/negative (different-identity) pairs accepted
#   VAL = TA / |P_same|,   FAR = FA / |P_diff|
threshold = 0.3
n_same = len(p_dist_list_Test)
n_diff = len(n_dist_list_Test)

positives_True = sum(1 for d in p_dist_list_Test if float(d) <= threshold)
false_accepts = sum(1 for d in n_dist_list_Test if float(d) <= threshold)
# Different-identity pairs correctly rejected (dist > threshold).
negatives_True = n_diff - false_accepts

val = positives_True / n_same if n_same else float('nan')
far = false_accepts / n_diff if n_diff else float('nan')

print("TA({}) {} out of P_same {}".format(threshold, positives_True, n_same))
print("FA({}) {} out of P_diff {}".format(threshold, false_accepts, n_diff))
print("VAL({}) (TA/P_same)  {}".format(threshold, val))
print("FAR({}) (FA/P_diff)  {}".format(threshold, far))


In [None]:
# Test only one batch: show each anchor / positive / negative image, titling
# the positive and negative panels with their squared distance to the anchor.

in_a, in_p, in_n = next(iter(dataLoader_generator_test))

p_dist_list = []
n_dist_list = []

model.eval()
with torch.no_grad():
    # Iterate over the actual batch length; the last batch may be short.
    for k in range(len(in_a)):
        emb_a = model(in_a[k].unsqueeze(0).to(device))
        emb_p = model(in_p[k].unsqueeze(0).to(device))
        emb_n = model(in_n[k].unsqueeze(0).to(device))

        loss, p_dist, n_dist = criterion(emb_a, emb_p, emb_n)
        p_dist_list.append(p_dist.item())
        n_dist_list.append(n_dist.item())

# Undo transforms.Normalize so imshow gets RGB values in [0, 1]; these
# constants must match the Normalize step in dataset_transform.
norm_mean = np.array([0.4914, 0.4822, 0.4465]).reshape(3, 1, 1)
norm_std = np.array([0.247, 0.243, 0.261]).reshape(3, 1, 1)


def _denormalize_for_display(img_chw):
    """Convert a normalized CHW image array into a clipped HWC array for imshow."""
    return np.clip(np.transpose(img_chw * norm_std + norm_mean, (1, 2, 0)), 0.0, 1.0)


columns = 3
rows = len(in_a)
# (images, distances shown as panel title or None) for anchor, positive, negative.
panels = ((in_a, None), (in_p, p_dist_list), (in_n, n_dist_list))

fig = plt.figure(figsize=(8, 8))
for row in range(rows):
    for col, (images, dists) in enumerate(panels):
        fig.add_subplot(rows, columns, row * columns + col + 1)
        plt.imshow(_denormalize_for_display(images[row].cpu().numpy()))
        if dists is not None:
            plt.title(str(dists[row]))
        plt.axis('off')
plt.show()

# # Try with single images

In [None]:
def __loadfile(data_file):
    """Read an image file and return it as a 3-channel (H, W, 3) array.

    Images with fewer than 3 dimensions (grayscale) and single-channel
    (H, W, 1) images are replicated to 3 channels. RGBA (H, W, 4) images have
    their alpha channel dropped, so every result matches the 3-channel
    Normalize used in dataset_transform.

    Args:
        data_file: path of the image to read.
    """
    image = io.imread(data_file)
    if image.ndim < 3:
        image = np.stack((image,) * 3, axis=-1)
    elif image.shape[-1] == 1:
        image = np.concatenate((image,) * 3, axis=-1)
    elif image.shape[-1] == 4:
        # Alpha channel (e.g. from PNGs) would give a 4-channel tensor.
        image = image[..., :3]
    return image

def triplet_generator(transform,
                      anchor_path="/home/rita/JupyterProjects/EYE-SEA/Verification_RNCAPS/Images_examples/anchor.JPG",
                      positive_path="/home/rita/JupyterProjects/EYE-SEA/Verification_RNCAPS/Images_examples/positive.jpg",
                      negative_path="/home/rita/JupyterProjects/EYE-SEA/Verification_RNCAPS/Images_examples/negative.JPG"):
    """Load one anchor / positive / negative image triplet and transform it.

    Args:
        transform: callable applied to each PIL image (e.g. dataset_transform).
        anchor_path, positive_path, negative_path: image files to load; the
            defaults are the example images previously hard-coded here.

    Returns:
        Tuple ``(b_a, b_p, b_n)`` of transformed images in anchor, positive,
        negative order.
    """
    return tuple(
        transform(Image.fromarray(__loadfile(path)))
        for path in (anchor_path, positive_path, negative_path)
    )

In [None]:
# INPUT: one anchor, positive and negative example image.

dataset_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])

# Move each transformed image to the compute device and give it a leading
# batch dimension of size 1.
in_a, in_p, in_n = (
    img.to(device).unsqueeze(0) for img in triplet_generator(dataset_transform)
)

In [None]:
# Fresh, untrained embedding network for the single-triplet experiment,
# placed on the compute device (nn.Module.to returns the module itself).
model = ResNetCaps_E.ResNetCaps_E().to(device)

In [None]:
# torchsummary.summary defaults to device="cuda"; pass the device actually in
# use so the summary also works on CPU-only machines.
summary(model, (3, 224, 224), device="cuda" if torch.cuda.is_available() else "cpu")

In [None]:
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
criterion = TripletLossLayer(0.2)
n_epochs = 100

# Per-epoch values stored as plain floats: keeping the tensors would retain
# each epoch's autograd graph and break plotting of CUDA tensors.
p_dist_list = []
n_dist_list = []
loss_list = []

for epoch in range(n_epochs):
    model.train()
    print('epoch {}:{}'.format(epoch + 1, n_epochs))
    emb_a = model(in_a)
    emb_p = model(in_p)
    emb_n = model(in_n)

    optimizer.zero_grad()

    loss, p_dist, n_dist = criterion(emb_a, emb_p, emb_n)
    p_dist_list.append(p_dist.item())
    n_dist_list.append(n_dist.item())
    loss_list.append(loss.item())
    print("loss {}".format(loss.item()))
    loss.backward()
    optimizer.step()
    
   

In [None]:
# Inspect the embeddings of the example triplet after training. eval() and
# no_grad() make this pure inference (no graph is built).
model.eval()
with torch.no_grad():
    emb_a = model(in_a)
    emb_p = model(in_p)
    emb_n = model(in_n)

print(emb_a)
print(emb_p)
print(emb_n)

In [None]:
# Anchor-positive vs anchor-negative distance per epoch of the
# single-triplet experiment; float() also accepts 1-element tensors.
epochs = np.arange(1, len(p_dist_list) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, [float(d) for d in p_dist_list], color='g', label='p_dist')
ax.plot(epochs, [float(d) for d in n_dist_list], color='orange', label='n_dist')
ax.set_xlabel('epochs')
ax.set_ylabel('p_dist,n_dist')
ax.set_title('Training phase')
ax.legend()
plt.show()

In [None]:
# Loss per epoch of the single-triplet experiment, with x positions derived
# from the recorded values rather than `epochs` from another cell.
loss_values = [float(v) for v in loss_list]
fig, ax = plt.subplots()
ax.plot(np.arange(1, len(loss_values) + 1), loss_values, color='pink')
ax.set_xlabel('epochs')
ax.set_ylabel('loss')
ax.set_title('Training phase')
plt.show()