In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms

import random

from PIL import Image

from pathlib import Path
import os

# Default floating-point type used for tensors created in this notebook
# (torch.float32 is the same object as torch.float).
dtype = torch.float32

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# --- Training configuration -------------------------------------------------
nepochs = 10  # number of passes over the training loader
nbatchs = 5  # passed to the DataLoader as `batch_size`
nbatchs_per_ep = 10  # NOTE(review): not referenced by the visible cells
nbatchs_per_ev = 10  # NOTE(review): not referenced by the visible cells
embed_size = 50  # NOTE(review): not referenced by the visible cells
lr = 2e-3  # Adam learning rate

# Directory the training loop writes its checkpoints to.
# `exist_ok=True` makes the cell idempotent on re-run and avoids the
# check-then-create race of `if not os.path.exists(...)`.
xx = "../checkpoints"
os.makedirs(xx, exist_ok=True)

In [None]:
class LogoDataset(Dataset):
    """Dataset of image pairs built on top of an ``ImageFolder`` directory tree.

    Each item is ``(im1, im2, label)``: ``im1`` is the image at ``index``,
    ``im2`` is a randomly drawn partner (same class or a different class with
    equal probability), both converted to single-channel ("L") images.

    ``label`` is a tensor of shape ``(1,)`` holding ``1.0`` when the two images
    belong to *different* classes and ``0.0`` when they belong to the same
    class. This is the convention ``ContrastiveLoss.forward`` in this notebook
    consumes: its ``(1 - y)`` term scales the squared distance and its ``y``
    term scales the margin hinge.

    Args:
        folder_dataset: Root directory handed to ``ImageFolder``.
        transform: Optional callable applied to both images of a pair.

    Raises:
        ValueError: If the folder contains fewer than two classes, because a
            different-class partner could then never be drawn.
    """

    def __init__(
        self,
        folder_dataset: Path,
        transform=None,
    ) -> None:
        self.dataset = ImageFolder(root=folder_dataset)
        self.transform = transform
        self._build_index()

    def _build_index(self) -> None:
        """Group ``(path, class_index)`` samples by class index."""
        by_label = {}
        for sample in self.dataset.imgs:
            by_label.setdefault(sample[1], []).append(sample)
        if len(by_label) < 2:
            # The original rejection loop would spin forever in this case.
            raise ValueError(
                "LogoDataset needs at least two classes to sample "
                f"different-class pairs, found {len(by_label)}"
            )
        self._by_label = by_label

    def __len__(self) -> int:
        return len(self.dataset.imgs)

    def _sample_pair(self, index: int):
        """Pick the anchor at ``index`` and a random partner; no file I/O.

        Returns ``(anchor, partner, label)`` where anchor/partner are
        ``(path, class_index)`` tuples and ``label`` is ``1.0`` for a
        different-class pair, ``0.0`` for a same-class pair.
        """
        anchor = self.dataset.imgs[index]

        if random.randint(0, 1):
            # Same-class partner: uniform over the anchor's class (may be the
            # anchor itself, as in the original rejection sampling).
            partner = random.choice(self._by_label[anchor[1]])
        else:
            # Different-class partner: uniform over all other-class samples.
            # Terminates because _build_index guarantees a second class.
            while True:
                partner = random.choice(self.dataset.imgs)
                if partner[1] != anchor[1]:
                    break

        # Compare class indices (element 1), not a class index with a path.
        return anchor, partner, float(anchor[1] != partner[1])

    def __getitem__(
        self,
        index: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        anchor, partner, label = self._sample_pair(index)

        # Context managers close the underlying files deterministically;
        # convert() returns an independent, fully loaded image.
        with Image.open(anchor[0]) as raw1:
            im1 = raw1.convert("L")
        with Image.open(partner[0]) as raw2:
            im2 = raw2.convert("L")

        if self.transform is not None:
            im1 = self.transform(im1)
            im2 = self.transform(im2)

        return im1, im2, torch.tensor([label], dtype=dtype)


# Preprocessing applied to every image of a pair: resize to 400x400, then
# convert the PIL image to a tensor.
resize_step = transforms.Resize((400, 400))
to_tensor_step = transforms.ToTensor()
transform = transforms.Compose([resize_step, to_tensor_step])

# Pair dataset rooted at the logo image folder.
dataset = LogoDataset(Path("../logos/"), transform=transform)

# Shuffled loader; batches are loaded in the main process (num_workers=0).
train_loader = DataLoader(
    dataset,
    batch_size=nbatchs,
    shuffle=True,
    num_workers=0,
)

In [None]:
class SiameseNet(nn.Module):
    """Siamese encoder: one AlexNet applied to both inputs with shared weights.

    The first convolution is replaced so the network accepts single-channel
    images instead of RGB.

    Args:
        embed_size: Dimension of the embedding produced for each input. It is
            forwarded to AlexNet as ``num_classes``; the default of 1000 keeps
            the output size of the stock architecture.
    """

    def __init__(self, embed_size: int = 1000) -> None:
        super().__init__()

        # vgg11 was tried as the encoder but takes too much memory.
        self.encoder = torchvision.models.alexnet(
            progress=False, num_classes=embed_size
        )
        # Same geometry as AlexNet's first conv, but with 1 input channel.
        self.encoder.features[0] = nn.Conv2d(
            1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)
        )

    def forward(
        self, x1: torch.Tensor, x2: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode both inputs and return their flattened embeddings.

        Returns:
            ``(emb1, emb2)``, each reshaped to ``(batch, -1)``.
        """
        emb1 = self.encoder(x1).flatten(start_dim=1)
        emb2 = self.encoder(x2).flatten(start_dim=1)

        return emb1, emb2

In [None]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss over pairs of embeddings.

    For each pair the loss is
    ``(1 - y) * d**2 + y * max(margin - d**2, 0)`` where ``d`` is the
    Euclidean distance between the embeddings; the result is half the mean
    over all pairs. With this formula ``y = 0`` pulls a pair together and
    ``y = 1`` pushes it apart until the squared distance reaches ``margin``.

    Args:
        margin: Margin applied to the squared distance. A Python number is
            wrapped in a one-element tensor; a tensor is stored as given.
    """

    def __init__(
        self,
        margin: float,
    ) -> None:
        super().__init__()

        if isinstance(margin, torch.Tensor):
            self.margin = margin
        else:
            self.margin = torch.tensor([margin], dtype=dtype)

    def forward(
        self, x1: torch.Tensor, x2: torch.Tensor, y: torch.Tensor
    ) -> torch.Tensor:
        """Return the scalar contrastive loss for a batch of pairs.

        Args:
            x1: First embeddings of the pairs.
            x2: Second embeddings of the pairs, same shape as ``x1``.
            y: Pair labels, broadcastable against the per-pair distances
                (which keep a trailing dimension of size 1).
        """
        dist2 = torch.nn.functional.pairwise_distance(x1, x2, keepdim=True).pow(2)
        # The margin is a plain attribute, so Module.to() never moves it.
        # Align it with the inputs here; a one-element CPU tensor cannot be
        # combined with CUDA tensors.
        margin = self.margin.to(dist2.device)
        out = (1 - y) * dist2 + y * torch.clamp(margin - dist2, min=0.0)
        return 0.5 * out.mean()

In [None]:
# Build the network and move its parameters to the selected device.
# NOTE(review): this call requires SiameseNet to accept `embed_size`; the
# config cell's `embed_size` value is not used here, 200 is passed explicitly.
model = SiameseNet(embed_size=200)
model = model.to(device)

# Contrastive loss with a margin of 5, optimized with Adam.
loss_fn = ContrastiveLoss(margin=5)
optim = torch.optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# Mean training loss of each epoch, in epoch order.
loss_history = []
# Number of batches per epoch (used only for the progress message).
nbatches_total = len(train_loader)

for epoch in range(nepochs):
    model.train()
    running_loss = 0.0
    seen_batches = 0

    for batch, (x1, x2, y) in enumerate(train_loader):
        x1, x2, y = x1.to(device), x2.to(device), y.to(device)

        optim.zero_grad()

        # Embeddings of both images of each pair.
        out1, out2 = model(x1, x2)

        loss = loss_fn(out1, out2, y)
        loss.backward()
        optim.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        seen_batches += 1

        # Report every 10th batch; the denominator is the batch count, not
        # the batch size.
        if batch % 10 == 0:
            print(
                f"epoch {epoch+1}/{nepochs} batch {batch+1}/{nbatches_total} loss {batch_loss}"
            )

    # Record the epoch mean rather than the last batch's loss; an empty
    # loader yields NaN instead of failing on an undefined `loss`.
    epoch_loss = running_loss / seen_batches if seen_batches else float("nan")
    loss_history.append(epoch_loss)

    # Checkpoint every 10th epoch and always after the final epoch, so the
    # fully trained weights are never lost.
    if epoch % 10 == 0 or epoch == nepochs - 1:
        torch.save(
            {
                "epoch": epoch,
                "loss": epoch_loss,
                "plot_performance": loss_history,
                "model": model.state_dict(),
                "optimizer": optim.state_dict(),
            },
            f"../checkpoints/checkpoint_{epoch}.pkl",
        )

In [None]:
import matplotlib.pyplot as plt

# One point per recorded epoch. The x values come from the history itself so
# the plot still works when training stopped before `nepochs` epochs.
epochs_done = range(1, len(loss_history) + 1)

fig, ax = plt.subplots()
ax.plot(epochs_done, loss_history, label="Training loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Contrastive loss")
ax.legend()
ax.grid(True)
plt.show()