# Few-Shot Evaluation Notebook

This notebook implements the evaluation logic for a few-shot learning model using a ResNet50 backbone.

### Imports and configuration settings

**Key operations:**
1. The libraries needed are imported
2. **Config**: Config contains all the configurations necessary to run the script
3. Seeding to obtain a repeatable experiment

In [None]:
import torch
import sys
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import glob
import random
import xml.etree.ElementTree as ET
from itertools import cycle
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models
from torchvision.models import ResNet50_Weights
from google.colab import drive
import zipfile
import pandas as pd

class Config:
    """Single namespace holding every setting read by the evaluation notebook."""

    # --- General setup ---------------------------------------------------
    project_name = "FewShot_Evaluation"
    seed = 42  # value handed to the random-number generators right after this class
    device = "cpu" if not torch.cuda.is_available() else "cuda"

    # --- Dataset locations -----------------------------------------------
    dataset_root = "LogoDet-3K/LogoDet-3K-divided"
    csv_index_path = "LogoDet-3K/brand_to_index.csv"  # CSV with 'brand' and 'index' columns

    # --- Model -----------------------------------------------------------
    embedding_dim = 128      # size of the embedding produced by the model head
    pretrained = True        # start the backbone from pretrained weights
    freeze_layers = 5        # number of backbone levels to freeze
    trained_model_path = ""  # path handed to load_model

    # --- Evaluation ------------------------------------------------------
    # Distance cutoff: a pair counts as a match when its distance is <= this value.
    # (The attribute name is misspelled, but it is referenced throughout the notebook.)
    prediciton_threashold = 0.5
    n_shot = 1           # support images per episode
    num_episodes = 1000  # number of evaluation episodes

# Seed Python's and PyTorch's random generators with the shared configured value.
random.seed(Config.seed)
torch.manual_seed(Config.seed)

def setup_dataset(zip_path, extract_to):
    """
    Make the dataset available locally, unpacking it from Google Drive if needed.

    Google Drive is mounted first (unless '/content/drive' already exists).
    When `extract_to` is already present nothing else is done; otherwise the
    archive at `zip_path` is extracted into `extract_to`, or an error message
    is printed when the archive cannot be found.
    """
    drive_mount_point = '/content/drive'
    if not os.path.exists(drive_mount_point):
        drive.mount(drive_mount_point)

    # An existing target folder is taken as "already extracted".
    if os.path.exists(extract_to):
        print(f"Dataset folder '{extract_to}' already exists. Skipping extraction.")
        return

    print(f"Extracting dataset from {zip_path}...")
    if not os.path.exists(zip_path):
        print(f"ERROR: Zip file not found at {zip_path}. Check your path.")
        return

    with zipfile.ZipFile(zip_path, 'r') as archive:
        archive.extractall(extract_to)
    print("Extraction complete.")


setup_dataset("/content/drive/MyDrive/LogoDet-3K-divided.zip", "/content/LogoDet-3K")

### Dataset and Model Architecture

**Key operations:**
1. **DatasetTest**: Loader for test images that takes each label from the image's parent (brand) folder name and maps it to an index through the brand-to-index CSV file.
2. **LogoResNet50**: A modified ResNet50 architecture that replaces the final classifier with a layer generating feature embeddings.
3. **load_model**: A function that given the path of a saved model loads it and returns it

In [None]:
class DatasetTest(Dataset):
    """
    Dataset over a list of test image paths.

    Every item is a dict ``{"image": ..., "label": ...}``. The image is loaded
    as RGB and passed through `transform` when one is given. The label is the
    integer index of the image's brand: the brand name is the name of the
    image's parent folder, translated through the CSV at
    ``Config.csv_index_path`` (columns ``brand`` and ``index``).
    """

    # CSV path -> {brand: index}. Datasets are created many times with the
    # same CSV (twice per evaluation episode), so each file is parsed once.
    _label_maps = {}

    def __init__(self, file_list, transform=None):
        """
        Args:
            file_list (list): Image paths of the form .../<brand>/<image>.
            transform (callable, optional): Applied to the RGB PIL image.
        """
        self.file_list = file_list
        self.transform = transform

        # Label string -> index mapping (shared between instances; treat as read-only)
        self.label_to_id = self._load_label_map(Config.csv_index_path)

    @classmethod
    def _load_label_map(cls, csv_path):
        """Return the brand -> index mapping stored in `csv_path`, parsing it on first use."""
        if csv_path not in cls._label_maps:
            df = pd.read_csv(csv_path)
            cls._label_maps[csv_path] = {
                brand: int(index) for brand, index in zip(df['brand'], df['index'])
            }
        return cls._label_maps[csv_path]

    def __len__(self):
        return len(self.file_list)

    def load_image(self, image_path):
        """
        Load one image and look up its label.

        Raises:
            KeyError: If the parent folder name is not a brand listed in the CSV.
        """
        # Force 3 channels: grayscale, palette, CMYK or RGBA files would otherwise
        # not fit the 3-channel normalisation used in this notebook, nor stack into
        # one batch with RGB images. The context manager closes the file handle;
        # convert() returns an independent, fully loaded copy.
        with Image.open(image_path) as raw_image:
            img = raw_image.convert('RGB')

        if self.transform:
            img = self.transform(img)

        image_label = os.path.basename(os.path.dirname(image_path))
        label_idx = self.label_to_id[image_label]

        return {"image": img, "label": label_idx}

    def __getitem__(self, idx):
        return self.load_image(self.file_list[idx])

class LogoResNet50(nn.Module):
    """ResNet50 whose final classifier is replaced by a linear embedding head."""

    def __init__(self, embedding_dim=128, pretrained=True, num_of_freeze_layer=5, activation_fn=None):
        super().__init__()

        # Backbone: default pretrained weights when requested, random initialisation otherwise.
        backbone_weights = ResNet50_Weights.DEFAULT if pretrained else None
        self.model = models.resnet50(weights=backbone_weights)

        # Head: project the backbone features (fc.in_features) down to the embedding
        # size, optionally followed by the activation supplied by the caller.
        projection = nn.Linear(self.model.fc.in_features, embedding_dim)
        if activation_fn is None:
            head = [projection]
        else:
            head = [projection, activation_fn]
        self.model.fc = nn.Sequential(*head)

        # Backbone attributes grouped by freezing level, shallow to deep.
        # Kept on the instance so freeze_numer_of_layer can walk through them.
        self.blocks = [
            ['conv1', 'bn1'],   # Level 1
            ['layer1'],         # Level 2
            ['layer2'],         # Level 3
            ['layer3'],         # Level 4
            ['layer4'],         # Level 5: entire backbone frozen
        ]

        # Initial freezing configuration
        self.freeze_numer_of_layer(num_of_freeze_layer)

    def forward(self, x):
        return self.model(x)

    def freeze_numer_of_layer(self, num_of_freeze_layer):
        """
        Set which backbone levels are trainable.

        Every parameter is first made trainable again, then the first
        `num_of_freeze_layer` levels of `self.blocks` are frozen.

        Args:
            num_of_freeze_layer (int):
              0   -> nothing is frozen (full fine-tuning)
              1-5 -> freeze that many levels, starting from the shallowest;
                     larger values are capped at the number of levels
        """
        # Start from a clean slate so repeated calls do not accumulate.
        for weight in self.model.parameters():
            weight.requires_grad = True

        if num_of_freeze_layer == 0:
            print("Configuration: Full Fine-Tuning (All layers are trainable)")
            return

        # Never index past the last defined level.
        limit = min(num_of_freeze_layer, len(self.blocks))

        # Names of all sub-modules belonging to the requested levels, in order.
        frozen_list = [name for level in range(limit) for name in self.blocks[level]]

        for name in frozen_list:
            for weight in getattr(self.model, name).parameters():
                weight.requires_grad = False

        print(f"Freezing Level {limit}. Frozen blocks: {frozen_list}")

def load_model(model_path, device):
    """
    Build the embedding model from Config, restore trained weights and prepare it for inference.

    Args:
        model_path (str): Path to a state dict saved with ``torch.save``.
            An empty value skips loading and returns the model as constructed.
        device: Device the model (and the loaded tensors) are placed on.

    Returns:
        LogoResNet50: The model on `device`, in eval mode.

    Raises:
        FileNotFoundError: If `model_path` is given but does not point to a file.
    """
    model = LogoResNet50(
        embedding_dim=Config.embedding_dim,
        pretrained=Config.pretrained,
        num_of_freeze_layer=Config.freeze_layers,
    )

    if model_path:
        if not os.path.isfile(model_path):
            raise FileNotFoundError(f"Trained model not found at '{model_path}'.")
        # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
        # NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
        state = torch.load(model_path, map_location=device)
        model.load_state_dict(state)
    else:
        # Without a checkpoint the embedding head keeps its initial weights,
        # so make it visible that the numbers do not come from a trained model.
        print("WARNING: no trained model path given; using the model without trained weights.")

    model.to(device)
    model.eval()
    return model

### Evaluation Utilities

**Key operations:**
1. **MetricEvaluator**: Calculates standard metrics such as the Discriminant Ratio, mAP, Precision, Recall, and F1 Score.
2. **FewShotIterator**: Manages the sampling of N-Shot tasks from the test dataset.
3. **getTestPaths**: A function that returns a set of paths from the test set

In [None]:
class MetricEvaluator:
    """
    Evaluation metrics for Few-Shot Learning and Metric Learning.

    All score-based methods work on DISTANCES: a lower score means a closer
    (more similar) pair.

    Implements:
    1. Discriminant Ratio (J): computed from scatter traces only.
    2. Mean Average Precision (mAP): ranking quality metric.
    3. Recall at Fixed Precision (R@P): operational metric.
    4. Precision & Recall: raw metrics at a specific distance threshold.
    5. F1 Score: harmonic mean of Precision and Recall.
    """

    def __init__(self, device=None):
        """
        Initialize the evaluator.

        Args:
            device (str): 'cuda' or 'cpu'. If None, detects automatically.
        """
        if device:
            self.device = device
        else:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.epsilon = 1e-6  # Added to denominators for numerical stability

    def compute_discriminant_ratio(self, embeddings, labels):
        """
        Calculates the Discriminant Ratio J = Tr(Sb) / Tr(Sw).

        Only the traces are computed (sums of squared distances), so no
        scatter matrix is ever built. Tr(Sw) is obtained as Tr(St) - Tr(Sb).

        Args:
            embeddings (torch.Tensor): One embedding per row.
            labels (torch.Tensor): Class label of each row.

        Returns:
            float: The Discriminant Ratio score.
        """
        embeddings = embeddings.to(self.device)
        labels = labels.to(self.device)

        global_mean = embeddings.mean(dim=0)

        # Tr(St): squared distances of all points from the global mean.
        tr_st = torch.sum((embeddings - global_mean) ** 2)

        # Tr(Sb): squared distance of each class mean from the global mean,
        # weighted by the class size. Kept as a tensor so .item() always works.
        tr_sb = embeddings.new_zeros(())
        for c in torch.unique(labels):
            class_embeddings = embeddings[labels == c]
            n_c = class_embeddings.size(0)
            mu_c = class_embeddings.mean(dim=0)
            tr_sb = tr_sb + n_c * torch.sum((mu_c - global_mean) ** 2)

        # Tr(Sw) via the trace identity
        tr_sw = tr_st - tr_sb

        j_score = tr_sb / (tr_sw + self.epsilon)
        return j_score.item()

    def compute_map(self, query_emb, gallery_emb, query_labels, gallery_labels):
        """
        Calculates Mean Average Precision (mAP).

        For every query the gallery is ranked by increasing Euclidean
        distance; gallery items sharing the query label are the relevant
        ones. A query with no relevant gallery item contributes an AP of 0.

        Returns:
            float: Mean of the per-query average precisions (0.0 without queries).
        """
        query_emb = query_emb.to(self.device)
        gallery_emb = gallery_emb.to(self.device)
        query_labels = query_labels.to(self.device)
        gallery_labels = gallery_labels.to(self.device)

        distance_matrix = torch.cdist(query_emb, gallery_emb, p=2)

        # Rank positions 1..G are the same for every query: build them once.
        ranks = torch.arange(1, gallery_labels.size(0) + 1, device=self.device)

        average_precisions = []
        for i in range(query_labels.size(0)):
            # Closest gallery items first
            sorted_indices = torch.argsort(distance_matrix[i], descending=False)
            relevance_mask = (gallery_labels[sorted_indices] == query_labels[i]).float()

            total_relevant = relevance_mask.sum()
            if total_relevant == 0:
                average_precisions.append(0.0)
                continue

            # Precision at every rank, averaged over the ranks holding a relevant item
            precisions = torch.cumsum(relevance_mask, dim=0) / ranks
            ap = (precisions * relevance_mask).sum() / total_relevant
            average_precisions.append(ap.item())

        if not average_precisions:
            return 0.0
        return sum(average_precisions) / len(average_precisions)

    def compute_precision_recall(self, distance_scores, is_match, threshold=None):
        """
        Calculates raw Precision and Recall at a specific distance threshold.

        Definitions:
            Precision = TP / (TP + FP)
            Recall    = TP / (TP + FN)

        Args:
            distance_scores (torch.Tensor): 1D tensor of distances (lower = closer).
            is_match (torch.Tensor): 1D binary tensor (Ground Truth); bool,
                integer or floating point.
            threshold (float): A pair is predicted Positive when its distance
                is <= threshold. Defaults to Config.prediciton_threashold,
                read at call time.

        Returns:
            tuple: (precision, recall)
        """
        if threshold is None:
            # Resolved here rather than in the signature so that later changes
            # to the config are honoured.
            threshold = Config.prediciton_threashold

        distance_scores = distance_scores.to(self.device)
        is_match = is_match.to(self.device)
        if not is_match.is_floating_point():
            # `1 - mask` is not defined for bool tensors
            is_match = is_match.float()

        predicted_positive = (distance_scores <= threshold).float()

        tp = (predicted_positive * is_match).sum()        # predicted match, is a match
        fp = (predicted_positive * (1 - is_match)).sum()  # predicted match, is not
        fn = ((1 - predicted_positive) * is_match).sum()  # missed match

        precision = tp / (tp + fp + self.epsilon)
        recall = tp / (tp + fn + self.epsilon)

        return precision.item(), recall.item()

    def compute_recall_at_fixed_precision(self, distance_scores, is_match, min_precision=0.95):
        """
        Calculates Recall at a Fixed Precision (R@P).

        Items are ranked by increasing distance. Among all ranking prefixes
        whose precision is >= min_precision, the longest one is used as the
        cutoff and its recall is returned.

        Returns:
            float: Recall at that cutoff; 0.0 when no prefix reaches the
            required precision or when there are no matches at all.
        """
        distance_scores = distance_scores.to(self.device)
        is_match = is_match.to(self.device)

        sorted_indices = torch.argsort(distance_scores, descending=False)
        sorted_matches = is_match[sorted_indices]

        tps = torch.cumsum(sorted_matches, dim=0)
        total_retrieved = torch.arange(1, len(sorted_matches) + 1, device=self.device)

        precisions = tps / total_retrieved

        # Prefixes that satisfy the precision constraint
        valid_indices = torch.where(precisions >= min_precision)[0]
        if len(valid_indices) == 0:
            return 0.0

        cutoff_index = valid_indices[-1]

        total_relevant_in_dataset = is_match.sum()
        if total_relevant_in_dataset == 0:
            return 0.0

        recall = tps[cutoff_index] / total_relevant_in_dataset
        return recall.item()

    def compute_f1_score(self, precision, recall):
        """
        Calculates F1 Score (Harmonic Mean); 0.0 when both inputs are 0.
        """
        if (precision + recall) == 0:
            return 0.0
        return 2 * (precision * recall) / (precision + recall)

class FewShotIterator:
    """
    Callable that produces one few-shot task (support set + query set) per call.

    Brands are visited one after the other in an endless cycle. The brand of
    an image is the name of its parent folder.
    """

    def __init__(self, file_list, n_shot, query_size=50, max_guaranteed_positives=5):
        """
        Args:
            file_list (list): Image paths of the form .../<brand>/<image>.
            n_shot (int): Number of support images drawn per task.
            query_size (int): Target size of each query set.
            max_guaranteed_positives (int): Upper bound for the number of
                same-brand images placed in the query set; the actual number is
                drawn uniformly from 1..max_guaranteed_positives for every task.

        Raises:
            ValueError: If `file_list` is empty, a size argument is smaller
                than 1, or no brand has at least `n_shot` images.
        """
        if not file_list:
            raise ValueError("The test file list is empty.")
        if query_size < 1 or max_guaranteed_positives < 1:
            raise ValueError("query_size and max_guaranteed_positives must be at least 1.")

        self.n_shot = n_shot
        self.query_size = query_size
        self.max_guaranteed_positives = max_guaranteed_positives

        # Set for fast membership tests and de-duplication.
        self.all_files_set = set(file_list)
        # Sorted copy used for sampling: the iteration order of a set of strings
        # changes between interpreter runs, which would defeat random.seed().
        self._all_files_sorted = sorted(self.all_files_set)

        # { 'BrandName': [list_of_image_paths] }
        self.brands_map = {}
        for file_path in file_list:
            brand_name = os.path.basename(os.path.dirname(file_path))
            self.brands_map.setdefault(brand_name, []).append(file_path)

        # Only brands that can supply a full support set are usable.
        self.valid_brands_list = [
            brand for brand, images in self.brands_map.items() if len(images) >= n_shot
        ]
        if not self.valid_brands_list:
            raise ValueError(f"No brand found with at least {n_shot} images.")

        # Endless round-robin over the usable brands.
        self.brand_iterator = cycle(self.valid_brands_list)

    def __call__(self):
        """
        Build the task for the next brand.

        Steps:
        1. Take the next brand of the cycle.
        2. Support set: `n_shot` random images of that brand.
        3. Query set: a random number (1..max_guaranteed_positives, capped by
           availability) of other images of the same brand, filled up to
           `query_size` with random images of other brands, then shuffled.

        Returns:
            dict: {"brand_name": str, "support_set": list, "query_set": list}
        """
        # The cycle is built from a non-empty list, so next() never runs out.
        selected_brand_name = next(self.brand_iterator)
        images_of_current_brand = self.brands_map[selected_brand_name]

        num_query_guarantee_if_available = random.randint(1, self.max_guaranteed_positives)

        # Support set
        support_set_list = random.sample(images_of_current_brand, self.n_shot)
        support_set_set = set(support_set_list)

        # Positives: same brand, not used as support (sorted for reproducibility)
        remaining_brand_images = sorted(set(images_of_current_brand) - support_set_set)
        num_positives = min(num_query_guarantee_if_available, len(remaining_brand_images))
        guaranteed_images_in_query = random.sample(remaining_brand_images, num_positives)

        if not remaining_brand_images:
            print(f"for the brand {selected_brand_name} {len(remaining_brand_images)} images where put in the query set")

        # Negatives: every image of the current brand is excluded so that the
        # guaranteed images are the only matches in the query set.
        current_brand_images = set(images_of_current_brand)
        distractor_pool = [p for p in self._all_files_sorted if p not in current_brand_images]
        num_distractors = max(0, min(len(distractor_pool), self.query_size - num_positives))
        query_remaining = random.sample(distractor_pool, num_distractors)

        query_set_list = guaranteed_images_in_query + query_remaining
        random.shuffle(query_set_list)

        return {
            "brand_name": selected_brand_name,
            "support_set": support_set_list,
            "query_set": query_set_list
        }

def getTestPaths(root_dir, total_set_size=None, min_images_per_brand=2):
    """
    Collect the '.jpg' image paths of the test split.

    Expected layout: <root_dir>/test/<category>/<brand>/<image>.jpg

    Args:
        root_dir (str): Dataset root containing the 'test' folder.
        total_set_size (int, optional): Approximate number of images to
            return. The budget is split evenly between brands; if that would
            give fewer than `min_images_per_brand` images per brand, a random
            subset of brands is used instead. None returns every image.
        min_images_per_brand (int): Minimum number of images sampled per brand
            when `total_set_size` is given.

    Returns:
        list: Image paths; empty when the test folder is missing or contains
        no brand folders.
    """
    test_path = os.path.join(root_dir, 'test')
    if not os.path.exists(test_path):
        print(f"Warning: {test_path} not found.")
        return []

    # Directory listings come back in arbitrary order; sorting keeps the
    # result, and any seeded sampling below, reproducible.
    test_brand_list = []
    for category in sorted(os.listdir(test_path)):
        cat_path = os.path.join(test_path, category)
        if not os.path.isdir(cat_path):
            continue
        for brand in sorted(os.listdir(cat_path)):
            brand_full_path = os.path.join(cat_path, brand)
            if os.path.isdir(brand_full_path):
                test_brand_list.append(brand_full_path)

    if not test_brand_list:
        # Also avoids dividing by zero when computing the per-brand budget.
        print(f"Warning: no brand folders found in {test_path}.")
        return []

    test_data_list = []

    if total_set_size is None:
        for brand in test_brand_list:
            test_data_list.extend(sorted(glob.glob(os.path.join(brand, '*.jpg'))))
        return test_data_list

    images_per_brand = round(total_set_size / len(test_brand_list))

    if images_per_brand < min_images_per_brand:
        # Too many brands for the budget: keep fewer brands, each with the minimum.
        new_test_brand_count = round(total_set_size / min_images_per_brand)
        test_brand_list = random.sample(test_brand_list, min(len(test_brand_list), new_test_brand_count))
        images_per_brand = min_images_per_brand

    for brand in test_brand_list:
        imgs = sorted(glob.glob(os.path.join(brand, '*.jpg')))

        if len(imgs) < min_images_per_brand:
            print(f"images are less than {min_images_per_brand} for this brand: {brand} in the TEST set")

        test_data_list.extend(random.sample(imgs, min(images_per_brand, len(imgs))))

    return test_data_list

### Evaluation Loop

**Key operations:**
1. **evaluate_few_shot**: Runs the evaluation episodes, computes embeddings, and calculates aggregated metrics.
2. **Distance Calculation**: Distances are computed using F.pairwise_distance.
3. **compute_global_embeddings**: A function used to compute the embeddings of all passed images and return the embeddings + the respective labels.

In [None]:
def evaluate_few_shot(model, fewshot_iterator, transform, device, num_episodes=100):
    """
    Run few-shot evaluation episodes and return the averaged metrics.

    The Discriminant Ratio (J) is computed once over every image known to the
    iterator. Then, for each episode, the support embeddings are averaged into
    a prototype, the Euclidean distance from the prototype to every query
    embedding is computed, and a query is predicted to belong to the support
    brand when that distance is <= Config.prediciton_threashold.

    Args:
        model: Embedding model; it is switched to eval mode.
        fewshot_iterator: Callable returning a dict with "support_set" and
            "query_set" path lists (or None to stop early); must expose
            `all_files_set`.
        transform: Image transform handed to the datasets.
        device: Device used for inference and metric computation.
        num_episodes (int): Maximum number of episodes.

    Returns:
        dict: Keys "accuracy", "precision", "recall", "f1", "r@95p", "map"
        (means over the evaluated episodes, NaN if none was evaluated) and "J".
    """
    evaluator = MetricEvaluator(device=device)
    accuracies = []
    precisions = []
    recalls = []
    f1_scores = []
    r_at_95p = []
    ap_scores = []

    torch.cuda.empty_cache()

    # 1. Global separability of the embedding space
    unique_paths = list(fewshot_iterator.all_files_set)
    embs, labels = compute_global_embeddings(model, unique_paths, transform, device)
    j_score = evaluator.compute_discriminant_ratio(embs, labels)

    # 2. Episodes (eval mode, no gradient tracking)
    model.eval()
    with torch.no_grad():
        for i in range(num_episodes):

            task = fewshot_iterator()

            if task is None:
                print(f"Stopped early at episode {i} because we ran out of brands.")
                break

            support_paths = task["support_set"]
            query_paths = task["query_set"]

            if not query_paths:
                # Nothing to rank or classify in this episode.
                continue

            support_loader = DataLoader(DatasetTest(support_paths, transform), batch_size=32)
            query_loader = DataLoader(DatasetTest(query_paths, transform), batch_size=64)

            # Support embeddings -> prototype
            support_embeddings = []
            support_labels = []
            for data in support_loader:
                images = data["image"].to(device)
                support_embeddings.append(F.normalize(model(images), p=2, dim=1))
                support_labels.append(data["label"])

            # All support images come from one brand, so any label identifies it.
            support_brand = torch.cat(support_labels)[0]
            prototype = torch.cat(support_embeddings).mean(dim=0, keepdim=True)

            # Query embeddings
            query_embeddings = []
            query_labels = []
            for data in query_loader:
                images = data["image"].to(device)
                query_embeddings.append(F.normalize(model(images), p=2, dim=1))
                query_labels.append(data["label"])

            query_embeddings_tensor = torch.cat(query_embeddings)
            query_labels_tensor = torch.cat(query_labels)

            # Average precision of the ranking induced by the prototype
            ap_score_single = evaluator.compute_map(
                query_emb=prototype,
                gallery_emb=query_embeddings_tensor,
                query_labels=support_brand.unsqueeze(0).to(device),
                gallery_labels=query_labels_tensor.to(device)
            )
            ap_scores.append(ap_score_single)

            # Distance of every query image from the prototype
            dists = F.pairwise_distance(prototype, query_embeddings_tensor)

            # Ground truth: does the query image belong to the support brand?
            gt = (query_labels_tensor == support_brand).float().to(device)

            # Accuracy of the thresholded prediction
            pred = (dists <= Config.prediciton_threashold).float()
            acc = (pred == gt).float().mean().item()
            accuracies.append(acc)

            # Precision, Recall, F1, R@95P.
            # The evaluator thresholds with `score <= threshold` and ranks in
            # ascending order, i.e. it expects distances. The raw distances must
            # be passed: negating them (and the threshold) would flip every
            # decision and the ranking.
            prec, rec = evaluator.compute_precision_recall(
                dists, gt, threshold=Config.prediciton_threashold
            )
            f1 = evaluator.compute_f1_score(prec, rec)
            r95 = evaluator.compute_recall_at_fixed_precision(dists, gt, min_precision=0.95)

            precisions.append(prec)
            recalls.append(rec)
            f1_scores.append(f1)
            r_at_95p.append(r95)

    def _mean(values):
        # NaN (rather than a division by zero) when no episode was evaluated.
        return sum(values) / len(values) if values else float("nan")

    # Aggregate results
    results = {
        "accuracy": _mean(accuracies),
        "precision": _mean(precisions),
        "recall": _mean(recalls),
        "f1": _mean(f1_scores),
        "r@95p": _mean(r_at_95p),
        "map": _mean(ap_scores),
        "J": j_score,
    }
    return results


def compute_global_embeddings(model, file_list, transform, device):
    """
    Embed every image in `file_list` and return the embeddings with their labels.

    Images are processed in order, in batches of 64, without gradient
    tracking. Embeddings are L2-normalised and moved to the CPU as soon as
    they are computed.

    Returns:
        tuple: (embeddings tensor, labels tensor), rows aligned with `file_list`.
    """
    model.eval()

    batches = DataLoader(DatasetTest(file_list, transform), batch_size=64, shuffle=False)
    embedding_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch in batches:
            features = model(batch["image"].to(device))
            # Keep results on the CPU so they do not pile up in GPU memory
            embedding_chunks.append(F.normalize(features, p=2, dim=1).cpu())
            label_chunks.append(batch["label"])

    return torch.cat(embedding_chunks), torch.cat(label_chunks)

def main():
    """Evaluate the model on the test split with the settings in Config and print the metrics."""
    device = torch.device(Config.device)

    # Resize to 224x224, convert to tensor, normalise each channel
    preprocessing = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    model = load_model(Config.trained_model_path, device)

    # getTestPaths also accepts total_set_size and min_images_per_brand
    test_paths = getTestPaths(Config.dataset_root)
    task_sampler = FewShotIterator(test_paths, n_shot=Config.n_shot)

    results = evaluate_few_shot(model, task_sampler, preprocessing, device, Config.num_episodes)

    print("\n=== Evaluation Results ===")
    for metric_name, value in results.items():
        print(f"{metric_name.capitalize():<10}: {value:.4f}")


if __name__ == "__main__":
    main()

Freezing Level 5. Frozen blocks: ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4']
J score computed: 1.9644616842269897

=== Evaluation Results ===
Accuracy  : 0.8944
Precision : 0.1313
Recall    : 0.4800
F1        : 0.1786
R@95p     : 0.3300
Map       : 0.4174
J         : 1.9645


### **show_visual_demo** Function 

This function performs a Visual Qualitative Analysis of the Deep Metric Learning model's performance. The goal is to simulate a real-world image retrieval scenario and verify if the network has successfully learned to distinguish between different brands.

The workflow is as follows:

1. **Triplet Sampling**: The code randomly selects three images from the Test Dataset:

- Anchor (Query): A reference image of a random brand (e.g., "Adidas").

- Positive: Another image of the same brand as the Anchor (the correct match the model should retrieve).

- Negative: An image of a different brand (distractor) that the model should reject.

2. **Feature Extraction (Inference)**: The three images are pre-processed and passed through the trained model. The model outputs a feature embedding (a 128-dimensional numeric vector) for each image, representing the visual characteristics of the logo. These vectors are then normalized.

3. **Distance Calculation**: The Pairwise Distance (Euclidean Distance) is computed between the embeddings:

Anchor vs. Positive: We expect a very low value (close to 0.0), indicating the vectors are close in the latent space.

Anchor vs. Negative: We expect a higher value, indicating the model successfully pushed the features of different brands apart.

4. **Result Visualization**: A plot with three side-by-side panels is generated. The title color indicates the test outcome based on the configured threshold (Config.prediciton_threashold, typically 0.5):

**Center Panel (Positive Pair)**:

- Green (MATCH): Distance is at or below the threshold. The model correctly recognized the brand (True Positive).

- Red (MISSED): Distance is above the threshold. The model failed to recognize the same brand (False Negative).

**Right Panel (Negative Pair)**:

- Green (REJECTED): Distance is above the threshold. The model correctly distinguished the different brands (True Negative).

- Red (FALSE POSITIVE): Distance is at or below the threshold. The model confused the different brand with the Anchor (False Positive).

In [None]:
import matplotlib.pyplot as plt


def get_embedding_demo(model, image_path, transform, device):
    """
    Return the L2-normalised embedding of one image together with the image itself.

    Returns:
        tuple: (embedding with a leading batch dimension, RGB PIL image), or
        (None, None) after printing a message when the image cannot be opened.
    """
    try:
        picture = Image.open(image_path).convert('RGB')
    except Exception as err:
        print(f"Error opening {image_path}: {err}")
        return None, None

    # The model expects a batch, so add a leading dimension of size 1
    batch = transform(picture).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        embedding = F.normalize(model(batch), p=2, dim=1)

    return embedding, picture

def show_visual_demo():
    """
    Plot an anchor / positive / negative triplet with the model's verdict.

    A random brand with at least two test images supplies the anchor and the
    positive; a random image of any other brand is the negative. The distance
    between the anchor embedding and each of the other two is compared with
    Config.prediciton_threashold and shown in the panel titles (green =
    correct decision, red = wrong decision).

    The function prints a message and returns without plotting when the demo
    cannot be built (no images, no suitable brands, unreadable image).
    """
    device = torch.device(Config.device)

    try:
        model = load_model(Config.trained_model_path, device)
    except NameError:
        print("ERROR: You must run the cell with class definitions (LogoResNet50, Config, etc.) first.")
        return

    # Transformations
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Retrieve test files
    test_paths = getTestPaths(Config.dataset_root)
    if not test_paths:
        print("No images found. Check Config.dataset_root")
        return

    # Organize by brand (parent folder name)
    brands_map = {}
    for p in test_paths:
        b = os.path.basename(os.path.dirname(p))
        brands_map.setdefault(b, []).append(p)

    # Anchor + positive: pick directly among the brands that have two images,
    # instead of retrying random brands and hoping to hit one.
    anchor_candidates = [b for b, imgs in brands_map.items() if len(imgs) >= 2]
    if not anchor_candidates:
        print("Could not find a brand with enough images for the demo.")
        return
    target_brand = random.choice(anchor_candidates)
    anchor_path, positive_path = random.sample(brands_map[target_brand], 2)

    # Negative: any other brand. With a single brand there is none, and an
    # unbounded search for one would never terminate.
    negative_candidates = [b for b in brands_map if b != target_brand]
    if not negative_candidates:
        print("Could not find a second brand to use as negative example.")
        return
    distractor_name = random.choice(negative_candidates)
    negative_path = random.choice(brands_map[distractor_name])

    # Compute embeddings
    emb_anchor, img_anchor = get_embedding_demo(model, anchor_path, transform, device)
    emb_pos, img_pos = get_embedding_demo(model, positive_path, transform, device)
    emb_neg, img_neg = get_embedding_demo(model, negative_path, transform, device)

    # get_embedding_demo returns (None, None) for an image it could not open
    if emb_anchor is None or emb_pos is None or emb_neg is None:
        print("Could not load all three demo images; aborting the demo.")
        return

    dist_pos = F.pairwise_distance(emb_anchor, emb_pos).item()
    dist_neg = F.pairwise_distance(emb_anchor, emb_neg).item()

    # Draw the plot
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Anchor image
    axes[0].imshow(img_anchor)
    axes[0].set_title(f"ANCHOR (Query)\nBrand: {target_brand}", fontsize=14, color='blue', fontweight='bold')
    axes[0].axis('off')

    # Positive image: correct when the distance is within the threshold
    is_match = dist_pos <= Config.prediciton_threashold
    color_p = 'green' if is_match else 'red'
    label_p = "MATCH" if is_match else "MISSED"

    axes[1].imshow(img_pos)
    axes[1].set_title(f"POSITIVE (Stesso Brand)\nDist: {dist_pos:.4f}\n{label_p}", fontsize=14, color=color_p, fontweight='bold')
    axes[1].axis('off')

    # Negative image: correct when the distance exceeds the threshold
    is_reject = dist_neg > Config.prediciton_threashold
    color_n = 'green' if is_reject else 'red'
    label_n = "REJECTED" if is_reject else "FALSE POSITIVE"

    axes[2].imshow(img_neg)
    axes[2].set_title(f"NEGATIVE (Brand: {distractor_name})\nDist: {dist_neg:.4f}\n{label_n}", fontsize=14, color=color_n, fontweight='bold')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()


show_visual_demo()

### **3D t-SNE Visualization of the Embedding Space** 
Generates an interactive 3D t-SNE visualization of the model's embeddings.

This function randomly selects a subset of classes (brands) from the test set,
extracts their feature embeddings using the provided model, and reduces
the dimensionality to 3 components using t-SNE.

The result is plotted as an interactive 3D scatter plot where clusters
can be rotated and inspected to assess the quality of the learned metric space.

In [None]:
import plotly.express as px
import pandas as pd
from sklearn.manifold import TSNE
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image


def visualize_tsne_3d(model, test_paths, num_classes=10, max_images_per_brand=30):
    """
    Show an interactive 3D t-SNE plot of the embeddings of a few random brands.

    Args:
        model: Embedding model; it is switched to eval mode.
        test_paths (list): Image paths of the form .../<brand>/<image>.
        num_classes (int): Number of brands to sample (capped by availability).
        max_images_per_brand (int): Only the first this-many images of each
            selected brand are embedded.

    Images that cannot be read are reported and skipped. Nothing is plotted
    (a message is printed instead) when too few images could be embedded.
    """
    device = torch.device(Config.device)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Group paths by brand (parent folder name)
    brands_map = {}
    for p in test_paths:
        b = os.path.basename(os.path.dirname(p))
        brands_map.setdefault(b, []).append(p)

    available_brands = list(brands_map.keys())
    num_to_plot = min(len(available_brands), num_classes)
    selected_brands = random.sample(available_brands, num_to_plot)

    embeddings = []
    labels = []

    model.eval()
    with torch.no_grad():
        for brand in selected_brands:
            # Load the brand's images first, then embed them in a single forward pass.
            tensors = []
            for p in brands_map[brand][:max_images_per_brand]:
                try:
                    with Image.open(p) as raw_image:
                        tensors.append(transform(raw_image.convert('RGB')))
                except Exception as e:
                    # Best effort: one unreadable file should not abort the plot.
                    print(f"Error reading {p}: {e}")

            if not tensors:
                continue

            batch = torch.stack(tensors).to(device)
            brand_embeddings = F.normalize(model(batch), p=2, dim=1).cpu().numpy()

            embeddings.extend(brand_embeddings)
            labels.extend([brand] * len(tensors))

    n_samples = len(embeddings)
    if n_samples < 4:
        # t-SNE requires perplexity < n_samples, and a 3-component embedding
        # of a handful of points would not be informative anyway.
        print("Not enough embeddings to run t-SNE.")
        return

    embeddings = np.array(embeddings)

    # Perplexity must stay below the number of samples. The iteration count is
    # left at the library default (1000): its keyword was renamed from
    # `n_iter` to `max_iter` across scikit-learn versions, so passing either
    # name explicitly fails on some installations.
    perplexity = min(30, n_samples - 1)
    tsne = TSNE(n_components=3, random_state=42, perplexity=perplexity)
    tsne_results = tsne.fit_transform(embeddings)

    df_tsne = pd.DataFrame({
        'x': tsne_results[:, 0],
        'y': tsne_results[:, 1],
        'z': tsne_results[:, 2],
        'Brand': labels
    })

    fig = px.scatter_3d(
        df_tsne, x='x', y='y', z='z',
        color='Brand',
        title=f"3D t-SNE Visualization of {num_to_plot} Brands",
        hover_name='Brand',
        opacity=0.7,
        size_max=10,
        template="plotly_dark"
    )

    fig.update_traces(marker=dict(size=4, line=dict(width=0)))
    fig.update_layout(margin=dict(l=0, r=0, b=0, t=40))

    fig.show()


# Run the visualisation only when the cells defining its prerequisites were executed.
if not {'Config', 'load_model', 'getTestPaths'} <= set(globals()):
    print("Config, load_model o getTestPaths not defined")
else:
    try:
        # Reuse a model left in the notebook namespace by an earlier cell, if any
        if 'model' not in globals():
            model = load_model(Config.trained_model_path, torch.device(Config.device))

        test_paths = getTestPaths(Config.dataset_root)

        if not test_paths:
            print("Image not found")
        else:
            visualize_tsne_3d(model, test_paths, num_classes=10)

    except Exception as e:
        print(f"Error during the execution: {e}")