In [1]:
import os
import pickle
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from tqdm import tqdm

# Seed every RNG this notebook can draw from. Seeding only `random` is not
# enough: torch's generator drives weight initialisation and the DataLoader
# shuffle, so it has to be seeded for runs to be repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Locations of the pickled, pre-transformed datasets: one (train, test)
# pair of files per species.
cat_train_path, cat_test_path = (
    "assets/transformed data/train_data_cat.pkl",
    "assets/transformed data/test_data_cat.pkl",
)
dog_train_path, dog_test_path = (
    "assets/transformed data/train_data_dog.pkl",
    "assets/transformed data/test_data_dog.pkl",
)


In [3]:
# --------- DATASET ---------
class SiamesePairDataset(Dataset):
    """Labelled image pairs for siamese training, built from one pickle file.

    The pickle at ``dataset_dir`` is read as a sequence of items where
    ``item[0]`` is the sample (presumably an image tensor -- confirm against
    the preprocessing step) and ``item[1]`` is its label (1 or 0).

    All pairs are materialised once, at construction time:

    * label 1 -- every unordered pair of two different label-1 samples;
    * label 0 -- the i-th label-1 sample with the i-th label-0 sample, for
      as many indices as both lists share.

    NOTE(review): the number of label-1 pairs grows quadratically with the
    number of label-1 samples while label-0 pairs grow only linearly, so the
    pair set can be heavily skewed towards label 1 -- confirm this is intended.

    Args:
        dataset_dir: path of the pickle file to load (a file, despite the name).
        train_bool: when True, ``__getitem__`` passes both samples through a
            random horizontal flip (p=0.5); when False they are returned as-is.
    """

    def __init__(self, dataset_dir, train_bool: bool):
        # NOTE(review): unpickling can execute arbitrary code -- only load
        # files from a trusted source.
        with open(dataset_dir, "rb") as handle:
            self.dataset = pickle.load(handle)

        self.train_bool = train_bool
        self.horizontal_flip = transforms.RandomHorizontalFlip(p=0.5)

        # Split the raw items by label.
        self.user_imgs = [entry for entry in self.dataset if entry[1] == 1]
        self.neg_imgs = [entry for entry in self.dataset if entry[1] == 0]

        # Every unordered combination of two distinct label-1 samples.
        positive_pairs = [
            (anchor[0], other[0], 1)
            for pos, anchor in enumerate(self.user_imgs)
            for other in self.user_imgs[pos + 1:]
        ]
        # Index-aligned label-1 / label-0 samples; zip stops at the shorter list.
        negative_pairs = [
            (user[0], neg[0], 0)
            for user, neg in zip(self.user_imgs, self.neg_imgs)
        ]
        self.all_pairs = positive_pairs + negative_pairs

    def __len__(self):
        """Return the total number of pairs."""
        return len(self.all_pairs)

    def __getitem__(self, idx):
        """Return ``(sample_a, sample_b, label)`` with label as a float32 tensor."""
        first, second, label = self.all_pairs[idx]
        target = torch.tensor(label, dtype=torch.float32)

        if not self.train_bool:
            return first, second, target

        # Training mode: the flip is applied to each sample separately.
        return self.horizontal_flip(first), self.horizontal_flip(second), target


In [4]:
# --------- MODEL ---------
class EmbeddingNet(nn.Module):
    """Embedding encoder: pretrained MobileNetV2 features -> 128-d vector.

    The MobileNetV2 convolutional feature extractor is followed by global
    average pooling to 1x1 and a ``Linear(1280, 128)`` + ReLU head.

    NOTE(review): the trailing ReLU restricts every embedding component to be
    non-negative; confirm this is intended for a distance-based loss.
    """

    def __init__(self):
        super().__init__()
        # `pretrained=True` is deprecated in recent torchvision releases in
        # favour of the `weights=` enum. IMAGENET1K_V1 is the checkpoint that
        # `pretrained=True` used to load, so the starting weights are the same.
        # Older torchvision versions without the enum keep the legacy call.
        weights_enum = getattr(models, "MobileNet_V2_Weights", None)
        if weights_enum is not None:
            mobilenet = models.mobilenet_v2(weights=weights_enum.IMAGENET1K_V1)
        else:
            mobilenet = models.mobilenet_v2(pretrained=True)

        # Keep only the feature extractor; the classifier head is dropped.
        self.features = mobilenet.features
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1280, 128),
            nn.ReLU()
        )

    def forward(self, x):
        """Map a batch of inputs to a ``(batch, 128)`` embedding tensor."""
        x = self.features(x)
        x = self.pool(x)
        x = self.fc(x)
        return x

In [5]:
# --------- CONTRASTIVE LOSS ---------
class ContrastiveLoss(nn.Module):
    """Contrastive loss over pairs of embeddings.

    Per pair, with ``d`` the pairwise distance between the two embeddings:

    * ``label == 1``: ``d ** 2`` (pulls the pair together);
    * ``label == 0``: ``max(margin - d, 0) ** 2`` (pushes the pair apart
      until it is at least ``margin`` away).

    The per-pair terms are averaged over the batch.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, o1, o2, label):
        """Return the mean contrastive loss for embeddings ``o1``/``o2``."""
        gap = torch.nn.functional.pairwise_distance(o1, o2)
        shortfall = torch.clamp(self.margin - gap, min=0)

        pull_term = label * gap.pow(2)
        push_term = (1 - label) * shortfall.pow(2)
        return (pull_term + push_term).mean()

In [11]:
# --------- TRAINING ---------
def train_siamese(model, dataloader, epochs=10, lr=1e-4, patience=5,
                  checkpoint_dir="assets/trained models"):
    """Train ``model`` as a siamese encoder with contrastive loss.

    Both elements of every pair go through the same ``model``; the two
    embeddings and the pair label are fed to ``ContrastiveLoss``. After each
    epoch the average batch loss is printed. Whenever it improves on the best
    value so far, the model's ``state_dict`` is written to
    ``<checkpoint_dir>/model_temp_epoch_<n>.pt``. Training stops early after
    ``patience`` consecutive epochs without improvement.

    Note that early stopping monitors the *training* loss; no validation data
    is involved.

    Args:
        model: module mapping a batch of inputs to a batch of embeddings.
        dataloader: sized iterable yielding ``(x1, x2, y)`` batches.
        epochs: maximum number of passes over ``dataloader``.
        lr: Adam learning rate.
        patience: non-improving epochs tolerated before stopping.
        checkpoint_dir: directory for checkpoints; created if missing.

    Raises:
        ValueError: if ``dataloader`` has no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        # Fail up front with a clear message instead of a ZeroDivisionError
        # when averaging the loss after the first (empty) epoch.
        raise ValueError("dataloader yields no batches; nothing to train on")

    # Make sure the first checkpoint write cannot fail on a missing directory
    # after a whole epoch of work has already been spent.
    os.makedirs(checkpoint_dir, exist_ok=True)

    model.to(DEVICE)
    criterion = ContrastiveLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Early-stopping state.
    best_loss = float('inf')
    trigger_times = 0

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        start_epoch = time.time()

        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}")
        for x1, x2, y in pbar:
            x1, x2, y = x1.to(DEVICE), x2.to(DEVICE), y.to(DEVICE)
            out1, out2 = model(x1), model(x2)
            loss = criterion(out1, out2, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            total_loss += batch_loss
            pbar.set_postfix(loss=batch_loss)

        end_epoch = time.time()
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1} Summary: Avg Loss = {avg_loss:.4f} | Time: {end_epoch - start_epoch:.2f}s")

        # Checkpoint on improvement, otherwise count towards early stopping.
        if avg_loss < best_loss:
            best_loss = avg_loss
            trigger_times = 0
            torch.save(
                model.state_dict(),
                os.path.join(checkpoint_dir, f"model_temp_epoch_{epoch+1}.pt"),
            )
        else:
            trigger_times += 1
            if trigger_times >= patience:
                print("Early stopping triggered")
                break
        

In [13]:
# Pick the species by choosing the pickle that is loaded:
#   cats -> cat_train_path
#   dogs -> dog_train_path
dataset = SiamesePairDataset(dataset_dir=cat_train_path, train_bool=True)

In [15]:
# Serve the pairs in shuffled mini-batches of 16.
loader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)

In [17]:
# Fresh embedding network; constructing it loads the pretrained MobileNetV2 backbone.
model = EmbeddingNet()



In [19]:
# Train for at most 50 epochs; early stopping inside train_siamese may end sooner.
train_siamese(model=model, dataloader=loader, epochs=50)

Epoch 1: 100%|████████████████████████████████████████████████████████████| 45/45 [00:07<00:00,  5.96it/s, loss=0.0223]


Epoch 1 Summary: Avg Loss = 0.1721 | Time: 7.56s


Epoch 2: 100%|████████████████████████████████████████████████████████████| 45/45 [00:07<00:00,  6.26it/s, loss=0.0181]


Epoch 2 Summary: Avg Loss = 0.1622 | Time: 7.21s


The code below shows the effect of the horizontal flip. Change the `p` value in `RandomHorizontalFlip` (e.g. `p=1` always flips, `p=0` never flips)

to visually see the difference between flipped and unflipped images.

In [29]:

# Load the dog training set via the shared path constant.
# NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
with open(dog_train_path, "rb") as f:
    data = pickle.load(f)

# Define the transform (p is the flip probability: 0 never flips, 1 always flips)
horizontal_flip = transforms.RandomHorizontalFlip(p=0)

# Pick the image to look at (second item of the dataset)
tens = data[1][0]

# make the flip
tens = horizontal_flip(tens)

# Undo normalization to see the picture.
# Assumes the data was normalised with these (ImageNet) per-channel
# statistics -- TODO confirm against the preprocessing step.
mean = torch.tensor([0.485, 0.456, 0.406], device=tens.device).view(-1, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225], device=tens.device).view(-1, 1, 1)

# Clamp to [0, 1]: ToPILImage scales float tensors by 255 and casts to uint8,
# so any value outside that range would corrupt the displayed pixels.
b = ((tens * std) + mean).clamp(0.0, 1.0)

# Convert to PIL image
to_pil = transforms.ToPILImage()
image = to_pil(b)

# Show the image (opens the system image viewer)
image.show()

