# Few-Shot Logo Recognition Project: ResNet50 Baseline Evaluation

### 1. Abstract and Objective
This notebook aims to establish a **baseline** for the *Few-Shot Logo Recognition* task using the **LogoDet-3K** dataset.
Before proceeding with fine-tuning or modifying the network architecture, it is crucial to measure the performance of a standard pre-trained model. The results obtained in this notebook will serve as a benchmark to evaluate the effectiveness of future optimizations.

### 2. Methodology
The approach used in this phase is **"Off-the-shelf Feature Extraction"**:
* **Model:** A **ResNet50** pre-trained on *ImageNet* is used.
* **Feature Extraction:** The final classification layer (Fully Connected) is removed and replaced with an identity function. Instead of predicting the 1000 ImageNet classes, the model returns the feature vector (embedding) of dimension **2048**.
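* **Inference:** Classification is performed via **Cosine Similarity**. The similarity between each *query* image embedding (to be classified) and the averaged *support* embedding (known brand examples) is computed; a query is assigned to the support brand when the similarity reaches the decision threshold (0.5 in `Config`).
* **Protocol:** Testing follows an *episodic* K-shot approach: each episode draws a support set from one brand and a query set of up to 50 images, and the metrics are averaged over the episodes (`Config.num_episodes`, set to 1000) to reduce sampling variance.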
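
The inference rule above can be sketched on toy vectors. This is a minimal illustration in plain Python, not the notebook's implementation: the 3-dimensional vectors stand in for the 2048-dimensional ResNet50 embeddings, the helper names `mean_vector` and `cosine` are hypothetical, and the per-image L2 normalization that the notebook applies before averaging is omitted for brevity.

```python
import math

def mean_vector(vectors):
    """Element-wise mean of equally sized vectors (the support prototype)."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]

def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Toy stand-ins for embeddings (assumed values, chosen for illustration only).
support = [[1.0, 0.0, 0.0], [0.8, 0.2, 0.0]]  # K = 2 support images of one brand
queries = {
    "same_brand": [0.9, 0.1, 0.0],
    "other_brand": [0.0, 0.0, 1.0],
}

prototype = mean_vector(support)
threshold = 0.5  # same value as the threshold configured in the notebook
decisions = {name: cosine(prototype, q) >= threshold for name, q in queries.items()}
print(decisions)
```

The query aligned with the prototype is accepted, while the orthogonal one (similarity 0) is rejected.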

### 3. Notebook Structure
The code is organized into the following logical sections:

1.  **Configuration and Setup:** Definition of global parameters (seed, path, device) and environment setup.
2.  **Dataset Management:**
    * `DatasetTest`: PyTorch class for loading images.
    * `FewShotIterator`: Logic for creating episodes (Support Set vs Query Set).
3.  **Evaluation Metrics:** Implementation of the `MetricEvaluator` class for calculating mAP, F1-Score, Precision, Recall, Recall at fixed Precision (R@P), and Discriminant Ratio (J). Accuracy is computed directly in the evaluation loop.
4.  **Baseline Model Definition:** `get_baseline_resnet50` function for loading the model and modifying the head.
5.  **Evaluation Loop:** Execution of the test and aggregation of final results.

---
**Dataset:** LogoDet-3K | **Model:** ResNet50 (Frozen) | **Similarity Measure:** Cosine Similarity

In [None]:
import torch
import sys
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import glob
import random
import xml.etree.ElementTree as ET
from itertools import cycle
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models
from torchvision.models import ResNet50_Weights
#from google.colab import drive
import zipfile
import pandas as pd

class Config:
    # 1. SETUP
    project_name = "FewShot_Evaluation"
    seed = 42
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 2. DATASET PATH
    #dataset_root = "LogoDet-3K/LogoDet-3K-divided"
    dataset_root = "./LogoDet-3K"
    
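    csv_index_path = "LogoDet-3K/brand_to_index.csv"

    # 3. MODEL PARAMETERS
    # These fields are not read by the baseline code in this notebook: the frozen
    # ResNet50 always returns 2048-dimensional embeddings, regardless of embedding_dim.
    embedding_dim = 128
    pretrained = True
    freeze_layers = 5
    trained_model_path = ""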

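    # 4. EVALUATION SETTINGS
    prediciton_threashold = 0.5  # cosine-similarity cutoff for predicting a match
    n_shot = 1                   # support images sampled per episode
    num_episodes = 1000          # number of evaluation episodes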

torch.manual_seed(Config.seed)
random.seed(Config.seed)

#def setup_dataset(zip_path, extract_to):
#    """
#    Mounts Google Drive and extracts the dataset if not already present.
#    """
#    # 1. Mount Google Drive
#    if not os.path.exists('/content/drive'):
#        drive.mount('/content/drive')
#
#    # 2. Check if the folder already exists
#    if os.path.exists(extract_to):
#        print(f"Dataset folder '{extract_to}' already exists. Skipping extraction.")
#    else:
#        print(f"Extracting dataset from {zip_path}...")
#        if os.path.exists(zip_path):
#            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
#                zip_ref.extractall(extract_to)
#            print("Extraction complete.")
#        else:
#            print(f"ERROR: Zip file not found at {zip_path}. Check your path.")
def setup_dataset():
    """
    Checks whether the dataset folder exists locally.
    Does not download or extract anything.
    """
    # Check that the dataset folder exists
    if os.path.exists(Config.dataset_root):
        print(f"✅ Dataset found in: {os.path.abspath(Config.dataset_root)}")
    else:
        print(f"ERROR: Folder '{Config.dataset_root}' was not found.")
        print("Make sure Config.dataset_root exactly matches the folder name on your machine.")

    # Check that the CSV index exists (optional but useful).
    # If the CSV lives inside the dataset folder, adjust the path in Config.
    if os.path.exists(Config.csv_index_path):
        print("✅ CSV index found.")
    else:
        print(f"WARNING: CSV file '{Config.csv_index_path}' not found. Check the path.")

def generate_brand_index_csv(root_dir, csv_output_path):
    """
    Scans the brand folders and generates the brand_to_index.csv file
    needed for the mapping (brand name -> numeric ID).
    """
    print(f"Generating CSV index file from: {root_dir}...")

    brands = set()

    # Walk the structure: root -> Category -> Brand
    if os.path.exists(root_dir):
        for category in os.listdir(root_dir):
            cat_path = os.path.join(root_dir, category)
            if os.path.isdir(cat_path):
                for brand in os.listdir(cat_path):
                    brand_path = os.path.join(cat_path, brand)
                    if os.path.isdir(brand_path):
                        brands.add(brand)

    if not brands:
        print("❌ ERROR: No brands found, cannot generate the CSV!")
        return

    # Sort alphabetically so the indices are consistent across runs
    sorted_brands = sorted(brands)

    # Build the DataFrame and save it
    df = pd.DataFrame({
        'brand': sorted_brands,
        'index': range(len(sorted_brands))
    })

    df.to_csv(csv_output_path, index=False)
    print(f"✅ CSV file generated successfully: {csv_output_path} ({len(brands)} brands mapped)")

# Generate the index, then run the setup (which is now just an existence check)
generate_brand_index_csv(Config.dataset_root, Config.csv_index_path)
setup_dataset()

#setup_dataset("/content/drive/MyDrive/LogoDet-3K-divided.zip", "/content/LogoDet-3K")

Generating CSV index file from: ./LogoDet-3K...
✅ CSV file generated successfully: LogoDet-3K/brand_to_index.csv (3000 brands mapped)
✅ Dataset found in: c:\Users\Lenovo\Desktop\ProgettoFinale1\PY_script\LogoDet-3K
✅ CSV index found.


In [None]:



class DatasetTest(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

        # Load label string to index mapping
        df = pd.read_csv(Config.csv_index_path)
        self.label_to_id = {row['brand']: int(row['index']) for _, row in df.iterrows()}


    def __len__(self):
        return len(self.file_list)

    def load_image(self, image_path):
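        # Force 3 channels: grayscale, palette or RGBA files would otherwise
        # break the 3-channel Normalize transform.
        img = Image.open(image_path).convert("RGB")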
        if self.transform:
            img = self.transform(img)
        image_label = os.path.basename(os.path.dirname(image_path))

        label_idx = self.label_to_id[image_label]

        return {"image": img, "label": label_idx}

    def __getitem__(self, idx):
        return self.load_image(self.file_list[idx])


def getTestPaths(root_dir, total_set_size=None, min_images_per_brand=2):
    # CHANGE: point directly at the root instead of looking for a 'test' subfolder
    test_path = root_dir
    test_brand_list = []

    # Collect brand folders
    if not os.path.exists(test_path):
        print(f"Warning: {test_path} not found.")
        return []

    print(f"Scanning folders in: {os.path.abspath(test_path)}")

    for category in os.listdir(test_path):
        cat_path = os.path.join(test_path, category)
        # Only process directories (skip any README, CSV files, etc.)
        if os.path.isdir(cat_path):
            for brand in os.listdir(cat_path):
                brand_full_path = os.path.join(cat_path, brand)
                if os.path.isdir(brand_full_path):
                    test_brand_list.append(brand_full_path)

    # If no brand is found, print a clear error
    if not test_brand_list:
        print("ERROR: No brands found! Check that the path points to the right folder.")
        return []

    print(f"Found {len(test_brand_list)} total brands.")

    test_data_list = []

    # Sampling logic (unchanged)
    if total_set_size is not None:
        if len(test_brand_list) == 0: return []
        images_per_brand = round(total_set_size / len(test_brand_list))

        if images_per_brand < min_images_per_brand:
            new_test_brand_count = round(total_set_size / min_images_per_brand)
            test_brand_list = random.sample(test_brand_list, min(len(test_brand_list), new_test_brand_count))
            images_per_brand = min_images_per_brand

        for brand in test_brand_list:
            imgs = glob.glob(os.path.join(brand, '*.jpg'))

            if len(imgs) < min_images_per_brand:
                # print(f"images are less than {min_images_per_brand} for this brand: {brand}")
                pass

            test_data_list.extend(random.sample(imgs, min(images_per_brand, len(imgs))))
    else:
        for brand in test_brand_list:
            test_data_list.extend(glob.glob(os.path.join(brand, '*.jpg')))

    return test_data_list
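
The sampling branch of `getTestPaths` can be checked with a small arithmetic sketch. The helper `plan_sampling` is hypothetical (it is not part of the notebook) and reproduces only the budget computation: how many brands are kept and how many images are drawn per brand for a given `total_set_size`.

```python
def plan_sampling(total_set_size, num_brands, min_images_per_brand=2):
    """Return (brands_to_keep, images_per_brand) for a target test-set size."""
    images_per_brand = round(total_set_size / num_brands)
    if images_per_brand < min_images_per_brand:
        # The budget cannot cover every brand with the minimum number of
        # images, so fewer brands are kept instead.
        brands_to_keep = min(num_brands, round(total_set_size / min_images_per_brand))
        return brands_to_keep, min_images_per_brand
    return num_brands, images_per_brand

print(plan_sampling(9000, 3000))  # (3000, 3)
print(plan_sampling(1000, 3000))  # (500, 2)
```

With 3000 brands, a budget of 9000 images keeps every brand at 3 images each, while a budget of 1000 images falls back to 500 randomly chosen brands at 2 images each. Note that Python's `round` rounds exact halves to the nearest even integer, so budgets that divide to `x.5` may round down.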

In [None]:
def compute_global_embeddings(model, file_list, transform, device):
    """
    Extracts features from all images in the given list.
    Version with DEBUG PRINTS to monitor progress.
    """
    total_files = len(file_list)
    
    model.eval()
    all_embeddings = []
    all_labels = []

    # Simple linear dataset of all unique test images
    dataset = DatasetTest(file_list, transform)
    
    # Batch size 64 is a good trade-off for the available RAM
    batch_size = 64
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    
    num_batches = len(loader)
    
    with torch.no_grad():
        for i, data in enumerate(loader):
            # Print the progress status on a single line:
            # end="\r" overwrites the line, giving a "loading" effect
            print(f"  > [Batch {i+1}/{num_batches}] Processing batch...", end="\r")
            
            images = data["image"].to(device)
            labels = data["label"]

            # Extract and move to CPU to save VRAM
            # The heavy ResNet computation happens here
            embeddings = F.normalize(model(images), p=2, dim=1).cpu()

            all_embeddings.append(embeddings)
            all_labels.append(labels)
            
            if (i + 1) % 10 == 0:
                torch.cuda.empty_cache()

    
    final_embeddings = torch.cat(all_embeddings)
    final_labels = torch.cat(all_labels)
    
    print(f"[END] 'compute_global_embeddings' completed. Final shape: {final_embeddings.shape}")
    return final_embeddings, final_labels

def cosine_similarity(averaged_support_embeddings, query_embeddings_tensor):

    # Normalize embeddings if you want cosine similarity
    support_emb_norm = F.normalize(averaged_support_embeddings, p=2, dim=0)       # [embedding_dim]
    query_emb_norm = F.normalize(query_embeddings_tensor, p=2, dim=1)             # [num_queries, embedding_dim]

    # Compute cosine similarity
    sims = torch.matmul(query_emb_norm, support_emb_norm)  # [num_queries]
    return sims

def evaluate_few_shot(model, fewshot_iterator, transform, device, num_episodes=100):
    
    evaluator = MetricEvaluator(device=device)
    accuracies = []
    precisions = []
    recalls = []
    f1_scores = []
    r_at_95p = []
    ap_scores = []

    torch.cuda.empty_cache()


    unique_paths = list(fewshot_iterator.all_files_set)
    
    print("[HEAVY TASK] Starting compute_global_embeddings")

    # This is probably where most of the time was being spent
    embs, labels = compute_global_embeddings(model, unique_paths, transform, device)
    j_score = evaluator.compute_discriminant_ratio(embs, labels)
    print("[HEAVY TASK] compute_global_embeddings finished")

    # -----------------------------------

    # 2. Set to eval mode and disable gradient tracking
    model.eval()
    
    
    with torch.no_grad():
        for i in range(num_episodes):
            # Status line for each episode
            print(f"  > [Episode {i+1}/{num_episodes}] Generating task...", end="\r")

            task = fewshot_iterator()

            if task is None:
                print(f"\n[STOP] Stopped early at episode {i} (no brands left).")
                break

            support_paths = task["support_set"]
            query_paths = task["query_set"]
            

            # Build datasets and loaders
            support_dataset = DatasetTest(support_paths, transform)
            query_dataset = DatasetTest(query_paths, transform)

            support_loader = DataLoader(support_dataset, batch_size=32)
            query_loader = DataLoader(query_dataset, batch_size=64)

            # Extract embeddings
            support_embeddings = []
            query_embeddings = []
            query_labels = []
            support_labels = []

            # Compute embeddings for support set
            for data in support_loader:
                images = data["image"].to(device)
                support_embeddings.append(F.normalize(model(images), p=2, dim=1))

                batch_labels = data["label"]
                support_labels.append(batch_labels)
                support_brand = batch_labels[0]

            support_embeddings_tensor = torch.cat(support_embeddings)
            support_labels_tensor = torch.cat(support_labels)

            # Average embeddings
            averaged_support_embeddings = support_embeddings_tensor.mean(dim=0)
            averaged_support_embeddings_for_ap = support_embeddings_tensor.mean(dim=0, keepdim=True)

            # Compute embeddings for query set
            for data in query_loader:
                images = data["image"].to(device)
                query_embeddings.append(F.normalize(model(images), p=2, dim=1))

                batch_labels = data["label"]
                query_labels.append(batch_labels)

            # query_embeddings and query_labels are list of tensors, this unrolls them
            query_embeddings_tensor = torch.cat(query_embeddings)
            query_labels_tensor = torch.cat(query_labels)

            # mAP
            ap_score_single = evaluator.compute_map(
                query_emb=averaged_support_embeddings_for_ap,
                gallery_emb=query_embeddings_tensor,
                query_labels=support_labels_tensor,
                gallery_labels=query_labels_tensor
            )

            ap_scores.append(ap_score_single)

            # Compute similarity
            sims = cosine_similarity(averaged_support_embeddings, query_embeddings_tensor)

            # Ground truth: query belongs to support brand?
            gt = (query_labels_tensor == support_brand).float()

            # Predictions, does the model predict it is the same brand?
            pred = (sims >= Config.prediciton_threashold).float().cpu()

            # Accuracy
            acc = (pred == gt).float().mean().item()
            accuracies.append(acc)

            # Precision, Recall, F1
            prec, rec = evaluator.compute_precision_recall(sims, gt, threshold=Config.prediciton_threashold)
            f1 = evaluator.compute_f1_score(prec, rec)
            r95 = evaluator.compute_recall_at_fixed_precision(sims, gt, min_precision=0.95)

            precisions.append(prec)
            recalls.append(rec)
            f1_scores.append(f1)
            r_at_95p.append(r95)
            
            # Progress feedback every 10 episodes: enough to see it advancing without flooding the output.
            # Note: 'acc' is the accuracy of the current episode only.
            if (i + 1) % 10 == 0:
                print(f"  > [Episode {i+1}/{num_episodes}] Completed. (Partial acc: {acc:.2f})")

    # Aggregate results
    results = {
        "accuracy": sum(accuracies) / len(accuracies),
        "precision": sum(precisions) / len(precisions),
        "recall": sum(recalls) / len(recalls),
        "f1": sum(f1_scores) / len(f1_scores),
        "r@95p": sum(r_at_95p) / len(r_at_95p),
        "map": sum(ap_scores) / len(ap_scores),
        "J": j_score,
    }
    return results


class MetricEvaluator:
    """
    A class to calculate evaluation metrics for Few-Shot Learning and Metric Learning.

    Implements:
    1. Discriminant Ratio (J): Optimized scalar implementation (O(d) memory).
    2. Mean Average Precision (mAP): Ranking quality metric.
    3. Recall at Fixed Precision (R@P): Operational metric.
    4. Precision & Recall: Raw metrics at a specific similarity threshold.
    5. F1 Score: Harmonic mean of Precision and Recall.
    """

    def __init__(self, device=None):
        """
        Initialize the evaluator.

        Args:
            device (str): 'cuda' or 'cpu'. If None, detects automatically.
        """
        if device:
            self.device = device
        else:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.epsilon = 1e-6  # For numerical stability

    def compute_discriminant_ratio(self, embeddings, labels):
        """
        Calculates the Discriminant Ratio (J) using the optimized Scalar approach.

        Theory:
            J = Tr(Sb) / Tr(Sw)
            Using the Trace Trick: Tr(Sw) = Tr(St) - Tr(Sb)

        Args:
            embeddings (torch.Tensor): Tensor of shape (Batch_Size, Dimension).
            labels (torch.Tensor): Tensor of class labels.

        Returns:
            float: The Discriminant Ratio score.
        """
        embeddings = embeddings.to(self.device)
        labels = labels.to(self.device)

        # 1. Global Mean Computation
        global_mean = embeddings.mean(dim=0)

        # 2. Calculate Trace of Total Scatter (St)
        # Sum of squared Euclidean distances of all points from the global mean.
        tr_st = torch.sum((embeddings - global_mean) ** 2)

        # 3. Calculate Trace of Between-Class Scatter (Sb)
        tr_sb = 0
        unique_classes = torch.unique(labels)

        for c in unique_classes:
            class_mask = (labels == c)
            class_embeddings = embeddings[class_mask]
            n_c = class_embeddings.size(0)

            if n_c > 0:
                mu_c = class_embeddings.mean(dim=0)
                tr_sb += n_c * torch.sum((mu_c - global_mean) ** 2)

        # 4. Calculate Trace of Within-Class Scatter (Sw)
        tr_sw = tr_st - tr_sb

        # Calculate J
        j_score = tr_sb / (tr_sw + self.epsilon)

        return j_score.item()

    def compute_map(self, query_emb, gallery_emb, query_labels, gallery_labels):
        """
        Calculates Mean Average Precision (mAP).
        """
        query_emb = query_emb.to(self.device)
        gallery_emb = gallery_emb.to(self.device)
        query_labels = query_labels.to(self.device)
        gallery_labels = gallery_labels.to(self.device)

        # L2 Normalize for Cosine Similarity
        query_emb = F.normalize(query_emb, p=2, dim=1)
        gallery_emb = F.normalize(gallery_emb, p=2, dim=1)

        # Similarity Matrix: S = Q * G^T
        similarity_matrix = torch.matmul(query_emb, gallery_emb.T)

        num_queries = query_labels.size(0)
        average_precisions = []

        for i in range(num_queries):
            scores = similarity_matrix[i]
            target_label = query_labels[i]

            # Ranking
            sorted_indices = torch.argsort(scores, descending=True)
            sorted_gallery_labels = gallery_labels[sorted_indices]

            # Relevance Mask
            relevance_mask = (sorted_gallery_labels == target_label).float()

            total_relevant = relevance_mask.sum()
            if total_relevant == 0:
                average_precisions.append(0.0)
                continue

            # Cumulative Precision
            cumsum = torch.cumsum(relevance_mask, dim=0)
            ranks = torch.arange(1, len(relevance_mask) + 1).to(self.device)
            precisions = cumsum / ranks

            # Average Precision (AP)
            ap = (precisions * relevance_mask).sum() / total_relevant
            average_precisions.append(ap.item())

        if not average_precisions:
            return 0.0
        return sum(average_precisions) / len(average_precisions)

    def compute_precision_recall(self, similarity_scores, is_match, threshold=0.5):
        """
        Calculates raw Precision and Recall at a specific similarity threshold.

        Definitions:
            Precision = TP / (TP + FP)
            Recall    = TP / (TP + FN)

        Args:
            similarity_scores (torch.Tensor): 1D tensor of cosine similarities (range -1.0 to 1.0).
            is_match (torch.Tensor): 1D binary tensor (Ground Truth).
            threshold (float): Cutoff for deciding if a retrieval is Positive.

        Returns:
            tuple: (precision, recall)
        """
        similarity_scores = similarity_scores.to(self.device)
        is_match = is_match.to(self.device)

        # Binarize predictions: 1 if score >= threshold (Positive), else 0 (Negative)
        predicted_positive = (similarity_scores >= threshold).float()

        # True Positives (TP): Predicted Positive AND Actually Match
        tp = (predicted_positive * is_match).sum()

        # False Positives (FP): Predicted Positive BUT Actually Non-Match
        fp = (predicted_positive * (1 - is_match)).sum()

        # False Negatives (FN): Predicted Negative BUT Actually Match
        # (We invert the prediction mask to find negatives)
        fn = ((1 - predicted_positive) * is_match).sum()

        precision = tp / (tp + fp + self.epsilon)
        recall = tp / (tp + fn + self.epsilon)

        return precision.item(), recall.item()

    def compute_recall_at_fixed_precision(self, similarity_scores, is_match, min_precision=0.95):
        """
        Calculates Recall at a Fixed Precision (R@P).
        Finds the lowest threshold where Precision >= min_precision.
        """
        similarity_scores = similarity_scores.to(self.device)
        is_match = is_match.to(self.device)

        sorted_indices = torch.argsort(similarity_scores, descending=True)
        sorted_matches = is_match[sorted_indices]

        tps = torch.cumsum(sorted_matches, dim=0)
        total_retrieved = torch.arange(1, len(sorted_matches) + 1).to(self.device)

        precisions = tps / total_retrieved

        # Find indices where Precision satisfies the constraint
        valid_indices = torch.where(precisions >= min_precision)[0]

        if len(valid_indices) == 0:
            return 0.0

        cutoff_index = valid_indices[-1]

        # Recall = TP_at_cutoff / Total_Relevant_In_Dataset
        total_relevant_in_dataset = is_match.sum()

        if total_relevant_in_dataset == 0:
            return 0.0

        recall = tps[cutoff_index] / total_relevant_in_dataset

        return recall.item()

    def compute_f1_score(self, precision, recall):
        """
        Calculates F1 Score (Harmonic Mean).
        """
        if (precision + recall) == 0:
            return 0.0
        return 2 * (precision * recall) / (precision + recall)



class FewShotIterator:
    def __init__(self, file_list, n_shot):
        """
        Initializes the iterator class.
        It prepares the global testset and creates a cyclic iterator over the valid brands.
        """
        self.n_shot = n_shot

        # 1. Validation: Check if input list is empty
        if not file_list:
            raise ValueError("The test file list is empty.")

        # 2. Keep every file in a set: the set difference (Dataset - SupportSet)
        #    is significantly faster with sets (O(1) lookups) than with lists.
        self.all_files_set = set(file_list)

        # 3. Organize data by Brand
        #    We create a dictionary mapping: { 'BrandName': [list_of_image_paths] }
        self.brands_map = {}

        for file_path in file_list:
            # Extract brand name assuming structure: .../Category/Brand/Image.jpg
            brand_name = os.path.basename(os.path.dirname(file_path))

            if brand_name not in self.brands_map:
                self.brands_map[brand_name] = []
            self.brands_map[brand_name].append(file_path)

        # 4. Keep only brands with more than 'n_shot' images, so that at least one
        #    positive image is left for the query set once the support set is drawn.
        self.valid_brands_list = [
            brand for brand, paths in self.brands_map.items() if len(paths) > n_shot
        ]

        if not self.valid_brands_list:
            raise ValueError(f"No brand found with more than {n_shot} images.")

        # 5. 'itertools.cycle' creates an infinite loop over the valid brands list.
        self.brand_iterator = cycle(self.valid_brands_list)

    def __call__(self):
        """
        Executed when the class instance is called.
        Logic:
        1. Pick next brand (Sequential).
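        2. Pick Support Set ('n_shot' random images from that brand).
        3. Pick Query Set (up to 50 images: 1 to 5 positives from the same brand,
           if available, plus distractors sampled from all other brands).
        """
        # A. Get the next brand sequentially from the cycle.
        #    Note: 'cycle' never raises StopIteration on a non-empty list, so the
        #    iterator wraps around and the fallback below is only a safeguard.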
        try:
            selected_brand_name = next(self.brand_iterator)
        except StopIteration:
            # Gracefully signal that we are done
            print("Iterator finished: All brands have been processed.")
            return None

        # B. Retrieve all images specific to this chosen brand
        images_of_current_brand = self.brands_map[selected_brand_name]

        # Number of positives (1 to 5) from the support brand to place in the query set, if that many are available
        num_query_guarantee_if_available = random.randint(1, 5)

        # C. Create SUPPORT SET
        #    Select 'n_shot' unique images randomly from the current brand.
        support_set_list = random.sample(images_of_current_brand, self.n_shot)
        support_set_set = set(support_set_list)

        # D. Create QUERY SET (Global Subtraction)
        #    Requirement: the Query Set contains up to 50 images, of which 1 to 5 come
        #    from the support brand (when the brand has enough images left).
        #    Step 1: initialize the Query list
        query_set_list = []

        #    Step 2: Sample randomly the images to guarantee in the query set
        remaining_brand_images = list(set(images_of_current_brand) - support_set_set)
        guaranteed_images_in_query = random.sample(remaining_brand_images, min(num_query_guarantee_if_available, len(remaining_brand_images)))

        if len(remaining_brand_images) == 0:
            print(f"For the brand {selected_brand_name}, no positive images were put in the query set")

        #    Step 3: Sample the Negative Queries (Distractors from OTHER brands)
        # We subtract ALL images of the current brand to ensure zero accidental matches
        remaining_images_in_query = list(self.all_files_set - set(images_of_current_brand) - set(guaranteed_images_in_query))

        total_query_size = 50
        num_remaining_query = total_query_size - min(num_query_guarantee_if_available, len(remaining_brand_images))
        query_remaining = random.sample(remaining_images_in_query,  min(len(remaining_images_in_query),num_remaining_query))

        #    Step 4: Combine and Shuffle
        query_set_list = guaranteed_images_in_query + query_remaining
        random.shuffle(query_set_list)


        return {
            "brand_name": selected_brand_name,
            "support_set": support_set_list,
            "query_set": query_set_list
        }
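
As a worked example of the ranking metric used by `MetricEvaluator.compute_map`, the sketch below computes Average Precision for a single query in plain Python. The helper `average_precision` is hypothetical; it assumes the 0/1 relevance flags are already sorted by descending similarity, which is the step `compute_map` performs with `torch.argsort`.

```python
def average_precision(ranked_relevance):
    """AP for one query, given 0/1 relevance flags sorted by descending score."""
    total_relevant = sum(ranked_relevance)
    if total_relevant == 0:
        return 0.0  # same convention as compute_map: no positives gives AP = 0
    hits = 0
    precision_sum = 0.0
    for rank, relevant in enumerate(ranked_relevance, start=1):
        if relevant:
            hits += 1
            precision_sum += hits / rank  # precision at this relevant rank
    return precision_sum / total_relevant

# Relevant items at ranks 1 and 4: AP = (1/1 + 2/4) / 2 = 0.75
ap = average_precision([1, 0, 0, 1, 0])
print(ap)  # 0.75
```

Because each episode uses a single averaged support embedding as the query, the per-episode value returned by `compute_map` is one such AP; the mean over episodes gives the reported mAP.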

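The trace identity used by `compute_discriminant_ratio`, Tr(Sw) = Tr(St) - Tr(Sb), can be verified on a tiny 2-D example. This is a plain-Python sketch with made-up points; the helper names are hypothetical and the epsilon term that the notebook adds to the denominator is left out.

```python
def sq_dist(p, q):
    """Squared Euclidean distance between two points."""
    return sum((a - b) ** 2 for a, b in zip(p, q))

def mean(points):
    """Element-wise mean of a list of points."""
    n = len(points)
    return tuple(sum(column) / n for column in zip(*points))

# Two well-separated toy classes (assumed values, for illustration only).
classes = {
    0: [(0.0, 0.0), (2.0, 0.0)],
    1: [(4.0, 4.0), (6.0, 4.0)],
}
all_points = [p for pts in classes.values() for p in pts]
global_mean = mean(all_points)

# Total scatter: every point against the global mean.
tr_st = sum(sq_dist(p, global_mean) for p in all_points)
# Between-class scatter: class means against the global mean, weighted by class size.
tr_sb = sum(len(pts) * sq_dist(mean(pts), global_mean) for pts in classes.values())
# Within-class scatter computed directly, to check the identity.
tr_sw_direct = sum(sq_dist(p, mean(pts)) for pts in classes.values() for p in pts)

j_score = tr_sb / (tr_st - tr_sb)
print(tr_st, tr_sb, tr_sw_direct, j_score)  # 36.0 32.0 4.0 8.0
```

The directly computed within-class scatter (4.0) matches Tr(St) - Tr(Sb) = 36.0 - 32.0, so J can be obtained without a second pass over per-class deviations.
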

In [4]:
def get_baseline_resnet50(device):
    """
    Downloads the original ResNet50 pre-trained on ImageNet.
    Replaces the fully connected layer (fc) with Identity so the model returns
    the feature embeddings (dimension 2048) instead of class scores.
    """
    weights = ResNet50_Weights.DEFAULT
    model = models.resnet50(weights=weights)
    model.fc = nn.Identity()
    model.to(device)
    model.eval() 
    return model

def main():

    # 1. Transforms and device configuration
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    device_name = Config.device if hasattr(Config, 'device') else 'cuda'
    device = torch.device(device_name)

    # 2. Instantiate the baseline model
    model = get_baseline_resnet50(device)

    dataset_root = Config.dataset_root
    
    test_paths = getTestPaths(dataset_root) 
    iterator = FewShotIterator(test_paths, n_shot=Config.n_shot)
    
    # Note: the progress messages inside the loop are printed by evaluate_few_shot;
    # here we only see when it starts and when it returns.
    results = evaluate_few_shot(
        model=model, 
        fewshot_iterator=iterator, 
        transform=transform, 
        device=device, 
        num_episodes=Config.num_episodes
    )

    # 3. Print the results
    for metrica, valore in results.items():
        print(f"{metrica.capitalize():<15}: {valore:.4f}")

if __name__ == "__main__":
    main()

Scanning folders in: c:\Users\Lenovo\Desktop\ProgettoFinale1\PY_script\LogoDet-3K
✅ Found 3000 total brands.
[HEAVY TASK] Starting compute_global_embeddings

[START] Starting 'compute_global_embeddings' on 158654 images.
[INFO] Creating Dataset and linear DataLoader...
[INFO] DataLoader ready. Total batches to process: 2479
[LOOP] Starting feature extraction (this process uses CPU/GPU intensively)...
  > [Batch 2479/2479] Processing batch...
[INFO] Batch loop finished. Concatenating results...
[END] 'compute_global_embeddings' completed. Final shape: torch.Size([158654, 2048])
[HEAVY TASK] compute_global_embeddings finished
  > [Episode 1/10] Generating task...
    [DEBUG Ep.1] Support size: 1, Query size: 50
  > [Episode 10/10] Completed. (Partial acc: 0.98)
Accuracy       : 0.9160
Precision      : 0.1542
Recall         : 0.1950
F1             : 0.1624
R@95p          : 0.1550
Map            : 0.2852
J              : 0.3716
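
The gap between accuracy and precision/recall in the output above is consistent with the class imbalance of each episode: `FewShotIterator` builds query sets of 50 images that contain at most 5 positives. As a sanity check (plain arithmetic, not a result produced by the notebook), the sketch below computes the accuracy of a hypothetical classifier that rejects every query.

```python
def all_negative_accuracy(num_positives, query_size=50):
    """Accuracy of a classifier that predicts 'no match' for every query."""
    return (query_size - num_positives) / query_size

for positives in (1, 5):
    print(positives, all_negative_accuracy(positives))
# 1 0.98
# 5 0.9
```

An accuracy between 0.90 and 0.98 is therefore reachable without retrieving a single positive, which suggests reading precision, recall and mAP as the more informative figures for this baseline.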
