In [None]:
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch import optim
import os
from tqdm.notebook import tqdm
from model.model_class import FaceRecogModel

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [7]:
# Directory that the image file names from the triplet CSV are joined onto.
path_to_dataset = './archive/images/'
# CSV file listing the triplets; its "anchor", "pos" and "neg" columns are read below.
path_to_csv = './archive/triplets.csv'

In [8]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Load the triplet table and keep only the three file-name columns, in this order.
dataset_csv = pd.read_csv(path_to_csv).loc[:, ["anchor", "pos", "neg"]]

In [9]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

class FaceDataset(Dataset):
    """Triplet dataset yielding ``(anchor, positive, negative)`` face images.

    Each row of the triplet table names three image files, relative to
    ``data_path``, under the columns ``"anchor"``, ``"pos"`` and ``"neg"``.

    Args:
        data_path: Directory that the file names in the table are joined onto.
        transform: Optional callable applied independently to each of the
            three PIL images of a triplet.
        triplets: Optional DataFrame with ``"anchor"``, ``"pos"`` and ``"neg"``
            columns. When omitted, the notebook-level ``dataset_csv`` frame is
            used, which keeps existing ``FaceDataset(path, transform)`` calls
            working unchanged.
    """

    def __init__(self, data_path, transform=None, triplets=None):
        self.data_path = data_path
        self.transform = transform

        if triplets is None:
            # Backward-compatible default: the table loaded earlier in the notebook.
            triplets = dataset_csv

        self.num_files = len(triplets)
        # Drop the frame's own index so the series are addressed by row position;
        # a filtered or re-ordered frame would otherwise break integer lookups.
        self.anchor_images = triplets["anchor"].reset_index(drop=True)
        self.pos_images = triplets["pos"].reset_index(drop=True)
        self.neg_images = triplets["neg"].reset_index(drop=True)

    def __len__(self):
        """Return the number of triplets (rows of the triplet table)."""
        return self.num_files

    def _load_image(self, file_name):
        """Open ``file_name`` under ``data_path`` and return it as an RGB PIL image."""
        # ``with`` closes the file handle promptly; ``convert`` returns a loaded copy,
        # so the result stays valid after the handle is closed. Forcing RGB keeps
        # grayscale/RGBA files compatible with a three-channel transform pipeline.
        with Image.open(os.path.join(self.data_path, file_name)) as image:
            return image.convert("RGB")

    def __getitem__(self, idx):
        """Return the ``(anchor, positive, negative)`` images of triplet ``idx``.

        The configured ``transform`` (if any) is applied to each image separately.
        """
        anchor_image = self._load_image(self.anchor_images.iloc[idx])
        pos_image = self._load_image(self.pos_images.iloc[idx])
        neg_image = self._load_image(self.neg_images.iloc[idx])

        if self.transform:
            anchor_image = self.transform(anchor_image)
            pos_image = self.transform(pos_image)
            neg_image = self.transform(neg_image)

        return anchor_image, pos_image, neg_image

In [10]:
# Preprocessing applied to every image of a triplet: resize to 128x128, convert to a
# tensor, then normalise each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Shuffled mini-batches of 32 triplets drawn from the image directory.
dataset = FaceDataset(data_path=path_to_dataset, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

In [11]:
# Pull a single batch and print the shapes of its anchor, positive and negative tensors.
for images in dataloader:
    anchors, pos, neg = images
    print(anchors.shape, pos.shape, neg.shape, sep="\n")
    break

torch.Size([32, 3, 128, 128])
torch.Size([32, 3, 128, 128])
torch.Size([32, 3, 128, 128])


In [12]:
class TripletLoss(nn.Module):
    """Margin-based triplet loss over batches of embeddings.

    For each triplet the loss is
    ``relu(d(anchor, positive) - d(anchor, negative) + margin)`` where ``d`` is
    ``F.pairwise_distance``; the per-triplet values are then averaged.
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the mean triplet loss for the given anchor/positive/negative batches."""
        d_pos = F.pairwise_distance(anchor, positive)
        d_neg = F.pairwise_distance(anchor, negative)
        # Positive only when the negative is not at least ``margin`` farther than the positive.
        violation = d_pos - d_neg + self.margin
        return F.relu(violation).mean()

In [13]:
# Instantiate the model on the selected device, then the Adam optimiser over its
# parameters and the triplet loss (margin 0.4) used by the training loop.
model = FaceRecogModel().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
loss_function = TripletLoss(0.4)

13456


In [None]:
# Number of full passes over the dataloader made by the training loop.
NUM_EPOCHS = 10
# Per-batch loss values; the training loop appends one float per optimisation step.
loss_list = list()

In [None]:
# Train with the triplet objective: each batch's anchor, positive and negative images
# are passed through the same model, and the per-batch loss is recorded in ``loss_list``.

# Explicitly select training mode so the cell behaves the same even if the model was
# switched to eval mode by an earlier cell.
model.train()

for epoch in tqdm(range(NUM_EPOCHS)):
    for anchor, positive, negative in dataloader:
        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)

        # Clear gradients *before* the forward/backward pass. Zeroing only after
        # ``step()`` lets gradients left behind by an interrupted or earlier run of
        # this cell leak into the first update.
        optimizer.zero_grad()

        anchor_output = model(anchor)
        positive_output = model(positive)
        negative_output = model(negative)

        loss = loss_function(anchor_output, positive_output, negative_output)
        loss_list.append(loss.item())
        loss.backward()
        optimizer.step()

In [22]:
# Serialise the entire model object (not only a state_dict) to 'model_weights.pt'.
# NOTE(review): loading this file presumably requires the FaceRecogModel class to be
# importable from the same module path; consider saving model.state_dict() instead —
# confirm with whatever code loads this file before changing it.
torch.save(model, 'model_weights.pt')