In [4]:
import numpy as np  
import torch 
import matplotlib.pyplot as plt 
import sys
sys.path.append('/home/alex/Documents/InformationHeuristic')
from src.datasets.datasets import cifar100_dataset

In [22]:
from torch.nn import functional as F
def compute_knn_accuracy(test_features, test_labels, train_features, train_labels, k=1,
                         distance_metric="euclidean"):
    """
    Computes a k-nearest-neighbour hit accuracy using Euclidean distance or cosine similarity.

    A test sample counts as correct when its true label appears among the labels
    of its k nearest training samples (no majority vote is taken). With k=1 this
    is ordinary 1-NN classification accuracy.

    Args:
        test_features (torch.Tensor): Features of the test samples.
                                      Shape: (N_test, feature_dim).
        test_labels (torch.Tensor): Labels of the test samples. Shape: (N_test).
        train_features (torch.Tensor): Features of the training/database samples.
                                       Shape: (N_train, feature_dim).
        train_labels (torch.Tensor): Labels of the database samples. Shape: (N_train).
        k (int): Number of nearest neighbours to inspect. Must be a positive
                 integer. Defaults to 1.
        distance_metric (str): "euclidean" or "cosine". Defaults to "euclidean".

    Returns:
        float: Fraction of test samples with a hit, as a decimal (e.g. 0.85 for 85%).

    Raises:
        TypeError: If any of the four data arguments is not a torch.Tensor.
        ValueError: If k is not a positive integer or distance_metric is unknown.
    """
    inputs = (test_features, test_labels, train_features, train_labels)
    if not all(isinstance(t, torch.Tensor) for t in inputs):
        raise TypeError("All inputs must be torch.Tensor objects.")

    # Validate k and distance_metric.
    if not isinstance(k, int) or k < 1:
        raise ValueError("k must be a positive integer.")
    if distance_metric not in ["euclidean", "cosine"]:
        raise ValueError("distance_metric must be either 'euclidean' or 'cosine'.")

    # Bring everything onto the device of the test features. Labels are often
    # built on the CPU while features are loaded on an accelerator; indexing a
    # tensor with indices living on another device would fail further down.
    device = test_features.device
    train_features = train_features.to(device)
    test_labels = test_labels.to(device)
    train_labels = train_labels.to(device)

    if distance_metric == "euclidean":
        # Pairwise distance matrix, shape (N_test, N_train). Lower means closer.
        distances = torch.cdist(test_features, train_features, p=2)
        # Indices of the k smallest distances, ordered from nearest to farthest.
        _, topk_indices = torch.topk(distances, k, largest=False, sorted=True)
    else:
        # After L2-normalising the rows, cosine similarity is just a dot product.
        # A matrix product yields the (N_test, N_train) similarity matrix directly,
        # avoiding the (N_test, N_train, feature_dim) broadcast intermediate that
        # an unsqueeze + F.cosine_similarity formulation would materialise.
        test_unit = F.normalize(test_features, dim=1)
        train_unit = F.normalize(train_features, dim=1)
        similarities = test_unit @ train_unit.t()
        # Indices of the k largest similarities, ordered from most to least similar.
        _, topk_indices = torch.topk(similarities, k, largest=True, sorted=True)

    # Labels of the k nearest neighbours for every test sample: shape (N_test, k).
    predicted_labels_topk = train_labels[topk_indices]

    # Compare against the true label, broadcast over the k dimension; a sample
    # is a hit if any of its k neighbours carries the true label.
    correct_predictions = torch.any(predicted_labels_topk == test_labels.unsqueeze(1), dim=1)

    # Mean of the boolean hits is the accuracy.
    return correct_predictions.float().mean().item()



In [16]:
# Build the CIFAR-100 train/test datasets; the third value returned by
# cifar100_dataset is not needed here and is discarded.
train_dataset, test_dataset, _ = cifar100_dataset(
    data_folder='../../data/cifar100',
)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
# Gather the second element of every (sample, label) pair and pack the values
# into one tensor per split.
train_labels = torch.tensor(
    [train_dataset[idx][1] for idx in range(len(train_dataset))]
)
test_labels = torch.tensor(
    [test_dataset[idx][1] for idx in range(len(test_dataset))]
)

In [None]:
# Directory holding the precomputed feature tensors for this run.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
features_dir = "/home/alex/Documents/InformationHeuristic/weights/cifar100/representations/barlowtwin/barlowtwin_resnet18_5.00"

# torch.load unpickles the file contents, so only load files from a trusted source.
train_features = torch.load(f"{features_dir}/train_features.pt")
test_features = torch.load(f"{features_dir}/test_features.pt")

In [24]:
# Evaluate with the cosine metric on the first 1000 test samples only,
# using the complete training set as the neighbour database (k defaults to 1).
compute_knn_accuracy(
    test_features[:1000],
    test_labels[:1000],
    train_features,
    train_labels,
    distance_metric="cosine",
)

: 