# Few-Shot Evaluation Notebook

This notebook implements the evaluation logic for a few-shot learning model using a ResNet50 backbone.

### Imports and configuration settings

**Key operations:**
1. The libraries needed are imported
2. **Config**: Config contains all the configurations necessary to run the script
3. Seeding to obtain a repeatable experiment

In [65]:
import torch
import sys
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import glob
import random
import xml.etree.ElementTree as ET
from itertools import cycle
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models
from torchvision.models import ResNet50_Weights
from google.colab import drive
import zipfile

class Config:
    # 1. SETUP
    project_name = "FewShot_Evaluation"
    seed = 42
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 2. DATASET PATH
    dataset_root = "LogoDet-3K/LogoDet-3K-divided"

    # 3. MODEL PARAMETERS
    embedding_dim = 128
    pretrained = True
    freeze_layers = 5
    trained_model_path = ""

    # 4. EVALUATION SETTINGS
    prediciton_threashold = 0.5
    n_shot = 1
    num_episodes = 100

torch.manual_seed(Config.seed)
random.seed(Config.seed)

def setup_dataset(zip_path, extract_to):
    """
    Mounts Google Drive and extracts the dataset if not already present.
    """
    # 1. Mount Google Drive
    if not os.path.exists('/content/drive'):
        drive.mount('/content/drive')

    # 2. Check if the folder already exists
    if os.path.exists(extract_to):
        print(f"Dataset folder '{extract_to}' already exists. Skipping extraction.")
    else:
        print(f"Extracting dataset from {zip_path}...")
        if os.path.exists(zip_path):
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_to)
            print("Extraction complete.")
        else:
            print(f"ERROR: Zip file not found at {zip_path}. Check your path.")

setup_dataset("/content/drive/MyDrive/LogoDet-3K-divided.zip", "/content/LogoDet-3K")

Dataset folder '/content/LogoDet-3K' already exists. Skipping extraction.


### Dataset and Model Architecture

**Key operations:**
1. **DatasetTest**: Specialized loader for test images that parses XML files to retrieve label indices.
2. **LogoResNet50**: A modified ResNet50 architecture that replaces the final classifier with a layer generating feature embeddings.
3. **load_model**: A function that given the path of a saved model loads it and returns it

In [66]:
class DatasetTest(Dataset):
  def __init__(self, file_list, transform=None):
      self.file_list = file_list
      self.transform = transform

  def __len__(self):
      return len(self.file_list)

  def load_image(self, image_path):
      xml_path = image_path.replace(".jpg", ".xml")
      img = Image.open(image_path)
      if self.transform:
          img = self.transform(img)

      # Parse XML
      tree = ET.parse(xml_path)
      root = tree.getroot()
      objects = root.findall("object")

      # Take first label's index only
      index_text = objects[0].find("index").text
      label_idx = int(index_text)  # Convert string to int

      return {"image": img, "label": label_idx}

  def __getitem__(self, idx):
      return self.load_image(self.file_list[idx])

class LogoResNet50(nn.Module):
    def __init__(self, embedding_dim=128, pretrained=True, num_of_freeze_layer=5, activation_fn=None):
        super(LogoResNet50, self).__init__()

        # 1. Load Pre-trained Weights
        # Initialize the model with weights pretrained on ImageNet for transfer learning
        if pretrained:
            weights = ResNet50_Weights.DEFAULT
            self.model = models.resnet50(weights=weights)
        else:
            self.model = models.resnet50(weights=None)

        # 2. Modify the Head (Fully Connected Layer)
        # We need to produce feature embeddings instead of class probabilities
        input_features_fc = self.model.fc.in_features # Typically 2048 for ResNet50

        head_layers = []
        # Project features to the desired embedding dimension (e.g., 128)
        head_layers.append(nn.Linear(input_features_fc, embedding_dim))

        # Add an optional activation function if provided
        if activation_fn is not None:
            head_layers.append(activation_fn)

        # Replace the original classifier with our custom embedding head
        self.model.fc = nn.Sequential(*head_layers)

        # 3. Freezing Management
        # Define the blocks here to access them in the freeze method.
        # This structure allows progressive freezing/unfreezing strategies
        self.blocks = [
            ['conv1', 'bn1'],   # Level 1
            ['layer1'],         # Level 2
            ['layer2'],         # Level 3
            ['layer3'],         # Level 4
            ['layer4'],         # Level 5: Entire backbone frozen
        ]

        # Apply the initial freezing configuration
        self.freeze_numer_of_layer(num_of_freeze_layer)

    def forward(self, x):
        return self.model(x)

    def freeze_numer_of_layer(self, num_of_freeze_layer):
        """
        Manages layer freezing for transfer learning strategies.

        Args:
            num_of_freeze_layer (int):
              0   -> All layers unlocked (Full Fine-Tuning)
              1-5 -> Progressively freezes the backbone layers from shallow to deep
        """

        # STEP 1: RESET. Unfreeze everything (requires_grad = True).
        # This ensures we start from a clean state before applying new constraints.
        for param in self.model.parameters():
            param.requires_grad = True

        # If num is 0, exit immediately (Full Fine-Tuning mode)
        if num_of_freeze_layer == 0:
            print("Configuration: Full Fine-Tuning (All layers are trainable)")
            return

        # Safety check to avoid index out of bounds
        limit = min(num_of_freeze_layer, len(self.blocks))

        frozen_list = []

        # STEP 2: Progressively freeze the requested blocks
        for i in range(limit):
            current_blocks = self.blocks[i]
            for block_name in current_blocks:
                # Retrieve the layer by name
                layer = getattr(self.model, block_name)

                # Freeze parameters for this specific block
                for param in layer.parameters():
                    param.requires_grad = False

                frozen_list.append(block_name)

        print(f"Freezing Level {limit}. Frozen blocks: {frozen_list}")

def load_model(model_path, device):
    model = LogoResNet50(embedding_dim=Config.embedding_dim, pretrained=Config.pretrained, num_of_freeze_layer=Config.freeze_layers)
    # state = torch.load(model_path)
    # model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model

### Evaluation Utilities

**Key operations:**
1. **MetricEvaluator**: Calculates standard metrics like Discriminant Ration, mAP, Precision, Recall, and F1 Score.
2. **FewShotIterator**: Manages the sampling of N-Shot tasks from the test dataset.
3. **getTestPaths**: A function that returns a set of paths from the test set

In [67]:
class MetricEvaluator:
    """
    A class to calculate evaluation metrics for Few-Shot Learning and Metric Learning.

    Implements:
    1. Discriminant Ratio (J): Optimized scalar implementation (O(d) memory).
    2. Mean Average Precision (mAP): Ranking quality metric.
    3. Recall at Fixed Precision (R@P): Operational metric.
    4. Precision & Recall: Raw metrics at a specific similarity threshold.
    5. F1 Score: Harmonic mean of Precision and Recall.
    """

    def __init__(self, device=None):
        """
        Initialize the evaluator.

        Args:
            device (str): 'cuda' or 'cpu'. If None, detects automatically.
        """
        if device:
            self.device = device
        else:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.epsilon = 1e-6  # For numerical stability

    def compute_discriminant_ratio(self, embeddings, labels):
        """
        Calculates the Discriminant Ratio (J) using the optimized Scalar approach.

        Theory:
            J = Tr(Sb) / Tr(Sw)
            Using the Trace Trick: Tr(Sw) = Tr(St) - Tr(Sb)

        Args:
            embeddings (torch.Tensor): Tensor of shape (Batch_Size, Dimension).
            labels (torch.Tensor): Tensor of class labels.

        Returns:
            float: The Discriminant Ratio score.
        """
        embeddings = embeddings.to(self.device)
        labels = labels.to(self.device)

        # 1. Global Mean Computation
        global_mean = embeddings.mean(dim=0)

        # 2. Calculate Trace of Total Scatter (St)
        # Sum of squared Euclidean distances of all points from the global mean.
        tr_st = torch.sum((embeddings - global_mean) ** 2)

        # 3. Calculate Trace of Between-Class Scatter (Sb)
        tr_sb = 0
        unique_classes = torch.unique(labels)

        for c in unique_classes:
            class_mask = (labels == c)
            class_embeddings = embeddings[class_mask]
            n_c = class_embeddings.size(0)

            if n_c > 0:
                mu_c = class_embeddings.mean(dim=0)
                tr_sb += n_c * torch.sum((mu_c - global_mean) ** 2)

        # 4. Calculate Trace of Within-Class Scatter (Sw)
        tr_sw = tr_st - tr_sb

        # Calculate J
        j_score = tr_sb / (tr_sw + self.epsilon)

        return j_score.item()

    def compute_map(self, query_emb, gallery_emb, query_labels, gallery_labels):
        """
        Calculates Mean Average Precision (mAP).
        """
        query_emb = query_emb.to(self.device)
        gallery_emb = gallery_emb.to(self.device)
        query_labels = query_labels.to(self.device)
        gallery_labels = gallery_labels.to(self.device)

        # L2 Normalize for Cosine Similarity
        query_emb = F.normalize(query_emb, p=2, dim=1)
        gallery_emb = F.normalize(gallery_emb, p=2, dim=1)

        # Similarity Matrix: S = Q * G^T
        similarity_matrix = torch.matmul(query_emb, gallery_emb.T)

        num_queries = query_labels.size(0)
        average_precisions = []

        for i in range(num_queries):
            scores = similarity_matrix[i]
            target_label = query_labels[i]

            # Ranking
            sorted_indices = torch.argsort(scores, descending=True)
            sorted_gallery_labels = gallery_labels[sorted_indices]

            # Relevance Mask
            relevance_mask = (sorted_gallery_labels == target_label).float()

            total_relevant = relevance_mask.sum()
            if total_relevant == 0:
                average_precisions.append(0.0)
                continue

            # Cumulative Precision
            cumsum = torch.cumsum(relevance_mask, dim=0)
            ranks = torch.arange(1, len(relevance_mask) + 1).to(self.device)
            precisions = cumsum / ranks

            # Average Precision (AP)
            ap = (precisions * relevance_mask).sum() / total_relevant
            average_precisions.append(ap.item())

        if not average_precisions:
            return 0.0
        return sum(average_precisions) / len(average_precisions)

    def compute_precision_recall(self, similarity_scores, is_match, threshold=0.5):
        """
        Calculates raw Precision and Recall at a specific similarity threshold.

        Definitions:
            Precision = TP / (TP + FP)
            Recall    = TP / (TP + FN)

        Args:
            similarity_scores (torch.Tensor): 1D tensor of scores (0.0 to 1.0).
            is_match (torch.Tensor): 1D binary tensor (Ground Truth).
            threshold (float): Cutoff for deciding if a retrieval is Positive.

        Returns:
            tuple: (precision, recall)
        """
        similarity_scores = similarity_scores.to(self.device)
        is_match = is_match.to(self.device)

        # Binarize predictions: 1 if score >= threshold (Positive), else 0 (Negative)
        predicted_positive = (similarity_scores >= threshold).float()

        # True Positives (TP): Predicted Positive AND Actually Match
        tp = (predicted_positive * is_match).sum()

        # False Positives (FP): Predicted Positive BUT Actually Non-Match
        fp = (predicted_positive * (1 - is_match)).sum()

        # False Negatives (FN): Predicted Negative BUT Actually Match
        # (We invert the prediction mask to find negatives)
        fn = ((1 - predicted_positive) * is_match).sum()

        precision = tp / (tp + fp + self.epsilon)
        recall = tp / (tp + fn + self.epsilon)

        return precision.item(), recall.item()

    def compute_recall_at_fixed_precision(self, similarity_scores, is_match, min_precision=0.95):
        """
        Calculates Recall at a Fixed Precision (R@P).
        Finds the lowest threshold where Precision >= min_precision.
        """
        similarity_scores = similarity_scores.to(self.device)
        is_match = is_match.to(self.device)

        sorted_indices = torch.argsort(similarity_scores, descending=True)
        sorted_matches = is_match[sorted_indices]

        tps = torch.cumsum(sorted_matches, dim=0)
        total_retrieved = torch.arange(1, len(sorted_matches) + 1).to(self.device)

        precisions = tps / total_retrieved

        # Find indices where Precision satisfies the constraint
        valid_indices = torch.where(precisions >= min_precision)[0]

        if len(valid_indices) == 0:
            return 0.0

        cutoff_index = valid_indices[-1]

        # Recall = TP_at_cutoff / Total_Relevant_In_Dataset
        total_relevant_in_dataset = is_match.sum()

        if total_relevant_in_dataset == 0:
            return 0.0

        recall = tps[cutoff_index] / total_relevant_in_dataset

        return recall.item()

    def compute_f1_score(self, precision, recall):
        """
        Calculates F1 Score (Harmonic Mean).
        """
        if (precision + recall) == 0:
            return 0.0
        return 2 * (precision * recall) / (precision + recall)

class FewShotIterator:
    def __init__(self, file_list, n_shot):
        """
        Initializes the iterator class.
        It prepares the global testset and creates a cyclic iterator over the valid brands.
        """
        self.n_shot = n_shot

        # 1. Validation: Check if input list is empty
        if not file_list:
            raise ValueError("The test file list is empty.")

        #    (Dataset - SupportSet)  is significantly faster with sets (O(1)) compared to lists.
        self.all_files_set = set(file_list)

        # 3. Organize data by Brand
        #    We create a dictionary mapping: { 'BrandName': [list_of_image_paths] }
        self.brands_map = {}

        for file_path in file_list:
            # Extract brand name assuming structure: .../Category/Brand/Image.jpg
            brand_name = os.path.basename(os.path.dirname(file_path))

            if brand_name not in self.brands_map:
                self.brands_map[brand_name] = []
            self.brands_map[brand_name].append(file_path)

        self.valid_brands_list = list(self.brands_map.keys())

        if not self.valid_brands_list:
            raise ValueError(f"No brand found with more than {n_shot} images.")

        #    'itertools.cycle' creates an infinite loop over the valid brands list.
        self.brand_iterator = cycle(self.valid_brands_list)

    def __call__(self):
        """
        Executed when the class instance is called.
        Logic:
        1. Pick next brand (Sequential).
        2. Pick Support Set (Random 5 images from that brand).
        3. Pick Query Set (EVERYTHING else in the testset).
        """
        while True:
            # A. Get the next brand sequentially from the cycle
            selected_brand_name = next(self.brand_iterator)

            # B. Retrieve all images specific to this chosen brand
            images_of_current_brand = self.brands_map[selected_brand_name]

            # Select a random number between 1 to 5 which is the number of images of the support brand guaranteed in the query set
            num_query_guarantee = random.randint(1, 5)

            # Check if this brand has enough images
            if len(images_of_current_brand) >= self.n_shot + num_query_guarantee:
                break  # Brand is suitable, exit the loop
            else:
                print(f"The brand {selected_brand_name} wa skipped because it doesnt have enough images to create a query set with {num_query_guarantee} images of the support brand")

        # C. Create SUPPORT SET
        #    Select 'n_shot' unique images randomly from the current brand.
        support_set_list = random.sample(images_of_current_brand, self.n_shot)
        support_set_set = set(support_set_list)

        # D. Create QUERY SET (Global Subtraction)
        #    Requirement: The Query Set contains 50 images of which at least 1 is from the support brand
        #    Step 1: initialize the Query list
        query_set_list = []

        #    Step 2: Sample randomly the images to guarantee in the query set
        remaining_brand_images = list(set(images_of_current_brand) - support_set_set)
        guaranteed_images_in_query = random.sample(remaining_brand_images, num_query_guarantee)

        #    Step 3: Sample the Negative Queries (Distractors from OTHER brands)
        # We subtract ALL images of the current brand to ensure zero accidental matches
        remaining_images_in_query = list(self.all_files_set - set(images_of_current_brand) - set(guaranteed_images_in_query))

        total_query_size = 50
        num_remaining_query = total_query_size - num_query_guarantee
        query_remaining = random.sample(remaining_images_in_query,  min(len(remaining_images_in_query),num_remaining_query))

        #    Step 4: Combine and Shuffle
        query_set_list = guaranteed_images_in_query + query_remaining
        random.shuffle(query_set_list)


        return {
            "brand_name": selected_brand_name,
            "support_set": support_set_list,
            "query_set": query_set_list
        }

def getTestPaths(root_dir, total_set_size=None, min_images_per_brand=2):
    test_path = os.path.join(root_dir, 'test')
    test_brand_list = []

    # Collect brand folders
    if not os.path.exists(test_path):
        print(f"Warning: {test_path} not found.")
        return []

    for category in os.listdir(test_path):
        cat_path = os.path.join(test_path, category)
        if os.path.isdir(cat_path):
            for brand in os.listdir(cat_path):
                brand_full_path = os.path.join(cat_path, brand)
                if os.path.isdir(brand_full_path):
                    test_brand_list.append(brand_full_path)

    test_data_list = []

    # Sampling Logic
    if total_set_size is not None:
        images_per_brand = round(total_set_size / len(test_brand_list))

        if images_per_brand < min_images_per_brand:
            new_test_brand_count = round(total_set_size / min_images_per_brand)
            test_brand_list = random.sample(test_brand_list, min(len(test_brand_list), new_test_brand_count))
            images_per_brand = min_images_per_brand

        for brand in test_brand_list:
            imgs = glob.glob(os.path.join(brand, '*.jpg'))

            if len(imgs) < min_images_per_brand:
                print(f"images are less than {min_images_per_brand} for this brand: {brand} in the TEST set")

            test_data_list.extend(random.sample(imgs, min(images_per_brand, len(imgs))))
    else:
        for brand in test_brand_list:
            test_data_list.extend(glob.glob(os.path.join(brand, '*.jpg')))

    return test_data_list

### Evaluation Loop

**Key operations:**
1. **evaluate_few_shot**: Runs the evaluation episodes, computes embeddings, and calculates aggregated metrics.
2. **cosine_similarity**: A funciton that computes the cosine similarity between two passed embeddings

In [68]:
def evaluate_few_shot(model, fewshot_iterator, transform, device, num_episodes=100):
    evaluator = MetricEvaluator(device=device)
    accuracies = []
    precisions = []
    recalls = []
    f1_scores = []
    r_at_95p = []

    torch.cuda.empty_cache()

    # 2. Set to eval mode and disable gradient tracking
    model.eval()
    with torch.no_grad():
        for _ in range(num_episodes):
            print("New iteration ___________________________________")
            task = fewshot_iterator()
            support_paths = task["support_set"]
            query_paths = task["query_set"]

            # Build datasets and loaders
            support_dataset = DatasetTest(support_paths, transform)
            query_dataset = DatasetTest(query_paths, transform)

            support_loader = DataLoader(support_dataset, batch_size=32)
            query_loader = DataLoader(query_dataset, batch_size=64)

            # Extract embeddings
            support_embeddings = []
            query_embeddings = []
            query_labels = []

            # Compute embeddings for support set
            for data in support_loader:
                print(f"support labels: {data["label"]}")
                images = data["image"].to(device)
                support_embeddings.append(model(images))

                batch_labels = data["label"]
                support_brand = batch_labels[0]

            support_embeddings_tensor = torch.cat(support_embeddings)

            # Average embeddings
            averaged_support_embeddings = support_embeddings_tensor.mean(dim=0)

            # Compute embeddings for query set
            for data in query_loader:
                print(f"query labels: {data["label"]}")
                images = data["image"].to(device)
                query_embeddings.append(model(images))

                batch_labels = data["label"]
                query_labels.append(batch_labels)

            # query_embeddings and query_labels are list of tensors, this unrolls them
            query_embeddings_tensor = torch.cat(query_embeddings)
            query_labels_tensor = torch.cat(query_labels)

            # Compute similarity
            sims = cosine_similarity(averaged_support_embeddings, query_embeddings_tensor)

            # Ground truth: query belongs to support brand?
            gt = (query_labels_tensor == support_brand).float()

            # Predictions, does the model predict it is the same brand?
            pred = (sims >= Config.prediciton_threashold).float().cpu()

            # Accuracy
            acc = (pred == gt).float().mean().item()
            accuracies.append(acc)

            # Precision, Recall, F1
            prec, rec = evaluator.compute_precision_recall(sims, gt, threshold=Config.prediciton_threashold)
            f1 = evaluator.compute_f1_score(prec, rec)
            r95 = evaluator.compute_recall_at_fixed_precision(sims, gt, min_precision=0.95)

            precisions.append(prec)
            recalls.append(rec)
            f1_scores.append(f1)
            r_at_95p.append(r95)

    # Aggregate results
    results = {
        "accuracy": sum(accuracies) / len(accuracies),
        "precision": sum(precisions) / len(precisions),
        "recall": sum(recalls) / len(recalls),
        "f1": sum(f1_scores) / len(f1_scores),
        "r@95p": sum(r_at_95p) / len(r_at_95p),
    }
    return results

def cosine_similarity(averaged_support_embeddings, query_embeddings_tensor):

    # Normalize embeddings if you want cosine similarity
    support_emb_norm = F.normalize(averaged_support_embeddings, p=2, dim=0)       # [embedding_dim]
    query_emb_norm = F.normalize(query_embeddings_tensor, p=2, dim=1)             # [num_queries, embedding_dim]

    # Compute cosine similarity
    sims = torch.matmul(query_emb_norm, support_emb_norm)  # [num_queries]
    return sims

def main():
    transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(),
                                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    device = torch.device(Config.device)
    model = load_model(Config.trained_model_path, device)
    # model.load_state_dict(torch.load(Config.trained_model_path)) # Uncomment to load trained weights

    test_paths = getTestPaths(Config.dataset_root, total_set_size=500, min_images_per_brand=6) # Optionally add total_set_size and min_images_per_brand
    iterator = FewShotIterator(test_paths, n_shot=Config.n_shot)

    results = evaluate_few_shot(model, iterator, transform, device, Config.num_episodes)
    print("\n=== Evaluation Results ===")
    for k, v in results.items():
        print(f"{k.capitalize():<10}: {v:.4f}")

if __name__ == "__main__":
    main()

Freezing Level 5. Frozen blocks: ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4']
images are less than 6 for this brand: LogoDet-3K/LogoDet-3K-divided/test/Clothes/regatta in the TEST set
images are less than 6 for this brand: LogoDet-3K/LogoDet-3K-divided/test/Others/stage stores inc in the TEST set
images are less than 6 for this brand: LogoDet-3K/LogoDet-3K-divided/test/Others/Carpathia in the TEST set
New iteration ___________________________________
support labels: tensor([338])
query labels: tensor([ 209, 1562,  668,  466,    9,   97, 1540, 2245,  668,  533,  387, 1521,
         668, 2458, 1962,  691, 2354, 2147,  209, 1962, 1586, 1471, 1628,  338,
        1586,  338,    9, 2245, 2536, 1918, 1096, 1145, 1232, 1474, 1540,  423,
        2147, 1995, 2666, 1302,  338, 1093,  149, 2632, 2487,  209,  560,  338,
        2567, 2200])
New iteration ___________________________________
support labels: tensor([576])
query labels: tensor([1540, 1446,  576,  949, 1628,  576,  912, 2354

KeyboardInterrupt: 