In [None]:
import os
import torch
import random
import torchvision
from PIL import Image
from torchvision import datasets, transforms
from torch import nn
from utils.utils import LoadDataset, set_seed
from simclr.simclr_model import SimCLR
from byol.byol_model import BYOL
from moco.moco_model import MoCo
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Seed the run once, up front, through the project's set_seed helper (imported from utils.utils).
# NOTE(review): which random number generators set_seed covers is defined in utils.utils,
# which is not visible in this notebook — confirm it seeds everything the data pipeline uses.
set_seed(42)

In [None]:
def extracting_feature_vectors_from_simclr(model_path, data_loader):

    """Function to extract backbone feature vectors from a saved SimCLR model
    Parameters:
    model_path (str): Path to the saved state dict of the SimCLR model
    data_loader (torch.utils.data.DataLoader): Data loader yielding (images, labels) batches
    Output:
    feature_vectors (numpy array): Flattened backbone outputs, one row per image
    labels (numpy array): Labels, in the same order as the feature vectors
    """

    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Rebuild the architecture the weights belong to: every child module of a
    # torchvision ResNet-18 except the last one, wrapped in the SimCLR model.
    resnet = torchvision.models.resnet18()
    backbone = nn.Sequential(*list(resnet.children())[:-1])
    model = SimCLR(backbone)

    # map_location makes a checkpoint that was saved on a GPU loadable on the
    # CPU fallback selected above; without it torch.load tries to restore the
    # tensors onto the device they were saved from.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    feature_batches = []
    label_batches = []

    with torch.no_grad():
        for images, label in data_loader:
            images = images.to(device)
            # Only the backbone is used; its output is flattened to (batch, features).
            outputs = model.backbone(images).flatten(start_dim=1)
            feature_batches.append(outputs.cpu().numpy())
            label_batches.append(label.numpy())

    feature_vectors = np.concatenate(feature_batches)
    labels = np.concatenate(label_batches)

    return feature_vectors, labels

In [None]:
def extracting_feature_vectors_from_byol(model_path, data_loader):

    """Function to extract backbone feature vectors from a saved BYOL model
    Parameters:
    model_path (str): Path to the saved state dict of the BYOL model
    data_loader (torch.utils.data.DataLoader): Data loader yielding (images, labels) batches
    Output:
    feature_vectors (numpy array): Flattened backbone outputs, one row per image
    labels (numpy array): Labels, in the same order as the feature vectors
    """

    # Re-seed before every extraction, exactly as the SimCLR extractor does, so
    # that all three extractors start from the same random state.
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Rebuild the architecture the weights belong to: every child module of a
    # torchvision ResNet-18 except the last one, wrapped in the BYOL model.
    resnet = torchvision.models.resnet18()
    backbone = nn.Sequential(*list(resnet.children())[:-1])
    model = BYOL(backbone)

    # map_location makes a checkpoint that was saved on a GPU loadable on the
    # CPU fallback selected above; without it torch.load tries to restore the
    # tensors onto the device they were saved from.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    feature_batches = []
    label_batches = []

    with torch.no_grad():
        for images, label in data_loader:
            images = images.to(device)
            # Only the backbone is used; its output is flattened to (batch, features).
            outputs = model.backbone(images).flatten(start_dim=1)
            feature_batches.append(outputs.cpu().numpy())
            label_batches.append(label.numpy())

    feature_vectors = np.concatenate(feature_batches)
    labels = np.concatenate(label_batches)

    return feature_vectors, labels

In [None]:
def extracting_feature_vectors_from_moco(model_path, data_loader):

    """Function to extract backbone feature vectors from a saved MoCo model
    Parameters:
    model_path (str): Path to the saved state dict of the MoCo model
    data_loader (torch.utils.data.DataLoader): Data loader yielding (images, labels) batches
    Output:
    feature_vectors (numpy array): Flattened backbone outputs, one row per image
    labels (numpy array): Labels, in the same order as the feature vectors
    """

    # Re-seed before every extraction, exactly as the SimCLR extractor does, so
    # that all three extractors start from the same random state.
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Rebuild the architecture the weights belong to: every child module of a
    # torchvision ResNet-18 except the last one, wrapped in the MoCo model.
    resnet = torchvision.models.resnet18()
    backbone = nn.Sequential(*list(resnet.children())[:-1])
    model = MoCo(backbone)

    # map_location makes a checkpoint that was saved on a GPU loadable on the
    # CPU fallback selected above; without it torch.load tries to restore the
    # tensors onto the device they were saved from.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    feature_batches = []
    label_batches = []

    with torch.no_grad():
        for images, label in data_loader:
            images = images.to(device)
            # Only the backbone is used; its output is flattened to (batch, features).
            outputs = model.backbone(images).flatten(start_dim=1)
            feature_batches.append(outputs.cpu().numpy())
            label_batches.append(label.numpy())

    feature_vectors = np.concatenate(feature_batches)
    labels = np.concatenate(label_batches)

    return feature_vectors, labels

In [None]:
# SimCLR checkpoints: for every training seed, one state-dict path per augmentation.
# To use your own machine's layout, change the directory prefix inside the f-string below.
# The list order (augmentations) is the same for every seed, so the i-th entries of the
# three lists always refer to the same augmentation.
simclr_models = {
    seed: [
        f"/home/jovyan/models/trained_models/{seed_dir}/simclr/simclr_model_{augmentation}.pth"
        for augmentation in (
            "center_cropping",
            "random_cropping",
            "color_jitter",
            "random_flipping",
            "random_perspective",
            "random_rotation",
            "random_grayscale",
            "gaussian_blur",
            "random_invert",
            "random_erasing",
        )
    ]
    # (display name of the seed, name of its checkpoint sub-directory)
    for seed, seed_dir in (
        ("seed 0", "seed_zero"),
        ("seed 42", "seed_42"),
        ("seed 123", "seed_123"),
    )
}

In [None]:
# BYOL checkpoints: for every training seed, one state-dict path per augmentation.
# To use your own machine's layout, change the directory prefix inside the f-string below.
# The list order (augmentations) is the same for every seed, so the i-th entries of the
# three lists always refer to the same augmentation.
byol_models = {
    seed: [
        f"/home/jovyan/models/trained_models/{seed_dir}/byol/byol_model_{augmentation}.pth"
        for augmentation in (
            "center_cropping",
            "random_cropping",
            "color_jitter",
            "random_flipping",
            "random_perspective",
            "random_rotation",
            "random_grayscale",
            "gaussian_blur",
            "random_invert",
            "random_erasing",
        )
    ]
    # (display name of the seed, name of its checkpoint sub-directory)
    for seed, seed_dir in (
        ("seed 0", "seed_zero"),
        ("seed 42", "seed_42"),
        ("seed 123", "seed_123"),
    )
}

In [None]:
# MoCo checkpoints: for every training seed, one state-dict path per augmentation.
# To use your own machine's layout, change the directory prefix inside the f-string below.
# The list order (augmentations) is the same for every seed, so the i-th entries of the
# three lists always refer to the same augmentation.
moco_models = {
    seed: [
        f"/home/jovyan/models/trained_models/{seed_dir}/moco/moco_model_{augmentation}.pth"
        for augmentation in (
            "center_cropping",
            "random_cropping",
            "color_jitter",
            "random_flipping",
            "random_perspective",
            "random_rotation",
            "random_grayscale",
            "gaussian_blur",
            "random_invert",
            "random_erasing",
        )
    ]
    # (display name of the seed, name of its checkpoint sub-directory)
    for seed, seed_dir in (
        ("seed 0", "seed_zero"),
        ("seed 42", "seed_42"),
        ("seed 123", "seed_123"),
    )
}

In [None]:
def compute_distances_for_single_image(query_vector, feature_vectors):

    """Function to score every feature vector against one query vector with cosine similarity
    Parameters:
    query_vector (numpy array): Feature vector of the query image
    feature_vectors (numpy array): Feature vectors to compare against, one per row
    Output:
    distances (numpy array): One cosine similarity per row of feature_vectors
                             (despite the name, a larger value means a closer match)
    """

    # cosine_similarity works on 2-D inputs, so present the query as a single row.
    query_as_row = query_vector.reshape(1, -1)
    similarity_matrix = cosine_similarity(query_as_row, feature_vectors)

    # The result has one row (the query); return it as a flat array.
    return similarity_matrix.flatten()

In [None]:
def format_model_name(model_path):

    """Function to turn a checkpoint path into a readable augmentation name
    Parameters:
    model_path (str): Path to the model, using '/' as separator
    Output:
    formatted_name (str): File name without model prefix and '.pth', underscores
                          replaced by spaces, first letter upper-cased and the rest lower-cased
    """

    # Keep only what follows the last '/'.
    file_name = model_path.rsplit("/", 1)[-1]

    # Strip the model-specific prefixes first, then the extension.
    for removable in ("simclr_model_", "byol_model_", "moco_model_", ".pth"):
        file_name = file_name.replace(removable, "")

    return file_name.replace("_", " ").capitalize()

In [None]:
# Load the three datasets; replace the directories below with the dataset locations on your machine.
# NOTE(review): the second LoadDataset argument (50) is presumably the batch size — confirm in utils.utils.
binary_class_data, vehicles_dataset, clothing_dataset = (
    LoadDataset(dataset_dir, 50).load_data()
    for dataset_dir in (
        "/home/jovyan/data/cat_dog/",
        "/home/jovyan/data/vehicles/",
        "/home/jovyan/data/clothing/",
    )
)

In [None]:
# Display name -> loaded dataset; the display name ends up in the plot titles.
# NOTE(review): this rebinds `datasets`, hiding the `datasets` module imported
# from torchvision at the top of the notebook.
datasets = dict(
    (
        ("Cat vs Dogs", binary_class_data),
        ("Vehicles", vehicles_dataset),
        ("Clothing", clothing_dataset),
    )
)

In [None]:
def plot_comparison_nearest_neighbors(query_image, nearest_neighbors_images_by_seed, distances_by_seed, title,
                                      save_dir="/home/jovyan/scripts/plots/saved_plots/nearest_neighbours"):

    """Function to plot the nearest neighbors of a query image, one row per seed
    Parameters:
    query_image (torch.Tensor): Query image, channels first (it is permuted to channels last for display)
    nearest_neighbors_images_by_seed (dict): Seed name -> list of neighbor images (channels-first tensors)
    distances_by_seed (dict): Seed name -> similarity values, aligned with the neighbor images
    title (str): Title of the plot, also used as the name of the saved PDF
    save_dir (str): Directory the PDF is written to; created if it does not exist
    Output:
    Shows the figure and saves it as "<save_dir>/<title>.pdf"
    """

    num_seeds = len(nearest_neighbors_images_by_seed)
    num_neighbors = len(next(iter(nearest_neighbors_images_by_seed.values())))

    # squeeze=False keeps `axes` two-dimensional even with a single seed, so the
    # axes[row, column] indexing below always works.
    fig, axes = plt.subplots(num_seeds, num_neighbors + 1,
                             figsize=(3 * (num_neighbors + 1), 3 * num_seeds), squeeze=False)
    fig.suptitle(title, fontsize=20, y=0.96, fontweight="bold")

    # The query is the same in every row; convert it once.
    query_pixels = query_image.permute(1, 2, 0).numpy()

    for i, (seed, neighbors_images) in enumerate(nearest_neighbors_images_by_seed.items()):
        # First column: the query image, labelled with the seed of this row.
        query_ax = axes[i, 0]
        query_ax.imshow(query_pixels)
        query_ax.set_title("Query image", fontsize=12, pad=10)
        query_ax.set_ylabel(f"{seed}", fontsize=16, fontweight="bold")
        # Hide only the ticks (not the whole axis) so the seed label stays visible.
        query_ax.set_xticks([])
        query_ax.set_yticks([])

        # Remaining columns: the neighbors with their similarity to the query.
        for j, (neighbor_img, distance) in enumerate(zip(neighbors_images, distances_by_seed[seed])):
            ax = axes[i, j + 1]
            ax.imshow(neighbor_img.permute(1, 2, 0).numpy())
            ax.axis("off")
            ax.text(0.5, 1.05, f"Similarity: {distance:.8f}", transform=ax.transAxes,
                    fontsize=12, ha='center', va='bottom', color='black')

    # Adjust the spacing BEFORE saving, otherwise the PDF is written with the
    # default spacing and differs from the figure that is displayed.
    fig.subplots_adjust(wspace=0.1, hspace=0.4)

    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, f"{title}.pdf")
    fig.savefig(path, bbox_inches='tight', dpi=50)

    plt.show()
    # This function is called in a loop; release the figure so they do not pile up in memory.
    plt.close(fig)

In [None]:
def process_and_plot_nearest_neighbors(model_name, models, datasets, num_neighbors=10):

    """Function to find and plot the nearest neighbors of a query image for every dataset and augmentation
    Parameters:
    model_name (str): Name of the model family: "SimCLR", "BYOL" or "MoCo"
    models (dict): Seed name -> list of checkpoint paths; the i-th path of every seed
                   must belong to the same augmentation
    datasets (dict): Dataset display name -> data loader
    num_neighbors (int): Number of nearest neighbors shown per seed (default 10)
    Output:
    One plot per (dataset, augmentation) with one row of nearest neighbors per seed
    Raises:
    ValueError: If model_name is not one of the supported model families
    """

    extractors = {
        "SimCLR": extracting_feature_vectors_from_simclr,
        "BYOL": extracting_feature_vectors_from_byol,
        "MoCo": extracting_feature_vectors_from_moco,
    }
    if model_name not in extractors:
        raise ValueError(f"Unknown model_name {model_name!r}; expected one of {sorted(extractors)}")
    extract_features = extractors[model_name]

    # The query is the first image the data loader yields; its feature vector is row 0.
    # NOTE(review): this, and the dataset[i] lookups below, assume the data loader
    # iterates the dataset in order (no shuffling) — confirm in LoadDataset.
    query_index = 0

    for dataset_name, dataloader in datasets.items():
        # zip(*...) groups the checkpoints of all seeds that share an augmentation.
        for model_paths in zip(*models.values()):

            seed_feature_vectors = {
                seed: extract_features(model_path, dataloader)[0]
                for seed, model_path in zip(models.keys(), model_paths)
            }

            query_batch, _ = next(iter(dataloader))
            query_image = query_batch[query_index]
            augmentation_name = format_model_name(model_paths[0])

            nearest_neighbors_images_by_seed = {}
            distances_by_seed = {}
            for seed, feature_vectors in seed_feature_vectors.items():
                # Use the query's embedding from THIS seed's model: feature spaces of
                # independently trained models are not aligned, so a vector from one
                # model cannot be compared against the features of another.
                query_vector = feature_vectors[query_index]
                similarities = compute_distances_for_single_image(query_vector, feature_vectors)

                # Cosine similarity: larger means closer, so rank in descending order.
                ranked_indices = np.argsort(similarities)[::-1]
                # The query always matches itself perfectly; leave it out of its own neighbors.
                ranked_indices = ranked_indices[ranked_indices != query_index]
                nearest_neighbor_indices = ranked_indices[:num_neighbors]

                nearest_neighbors_images_by_seed[seed] = [
                    dataloader.dataset[int(i)][0] for i in nearest_neighbor_indices
                ]
                distances_by_seed[seed] = [similarities[i] for i in nearest_neighbor_indices]

            plot_comparison_nearest_neighbors(
                query_image,
                nearest_neighbors_images_by_seed,
                distances_by_seed,
                title=f"{model_name} - {dataset_name} - {augmentation_name}"
            )

In [None]:
# SimCLR: for every dataset and every augmentation, plot the query image next to its
# nearest neighbors, one row per training seed (each figure is also saved as a PDF)
process_and_plot_nearest_neighbors("SimCLR", simclr_models, datasets)

In [None]:
# BYOL: for every dataset and every augmentation, plot the query image next to its
# nearest neighbors, one row per training seed (each figure is also saved as a PDF)
process_and_plot_nearest_neighbors("BYOL", byol_models, datasets)

In [None]:
# Process and plot the MoCo nearest neighbors of the query image for all models and datasets
# For every dataset and every augmentation this plots the query image next to its
# nearest neighbors, one row per training seed (each figure is also saved as a PDF)
process_and_plot_nearest_neighbors("MoCo", moco_models, datasets) 