# A comparison of Embedding space in Self-Supervised method and Supervised method

Each type of neural network has its own way of embedding the input data into another high-dimensional space. In this notebook I'll compare how the different Self-Supervised methods we saw during the course and a Supervised method embed their data, and we will measure the quality of those embeddings by applying some data mining techniques.

In [4]:
from losses import *
from networks import SimpleCNN, SiameseNetwork


import pandas as pd
import seaborn as sns
import numpy as np
import random
import os
import torch
import torch.nn as nn
from tqdm import tqdm

import torch.nn.functional as F
import matplotlib.pyplot as plt

# PCA
from sklearn.decomposition import PCA

# import mnist from torchvision
from torchvision import datasets, transforms

### Cuda to go faster

In [5]:
# Prefer the GPU when CUDA is usable; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Data Loading

In [6]:
class TripletDataset:
    """Expose a labelled image dataset as (anchor, positive, negative) triplets.

    The wrapped dataset must provide ``.targets`` (with a ``.numpy()`` method),
    ``len()`` and integer indexing that returns ``(image, label)``.

    Every image is read once at construction time and stored per label, so
    building the wrapper walks the whole dataset.
    """

    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset
        self.classes = list(set(mnist_dataset.targets.numpy()))
        self.class_to_images = {cls: [] for cls in self.classes}

        # Bucket every image under the label it is returned with.
        for img, lbl in (mnist_dataset[k] for k in range(len(mnist_dataset))):
            self.class_to_images[lbl].append(img)

    def __getitem__(self, index):
        """Return ``(anchor, positive, negative)`` for the sample at ``index``.

        The positive is drawn at random from the anchor's class (it may be the
        anchor itself); the negative is drawn from one randomly chosen other class.
        """
        anchor, anchor_cls = self.mnist_dataset[index]

        # Same-class image first, then a different class, then an image of it.
        positive = random.choice(self.class_to_images[anchor_cls])
        other_classes = [cls for cls in self.classes if cls != anchor_cls]
        negative = random.choice(self.class_to_images[random.choice(other_classes)])

        return anchor, positive, negative

    def __len__(self):
        """One triplet per sample of the wrapped dataset."""
        return len(self.mnist_dataset)
    
class SiameseMNIST:
    """Expose a labelled image dataset as image pairs with a same/different target.

    Each item is ``((img1, img2), target)`` where ``target`` is a float32 scalar
    tensor equal to 1.0 when the two labels differ and 0.0 when they match.
    """

    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset

    def __getitem__(self, index):
        """Pair the sample at ``index`` with a randomly drawn partner."""
        img1, label1 = self.mnist_dataset[index]

        # Coin flip: 1 -> partner must share the label, 0 -> partner must differ.
        want_same = bool(np.random.randint(0, 2))

        # Rejection sampling: keep drawing indices until the partner fits the flip.
        while True:
            index2 = np.random.randint(0, len(self.mnist_dataset))
            img2, label2 = self.mnist_dataset[index2]
            if (label1 == label2) == want_same:
                break

        target = torch.tensor(int(label1 != label2), dtype=torch.float32)
        return (img1, img2), target

    def __len__(self):
        """One pair per sample of the wrapped dataset."""
        return len(self.mnist_dataset)

In [None]:
# MNIST training split (downloaded into ./data if missing), images converted to tensors.
mnist = datasets.MNIST(
    './data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Triplet view of the same data, served in shuffled batches of 1024.
triplet_dataset = TripletDataset(mnist)
triplet_loader = torch.utils.data.DataLoader(
    triplet_dataset,
    batch_size=1024,
    shuffle=True,
)

In [None]:
# Smoke test of a single forward pass.
# SimpleCNN and SiameseNetwork are defined in networks.py.
model = SimpleCNN().to(device)

# Pick one random (image, label) sample and add a batch dimension to the image.
sample = mnist[random.choice(range(len(mnist)))]
image_batch = sample[0].unsqueeze(0).to(device)

# The printed shape is expected to be [1, 32] (one embedding for one image).
print(model.forward_once(image_batch).shape)

In [None]:
# Pull one batch of triplets and report the shape of each of its three tensors.
sample = next(iter(triplet_loader))
print(sample[0].shape, sample[1].shape, sample[2].shape)

# Show the first triplet of the batch side by side, one titled panel per role.
fig, axs = plt.subplots(1, 3)
for ax, images, title in zip(axs, sample, ('Anchor', 'Positive', 'Negative')):
    ax.imshow(images[0].squeeze(), cmap='gray')
    ax.set_title(title)

plt.show()

## First experiment: Extracting feature from Triplet Loss trained model

- Triplet Loss takes 3 images
  - Anchor
  - Positive
  - Negative

- The model is trained to minimize the distance between Anchor and Positive and maximize the distance between Anchor and Negative.

$$
\mathcal{L}(x, x^+, x^-) = \sum_{i=1}^{N} \max\left(0, \|f(x_i) - f(x_i^+)\|^2 - \|f(x_i) - f(x_i^-)\|^2 + \epsilon\right)
$$

Ideally we want to use hard samples, i.e. hard negatives or hard positives, so the sampling strategy matters if we want better performance.

In [None]:
# Reuse a saved checkpoint when one exists; otherwise train from scratch and save it.
if os.path.exists('tripletmodel.pth'):
    tripletmodel = SimpleCNN()
    # map_location lets a checkpoint written on a GPU load on a CPU-only machine.
    tripletmodel.load_state_dict(torch.load('tripletmodel.pth', map_location=device))
    tripletmodel.eval()
    tripletmodel.to(device)

    # Restore the loss history as well: the loss-plot cell reads `triplet_train_loss`
    # and would otherwise fail with a NameError when training is skipped.
    if os.path.exists('triplet_train_loss.npy'):
        triplet_train_loss = np.load('triplet_train_loss.npy')
    else:
        triplet_train_loss = []
    print('Model loaded')
else:
    tripletmodel = SimpleCNN()
    tripletloss = TripletLoss()

    optimizer = torch.optim.Adam(tripletmodel.parameters(), lr=1e-3)

    tripletmodel.to(device)
    triplet_train_loss = []  # one entry per batch

    for epoch in range(10):
        pbar = tqdm(triplet_loader)
        epoch_loss = 0
        for anchor, positive, negative in pbar:

            anchor = anchor.to(device)
            positive = positive.to(device)
            negative = negative.to(device)

            optimizer.zero_grad()

            # The same network embeds all three images of the triplet.
            anchor_embedding = tripletmodel.forward_once(anchor)
            positive_embedding = tripletmodel.forward_once(positive)
            negative_embedding = tripletmodel.forward_once(negative)

            loss = tripletloss(anchor_embedding, positive_embedding, negative_embedding)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            pbar.set_postfix_str(f'Loss {batch_loss}')
            triplet_train_loss.append(batch_loss)

        # Mean batch loss over the epoch.
        print(f'Epoch {epoch}, Loss {epoch_loss/len(triplet_loader)}')

    # Persist the weights and the per-batch loss history for later runs.
    torch.save(tripletmodel.state_dict(), 'tripletmodel.pth')
    np.save('triplet_train_loss.npy', triplet_train_loss)



In [None]:
tripletmodel.eval()
embeddings = []
labels = []

# Embed the first 5000 MNIST images, one at a time.
# no_grad: this is pure inference, so no autograd graph needs to be built.
with torch.no_grad():
    for i in range(5000):
        image, label = mnist[i]

        image = image.unsqueeze(0).to(device)  # add the batch dimension
        embedding = tripletmodel.forward_once(image)
        labels.append(label)
        embeddings.append(embedding.detach().cpu().numpy())

# Stack the per-image rows into one (n_images, embedding_dim) array.
embeddings_triplet = np.concatenate(embeddings, axis=0)


# Each DataFrame column is one embedding dimension (x0, x1, x2, ...); the last column is the label.
df = pd.DataFrame(embeddings_triplet, columns=[f'x{i}' for i in range(embeddings_triplet.shape[1])])
df['label'] = labels

df.head()

In [None]:
# Project the triplet embeddings onto their first two principal components.
pca = PCA(n_components=2)
pca_embeddings_triplet = pca.fit_transform(embeddings_triplet)

# Store both components next to the labels so seaborn can colour by class.
df['pca1'], df['pca2'] = pca_embeddings_triplet[:, 0], pca_embeddings_triplet[:, 1]

plt.figure(figsize=(12, 5))
sns.scatterplot(data=df, x='pca1', y='pca2', hue='label', palette='tab10')

# Siamese Network 

- A Siamese Network takes 2 images and works similarly to Triplet Loss: it is trained to minimize the distance between two images of the same class and to push apart two images of different classes.

- Runs two forward passes on the same network with the same weights on the images
- Apply the following loss function:

First compute a differentiable distance (L1 or L2; cosine also works), then apply a linear feed-forward layer and a nonlinearity ($\sigma$) to get a probability:
$$
p(x_i, x_j) = \sigma(W |f(x_i) - f(x_j)|)
$$

Of course the loss is computed over a batch, $\mathcal{L}(B)$, but for a single pair of samples it looks like:
$$
\mathcal{L}(x_i, x_j) = -\left[\mathbb{1}_{y_i = y_j} \log p(x_i, x_j) + \mathbb{1}_{y_i \neq y_j} \log\left(1 - p(x_i, x_j)\right)\right]
$$


In [None]:
# MNIST training split again (downloaded into ./data if missing), as tensors.
mnist = datasets.MNIST(
    './data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Pair view of the data, served in shuffled batches of 1024.
siamese_dataset = SiameseMNIST(mnist)

siamese_loader = torch.utils.data.DataLoader(
    siamese_dataset,
    batch_size=1024,
    shuffle=True,
)

In [None]:
# Fetch one batch of pairs: sample[0] holds the two image batches, sample[1] the targets.
sample = next(iter(siamese_loader))
left_batch, right_batch = sample[0]

# Display the two images that make up the first pair of the batch.
fig, axs = plt.subplots(1, 2)
axs[0].imshow(left_batch[0].squeeze(), cmap='gray')
axs[1].imshow(right_batch[0].squeeze(), cmap='gray')


In [None]:
# Reuse a saved checkpoint when one exists; otherwise train from scratch and save it.
if os.path.exists('siamese_network.pth'):
    siamese_network = SiameseNetwork()
    # map_location lets a checkpoint written on a GPU load on a CPU-only machine.
    siamese_network.load_state_dict(torch.load('siamese_network.pth', map_location=device))
    siamese_network.eval()
    siamese_network.to(device)

    # Restore the loss history as well: the loss-plot cell reads `siamese_train_loss`
    # and would otherwise fail with a NameError when training is skipped.
    if os.path.exists('siamese_train_loss.npy'):
        siamese_train_loss = np.load('siamese_train_loss.npy')
    else:
        siamese_train_loss = []
    print('Model loaded')
else:
    siamese_network = SiameseNetwork()
    siamese_network.to(device)

    optimizer = torch.optim.Adam(siamese_network.parameters(), lr=1e-3)

    contrastive_loss = ContrastiveLoss()

    siamese_train_loss = []  # one entry per batch

    for epoch in range(10):

        pbar = tqdm(siamese_loader)
        epoch_loss = 0

        for (img1, img2), target in pbar:

            img1 = img1.to(device)
            img2 = img2.to(device)
            # target is 1.0 for pairs of different classes, 0.0 for same-class pairs
            target = target.to(device)

            optimizer.zero_grad()
            output1, output2 = siamese_network(img1, img2)
            loss = contrastive_loss(output1, output2, target)

            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            pbar.set_postfix_str(f'Loss {batch_loss}')
            siamese_train_loss.append(batch_loss)

        # Mean batch loss over the epoch.
        print(f'Epoch {epoch}, Loss {epoch_loss/len(siamese_loader)}')

    # Persist the weights and the per-batch loss history for later runs.
    torch.save(siamese_network.state_dict(), 'siamese_network.pth')
    np.save('siamese_train_loss.npy', siamese_train_loss)


In [None]:
siamese_network.eval()

embeddings = []
labels = []

# Embed the first 5000 MNIST images, one at a time.
# no_grad: this is pure inference, so no autograd graph needs to be built.
with torch.no_grad():
    for i in range(5000):
        image, label = mnist[i]

        image = image.unsqueeze(0).to(device)  # add the batch dimension
        # The network takes a pair; feed the same image twice and keep the first output.
        embedding = siamese_network(image, image)[0]
        labels.append(label)
        embeddings.append(embedding.detach().cpu().numpy())

# Stack the per-image rows into one (n_images, embedding_dim) array.
embeddings_siamese = np.concatenate(embeddings, axis=0)

# One column per embedding dimension, plus the label column.
df = pd.DataFrame(embeddings_siamese, columns=[f'x{i}' for i in range(embeddings_siamese.shape[1])])
df['label'] = labels

In [None]:
# Project the siamese embeddings onto their first two principal components.
pca = PCA(n_components=2)
pca_embeddings_siamese = pca.fit_transform(embeddings_siamese)

# Store both components next to the labels so seaborn can colour by class.
df['pca1'], df['pca2'] = pca_embeddings_siamese[:, 0], pca_embeddings_siamese[:, 1]

plt.figure(figsize=(12, 5))
sns.scatterplot(data=df, x='pca1', y='pca2', hue='label', palette='tab10')

plt.show()

# Baseline on Supervised Approach: SimpleCNN

- We will train a simple CNN on the dataset and extract the features from the last layer before the classification layer.

- The loss is the usual CrossEntropy Loss

In [None]:
# Supervised baseline: a plain CNN trained on MNIST with cross-entropy.

mnist = datasets.MNIST('./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))

mnist_loader = torch.utils.data.DataLoader(mnist, batch_size=1024, shuffle=True)

# Reuse a saved checkpoint when one exists; otherwise train from scratch and save it.
if os.path.exists('baseline_model.pth'):
    model = SimpleCNN()
    # map_location lets a checkpoint written on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load('baseline_model.pth', map_location=device))
    # The embedding cell sends its inputs to `device`, so the weights must live there too.
    model.to(device)
    model.eval()

    if os.path.exists('baseline_train_loss.npy'):
        baseline_train_loss = np.load('baseline_train_loss.npy')
    else:
        baseline_train_loss = []
else:

    model = SimpleCNN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    baseline_train_loss = []  # one entry per batch

    for epoch in range(10):
        pbar = tqdm(mnist_loader)
        epoch_loss = 0
        for image, label in pbar:

            image = image.to(device)
            label = label.to(device)

            optimizer.zero_grad()

            # NOTE(review): the loss is applied directly to the forward_once output, which an
            # earlier cell's comment describes as a 32-dimensional embedding - confirm that
            # this is the intended classification output.
            output = model.forward_once(image)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            pbar.set_postfix_str(f'Loss {batch_loss}')
            baseline_train_loss.append(batch_loss)

        # Mean batch loss over the epoch.
        print(f'Epoch {epoch}, Loss {epoch_loss/len(mnist_loader)}')

    # Persist the weights and the per-batch loss history for later runs.
    torch.save(model.state_dict(), 'baseline_model.pth')
    np.save('baseline_train_loss.npy', baseline_train_loss)
    

In [None]:
model.eval()

embeddings = []
labels = []

# Embed the first 5000 MNIST images, one at a time.
# no_grad: this is pure inference, so no autograd graph needs to be built.
with torch.no_grad():
    for i in range(5000):
        image, label = mnist[i]

        image = image.unsqueeze(0).to(device)  # add the batch dimension
        embedding = model.forward_once(image)
        labels.append(label)
        embeddings.append(embedding.detach().cpu().numpy())

# Stack the per-image rows into one (n_images, embedding_dim) array.
embeddings_baseline = np.concatenate(embeddings, axis=0)

# One column per embedding dimension, plus the label column.
df = pd.DataFrame(embeddings_baseline, columns=[f'x{i}' for i in range(embeddings_baseline.shape[1])])
df['label'] = labels

# Project onto the first two principal components and colour the points by class.
pca = PCA(n_components=2)

pca_embeddings_baseline = pca.fit_transform(embeddings_baseline)

df['pca1'] = pca_embeddings_baseline[:, 0]
df['pca2'] = pca_embeddings_baseline[:, 1]

plt.figure(figsize=(12, 5))
sns.scatterplot(data=df, x='pca1', y='pca2', hue='label', palette='tab10')

plt.show()


## Evaluating goodness of clustering

In [None]:
# Silhouette score as a data-mining measure of how well each embedding space
# separates the classes; compared here: triplet loss, siamese (contrastive) loss
# and the supervised baseline.
from sklearn.metrics import silhouette_score

# `labels` is the list filled by the most recent embedding-extraction cell.
triplet_score = silhouette_score(embeddings_triplet, labels)
siamese_score = silhouette_score(embeddings_siamese, labels)
baseline_score = silhouette_score(embeddings_baseline, labels)

for method_name, method_score in (
    ('Triplet Loss', triplet_score),
    ('Siamese Loss', siamese_score),
    ('Baseline', baseline_score),
):
    print(f'{method_name} Silhouette Score: {method_score}')


# Bar chart of the three scores (fixed order: siamese, triplet, baseline).
scores = [siamese_score, triplet_score, baseline_score]

plt.bar(['Siamese', 'Triplet', 'Baseline'], scores)
plt.ylabel('Silhouette Score')
plt.show()


## Plotting Losses to visualize convergence of our models

In [None]:
# Overlay the recorded training losses of the three models.
# A little transparency keeps overlapping curves readable.
loss_curves = (
    (triplet_train_loss, 'Triplet Loss'),
    (siamese_train_loss, 'Siamese (Contrastive Loss)'),
    (baseline_train_loss, 'Baseline'),
)
for loss_history, curve_label in loss_curves:
    plt.plot(loss_history, label=curve_label, alpha=0.75)
plt.legend()

## Experimenting on FashionMNIST

- We will use the FashionMNIST dataset to train the models and extract the embeddings.

In [None]:
# FashionMNIST training split (downloaded into ./data if missing), as tensors.
fashion_mnist = datasets.FashionMNIST(
    './data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# From here on the triplet dataset/loader names point at FashionMNIST.
triplet_dataset = TripletDataset(fashion_mnist)
triplet_loader = torch.utils.data.DataLoader(
    triplet_dataset,
    batch_size=1024,
    shuffle=True,
)

In [None]:
# Fetch one batch of FashionMNIST triplets and split it into its three roles.
sample = next(iter(triplet_loader))
anchor_batch, positive_batch, negative_batch = sample

# Display the first anchor / positive / negative of the batch side by side.
fig, axs = plt.subplots(1, 3)
axs[0].imshow(anchor_batch[0].squeeze(), cmap='gray')
axs[1].imshow(positive_batch[0].squeeze(), cmap='gray')
axs[2].imshow(negative_batch[0].squeeze(), cmap='gray')

In [None]:
# Reuse a saved FashionMNIST checkpoint when one exists; otherwise train and save it.
if os.path.exists('triplet_fashion_model.pth'):
    tripletmodel = SimpleCNN()
    # map_location lets a checkpoint written on a GPU load on a CPU-only machine.
    tripletmodel.load_state_dict(torch.load('triplet_fashion_model.pth', map_location=device))
    tripletmodel.eval()
    tripletmodel.to(device)

    # Restore the FashionMNIST loss history; without this `triplet_train_loss` would
    # still hold the MNIST run (or be undefined) when the losses are plotted later.
    if os.path.exists('triplet_fashion_train_loss.npy'):
        triplet_train_loss = np.load('triplet_fashion_train_loss.npy')
    else:
        triplet_train_loss = []
    print('Model loaded')
else:

    tripletmodel = SimpleCNN()
    tripletloss = TripletLoss()

    optimizer = torch.optim.Adam(tripletmodel.parameters(), lr=1e-3)

    tripletmodel.to(device)
    triplet_train_loss = []  # one entry per batch

    for epoch in range(10):
        pbar = tqdm(triplet_loader)
        epoch_loss = 0
        for anchor, positive, negative in pbar:

            anchor = anchor.to(device)
            positive = positive.to(device)
            negative = negative.to(device)

            optimizer.zero_grad()

            # The same network embeds all three images of the triplet.
            anchor_embedding = tripletmodel.forward_once(anchor)
            positive_embedding = tripletmodel.forward_once(positive)
            negative_embedding = tripletmodel.forward_once(negative)

            loss = tripletloss(anchor_embedding, positive_embedding, negative_embedding)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            pbar.set_postfix_str(f'Loss {batch_loss}')
            triplet_train_loss.append(batch_loss)

        # Mean batch loss over the epoch.
        print(f'Epoch {epoch}, Loss {epoch_loss/len(triplet_loader)}')

    # Persist the weights and the per-batch loss history for later runs.
    torch.save(tripletmodel.state_dict(), 'triplet_fashion_model.pth')
    np.save('triplet_fashion_train_loss.npy', triplet_train_loss)
    

In [None]:

tripletmodel.eval()

embeddings = []
labels = []

# Embed the first 5000 FashionMNIST images, one at a time.
# no_grad: this is pure inference, so no autograd graph needs to be built.
with torch.no_grad():
    for i in range(5000):
        image, label = fashion_mnist[i]

        image = image.unsqueeze(0).to(device)  # add the batch dimension
        embedding = tripletmodel.forward_once(image)
        labels.append(label)
        embeddings.append(embedding.detach().cpu().numpy())

# Stack the per-image rows into one (n_images, embedding_dim) array.
embeddings_triplet_fashion = np.concatenate(embeddings, axis=0)

# One column per embedding dimension, plus the label column.
df = pd.DataFrame(embeddings_triplet_fashion, columns=[f'x{i}' for i in range(embeddings_triplet_fashion.shape[1])])
df['label'] = labels

In [None]:
# Project the FashionMNIST triplet embeddings onto their first two principal components.
pca = PCA(n_components=2)
pca_embeddings_triplet_fashion = pca.fit_transform(embeddings_triplet_fashion)

# Store both components next to the labels so seaborn can colour by class.
df['pca1'], df['pca2'] = (
    pca_embeddings_triplet_fashion[:, 0],
    pca_embeddings_triplet_fashion[:, 1],
)

plt.figure(figsize=(12, 5))
sns.scatterplot(data=df, x='pca1', y='pca2', hue='label', palette='tab10')

plt.show()


In [None]:
# FashionMNIST training split (downloaded into ./data if missing), as tensors.
# Assigned to `fashion_mnist`, the name the rest of the notebook actually reads.
fashion_mnist = datasets.FashionMNIST('./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))

# Pair view of FashionMNIST, served in shuffled batches of 1024.
siamese_dataset = SiameseMNIST(fashion_mnist)

siamese_loader = torch.utils.data.DataLoader(siamese_dataset, batch_size=1024, shuffle=True)


In [None]:
# Fetch one batch of pairs: sample[0] holds the two image batches, sample[1] the targets.
sample = next(iter(siamese_loader))
left_batch, right_batch = sample[0]

# Display the two images that make up the first pair of the batch.
fig, axs = plt.subplots(1, 2)
axs[0].imshow(left_batch[0].squeeze(), cmap='gray')
axs[1].imshow(right_batch[0].squeeze(), cmap='gray')

# Target of that pair: 1.0 when the two images belong to different classes, 0.0 when they match.
label = sample[1][0].item()
print(f'Target: {label}')


In [None]:
# Transfer check: embed MNIST digits with the triplet model that was just trained on
# FashionMNIST. The result is stored under its own name so the MNIST-trained
# `embeddings_triplet` used by the later KNN comparison is not overwritten.
tripletmodel.eval()
embeddings = []
labels = []

# Embed the first 5000 MNIST images, one at a time.
# no_grad: this is pure inference, so no autograd graph needs to be built.
with torch.no_grad():
    for i in range(5000):
        image, label = mnist[i]

        image = image.unsqueeze(0).to(device)  # add the batch dimension
        embedding = tripletmodel.forward_once(image)
        labels.append(label)
        embeddings.append(embedding.detach().cpu().numpy())

# Stack the per-image rows into one (n_images, embedding_dim) array.
embeddings_triplet_fashion_on_mnist = np.concatenate(embeddings, axis=0)


# Each DataFrame column is one embedding dimension (x0, x1, x2, ...); the last column is the label.
df = pd.DataFrame(
    embeddings_triplet_fashion_on_mnist,
    columns=[f'x{i}' for i in range(embeddings_triplet_fashion_on_mnist.shape[1])],
)
df['label'] = labels

df.head()

In [None]:
# Reuse a saved FashionMNIST checkpoint when one exists; otherwise train and save it.
if os.path.exists('siamese_fashion_network.pth'):
    siamese_network = SiameseNetwork()
    # map_location lets a checkpoint written on a GPU load on a CPU-only machine.
    siamese_network.load_state_dict(torch.load('siamese_fashion_network.pth', map_location=device))
    siamese_network.eval()
    siamese_network.to(device)

    # Restore the FashionMNIST loss history; without this `siamese_train_loss` would
    # still hold the MNIST run (or be undefined) when the losses are plotted later.
    if os.path.exists('siamese_fashion_train_loss.npy'):
        siamese_train_loss = np.load('siamese_fashion_train_loss.npy')
    else:
        siamese_train_loss = []
    print('Model loaded')
else:

    siamese_network = SiameseNetwork()

    siamese_network.to(device)

    optimizer = torch.optim.Adam(siamese_network.parameters(), lr=1e-3)

    contrastive_loss = ContrastiveLoss()

    siamese_train_loss = []  # one entry per batch

    for epoch in range(10):

        pbar = tqdm(siamese_loader)
        epoch_loss = 0

        for (img1, img2), target in pbar:

            img1 = img1.to(device)
            img2 = img2.to(device)
            # target is 1.0 for pairs of different classes, 0.0 for same-class pairs
            target = target.to(device)

            optimizer.zero_grad()
            output1, output2 = siamese_network(img1, img2)
            loss = contrastive_loss(output1, output2, target)

            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            pbar.set_postfix_str(f'Loss {batch_loss}')
            siamese_train_loss.append(batch_loss)

        # Mean batch loss over the epoch.
        print(f'Epoch {epoch}, Loss {epoch_loss/len(siamese_loader)}')

    # Persist the weights and the per-batch loss history for later runs.
    torch.save(siamese_network.state_dict(), 'siamese_fashion_network.pth')
    np.save('siamese_fashion_train_loss.npy', siamese_train_loss)
    

In [None]:
siamese_network.eval()

embeddings = []
labels = []

# Embed the first 5000 FashionMNIST images, one at a time.
# no_grad: this is pure inference, so no autograd graph needs to be built.
with torch.no_grad():
    for i in range(5000):
        image, label = fashion_mnist[i]

        image = image.unsqueeze(0).to(device)  # add the batch dimension
        # The network takes a pair; feed the same image twice and keep the first output.
        embedding = siamese_network(image, image)[0]
        labels.append(label)
        embeddings.append(embedding.detach().cpu().numpy())

# Stack the per-image rows into one (n_images, embedding_dim) array.
embeddings_siamese_fashion = np.concatenate(embeddings, axis=0)

# One column per embedding dimension, plus the label column.
df = pd.DataFrame(embeddings_siamese_fashion, columns=[f'x{i}' for i in range(embeddings_siamese_fashion.shape[1])])
df['label'] = labels


In [None]:
# Project the FashionMNIST siamese embeddings onto their first two principal components.
pca = PCA(n_components=2)
pca_embeddings_siamese_fashion = pca.fit_transform(embeddings_siamese_fashion)

# Store both components next to the labels so seaborn can colour by class.
df['pca1'], df['pca2'] = (
    pca_embeddings_siamese_fashion[:, 0],
    pca_embeddings_siamese_fashion[:, 1],
)

plt.figure(figsize=(12, 5))
sns.scatterplot(data=df, x='pca1', y='pca2', hue='label', palette='tab10')

plt.show()



In [None]:
# FashionMNIST training split (downloaded into ./data if missing), as tensors.
fashion_mnist = datasets.FashionMNIST(
    './data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Despite its name, `mnist_loader` now iterates over FashionMNIST.
mnist_loader = torch.utils.data.DataLoader(
    fashion_mnist,
    batch_size=1024,
    shuffle=True,
)


In [None]:
# Reuse a saved FashionMNIST checkpoint when one exists; otherwise train and save it.
if os.path.exists('baseline_fashion_model.pth'):
    model = SimpleCNN()
    # map_location lets a checkpoint written on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load('baseline_fashion_model.pth', map_location=device))
    # The embedding cell sends its inputs to `device`, so the weights must live there too.
    model.to(device)
    model.eval()

    if os.path.exists('baseline_fashion_train_loss.npy'):
        baseline_train_loss = np.load('baseline_fashion_train_loss.npy')
    else:
        baseline_train_loss = []
else:

    model = SimpleCNN().to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    baseline_train_loss = []  # one entry per batch

    for epoch in range(10):
        pbar = tqdm(mnist_loader)  # at this point the loader wraps FashionMNIST
        epoch_loss = 0
        for image, label in pbar:

            image = image.to(device)
            label = label.to(device)

            optimizer.zero_grad()

            # NOTE(review): the loss is applied directly to the forward_once output, which an
            # earlier cell's comment describes as a 32-dimensional embedding - confirm that
            # this is the intended classification output.
            output = model.forward_once(image)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            pbar.set_postfix_str(f'Loss {batch_loss}')
            baseline_train_loss.append(batch_loss)

        # Mean batch loss over the epoch.
        print(f'Epoch {epoch}, Loss {epoch_loss/len(mnist_loader)}')

    # Persist the weights and the per-batch loss history for later runs.
    torch.save(model.state_dict(), 'baseline_fashion_model.pth')
    np.save('baseline_fashion_train_loss.npy', baseline_train_loss)
    

In [None]:
model.eval()

embeddings = []
labels = []

# Embed the first 5000 FashionMNIST images, one at a time.
# no_grad: this is pure inference, so no autograd graph needs to be built.
with torch.no_grad():
    for i in range(5000):
        image, label = fashion_mnist[i]

        image = image.unsqueeze(0).to(device)  # add the batch dimension
        embedding = model.forward_once(image)
        labels.append(label)
        embeddings.append(embedding.detach().cpu().numpy())

# Stack the per-image rows into one (n_images, embedding_dim) array.
embeddings_baseline_fashion = np.concatenate(embeddings, axis=0)

# One column per embedding dimension, plus the label column.
df = pd.DataFrame(embeddings_baseline_fashion, columns=[f'x{i}' for i in range(embeddings_baseline_fashion.shape[1])])
df['label'] = labels


In [None]:
# Project the FashionMNIST baseline embeddings onto their first two principal components.
pca = PCA(n_components=2)
pca_embeddings_baseline_fashion = pca.fit_transform(embeddings_baseline_fashion)

# Store both components next to the labels so seaborn can colour by class.
df['pca1'], df['pca2'] = (
    pca_embeddings_baseline_fashion[:, 0],
    pca_embeddings_baseline_fashion[:, 1],
)

plt.figure(figsize=(12, 5))
sns.scatterplot(data=df, x='pca1', y='pca2', hue='label', palette='tab10')

plt.show()


## Visualize embeddings scores on FashionMNIST

In [None]:
# Silhouette scores of the three embedding spaces on FashionMNIST.
# `labels` is the list filled by the most recent embedding-extraction cell.
triplet_score = silhouette_score(embeddings_triplet_fashion, labels)
siamese_score = silhouette_score(embeddings_siamese_fashion, labels)
baseline_score = silhouette_score(embeddings_baseline_fashion, labels)

# Bar chart of the three scores (same bar order as the MNIST chart).
scores = [siamese_score, triplet_score, baseline_score]

# Draw the bars; without this call the figure would be an empty pair of axes.
plt.bar(['Siamese', 'Triplet', 'Baseline'], scores)
plt.ylabel('Silhouette Score')
plt.show()


## Another loss comparison

In [None]:
# Overlay the recorded training losses of the three models.
loss_curves = (
    (triplet_train_loss, 'Triplet Loss'),
    (siamese_train_loss, 'Siamese (Contrastive Loss)'),
    (baseline_train_loss, 'Baseline'),
)
for loss_history, curve_label in loss_curves:
    plt.plot(loss_history, label=curve_label, alpha=0.75)
plt.legend()

### This is not an apples-to-apples comparison

Since the models are trained on different tasks, their losses are not easy to compare. What we can do instead is take the embeddings, fit a KNN classifier on them, and see how well they perform on a classification task.


In [None]:
# KNN classification on top of the MNIST embeddings of the three models.
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# The three embedding arrays were computed on the first MNIST images, so they must be
# paired with MNIST labels. The global `labels` list cannot be used here: by this point
# it holds the FashionMNIST labels collected by the last extraction cell.
mnist_labels = mnist.targets[:len(embeddings_triplet)].numpy()

# Get all embeddings in a pd.DataFrame
df_triplet = pd.DataFrame(embeddings_triplet, columns=[f'x{i}' for i in range(embeddings_triplet.shape[1])])
df_triplet['label'] = mnist_labels

df_siamese = pd.DataFrame(embeddings_siamese, columns=[f'x{i}' for i in range(embeddings_siamese.shape[1])])
df_siamese['label'] = mnist_labels

df_baseline = pd.DataFrame(embeddings_baseline, columns=[f'x{i}' for i in range(embeddings_baseline.shape[1])])
df_baseline['label'] = mnist_labels

# Split the data. The same seed is used for every model so that all three are evaluated
# on the same held-out images and the run is reproducible.
KNN_SPLIT_SEED = 42
X_train_triplet, X_test_triplet, y_train_triplet, y_test_triplet = train_test_split(
    df_triplet.drop(columns='label'), df_triplet['label'], test_size=0.2, random_state=KNN_SPLIT_SEED)
X_train_siamese, X_test_siamese, y_train_siamese, y_test_siamese = train_test_split(
    df_siamese.drop(columns='label'), df_siamese['label'], test_size=0.2, random_state=KNN_SPLIT_SEED)
X_train_baseline, X_test_baseline, y_train_baseline, y_test_baseline = train_test_split(
    df_baseline.drop(columns='label'), df_baseline['label'], test_size=0.2, random_state=KNN_SPLIT_SEED)

# Sweep the number of neighbours and record the held-out accuracy for each k.
k_values = range(1, 20)

accuracy_triplet = []
accuracy_siamese = []
accuracy_baseline = []

for k in k_values:
    knn_triplet = KNeighborsClassifier(n_neighbors=k)
    knn_siamese = KNeighborsClassifier(n_neighbors=k)
    knn_baseline = KNeighborsClassifier(n_neighbors=k)

    knn_triplet.fit(X_train_triplet, y_train_triplet)
    knn_siamese.fit(X_train_siamese, y_train_siamese)
    knn_baseline.fit(X_train_baseline, y_train_baseline)

    accuracy_triplet.append(accuracy_score(y_test_triplet, knn_triplet.predict(X_test_triplet)))
    accuracy_siamese.append(accuracy_score(y_test_siamese, knn_siamese.predict(X_test_siamese)))
    accuracy_baseline.append(accuracy_score(y_test_baseline, knn_baseline.predict(X_test_baseline)))

# Accuracy as a function of k, one curve per model.
plt.plot(k_values, accuracy_triplet, label='Triplet Loss', alpha=0.75)
plt.plot(k_values, accuracy_siamese, label='Siamese Loss', alpha=0.75)
plt.plot(k_values, accuracy_baseline, label='Baseline', alpha=0.75)
plt.legend()
plt.ylabel('Accuracy')
plt.xlabel('K')
plt.show()
