In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

import tqdm as tqdm

import PIL
from PIL import Image

import random

In [None]:
# Unpack the LFW "funneled" archive from the Kaggle input directory into the current
# working directory (presumably /kaggle/working, where `dataset_dir` below points — confirm).
!tar -xzf /kaggle/input/lfwpeople/lfw-funneled.tgz

In [None]:
def generate_triplets(data_dir, rng=None):
    """Build (anchor, positive, negative) image-path triplets from an LFW-style directory.

    `data_dir` must contain one sub-directory per person holding that person's images.
    Plain files directly inside `data_dir` (e.g. LFW's ``pairs*.txt`` lists) are ignored.

    Every image of a person with two or more images becomes an anchor exactly once.
    - Its positive is a *different* image of the same person.
    - Its negative is the image of a randomly chosen person who has exactly one image.

    Args:
        data_dir: root directory of the dataset.
        rng: optional ``random.Random``-like object (needs ``choice`` and ``randrange``).
            Defaults to the global ``random`` module, so ``random.seed(...)`` makes the
            result reproducible.

    Returns:
        Three equally long lists ``(anchors, positives, negtives)`` of image paths.

    Raises:
        ValueError: if some person has several images but nobody has exactly one, so no
            negative can be drawn.
    """
    if rng is None:
        rng = random

    # Sorted listings make the output reproducible for a given seed; os.listdir order is arbitrary.
    people_pathes = [path for path in (os.path.join(data_dir, name) for name in sorted(os.listdir(data_dir)))
                     if os.path.isdir(path)]
    people_count = len(people_pathes)

    singles = []   # the only image of each single-image person: the negative pool
    trimmed = []   # image-path lists of people with two or more images

    for person in people_pathes:
        # List each directory once and reuse the result.
        images = [os.path.join(person, name) for name in sorted(os.listdir(person))]
        if len(images) == 1:
            singles.append(images[0])
        elif len(images) > 1:
            trimmed.append(images)

    print('This dataset contain {}, {} of which has a single, and {} have two or more '
          .format(people_count, len(singles), len(trimmed)))

    if trimmed and not singles:
        raise ValueError('no single-image people found in {!r}; cannot draw negatives'.format(data_dir))

    anchors = []
    positives = []
    negtives = []

    for images in trimmed:
        for i, anchor in enumerate(images):
            # Pick uniformly among the *other* images: draw from len-1 slots and skip the
            # anchor's own slot, so the positive is never the anchor picture itself.
            j = rng.randrange(len(images) - 1)
            if j >= i:
                j += 1
            anchors.append(anchor)
            positives.append(images[j])
            negtives.append(rng.choice(singles))

    return anchors, positives, negtives

# Seed the RNGs so the randomly drawn positives/negatives are the same on every
# Restart & Run All. torch is seeded too, so the model's weight initialisation is
# repeatable (GPU kernels may still be nondeterministic).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

dataset_dir = '/kaggle/working/lfw_funneled'
ANCH, POS, NEG = generate_triplets(dataset_dir)

TRIPLETS_LEN = len(ANCH)
def get_triplets_by_index(index):
    """Return the (anchor, positive, negative) image paths of triplet number `index`."""
    return tuple(paths[index] for paths in (ANCH, POS, NEG))

# Sanity check: the three lists are parallel (one entry per triplet), so all three
# lengths must match. Comparing POS with NEG as well catches a mismatch there.
assert len(ANCH) == len(POS) == len(NEG), 'triplet lists have different lengths'
print('Generated {} triplets'.format(len(NEG)))

In [None]:
# Peek at the last generated triplet as a quick check of the path layout.
last_triplet = get_triplets_by_index(TRIPLETS_LEN - 1)
print(last_triplet)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

In [None]:
IMAGE_INPUT_SIZE = (224, 224)

# Preprocessing shared by training and evaluation: resize each image to the network's
# input resolution, then convert the PIL image to a float tensor (ToTensor scales to [0, 1]).
transform = transforms.Compose([
    transforms.Resize(IMAGE_INPUT_SIZE),
    transforms.ToTensor(),
])

In [None]:
def _load_image_tensor(path):
    """Open the image at `path` as RGB and apply the shared `transform`.

    The `with` block closes the file handle; a bare Image.open keeps it open until the
    object is garbage-collected. convert('RGB') guarantees three channels even for
    greyscale or palette images.
    """
    with Image.open(path) as img:
        return transform(img.convert('RGB'))


def load_and_transform_triplets_by_index(index):
    """Load triplet `index` and return its (anchor, positive, negative) image tensors (on the CPU)."""
    return tuple(_load_image_tensor(path) for path in get_triplets_by_index(index))


def load_and_transform_batch(start, end):
    """Load triplets ``start`` .. ``end - 1`` as three stacked batches moved to `device`.

    Returns:
        ``(anchor_batch, positive_batch, negtive_batch)``, each stacking ``end - start``
        image tensors along a new leading dimension. Returns ``None`` when ``[start, end)``
        is empty or runs past the last triplet, so a caller can detect the end of the data.

    Errors while reading or transforming an image (missing/corrupt file, device errors)
    propagate instead of being silently turned into ``None`` as the bare ``except`` did.
    """
    if not 0 <= start < end <= len(ANCH):
        return None

    def stack(paths):
        # Same per-index lookup for all three lists keeps the batches aligned triplet-wise.
        return torch.stack([_load_image_tensor(paths[i]) for i in range(start, end)]).to(device)

    anchor_batch = stack(ANCH)
    positive_batch = stack(POS)
    negtive_batch = stack(NEG)
    return anchor_batch, positive_batch, negtive_batch


In [None]:
def plot_triplets_by_index(index):
    """Show the anchor, positive and negative images of triplet `index` side by side."""
    tensors = load_and_transform_triplets_by_index(index)
    labels = ('Anchor', 'Positive', 'Negative')

    fig, axes = plt.subplots(1, 3, figsize=(32, 8))
    for ax, tensor, label in zip(axes, tensors, labels):
        # ToTensor produces channels-first (C, H, W) but imshow wants (H, W, C).
        # .view() would only reinterpret the memory and scramble the pixels; permute
        # actually reorders the axes (and does not hard-code the image size).
        ax.imshow(tensor.permute(1, 2, 0).cpu().numpy())
        # A title is placed relative to the axes, unlike plt.text's data coordinates.
        ax.set_title(label, fontsize=18)
        ax.axis('off')

    plt.show()

plot_triplets_by_index(3)

In [None]:
# ResNet-18 trained from scratch (no pretrained weights). Its default classification
# output is used directly as the face embedding fed to the triplet loss.
model = models.resnet18(pretrained=False)
model = model.to(device)


In [None]:
# Triplet margin loss on Euclidean (p=2) distances between embeddings, margin 1.0.
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2).to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
EPOCHS = 10
BATCH_SIZE = 50   # triplets per optimisation step
LOG_EVERY = 10    # report the mean loss every LOG_EVERY batches

# Disjoint batch offsets. Stepping by BATCH_SIZE (not 1) means each triplet is seen once
# per epoch instead of in ~BATCH_SIZE overlapping windows. A trailing partial batch is
# dropped so every batch reaching BatchNorm has BATCH_SIZE samples and stays in range.
batch_starts = range(0, len(ANCH) - BATCH_SIZE + 1, BATCH_SIZE)

model.train()
for epoch in range(EPOCHS):
    running_loss = 0.0

    for batch_idx, start in enumerate(batch_starts):
        batch = load_and_transform_batch(start, start + BATCH_SIZE)
        if batch is None:
            # The loader reports failure by returning None; surface it instead of
            # silently ending the epoch.
            raise RuntimeError('could not load triplet batch [{}, {})'.format(start, start + BATCH_SIZE))
        A, P, N = batch

        optimizer.zero_grad()

        anchor_out = model(A)
        positive_out = model(P)
        negative_out = model(N)

        # TripletMarginLoss takes (anchor, positive, negative) in that order.
        loss = triplet_loss(anchor_out, positive_out, negative_out)
        loss.backward()
        optimizer.step()

        # Print the mean loss over the last LOG_EVERY batches.
        running_loss += loss.item()
        if batch_idx % LOG_EVERY == LOG_EVERY - 1:
            print('[%d, %5d] loss: %.9f' %
                  (epoch + 1, batch_idx + 1, running_loss / LOG_EVERY))
            running_loss = 0.0

    print("epoch {} done".format(epoch + 1))


In [None]:
# Pickle the whole module (architecture + weights).
# NOTE(review): saving model.state_dict() is the more portable option.
MODEL_PATH = 'facenet.ptm'
torch.save(model, MODEL_PATH)

In [None]:
def check(a, b):
    """Return the Euclidean distance between the model embeddings of two image files.

    Args:
        a, b: paths to image files.

    Returns:
        0-dim tensor (on `device`) holding the L2 distance between the two embeddings.
    """
    def embed(path):
        # `with` closes the file; convert('RGB') guarantees the 3 channels the model expects.
        with Image.open(path) as img:
            tensor = transform(img.convert('RGB'))
        # unsqueeze(0) adds the batch dimension without hard-coding the image size.
        return model(tensor.unsqueeze(0).to(device))

    # Inference only. In eval() mode BatchNorm uses its running statistics instead of
    # single-image batch stats, and does not update them. no_grad() skips building an
    # autograd graph. The caller's train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            return torch.dist(embed(a), embed(b))
    finally:
        model.train(was_training)


In [None]:
def compare(index, verbose=True):
    """Compare anchor-positive and anchor-negative embedding distances for triplet `index`.

    Args:
        index: triplet index into ANCH / POS / NEG.
        verbose: when True (default), print both distances.

    Returns:
        True if the positive image is closer to the anchor than the negative image.
    """
    dist_pos = check(ANCH[index], POS[index])
    dist_neg = check(ANCH[index], NEG[index])
    if verbose:
        print(dist_pos, 'anchor - pos')
        print(dist_neg, 'anchor - neg')
    return bool(dist_pos < dist_neg)

# NOTE(review): the training loop iterates over these same triplets, so this is a
# training-set check, not a held-out evaluation.
EVAL_COUNT = int(len(ANCH) * 0.9)
SHOW_FIRST = 5  # print distances only for the first few triplets instead of thousands of lines

correct = sum(compare(sample, verbose=sample < SHOW_FIRST) for sample in range(EVAL_COUNT))
print('positive closer than negative in {}/{} triplets ({:.1%})'.format(
    correct, EVAL_COUNT, correct / EVAL_COUNT if EVAL_COUNT else 0.0))