In [None]:
import os
import torch
import random
from tqdm import tqdm
import numpy as np
from torch import nn
import matplotlib.pyplot as plt
import pickle
import shutil
import cv2
from sklearn.metrics import confusion_matrix
from torch_mtcnn import detect_faces
from PIL import Image

# Prefer the first CUDA device when one is visible; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
    device_label = torch.cuda.get_device_name(device)
else:
    device = 'cpu'
    device_label = 'cpu'
print('Device = %s' % device_label)

## Generate the dataset here

In [None]:
class VGG_1(torch.utils.data.Dataset):
    """Random image pairs drawn from a ``<root>/<gender>/<identity>/<image>`` tree.

    Every entry of ``self.pairs`` is the 8-tuple
    ``(img_1, img_2, label, gender_1, gender_2, img_1, img_2, 'VGG_1')`` whose
    image fields are file paths. ``label`` is 0 when both images belong to the
    same identity and 1 when they belong to two different identities.
    """

    def __init__(self, folder_path, max_samples, identity_imgs=10):
        """Scan ``folder_path`` and generate ``max_samples`` random pairs.

        Args:
            folder_path: root directory holding one sub-directory per gender.
            max_samples: number of pairs to generate.
            identity_imgs: maximum number of images kept per identity.

        Raises:
            ValueError: if pairs are requested but fewer than two identities
                with at least one image were found (a "different identity"
                pair could not be built).
        """
        paths = self._collect_paths(folder_path, identity_imgs)

        n_identities = sum(len(idens) for idens in paths.values())
        if max_samples > 0 and n_identities < 2:
            raise ValueError(
                'VGG_1 needs at least two identities with images under %r, found %d'
                % (folder_path, n_identities))

        # create random data
        self.pairs = []
        for _ in tqdm(range(max_samples)):
            self.pairs.append(self._sample_pair(paths))

        # print data dist.
        counts = {1: 0, 0: 0}
        for p in self.pairs:
            counts[p[2]] += 1
        print('\n\nCounts 0 = %d | 1 = %d' %(counts[0], counts[1]))

        # images are only read from disk when load_images() is called
        self.img_pairs = []
        self.loaded = False

    @staticmethod
    def _collect_paths(folder_path, identity_imgs):
        """Return ``{gender: {identity: [image paths]}}`` for the directory tree.

        At most ``identity_imgs`` files are kept per identity. Identities without
        any file and genders without any usable identity are left out, so every
        list in the result can safely be sampled from.
        """
        paths = {}
        for gender in os.listdir(folder_path):
            gender_path = os.path.join(folder_path, gender)
            if not os.path.isdir(gender_path):
                continue

            identities = {}
            for iden in os.listdir(gender_path):
                iden_path = os.path.join(gender_path, iden)
                if not os.path.isdir(iden_path):
                    continue

                # add some pictures from that specific identity
                images = []
                for i in os.listdir(iden_path):
                    if len(images) >= identity_imgs:
                        break
                    img_path = os.path.join(iden_path, i)
                    if os.path.isdir(img_path):
                        continue
                    images.append(img_path)
                if images:
                    identities[iden] = images

            if identities:
                paths[gender] = identities
        return paths

    @staticmethod
    def _sample_pair(paths):
        """Draw one random pair tuple; ``paths`` must hold at least two identities."""
        genders = list(paths.keys())

        def draw_identity():
            # select a gender first -> can be changed to make the dataset biased
            gender = random.choice(genders)
            return gender, random.choice(list(paths[gender].keys()))

        gender_1, iden_1 = draw_identity()
        if random.random() <= 0.5:
            # label 0: same identity
            gender_2, iden_2 = gender_1, iden_1
            label = 0
        else:
            # label 1: different identities and maybe different genders as well!
            gender_2, iden_2 = draw_identity()
            # re-draw until the second identity really differs from the first
            while (gender_2, iden_2) == (gender_1, iden_1):
                gender_2, iden_2 = draw_identity()
            label = 1

        img_1 = random.choice(paths[gender_1][iden_1])
        img_2 = random.choice(paths[gender_2][iden_2])
        return (img_1, img_2, label, gender_1, gender_2, img_1, img_2, 'VGG_1')

    def load_images(self):
        """Read, resize (100x100) and min-max normalise every pair into ``self.img_pairs``.

        Each loaded entry is ``(image_1, image_2, label, gender_1, gender_2,
        path_1, path_2)``. Pairs with a file OpenCV cannot read are skipped.
        Calling this again rebuilds the list instead of appending duplicates.
        """
        norm_img = np.zeros((100, 100))
        self.img_pairs = []
        self.loaded = False

        # load the images from self.pairs
        for pair in tqdm(self.pairs):
            img_1 = cv2.imread(pair[0])
            img_2 = cv2.imread(pair[1])
            # cv2.imread returns None for unreadable files
            if img_1 is None or img_2 is None:
                continue

            self.img_pairs.append((
                cv2.normalize(cv2.resize(img_1, (100, 100)), norm_img, 0, 255, cv2.NORM_MINMAX),
                cv2.normalize(cv2.resize(img_2, (100, 100)), norm_img, 0, 255, cv2.NORM_MINMAX),
                pair[2], pair[3], pair[4], pair[5], pair[6]
            ))
        self.loaded = True

    def __getitem__(self, ind):
        """Return the loaded image tuple, or the raw path tuple before loading."""
        if not self.loaded:
            return self.pairs[ind]
        return self.img_pairs[ind]

    def __len__(self):
        """Number of items ``__getitem__`` can currently serve."""
        # after loading, skipped pairs make img_pairs shorter than pairs
        return len(self.img_pairs) if self.loaded else len(self.pairs)


##########   FaceDatabases   ##########
class FaceDB(torch.utils.data.Dataset):
    """Image-pair dataset that mixes VGG pairs with pairs from other face databases.

    Every entry of ``self.pairs`` is an 8-tuple
    ``(img_1, img_2, label, gender_1, gender_2, img_1, img_2, db_name)`` of file
    paths and metadata. ``label`` is 0 for two images of the same identity and
    1 for images of two different identities. After ``load_images()`` the first
    two fields of each ``self.img_pairs`` entry hold image arrays instead.
    """

    def __init__(self, path_dict, pre_load=False):
        '''
        path_dict = {
            'vgg_1'        : {path: str, samples: int},   # required
            'morphed'      : {path: str, samples: int},   # optional
            'multiview_01' : {path: str, samples: int},   # optional
            'iranian_women': {path: str, samples: int},   # optional
            'multiview_02' : {path: str, samples: int},   # optional
        }
        Optional databases that are missing from path_dict are skipped.
        pre_load: when True, load_images() is run at the end of construction.
        '''

        self.pairs = []

        # get the VGG data by default
        dataset = VGG_1(path_dict['vgg_1']['path'], path_dict['vgg_1']['samples'])
        self.pairs.extend(dataset.pairs)
        del dataset

        # read from other databases (same order as before so pair order is unchanged)
        sources = (
            ('morphed', self.iranian_woman_db),
            ('multiview_01', self.multi_pie_multiview),
            ('iranian_women', self.iranian_woman_db),
            ('multiview_02', self.multi_pie_multiview),
        )
        for key, reader in sources:
            if key not in path_dict:
                continue
            self.pairs.extend(reader(path_dict[key]['path'], path_dict[key]['samples'], key))

        # print data dist.
        counts = {1: 0, 0: 0}
        for p in self.pairs:
            counts[p[2]] += 1
        print('\n\nCounts 0 = %d | 1 = %d' %(counts[0], counts[1]))

        # preload images if needed!
        self.img_pairs = []
        self.loaded = False
        if pre_load:
            self.load_images()

    @staticmethod
    def _scan_identities(path, subdirs=None, keep=None):
        """Map each identity directory under ``path`` to the list of its image paths.

        Args:
            path: directory holding one sub-directory per identity.
            subdirs: optional relative sub-directories to read inside each
                identity directory; when None the identity directory itself is read.
            keep: optional predicate on the file name; entries it rejects are dropped.

        Entries of ``path`` that are not directories, missing sub-directories and
        identities without any kept image are skipped.
        """
        identities = {}
        for iden in os.listdir(path):
            iden_dir = os.path.join(path, iden)
            # the check must be made against the full path, not the bare name
            if not os.path.isdir(iden_dir):
                continue

            if subdirs is None:
                image_dirs = [iden_dir]
            else:
                image_dirs = [os.path.join(iden_dir, sub) for sub in subdirs]

            images = []
            for image_dir in image_dirs:
                if not os.path.isdir(image_dir):
                    continue
                for fname in os.listdir(image_dir):
                    if keep is None or keep(fname):
                        images.append(os.path.join(image_dir, fname))
            if images:
                identities[iden] = images
        return identities

    def _sample_pairs(self, identities, no_samples, gender, db_name):
        """Draw ``no_samples`` random pair tuples from ``identities``.

        Roughly half of the pairs use one identity twice (label 0); the rest use
        two distinct identities (label 1).

        Raises:
            ValueError: if pairs are requested but fewer than two identities exist.
        """
        names = list(identities.keys())
        if no_samples > 0 and len(names) < 2:
            raise ValueError(
                'database %r needs at least two identities with images, found %d'
                % (db_name, len(names)))

        pairs = []
        for _ in tqdm(range(no_samples)):
            if random.random() <= 0.5:
                # label 0: same identity
                iden_1 = iden_2 = random.choice(names)
                label = 0
            else:
                # label 1: two distinct identities
                iden_1, iden_2 = random.sample(names, 2)
                label = 1

            img_1 = random.choice(identities[iden_1])
            img_2 = random.choice(identities[iden_2])
            pairs.append((img_1, img_2, label, gender, gender, img_1, img_2, db_name))
        return pairs

    def multi_pie_multiview(self, path, no_samples, name, camera_list=['13_0', '14_0', '05_1', '05_0', '04_1'], session='01'):
        """Pairs from a ``<path>/<identity>/<session>/<camera>/<image>`` tree.

        Only the given ``session`` and cameras are read. Genders are reported as
        'unknown' and the database tag is ``name + session``.
        """
        subdirs = [os.path.join(session, camera) for camera in camera_list]
        identities = self._scan_identities(path, subdirs=subdirs)
        return self._sample_pairs(identities, no_samples, 'unknown', name + session)

    def iranian_woman_db(self, path, no_samples, name):
        """Pairs from a ``<path>/<identity>/<image>`` tree, tagged as 'female'.

        Files whose name contains '_2', '_3', '_4' or '_5' are ignored.
        """
        skipped_tags = ('_2', '_3', '_4', '_5')

        def keep(fname):
            return not any(tag in fname for tag in skipped_tags)

        identities = self._scan_identities(path, keep=keep)
        return self._sample_pairs(identities, no_samples, 'female', name)

    def crop_image(self, path):
        """Return the single detected face cropped out of the image at ``path``.

        Returns None when the file cannot be read, when the detector reports
        zero or more than one face, or when the detected box is empty.
        """
        image = cv2.imread(path)
        if image is None:
            # unreadable or missing file
            return None

        # NOTE(review): confirm the installed torch_mtcnn.detect_faces accepts the
        # BGR ndarray returned by cv2.imread rather than a PIL image.
        res = detect_faces(image)
        if res is None:
            return None
        bounding_boxes, landmarks = res[0], res[1]

        if len(bounding_boxes) != 1:
            return None

        # else, crop the image
        box = list(map(int, bounding_boxes[0]))
        # negative coordinates would wrap around in numpy slicing, so clamp at 0
        left, top = max(box[0], 0), max(box[1], 0)
        right, bottom = max(box[2], 0), max(box[3], 0)
        face = image[top:bottom, left:right]
        if face.size == 0:
            return None
        return face

    def load_images(self):
        """Read every pair into ``self.img_pairs`` as 100x100 min-max normalised images.

        VGG pairs are read as they are; pairs from the other databases are first
        cropped to the detected face. Pairs that cannot be read, cropped or
        resized are skipped. ``self.idx`` maps the position in ``self.img_pairs``
        back to the position in ``self.pairs``. Calling this again rebuilds both.
        """
        norm_img = np.zeros((100, 100))
        self.img_pairs = []
        self.idx = {}
        self.loaded = False
        idx_loaded = 0

        # load the images from self.pairs
        for idx_pair, pair in enumerate(tqdm(self.pairs)):
            # crop image if needed
            if 'VGG_1' == pair[7]:
                img_1 = cv2.imread(pair[0])
                img_2 = cv2.imread(pair[1])
            else:
                img_1 = self.crop_image(pair[0])
                img_2 = self.crop_image(pair[1])

            if img_1 is None or img_2 is None:
                continue

            try:
                entry = (
                    cv2.normalize(cv2.resize(img_1, (100, 100)), norm_img, 0, 255, cv2.NORM_MINMAX),
                    cv2.normalize(cv2.resize(img_2, (100, 100)), norm_img, 0, 255, cv2.NORM_MINMAX),
                    pair[2], pair[3], pair[4], pair[5], pair[6], pair[7]
                )
            except cv2.error:
                # OpenCV could not resize/normalise this pair; skip it
                continue

            self.img_pairs.append(entry)
            self.idx[idx_loaded] = idx_pair
            idx_loaded += 1

        self.loaded = True

    def __getitem__(self, ind):
        """Return the loaded image tuple, or the raw path tuple before loading."""
        if not self.loaded:
            return self.pairs[ind]
        return self.img_pairs[ind]

    def __len__(self):
        """Number of items ``__getitem__`` can currently serve."""
        return len(self.img_pairs) if self.loaded else len(self.pairs)


In [None]:
# Build the pair dataset once and reuse the pickled copy on later runs.
CACHE_PATH = './data_cache_rgb.pkl'

if not os.path.exists(CACHE_PATH):
    path_dict = {
        'iranian_women': {'path': 'D:\\Dropbox Main\\Dropbox\\datasets\\12_Iranianwomendb', 'samples': 1},
        'multiview_01' : {'path': 'data\\session01\\multiview', 'samples': 1},
        'vgg_1'        : {'path': 'dataset', 'samples': 200000},
        'multiview_02' : {'path': 'data\\session02\\multiview', 'samples': 1},
        'morphed': {'path': 'D:\\Dropbox Main\\Dropbox\\datasets\\1_Morphdb', 'samples': 1}
    }

    dataset = FaceDB(path_dict)
    dataset.load_images()
    # context manager closes (and flushes) the cache file even if pickling fails
    with open(CACHE_PATH, 'wb') as cache_file:
        pickle.dump(dataset, cache_file)
else:
    # NOTE: pickle.load can execute arbitrary code; only load a cache this notebook wrote.
    with open(CACHE_PATH, 'rb') as cache_file:
        dataset = pickle.load(cache_file)

In [None]:
# preview five randomly chosen pairs, left/right, with the label on the right panel
label_box = {'facecolor':'white', 'alpha':0.8, 'pad':10}
for i in range(5):
    r_ind = random.randint(0, len(dataset) - 1)

    for panel, field in ((121, 0), (122, 1)):
        plt.subplot(panel)
        plt.axis("off")
        if panel == 122:
            plt.text(75, 8, 'label %d' %dataset[r_ind][2], style='italic',fontweight='bold',
                    bbox=label_box)
        plt.imshow(dataset[r_ind][field])
    plt.show()

In [None]:
# split the dataset into different sets -> train and test for now
test_portion = 0.2

test_size = int(test_portion * len(dataset))
train_size = len(dataset) - test_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# create dataloaders
# shuffle=True re-orders the training pairs every epoch; without it each epoch
# replays the exact same batches. The test loader keeps a fixed order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128)

print('Loaders Size = %d | %d' %(len(train_loader), len(test_loader)))

In [None]:
# number of image pairs currently held in memory by the dataset
len(dataset.img_pairs)

## Create the model

In [None]:
class Siamese_Net(nn.Module):
    """Siamese CNN: one shared encoder applied to both images, then a 2-way classifier.

    ``net`` turns a 3-channel image batch into a 4096-d sigmoid embedding; the
    ``256 * 2 * 2`` flatten size corresponds to 100x100 inputs. ``classifier``
    maps the element-wise absolute difference of the two embeddings to two
    softmax probabilities.

    NOTE(review): the output is already softmax-normalised, while
    nn.CrossEntropyLoss applies log-softmax to its input itself.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(c_in, c_out, kernel, drop=None):
            # Conv -> BatchNorm -> ReLU -> (optional Dropout2d) -> 2x2 max-pool
            stage = [nn.Conv2d(c_in, c_out, kernel), nn.BatchNorm2d(c_out), nn.ReLU()]
            if drop is not None:
                stage.append(nn.Dropout2d(drop))
            stage.append(nn.MaxPool2d((2, 2)))
            return stage

        # (in_channels, out_channels, kernel_size, dropout) for each stage;
        # layers stay in one flat Sequential so state-dict keys are net.0 ... net.20
        stage_specs = (
            (3, 64, 10, None),
            (64, 128, 7, None),
            (128, 128, 4, 0.2),
            (128, 256, 4, 0.1),
        )
        layers = []
        for spec in stage_specs:
            layers.extend(conv_stage(*spec))

        # FC
        layers.extend([nn.Flatten(), nn.Linear(256 * 2 * 2, 4096), nn.Sigmoid()])
        self.net = nn.Sequential(*layers)

        self.classifier = nn.Sequential(nn.Linear(4096, 2), nn.Softmax(dim=-1))

    def forward(self, x1, x2):
        """Return the (batch, 2) class probabilities for the image pair batches."""
        emb_1 = self.net(x1)
        emb_2 = self.net(x2)
        return self.classifier((emb_1 - emb_2).abs())

# define the loss function
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss over pairs of embeddings.

    With label 0 the squared distance is penalised; with label 1 the squared
    shortfall of the distance below ``margin`` is penalised. The result is the
    mean over the batch.

    Args:
        margin: distance beyond which label-1 pairs contribute no loss.
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the mean contrastive loss for the two embedding batches."""
        # Find the pairwise distance (public API instead of the incidental
        # `torch.functional.F` alias)
        euclidean_distance = torch.nn.functional.pairwise_distance(output1, output2)

        # perform contrastive loss calculation with the distance
        same_term = (1 - label) * torch.pow(euclidean_distance, 2)
        different_term = label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        return torch.mean(same_term + different_term)

def test(model, data_loader):
    model.eval()
    preds_labels = []
    true_labels = []

    for batch in data_loader:
        preds = model(batch[0].float().permute(0, 3, 1, 2).to(device), batch[1].float().permute(0, 3, 1, 2).to(device))
        dists = torch.ones(preds.shape)
        dists[torch.where(torch.abs(preds) < 0.5)] = 0
        dists[torch.where(torch.abs(preds) > 0.5)] = 1
        preds_labels.extend(dists.squeeze(1).tolist())
        true_labels.extend( batch[2].tolist() )
    tn, fp, fn, tp = confusion_matrix(true_labels, preds_labels).ravel()
    return (tp + tn) / (tn + fp + fn + tp)

# create the model and move its parameters to the selected device
model = Siamese_Net().to(device)

In [None]:
iterations = 25
loss_history = []
optimizer = torch.optim.Adam(model.parameters(), lr=6e-4)
# Siamese_Net's classifier already ends in Softmax, so the model emits class
# probabilities. nn.CrossEntropyLoss expects raw logits and applies log-softmax
# itself, which would normalise twice; NLLLoss on the log of the probabilities
# is the matching cross-entropy for this output.
loss_func = nn.NLLLoss()


def save_checkpoint(path, epoch, epoch_loss):
    """Write the model/optimizer state and the epoch's mean loss (a float) to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': epoch_loss
        }, path)


for i in range(iterations):
    model.train()
    batch_losses = []

    for batch in tqdm(train_loader, desc='Batch %d' %i, total=len(train_loader)):
        optimizer.zero_grad()

        # images are stored channel-last; Conv2d wants (batch, channels, H, W)
        x0 = batch[0].float().permute(0, 3, 1, 2).to(device)
        x1 = batch[1].float().permute(0, 3, 1, 2).to(device)
        labels = batch[2].long().to(device)

        probs = model(x0, x1)
        # calc loss (clamp keeps log() finite if a probability underflows to 0)
        loss = loss_func(torch.log(torch.clamp(probs, min=1e-8)), labels)

        # update model's params
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    # report the mean over the epoch rather than whatever the last batch produced
    epoch_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
    print('Iteration %d - Loss = %0.8f' %(i, epoch_loss))
    loss_history.append(epoch_loss)

    if i != 0 and i % 2 == 0:
        # periodic checkpoint
        save_checkpoint('model_rgb_v2_%d.pt' %(i), i, epoch_loss)

# save the final model
save_checkpoint('model_rgb_v2.pt', i, epoch_loss)

In [None]:
# draw the loss history (one value per training iteration)
plt.plot(loss_history)

## Viz

In [None]:
import torchvision

def imshow(img, text=None, should_save=False):
    """Display ``img[0]`` and ``img[1]`` side by side with ``text`` drawn on the right panel.

    ``img`` is a tensor that is converted with ``.numpy()``; ``should_save`` is
    accepted but not used.
    """
    frames = img.numpy()
    caption_box = {'facecolor':'white', 'alpha':0.8, 'pad':10}

    for slot in (0, 1):
        plt.subplot(121 + slot)
        plt.axis("off")
        if slot == 1:
            # the caption belongs to the right-hand panel
            plt.text(75, 8, text, style='italic',fontweight='bold',
                    bbox=caption_box)
        plt.imshow(frames[slot])
    plt.show()

# load the model
model = Siamese_Net().to(device)
# NOTE(review): the training cell writes 'model_rgb_v2.pt' / 'model_rgb_v2_<i>.pt';
# confirm that 'model_rgb.pt' is the checkpoint intended here.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
checkpoint = torch.load('model_rgb.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

model.eval()
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1)

# zip with range() caps the preview at 20 pairs and simply stops early when the
# test set is smaller, instead of raising StopIteration
with torch.no_grad():
    for _, sample in zip(range(20), test_dataloader):
        x0, x1, label = sample[0], sample[1], sample[2]
        concatenated = torch.cat((x0, x1), 0)

        preds = model(x0.float().permute(0, 3, 1, 2).to(device), x1.float().permute(0, 3, 1, 2).to(device))
        # labels use 1 for "different identities", so column 1 is the dissimilarity score
        dissimilarity = preds.to('cpu').numpy()[0][1]
        imshow(concatenated,'Dissimilarity: {:.2f} | {:d}'.format(dissimilarity, label.item()))


In [None]:
# Export misclassified pairs that lie close to the decision threshold into
# Trials/mismatch/<n> and Trials/match/<n> for manual inspection.
TRIALS_DIR = 'D:/Dropbox Main/Dropbox/datasets/Trials'


def _next_trial_index(folder):
    """Return one past the largest numerically named sub-directory of ``folder`` (0 if none)."""
    taken = [int(entry) for entry in os.listdir(folder)
             if entry.isdigit() and os.path.isdir(os.path.join(folder, entry))]
    return max(taken) + 1 if taken else 0


def _to_model_input(images):
    """Convert a collated image batch to the float (batch, channels, H, W) layout."""
    images = images.float()
    if images.dim() == 3:
        # grayscale batch (batch, H, W): add the channel axis
        return images.unsqueeze(1)
    # colour batch stored channel-last
    return images.permute(0, 3, 1, 2)


def _export_trial(kind, index, batch, idx):
    """Copy both images of pair ``idx`` plus an info.txt into ``Trials/<kind>/<index>``."""
    target = TRIALS_DIR + '/' + kind + '/' + str(index)
    os.mkdir(target)
    for src in (batch[5][idx], batch[6][idx]):
        shutil.copy(src, target + '/' + os.path.basename(src))
    with open(target + '/info.txt', 'w') as fout:
        fout.write('Gender = %s\nDB Name = %s\n' %(batch[3][idx], batch[-1][idx]))


## Get the required image pairs (the weights do not change between iterations)
model = Siamese_Net().to(device)
saved_weights = torch.load('./model.pt', map_location=device)
model.load_state_dict(saved_weights['model_state_dict'])
model.eval()

path_dict = {
    'iranian_women': {'path': 'D:\\Dropbox Main\\Dropbox\\datasets\\12_Iranianwomendb', 'samples': 0},
    'multiview_01' : {'path': 'data\\session01\\multiview', 'samples': 1000},
    'vgg_1'        : {'path': 'dataset', 'samples': 500},
    'multiview_02' : {'path': 'data\\session02\\multiview', 'samples': 1000},
    'morphed': {'path': 'D:\\Dropbox Main\\Dropbox\\datasets\\1_Morphdb', 'samples': 5000}
}

for iteration in range(5):
    print('ITER', iteration)
    # get current index
    mismatch_idx = _next_trial_index(TRIALS_DIR + '/mismatch')
    match_idx = _next_trial_index(TRIALS_DIR + '/match')
    print(mismatch_idx, match_idx)

    # generate several images
    dataset = FaceDB(path_dict)
    dataset.load_images()
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=64)

    for batch in data_loader:
        with torch.no_grad():
            preds = model(_to_model_input(batch[0]).to(device), _to_model_input(batch[1]).to(device))

        # score of label 1 ("different"): the last column, which is also the
        # only column of a single-output model
        scores = preds.reshape(len(preds), -1)[:, -1].cpu()
        preds_labels = (scores.abs() >= 0.5).long().tolist()
        margins = (scores - 0.5).abs().tolist()
        true_labels = batch[2].tolist()

        for idx in range(len(true_labels)):
            if true_labels[idx] == 1 and preds_labels[idx] == 0 and margins[idx] <= 0.3:
                # different identities predicted as the same one
                _export_trial('mismatch', mismatch_idx, batch, idx)
                mismatch_idx += 1

            elif true_labels[idx] == 0 and preds_labels[idx] == 1 and margins[idx] <= 0.1:
                # same identity predicted as different
                _export_trial('match', match_idx, batch, idx)
                match_idx += 1