In [None]:
from __future__ import print_function
import pandas as pd
import numpy as np
import csv
import os
import torch

import matplotlib
import matplotlib.pyplot as plt
import random

from PIL import Image
from skimage import io
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

# Module-level debug switch: when True, AT_T_TripletGenerator prints the
# anchor / positive / negative choices while it builds the triplet CSV files.
verbose = False

In [None]:
class AT_T_TripletGenerator(Dataset):
    """Triplet (anchor, positive, negative) dataset over a folder of face images.

    The dataset folder is expected to contain one sub-folder per individual,
    each holding that individual's images.  On first use two CSV index files
    (``data_verification_Train.csv`` and ``data_verification_Test.csv``) are
    written into the dataset folder; later instantiations reuse them, so delete
    the files to force a new random draw.

    Every row of an index file holds three paths relative to the dataset
    folder: an anchor image, a different image of the same individual
    (positive) and an image of another individual (negative).

    Note: the train/test split is made on the *anchor* only (the first 80% of
    each individual's shuffled images are training anchors).  Positives are
    drawn from all images of the individual, so a test anchor can still show
    up as a positive in a training row.

    Args:
        dat_folder: Path to the dataset folder, with or without a trailing
            path separator.
        train: If True load the training triplets, otherwise the test ones.
        transform: Optional callable applied to each image after it has been
            converted to a ``PIL.Image``.
    """

    # Names of the index files stored next to the individuals' folders.
    TRAIN_CSV = "data_verification_Train.csv"
    TEST_CSV = "data_verification_Test.csv"
    # Header of the index files.
    CSV_COLUMNS = ["Anchor", "Positive", "Negative"]
    # Fraction of each individual's images used as training anchors.
    TRAIN_FRACTION = 0.8

    def __init__(self, dat_folder, train=True, transform=None):
        super(AT_T_TripletGenerator, self).__init__()
        if os.path.isdir(dat_folder):
            # Appending an empty component forces a trailing separator, so
            # dirname() returns the folder itself instead of its parent when
            # the caller omitted the trailing slash.
            self.root_dir = os.path.dirname(os.path.join(dat_folder, ""))
        else:
            # Not a directory: fall back to the containing directory.
            self.root_dir = os.path.dirname(dat_folder)
        self.transform = transform

        # One sub-folder per individual; sorted so indices are deterministic.
        self.ind = sorted(
            individual for individual in os.listdir(self.root_dir)
            if os.path.isdir(os.path.join(self.root_dir, individual))
        )
        print("number of individuals {}".format(len(self.ind)))

        csv_file_train = os.path.join(self.root_dir, self.TRAIN_CSV)
        csv_file_test = os.path.join(self.root_dir, self.TEST_CSV)

        # Both files are produced by the same random draw, so rebuild both
        # whenever either one is missing.
        if not (os.path.isfile(csv_file_train) and os.path.isfile(csv_file_test)):
            train_rows, test_rows = self._build_triplets()
            self._write_index(csv_file_train, train_rows)
            self._write_index(csv_file_test, test_rows)

        self.AT_T_datafile = pd.read_csv(csv_file_train if train else csv_file_test)

    def _list_images(self, individual):
        """Return the sorted file names found in one individual's folder."""
        folder_path = os.path.join(self.root_dir, individual)
        return sorted(
            name for name in os.listdir(folder_path)
            if os.path.isfile(os.path.join(folder_path, name))
        )

    def _build_triplets(self):
        """Draw the random triplets and return ``(train_rows, test_rows)``."""
        # List every folder once so the negative lookup is stable and cheap.
        images_of = [self._list_images(individual) for individual in self.ind]
        train_rows, test_rows = [], []

        for i, individual in enumerate(self.ind):
            name_images = list(images_of[i])
            random.shuffle(name_images)
            n_images = len(name_images)

            # Negatives come from any *other* individual that has images.
            negative_candidates = [
                k for k in range(len(self.ind)) if k != i and images_of[k]
            ]
            if n_images < 2 or not negative_candidates:
                # No positive or no negative can be formed for this individual.
                continue

            n_pairs = n_images - 1
            for j, anchor_name in enumerate(name_images):
                anchor_filename = os.path.join(individual, anchor_name)
                if verbose: print("anchor {}".format(j))
                if verbose: print("anchor filename {}".format(anchor_filename))

                # Every other image of the same individual is a positive.
                positive_names = name_images[:j] + name_images[j + 1:]

                if n_pairs <= len(negative_candidates):
                    # Enough individuals: use distinct ones for this anchor.
                    negative_folders = random.sample(negative_candidates, n_pairs)
                else:
                    # Too few individuals for distinct picks: allow repeats
                    # instead of letting random.sample raise ValueError.
                    negative_folders = [
                        random.choice(negative_candidates) for _ in range(n_pairs)
                    ]
                if verbose: print("negative folders {}".format(negative_folders))

                rows = train_rows if j < n_images * self.TRAIN_FRACTION else test_rows
                for positive_name, k in zip(positive_names, negative_folders):
                    positive_filename = os.path.join(individual, positive_name)
                    negative_filename = os.path.join(
                        self.ind[k], random.choice(images_of[k])
                    )
                    if verbose: print("positive filename {}".format(positive_filename))
                    if verbose: print("negative filename {}".format(negative_filename))
                    rows.append([anchor_filename, positive_filename, negative_filename])

        return train_rows, test_rows

    def _write_index(self, csv_path, rows):
        """Write the triplet rows, preceded by the header, to ``csv_path``."""
        # pandas takes care of newline handling on every platform.
        pd.DataFrame(rows, columns=self.CSV_COLUMNS).to_csv(csv_path, index=False)

    def __len__(self):
        """Return the number of triplets in the loaded split."""
        return len(self.AT_T_datafile)

    def __getitem__(self, idx):
        """Return the ``(anchor, positive, negative)`` images of row ``idx``.

        Each image is an ``H x W x 3`` array, or the output of ``transform``
        when one was supplied.
        """
        triplet = []
        for column in range(3):  # Anchor, Positive, Negative
            image = self.__loadfile(
                os.path.join(self.root_dir, self.AT_T_datafile.iloc[idx, column])
            )
            if self.transform:
                image = self.transform(Image.fromarray(image))
            triplet.append(image)
        return tuple(triplet)

    def __loadfile(self, data_file):
        """Read an image file; grayscale images are replicated to 3 channels."""
        image = io.imread(data_file)
        if len(image.shape) < 3:
            image = np.stack((image,) * 3, axis=-1)
        return image

In [None]:
#folder = "/home/rita/JupyterProjects/EYE-SEA/DataSets/Verification/ATeT_faces/orl_faces/"
#dataset_transform = transforms.Compose([
#    transforms.Resize((224,224)),
#    transforms.ToTensor(),        
#    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
#])
#generator = AT_T_TripletGenerator(folder, transform = dataset_transform)
#
#dataloaders = torch.utils.data.DataLoader(generator,batch_size=4, shuffle=True)
#inputs = next(iter(dataloaders))

In [None]:
#print(len(inputs))