In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms

import random

from PIL import Image

from pathlib import Path
import os

import collections

from sklearn.model_selection import train_test_split

dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG this notebook draws from: LogoDataset samples pairs with
# `random`; weight init, dropout and DataLoader shuffling use torch
# (torch.manual_seed also seeds CUDA). Reruns are then repeatable, up to
# nondeterministic GPU kernels.
seed = 0
random.seed(seed)
torch.manual_seed(seed)

In [None]:
nepochs = 100
nbatchs_per_ep = 30   # training DataLoader batch size (pairs per batch), despite the name
nbatchs_per_ev = 512  # evaluation DataLoader batch size
lr = 2e-3             # Adam learning rate

checkpoint_dir = "../checkpoints"
# exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
class LogoDataset(Dataset):
    """Pairs of grayscale logo images for training a Siamese network.

    Wraps a torchvision ``ImageFolder`` (one sub-folder per class). Item
    ``index`` pairs image ``index`` (the anchor) with a randomly drawn partner.
    With probability 1/2 the partner is from the same class; otherwise it is
    from a different class. The label is ``1.0`` for same-class pairs and
    ``0.0`` otherwise.
    """

    def __init__(
        self,
        folder_dataset: Path,
        transform=None,
    ) -> None:
        """
        Args:
            folder_dataset: Root folder in ``ImageFolder`` layout.
            transform: Optional callable applied to each PIL image. Without it,
                ``__getitem__`` returns PIL images rather than tensors.

        Raises:
            ValueError: if the folder holds fewer than two classes. No
                "different" pair can exist then, and sampling one would
                loop forever.
        """
        self.dataset = ImageFolder(root=folder_dataset)
        self.transform = transform

        # Group image indices by class once. A same-class partner is then a
        # single random.choice instead of rejection sampling over the whole set.
        self._indices_by_class = collections.defaultdict(list)
        for idx, (_, label) in enumerate(self.dataset.imgs):
            self._indices_by_class[label].append(idx)

        if len(self._indices_by_class) < 2:
            raise ValueError(
                f"LogoDataset needs at least 2 classes, found "
                f"{len(self._indices_by_class)} in {folder_dataset}"
            )

    def __len__(self) -> int:
        return len(self.dataset.imgs)

    def __getitem__(
        self,
        index: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(image_a, image_b, label)`` with image ``index`` as the anchor."""
        path1, label1 = self.dataset.imgs[index]

        if random.randint(0, 1):
            # Uniform over the anchor's class. This may return the anchor
            # itself, exactly like the rejection sampler it replaces.
            partner = random.choice(self._indices_by_class[label1])
        else:
            # Uniform over images of other classes. Rejection sampling
            # terminates because __init__ guarantees another class exists.
            while True:
                partner = random.randrange(len(self.dataset.imgs))
                if self.dataset.imgs[partner][1] != label1:
                    break
        path2, label2 = self.dataset.imgs[partner]

        im1 = self._load_grayscale(path1)
        im2 = self._load_grayscale(path2)

        if self.transform is not None:
            im1 = self.transform(im1)
            im2 = self.transform(im2)

        return im1, im2, torch.tensor([int(label1 == label2)], dtype=dtype)

    @staticmethod
    def _load_grayscale(path):
        """Open an image file, convert it to 8-bit grayscale ("L") and close the file."""
        # convert() loads the pixels and returns a new image, so the result
        # stays valid after the `with` closes the underlying file handle.
        # NOTE(review): the RGBA hop presumably normalises palette/transparent
        # images before the grayscale conversion — kept as in the original.
        with Image.open(path) as img:
            return img.convert("RGBA").convert("L")


# Resize every image to a fixed (height, width) = (264, 400) and convert it to
# a float tensor in [0, 1]. LogoDataset yields "L" images, so this gives 1 channel.
transform = transforms.Compose(
    [
        transforms.Resize((264, 400)),
        transforms.ToTensor(),
    ]
)


dataset = LogoDataset(
    folder_dataset=Path("../test_imgs"),
    transform=transform,
)

# Shuffle the anchors every epoch. ImageFolder lists images grouped by class,
# so when the dataset honours `index` an unshuffled loader would produce
# class-sorted batches.
train_loader = DataLoader(
    dataset=dataset,
    shuffle=True,
    num_workers=0,
    batch_size=nbatchs_per_ep,
)

In [None]:
class SiameseNet(nn.Module):
    """Weight-shared twin encoder that scores whether two grayscale images match.

    Both inputs pass through the same AlexNet encoder. The Euclidean distance
    between the two flattened embeddings is fed to a tiny MLP, whose sigmoid
    output lies in (0, 1).
    """

    def __init__(self, init_weight=None) -> None:
        """
        Args:
            init_weight: Accepted for call compatibility; currently unused.
        """
        super().__init__()

        # AlexNet with no pretrained weights requested (vgg11 was tried, but
        # it needs far more memory). Its stem convolution is replaced by a
        # 1-input-channel copy so single-channel images can be fed directly.
        encoder = torchvision.models.alexnet(progress=False)
        encoder.features[0] = nn.Conv2d(
            1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)
        )
        self.encoder = encoder

        # Scalar distance -> hidden width 3 -> scalar score in (0, 1).
        self.ff = nn.Sequential(
            nn.Linear(1, 3),
            nn.ReLU(),
            nn.Linear(3, 1),
            nn.Sigmoid(),
        )

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Return a ``(batch, 1)`` score for each pair ``(x1[i], x2[i])``."""
        # Both embeddings are flattened using x1's batch size; x1 is encoded first.
        batch = x1.size(0)
        emb_a, emb_b = (self.encoder(x).view((batch, -1)) for x in (x1, x2))

        distance = torch.nn.functional.pairwise_distance(emb_a, emb_b, keepdim=True)
        return self.ff(distance)

In [None]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss on a similarity score ``yhat``.

    Per pair::

        0.5 * ((1 - y) * yhat**2 + y * max(0, margin - yhat)**2)

    Pairs with ``y = 0`` are pushed toward ``yhat = 0``; pairs with ``y = 1``
    are pushed toward ``yhat >= margin``. In this notebook ``y = 1`` marks
    same-class pairs (see LogoDataset).
    """

    def __init__(
        self,
        reduce: bool,
        margin: float = 1,
    ) -> None:
        """
        Args:
            reduce: If True, return the mean over all elements. Otherwise
                return the per-element loss.
            margin: Target lower bound on ``yhat`` for ``y = 1`` pairs. A
                Python number or a Tensor (Tensors are stored as given).
        """
        super().__init__()

        if isinstance(margin, torch.Tensor):
            self.margin = margin
        else:
            self.margin = torch.tensor([margin], dtype=torch.float)

        self.reduce = reduce

    def forward(
        self, yhat: torch.Tensor, y: torch.Tensor
    ) -> torch.Tensor:
        # Match yhat's device and dtype at call time, so the loss works wherever
        # the model runs instead of depending on a device fixed at construction.
        margin = self.margin.to(yhat)
        out = (1 - y) * yhat.pow(2) + y * torch.clamp(margin - yhat, min=0.0).pow(2)

        if self.reduce:
            out = out.mean()
        return .5 * out

In [None]:
model = SiameseNet()
model = nn.DataParallel(model)
model = model.to(device)

# Optionally warm-start the encoder from an earlier checkpoint (per the
# original note, one produced by a "wrong" model). On a fresh checkout the
# file does not exist, so that case skips the warm start instead of crashing.
# NOTE(review): the training loop below writes this same file at epoch 20,
# so a re-run warm-starts from the previous run's epoch-20 weights.
warmstart_path = Path("../checkpoints/checkpoint_20.pkl")
if warmstart_path.exists():
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(warmstart_path, map_location=device)

    # Checkpoints saved by the training loop hold the DataParallel state dict,
    # whose keys look like "module.encoder.<name>". Strip those wrapper
    # prefixes so they line up with the bare encoder; other keys pass through.
    encoder_state = {}
    for key, value in checkpoint["model"].items():
        for prefix in ("module.", "encoder."):
            if key.startswith(prefix):
                key = key[len(prefix):]
        encoder_state[key] = value

    # strict=False tolerates extra keys (e.g. the "ff.*" head). Report the
    # mismatches so a warm start that loads nothing is visible.
    load_report = model.module.encoder.load_state_dict(encoder_state, strict=False)
    print(
        f"warm start from {warmstart_path}: "
        f"{len(load_report.missing_keys)} missing, "
        f"{len(load_report.unexpected_keys)} unexpected keys"
    )
else:
    print(f"no checkpoint at {warmstart_path}; training from scratch")

loss_fn = ContrastiveLoss(margin=1, reduce=True)
optim = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
# One entry per epoch: the mean contrastive loss over all training pairs.
loss_history = []
# A pair is predicted "same" when the sigmoid output is >= threshold. This is
# the same rule the evaluation and preview cells use.
threshold = .5
# Print progress and save a checkpoint every this many epochs.
checkpoint_every = 1

for epoch in range(nepochs):
    model.train()
    loss_sum = 0.0  # per-batch mean loss weighted by batch size
    n_correct = 0
    n_pairs = 0
    for x1, x2, y in train_loader:
        x1, x2, y = x1.to(device), x2.to(device), y.to(device)

        optim.zero_grad()

        yhat = model(x1, x2)

        loss = loss_fn(yhat, y)
        loss.backward()
        optim.step()

        # Accumulate plain Python numbers. `loss_fn` averages over the batch
        # (reduce=True), so re-weight by batch size; accuracy is counted per
        # pair, not per batch, so it ends up as a fraction in [0, 1].
        batch_size = y.size(0)
        loss_sum += loss.item() * batch_size
        with torch.no_grad():
            n_correct += ((yhat >= threshold) == y).sum().item()
        n_pairs += batch_size

    if n_pairs == 0:
        raise RuntimeError("train_loader produced no batches; check the dataset folder")

    epoch_loss = loss_sum / n_pairs
    accuracy = n_correct / n_pairs
    loss_history.append(epoch_loss)

    if epoch % checkpoint_every == 0:
        print(
            f"epoch {epoch+1}/{nepochs} loss {epoch_loss:.4f} accuracy {accuracy:.4f}"
        )

        torch.save(
            {
                "epoch": epoch,
                "loss": epoch_loss,
                "plot_performance": loss_history,
                "model": model.state_dict(),
                "optimizer": optim.state_dict(),
            },
            f"../checkpoints/checkpoint_{epoch}.pkl",
        )

In [None]:
# Result record of evaluate_siamese_network (typename kept as "Outcome").
EvaluationOutcome = collections.namedtuple(
    "Outcome",
    [
        "accuracy",
        "true_positive",
        "true_negative",
        "false_positive",
        "false_negative",
    ],
)


def evaluate_siamese_network(dataset_loader, model_path, device, threshold=.5):
    """Score a Siamese network on labelled pairs and tally its confusion matrix.

    Args:
        dataset_loader: Iterable of ``(x1, x2, y)`` batches, iterated once.
            ``y`` is 1 for "same" pairs and 0 otherwise; the model output and
            ``y`` are flattened, so ``(N,)`` and ``(N, 1)`` shapes both work.
        model_path: Either an ``nn.Module`` to evaluate directly, or a path to
            a whole model saved with ``torch.save(model, path)``.
        device: ``map_location`` used when loading from a path. Inputs are
            moved to the device of the model's parameters. A parameter-less
            model falls back to ``device`` (if it is a str/``torch.device``),
            otherwise to CPU.
        threshold: A pair is predicted "same" when the output is ``>= threshold``.

    Returns:
        ``Outcome(accuracy, true_positive, true_negative, false_positive,
        false_negative)``. ``accuracy`` is the fraction of correctly
        classified pairs, or NaN if the loader yielded no pairs. The model
        is left in eval mode.

    Raises:
        ValueError: if ``model_path`` is None.
        TypeError: if the object loaded from ``model_path`` is not an
            ``nn.Module`` (e.g. one of the checkpoint dicts the training loop saves).
    """
    if isinstance(model_path, torch.nn.Module):
        model = model_path
    elif model_path is None:
        raise ValueError("model_path must be an nn.Module or a path to a saved model")
    else:
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        model = torch.load(model_path, map_location=device)
        if not isinstance(model, torch.nn.Module):
            raise TypeError(
                f"{model_path!r} does not contain a pickled nn.Module "
                f"(got {type(model).__name__})"
            )

    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_device = first_param.device
    elif isinstance(device, (str, torch.device)):
        input_device = torch.device(device)
    else:
        input_device = torch.device("cpu")

    model.eval()
    true_positive = false_positive = true_negative = false_negative = 0
    with torch.no_grad():
        for x1, x2, y in dataset_loader:
            x1, x2, y = x1.to(input_device), x2.to(input_device), y.to(input_device)

            predicted_same = (model(x1, x2) >= threshold).reshape(-1)
            actually_same = y.reshape(-1).bool()

            true_positive += (predicted_same & actually_same).sum().item()
            false_positive += (predicted_same & ~actually_same).sum().item()
            true_negative += (~predicted_same & ~actually_same).sum().item()
            false_negative += (~predicted_same & actually_same).sum().item()

    total = true_positive + false_positive + true_negative + false_negative
    accuracy = (true_positive + true_negative) / total if total else float("nan")

    return EvaluationOutcome(
        accuracy=accuracy,
        true_positive=true_positive,
        false_positive=false_positive,
        true_negative=true_negative,
        false_negative=false_negative,
    )


In [None]:
# NOTE(review): this scores the training pairs; there is no held-out split
# (train_test_split is imported but unused).
outcome = evaluate_siamese_network(train_loader, model, device, threshold=.5)
print(f"Accuracy: {outcome.accuracy:0.2f}")
true_positive = outcome.true_positive
false_negative = outcome.false_negative
true_negative = outcome.true_negative
false_positive = outcome.false_positive


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is 0 (e.g. no positive pairs)."""
    return numerator / denominator if denominator else float("nan")


print(f"True Positive Rate: {_safe_ratio(true_positive, true_positive + false_negative): 0.2f}")
print(f"True Negative Rate: {_safe_ratio(true_negative, true_negative + false_positive):0.2f}")
print(f"Precision: {_safe_ratio(true_positive, true_positive + false_positive):0.2f}")
print(f"False Negative Rate: {_safe_ratio(false_negative, false_negative + true_positive):0.2f}")
print(f"False Positive Rate: {_safe_ratio(false_positive, false_positive + true_negative): 0.2f}")

In [None]:
import matplotlib.pyplot as plt

# loss_history gets one value appended per training epoch, so the x-axis is epochs.
fig, ax = plt.subplots()
ax.plot(range(1, len(loss_history) + 1), loss_history, label="Training loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Contrastive loss")
ax.set_title("Siamese network training loss")
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Visual spot-check: a few random pairs, titled with the ground truth (G) and
# the trained model's prediction (P).
nsamples = 3
thresh = .5

# A dedicated loader, so the training `train_loader` (batch size
# nbatchs_per_ep) is not overwritten. shuffle=True draws random anchors.
preview_loader = DataLoader(
    dataset=dataset,
    shuffle=True,
    num_workers=0,
    batch_size=nsamples,
)

x1, x2, y = next(iter(preview_loader))
x1, x2, y = x1.to(device), x2.to(device), y.to(device)

# Predict with the trained `model` from this notebook. Compare the thresholded
# output, not the raw probability: `.int()` on a probability truncates to 0.
model.eval()
with torch.no_grad():
    predicted = (model(x1, x2) >= thresh).int().view(-1).cpu()
truth = y.int().view(-1).cpu()

ncols = x1.size(0)  # may be < nsamples for a tiny dataset
# squeeze=False keeps `axs` 2-D even with a single column.
fig, axs = plt.subplots(2, ncols, figsize=(40, 25), squeeze=False)
for col in range(ncols):
    g, p = truth[col].item(), predicted[col].item()
    # Top row: second image of the pair, titled with the verdict; bottom row: first image.
    axs[0, col].set_title(f"G{g} P{p} - {'correct' if p==g else 'false'}")
    for row, images in ((0, x2), (1, x1)):
        ax = axs[row, col]
        # (C, H, W) tensor -> (H, W) for single-channel images, (H, W, C) otherwise.
        ax.imshow(images[col].cpu().permute(1, 2, 0).squeeze(-1).numpy(), cmap="gray")
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
plt.subplots_adjust(wspace=.05, hspace=-.7)
plt.show()