SCRIPT ANALYSIS: TRIPLET NETWORK TRAINING
=================================================================

Overview:
This script implements a Deep Metric Learning pipeline using a Triplet Network architecture.
The goal is to learn a 128-dimensional embedding space where intraclass distances are minimized
and interclass distances are maximized via a margin-based ranking loss.

Key Components:

1. Data Sampling Strategy (DatasetTriplet):
   - Efficient Indexing: Pre-computes a `label_to_images` map to enable O(1) retrieval of
     positive/negative samples.
   - Dynamic Triplet Construction: Each time an item is requested, the dataset samples:
     * Anchor: The current image.
     * Positive: Randomly sampled from the same class (handles single-image classes by duplication).
     * Negative: Randomly sampled from a different class.

2. Model Architecture (LogoResNet50):
   - Backbone: ResNet50 initialized with ImageNet weights.
   - Projection Head: The classification head is replaced by a linear layer projecting
     features to `embedding_dim=128`.
   - Progressive Fine-Tuning: Includes `freeze_numer_of_layer` logic to selectively freeze
     ResNet blocks (from conv1 up to layer4) during transfer learning.

3. Optimization Objective:
   - Loss Function: `TripletMarginLoss` with `margin=1.0` and `p=2` (Euclidean distance).
     The per-triplet loss is max(d(a, p) - d(a, n) + margin, 0), so it reaches zero only once d(a, p) + margin <= d(a, n).
   - Optimizer: Adam with a conservative learning rate (1e-5) to preserve pretrained feature quality.
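The objective in item 3 can be illustrated without PyTorch. A framework-free sketch, assuming plain Python lists as embeddings; it follows the formula behind `nn.TripletMarginLoss(margin=1.0, p=2)` with its default mean reduction, but skips the small `eps` PyTorch adds inside its distance computation. The toy vectors are made up.

```python
import math

def euclidean(u, v):
    """L2 distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def triplet_margin_loss(anchors, positives, negatives, margin=1.0):
    """Mean of max(d(a, p) - d(a, n) + margin, 0) over a batch of triplets."""
    losses = [
        max(euclidean(a, p) - euclidean(a, n) + margin, 0.0)
        for a, p, n in zip(anchors, positives, negatives)
    ]
    return sum(losses) / len(losses)

# Toy 2-D embeddings (illustrative values only)
anchors   = [[0.0, 0.0], [0.0, 0.0]]
positives = [[0.3, 0.4], [0.0, 1.0]]   # d(a, p) is about 0.5 and exactly 1.0
negatives = [[3.0, 4.0], [0.0, 1.5]]   # d(a, n) = 5.0 and 1.5

print(triplet_margin_loss(anchors, positives, negatives, margin=1.0))  # 0.25
```

The first triplet already satisfies the margin and contributes 0; the second violates it by 0.5, so the batch mean is 0.25.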

In [None]:
import torch

class Config:
    # 1. SETUP
    project_name = "FewShot"

    # Paths for saving results and checkpoints
    logs_dir = "./logs"
    checkpoints_dir = "./checkpoints"

    # Device configuration
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    seed = 42  # For reproducibility

    # 2. DATASET PATH
    dataset_root = "LogoDet-3K/LogoDet-3K-divided"
    csv_index_path = "LogoDet-3K/brand_to_index.csv"

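    # Split ratios. getTrainValPaths only reads val_split_ratio: 20% of the
    # train_val brands go to validation and the remaining 80% to training.
    # train_split_ratio is currently unused.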
    train_split_ratio = 0.7
    val_split_ratio = 0.2

    # 3. TRAINING HYPERPARAMETERS
    epochs = 20
    batch_size = 8
    learning_rate = 1e-5

    # 4. MODEL ARCHITECTURE
    backbone = "resnet50"
    pretrained = True
    embedding_dim = 128

    # TRAINED MODEL PATH
    trained_model_path = ""

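    # Prediction threshold (cosine similarity) used to decide whether two logos
    # are the same during inference and validation
    prediciton_threashold = 0.5

    # Number of ResNet blocks to freeze at start-up (0 = full fine-tuning),
    # passed to LogoResNet50 as num_of_freeze_layer
    freeze_layers = 0
    # Transfer Learning Strategy (not currently read by train_triplet)
    freeze_early_layers = True
    # Unfreeze all layers after this specific epoch for fine-tuning
    unfreeze_at_epoch = 5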

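    # 5. LOSS FUNCTION
    # Triplet margin: minimum gap required between d(anchor, negative) and
    # d(anchor, positive). train_triplet currently hardcodes margin=1.0 instead.
    margin = 0.2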
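`freeze_early_layers` and `unfreeze_at_epoch` describe a two-phase schedule that train_triplet does not implement yet. A minimal sketch of how the epoch could be mapped to a freezing level, assuming a hypothetical helper `freeze_level_for_epoch` and a hypothetical `initial_level` of 5 (whole backbone frozen) for the first phase:

```python
def freeze_level_for_epoch(epoch, freeze_early_layers=True, unfreeze_at_epoch=5, initial_level=5):
    """Hypothetical schedule: keep `initial_level` blocks frozen, then unfreeze everything.

    `epoch` is 0-based, so with unfreeze_at_epoch=5 the first five epochs run frozen.
    The return value is meant for a method like LogoResNet50.freeze_numer_of_layer
    (0 = full fine-tuning).
    """
    if not freeze_early_layers:
        return 0
    return initial_level if epoch < unfreeze_at_epoch else 0

schedule = [freeze_level_for_epoch(epoch) for epoch in range(8)]
print(schedule)  # [5, 5, 5, 5, 5, 0, 0, 0]
```

If a schedule like this were wired in, note that train_triplet builds Adam from `filter(lambda p: p.requires_grad, ...)`: parameters frozen when the optimizer is created are not in it, so after unfreezing the optimizer would need to be rebuilt, or the newly trainable parameters added with `optimizer.add_param_group(...)`.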


In [None]:
import sys
import os
import random
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
from torchvision import transforms
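try:
    from tqdm import tqdm
except ImportError:
    # Minimal stand-in: iterates normally and ignores progress-bar updates,
    # so calls such as pbar.set_postfix(...) in the training loop still work
    class tqdm:
        def __init__(self, iterable, desc=""):
            self.iterable = iterable

        def __iter__(self):
            return iter(self.iterable)

        def set_postfix(self, *args, **kwargs):
            pass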
import glob
import torchvision.models as models
from torchvision.models import ResNet50_Weights
from google.colab import drive
import zipfile
import pandas as pd
import matplotlib.pyplot as plt

torch.manual_seed(Config.seed)
random.seed(Config.seed)

def setup_dataset(zip_path, extract_to):
    """
    Mounts Google Drive and extracts the dataset if not already present.
    """
    # 1. Mount Google Drive
    if not os.path.exists('/content/drive'):
        drive.mount('/content/drive')

    # 2. Check if the folder already exists
    if os.path.exists(extract_to):
        print(f"Dataset folder '{extract_to}' already exists. Skipping extraction.")
    else:
        print(f"Extracting dataset from {zip_path}...")
        if os.path.exists(zip_path):
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_to)
            print("Extraction complete.")
        else:
            print(f"ERROR: Zip file not found at {zip_path}. Check your path.")

setup_dataset("/content/drive/MyDrive/LogoDet-3K-divided.zip", "/content/LogoDet-3K")


Dataset folder '/content/LogoDet-3K' already exists. Skipping extraction.


Custom Dataset for Triplet Learning (Anchor, Positive, Negative).

Key operations:
1. Efficient Indexing: Pre-computes a {label: [paths]} dictionary for fast positive/negative retrieval.
2. Triplet Sampling:
   - Anchor: Image at current index.
   - Positive: A different, randomly chosen image from the same class (single-image classes reuse the anchor).
   - Negative: A random image from a different class.
3. Robustness: A try-except block returns black 224x224 fallback images if any of the three files fails to load.
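The indexing and sampling steps can be sketched in plain Python. This is an illustrative variant, not the class itself: it assumes string paths in the `.../BrandName/img.jpg` layout, uses a seeded `random.Random` for reproducibility, and replaces the rejection loops with list filtering (which costs O(images in the brand) and O(brands) per call but always terminates). The helper names and brand names are hypothetical.

```python
import os
import random
from collections import defaultdict

def build_label_index(image_paths):
    """Group image paths by parent folder name (the brand), as DatasetTriplet does."""
    index = defaultdict(list)
    for path in image_paths:
        index[os.path.basename(os.path.dirname(path))].append(path)
    return dict(index)

def sample_triplet(anchor_path, label_index, rng):
    """Return (anchor, positive, negative) paths for one anchor image."""
    anchor_label = os.path.basename(os.path.dirname(anchor_path))
    # Positive: any other image of the same brand; single-image brands reuse the anchor
    candidates = [p for p in label_index[anchor_label] if p != anchor_path]
    positive = rng.choice(candidates) if candidates else anchor_path
    # Negative: pick among the other brands directly (IndexError if there are none)
    other_labels = [label for label in label_index if label != anchor_label]
    negative = rng.choice(label_index[rng.choice(other_labels)])
    return anchor_path, positive, negative

# Hypothetical paths following the .../BrandName/img.jpg layout (no files are opened)
paths = ["data/Nike/1.jpg", "data/Nike/2.jpg", "data/Adidas/1.jpg", "data/Puma/1.jpg"]
index = build_label_index(paths)
rng = random.Random(0)
print(sample_triplet("data/Nike/1.jpg", index, rng))
print(sample_triplet("data/Puma/1.jpg", index, rng))
```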

In [None]:


class DatasetTriplet(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

        # Load label string to index mapping
        df = pd.read_csv(Config.csv_index_path)
        self.label_to_id = {row['brand']: int(row['index']) for _, row in df.iterrows()}


        # --- OPTIMIZATION ---
        # Build a {label: [list_of_paths]} dictionary.
        # This lets us find positives and negatives quickly without scanning every path each time
        self.label_to_images = {}
        for img_path in image_paths:
            # Assumes the structure: .../BrandName/img.jpg
            # Adapt this split if your folders are organized differently!
            label = os.path.basename(os.path.dirname(img_path))

            if label not in self.label_to_images:
                self.label_to_images[label] = []
            self.label_to_images[label].append(img_path)

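        self.labels = list(self.label_to_images.keys())
        # Negative sampling loops until it draws a different brand, so at least
        # two brands are needed to avoid an infinite loop
        if len(self.labels) < 2:
            raise ValueError("DatasetTriplet needs images from at least two brands.")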

    def __getitem__(self, index):
        # 1. ANCHOR (starting image)
        anchor_path = self.image_paths[index]
        anchor_label = os.path.basename(os.path.dirname(anchor_path))
        anchor_id = self.label_to_id[anchor_label]

        # 2. POSITIVE (same brand, different image)
        potential_positives = self.label_to_images[anchor_label]

        # If the brand has only one image (edge case), reuse the anchor itself
        if len(potential_positives) > 1:
            while True:
                pos_path = random.choice(potential_positives)
                if pos_path != anchor_path:
                    break
        else:
            pos_path = anchor_path

        # 3. NEGATIVE (different brand)
        while True:
            neg_label = random.choice(self.labels)
            if neg_label != anchor_label:
                break
        neg_id = self.label_to_id[neg_label]
        neg_path = random.choice(self.label_to_images[neg_label])

        # Load images with error handling (a corrupted file does not crash the whole run)
        try:
            anchor_img = Image.open(anchor_path).convert('RGB')
            pos_img = Image.open(pos_path).convert('RGB')
            neg_img = Image.open(neg_path).convert('RGB')
        except Exception as e:
            print(f"Loading error: {e}. Using black fallback images.")
            anchor_img = Image.new('RGB', (224, 224))
            pos_img = Image.new('RGB', (224, 224))
            neg_img = Image.new('RGB', (224, 224))

        if self.transform:
            anchor_img = self.transform(anchor_img)
            pos_img = self.transform(pos_img)
            neg_img = self.transform(neg_img)

        # Return the 3 images + the anchor and negative class ids (used for validation metrics)
        return anchor_img, pos_img, neg_img, anchor_id, neg_id

    def __len__(self):
        return len(self.image_paths)


ResNet50-based architecture modified for Metric Learning (generating embeddings).

Key operations:
1. Backbone Initialization: Loads a standard ResNet50 (optionally with ImageNet weights).
2. Head Replacement: Swaps the original 1000-class classifier with a linear projection layer to output embeddings of size `embedding_dim`.
3. Progressive Freezing: Implements a custom `freeze_numer_of_layer` method to selectively freeze backbone blocks (from shallow 'conv1' to deep 'layer4') for controlled fine-tuning.
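The level-to-block mapping can be checked without loading ResNet50. A minimal sketch that reproduces only the bookkeeping of `freeze_numer_of_layer` (the real method additionally sets `requires_grad = False` on each module's parameters); the helper name `frozen_block_names` is hypothetical.

```python
# Block groups in the same order as LogoResNet50.blocks
BLOCKS = [["conv1", "bn1"], ["layer1"], ["layer2"], ["layer3"], ["layer4"]]

def frozen_block_names(num_of_freeze_layer, blocks=BLOCKS):
    """Return the module names frozen at a given level (0 = nothing frozen)."""
    # Clamp like the method does: negative levels freeze nothing, >5 freezes all groups
    limit = min(max(num_of_freeze_layer, 0), len(blocks))
    names = []
    for group in blocks[:limit]:
        names.extend(group)
    return names

for level in range(7):
    print(level, frozen_block_names(level))
```

The `fc` embedding head never appears in the list, so even at level 5 the projection layer keeps training.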


In [None]:

class LogoResNet50(nn.Module):
    def __init__(self, embedding_dim=128, pretrained=True, num_of_freeze_layer=5, activation_fn=None):
        super(LogoResNet50, self).__init__()

        # 1. Load Pre-trained Weights
        # Initialize the model with weights pretrained on ImageNet for transfer learning
        if pretrained:
            weights = ResNet50_Weights.DEFAULT
            self.model = models.resnet50(weights=weights)
        else:
            self.model = models.resnet50(weights=None)

        # 2. Modify the Head (Fully Connected Layer)
        # We need to produce feature embeddings instead of class probabilities
        input_features_fc = self.model.fc.in_features # Typically 2048 for ResNet50

        head_layers = []
        # Project features to the desired embedding dimension (e.g., 128)
        head_layers.append(nn.Linear(input_features_fc, embedding_dim))

        # Add an optional activation function if provided
        if activation_fn is not None:
            head_layers.append(activation_fn)

        # Replace the original classifier with our custom embedding head
        self.model.fc = nn.Sequential(*head_layers)

        # 3. Freezing Management
        # Define the blocks here to access them in the freeze method.
        # This structure allows progressive freezing/unfreezing strategies
        self.blocks = [
            ['conv1', 'bn1'],   # Level 1
            ['layer1'],         # Level 2
            ['layer2'],         # Level 3
            ['layer3'],         # Level 4
            ['layer4'],         # Level 5: entire convolutional backbone frozen (only fc trains)
        ]

        # Apply the initial freezing configuration
        self.freeze_numer_of_layer(num_of_freeze_layer)

    def forward(self, x):
        return self.model(x)

    def freeze_numer_of_layer(self, num_of_freeze_layer):
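        """
        Manages layer freezing for transfer learning strategies.

        Args:
            num_of_freeze_layer (int):
              0   -> All layers unlocked (Full Fine-Tuning)
              1-5 -> Freezes the first N groups in self.blocks, from shallow (conv1/bn1)
                     to deep (layer4); larger values are clamped to 5. The fc head is never frozen.

        Note: requires_grad=False stops weight updates, but BatchNorm layers still
        update their running statistics while the model is in train() mode.
        """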

        # STEP 1: RESET. Unfreeze everything (requires_grad = True).
        # This ensures we start from a clean state before applying new constraints.
        for param in self.model.parameters():
            param.requires_grad = True

        # If num is 0, exit immediately (Full Fine-Tuning mode)
        if num_of_freeze_layer == 0:
            print("Configuration: Full Fine-Tuning (All layers are trainable)")
            return

        # Safety check to avoid index out of bounds
        limit = min(num_of_freeze_layer, len(self.blocks))

        frozen_list = []

        # STEP 2: Progressively freeze the requested blocks
        for i in range(limit):
            current_blocks = self.blocks[i]
            for block_name in current_blocks:
                # Retrieve the layer by name
                layer = getattr(self.model, block_name)

                # Freeze parameters for this specific block
                for param in layer.parameters():
                    param.requires_grad = False

                frozen_list.append(block_name)

        print(f"Freezing Level {limit}. Frozen blocks: {frozen_list}")



Splits the dataset into training and validation sets at the brand level, ensuring no class overlap.

Key operations:
1. Brand Separation: Divides brand folders into train/val subsets based on `val_split` using a fixed seed to ensure reproducibility.
2. Adaptive Downsampling: If `total_set_size` is set, calculates the per-brand image quota. If this quota falls below `min_images_per_brand`, it reduces the number of participating brands so that each remaining brand can contribute the minimum image count.
3. Image Collection: Randomly samples the calculated number of images for each selected brand, or retrieves all images if no total size limit is set.
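The downsampling arithmetic is easy to misjudge, so here is a sketch that reproduces only the counting in getTrainValPaths (no file system access). The helper `plan_sampling` and the brand count of 2,500 are hypothetical; the real count depends on the contents of the `train_val` folder.

```python
def plan_sampling(num_brands, val_split, total_set_size=None, min_images_per_brand=2):
    """Return (train_brands, val_brands, images_per_brand) as getTrainValPaths computes them.

    images_per_brand is None when no total size is set (every image is used).
    """
    val_size = int(num_brands * val_split)
    train_size = num_brands - val_size
    if total_set_size is None:
        return train_size, val_size, None

    images_per_brand = round(total_set_size / num_brands)
    if images_per_brand < min_images_per_brand:
        # Keep fewer brands so each one can contribute the minimum number of images
        new_total = round(total_set_size / min_images_per_brand)
        new_val = round(new_total * val_split)
        new_train = new_total - new_val
        train_size = min(train_size, new_train)
        val_size = min(val_size, new_val)
        images_per_brand = min_images_per_brand
    return train_size, val_size, images_per_brand

print(plan_sampling(2500, 0.2, total_set_size=100))  # (40, 10, 2)
```

With `total_set_size=100`, as in train_triplet, 2,500 brands give round(100 / 2500) = 0 images per brand, so the function keeps round(100 / 2) = 50 brands (40 train, 10 validation) at 2 images each: at most 80 training and 20 validation images, fewer if some brands have only one image.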

In [None]:
def getTrainValPaths(root_dir, val_split, total_set_size=None, min_images_per_brand=2):
    train_val_path = os.path.join(root_dir, 'train_val')
    train_val_brands = []

    # Collect brand folders
    if not os.path.exists(train_val_path):
        print(f"Warning: {train_val_path} not found.")
        return [], []

    for category in os.listdir(train_val_path):
        cat_path = os.path.join(train_val_path, category)
        if os.path.isdir(cat_path):
            for brand in os.listdir(cat_path):
                brand_full_path = os.path.join(cat_path, brand)
                if os.path.isdir(brand_full_path):
                    train_val_brands.append(brand_full_path)

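    if not train_val_brands:
        print(f"Warning: no brand folders found in {train_val_path}.")
        return [], []

    # Split brands into Train and Val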
    val_size = int(len(train_val_brands) * val_split)
    train_size = len(train_val_brands) - val_size
    generator = torch.Generator().manual_seed(Config.seed)
    train_subset, val_subset = random_split(train_val_brands, [train_size, val_size], generator=generator)

    train_brand_list = [train_val_brands[i] for i in train_subset.indices]
    val_brand_list = [train_val_brands[i] for i in val_subset.indices]

    train_data_list = []
    val_data_list = []

    # Sampling Logic
    if total_set_size is not None:
        images_per_brand = round(total_set_size / len(train_val_brands))

        if images_per_brand < min_images_per_brand:
            print(f"Not enough images per brand ({images_per_brand}), downscaling brand sets to ensure {min_images_per_brand} images/brand.")

            # Calculate how many brands we can actually afford
            new_total_brand_count = round(total_set_size / min_images_per_brand)
            new_val_size = round(new_total_brand_count * val_split)
            new_train_size = new_total_brand_count - new_val_size

            train_brand_list = random.sample(train_brand_list, min(len(train_brand_list), new_train_size))
            val_brand_list = random.sample(val_brand_list, min(len(val_brand_list), new_val_size))
            images_per_brand = min_images_per_brand

        for brand in train_brand_list:
            imgs = glob.glob(os.path.join(brand, '*.jpg'))

            if len(imgs) < min_images_per_brand:
                print(f"Fewer than {min_images_per_brand} images for brand {brand} in the TRAIN set")

            train_data_list.extend(random.sample(imgs, min(images_per_brand, len(imgs))))

        for brand in val_brand_list:
            imgs = glob.glob(os.path.join(brand, '*.jpg'))

            if len(imgs) < min_images_per_brand:
                print(f"Fewer than {min_images_per_brand} images for brand {brand} in the VALIDATION set")

            val_data_list.extend(random.sample(imgs, min(images_per_brand, len(imgs))))
    else:
        for brand in train_brand_list:
            train_data_list.extend(glob.glob(os.path.join(brand, '*.jpg')))
        for brand in val_brand_list:
            val_data_list.extend(glob.glob(os.path.join(brand, '*.jpg')))

    return train_data_list, val_data_list



Define the metric evaluator used to monitor validation performance (model selection itself uses the validation loss).
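The shortcut in `compute_discriminant_ratio` rests on the scatter decomposition Tr(St) = Tr(Sb) + Tr(Sw). A framework-free check on four made-up 2-D points in two classes, computing Tr(Sw) both via the shortcut and directly (sum of squared distances to each class mean); the helper names are hypothetical.

```python
def mean(vectors):
    """Component-wise mean of a list of equal-length vectors."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]

def sq_dist(u, v):
    """Squared Euclidean distance."""
    return sum((a - b) ** 2 for a, b in zip(u, v))

def discriminant_ratio(embeddings, labels, eps=1e-6):
    """Return (J, Tr(Sw) via the shortcut, Tr(Sw) computed directly)."""
    global_mean = mean(embeddings)
    tr_st = sum(sq_dist(x, global_mean) for x in embeddings)
    tr_sb = 0.0
    tr_sw_direct = 0.0
    for c in set(labels):
        members = [x for x, y in zip(embeddings, labels) if y == c]
        class_mean = mean(members)
        tr_sb += len(members) * sq_dist(class_mean, global_mean)
        tr_sw_direct += sum(sq_dist(x, class_mean) for x in members)
    tr_sw = tr_st - tr_sb  # same shortcut as compute_discriminant_ratio
    return tr_sb / (tr_sw + eps), tr_sw, tr_sw_direct

embeddings = [[0.0, 0.0], [0.0, 2.0], [4.0, 0.0], [4.0, 2.0]]
labels = [0, 0, 1, 1]
print(discriminant_ratio(embeddings, labels))
```

Here Tr(St) = 20 and Tr(Sb) = 16, so both routes give Tr(Sw) = 4 and J is just under 4: the classes are spread four times more between than within.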

In [None]:
class MetricEvaluator:
    """
    A class to calculate evaluation metrics for Few-Shot Learning and Metric Learning.

    Implements:
    1. Discriminant Ratio (J): Optimized scalar implementation (O(d) memory).
    2. Mean Average Precision (mAP): Ranking quality metric.
    3. Recall at Fixed Precision (R@P): Operational metric.
    4. Precision & Recall: Raw metrics at a specific similarity threshold.
    5. F1 Score: Harmonic mean of Precision and Recall.
    """

    def __init__(self, device=None):
        """
        Initialize the evaluator.

        Args:
            device (str): 'cuda' or 'cpu'. If None, detects automatically.
        """
        if device:
            self.device = device
        else:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.epsilon = 1e-6  # For numerical stability

    def compute_discriminant_ratio(self, embeddings, labels):
        """
        Calculates the Discriminant Ratio (J) using the optimized Scalar approach.

        Theory:
            J = Tr(Sb) / Tr(Sw)
            Since the total scatter decomposes as St = Sb + Sw,
            Tr(Sw) = Tr(St) - Tr(Sb), so Sw never has to be built explicitly.

        Args:
            embeddings (torch.Tensor): Tensor of shape (Batch_Size, Dimension).
            labels (torch.Tensor): Tensor of class labels.

        Returns:
            float: The Discriminant Ratio score.
        """
        embeddings = embeddings.to(self.device)
        labels = labels.to(self.device)

        # 1. Global Mean Computation
        global_mean = embeddings.mean(dim=0)

        # 2. Calculate Trace of Total Scatter (St)
        # Sum of squared Euclidean distances of all points from the global mean.
        tr_st = torch.sum((embeddings - global_mean) ** 2)

        # 3. Calculate Trace of Between-Class Scatter (Sb)
        tr_sb = 0
        unique_classes = torch.unique(labels)

        for c in unique_classes:
            class_mask = (labels == c)
            class_embeddings = embeddings[class_mask]
            n_c = class_embeddings.size(0)

            if n_c > 0:
                mu_c = class_embeddings.mean(dim=0)
                tr_sb += n_c * torch.sum((mu_c - global_mean) ** 2)

        # 4. Calculate Trace of Within-Class Scatter (Sw)
        tr_sw = tr_st - tr_sb

        # Calculate J
        j_score = tr_sb / (tr_sw + self.epsilon)

        return j_score.item()

    def compute_map(self, query_emb, gallery_emb, query_labels, gallery_labels):
        """
        Calculates Mean Average Precision (mAP).
        """
        query_emb = query_emb.to(self.device)
        gallery_emb = gallery_emb.to(self.device)
        query_labels = query_labels.to(self.device)
        gallery_labels = gallery_labels.to(self.device)

        # L2 Normalize for Cosine Similarity
        query_emb = F.normalize(query_emb, p=2, dim=1)
        gallery_emb = F.normalize(gallery_emb, p=2, dim=1)

        # Similarity Matrix: S = Q * G^T
        similarity_matrix = torch.matmul(query_emb, gallery_emb.T)

        num_queries = query_labels.size(0)
        average_precisions = []

        for i in range(num_queries):
            scores = similarity_matrix[i]
            target_label = query_labels[i]

            # Ranking
            sorted_indices = torch.argsort(scores, descending=True)
            sorted_gallery_labels = gallery_labels[sorted_indices]

            # Relevance Mask
            relevance_mask = (sorted_gallery_labels == target_label).float()

            total_relevant = relevance_mask.sum()
            if total_relevant == 0:
                average_precisions.append(0.0)
                continue

            # Cumulative Precision
            cumsum = torch.cumsum(relevance_mask, dim=0)
            ranks = torch.arange(1, len(relevance_mask) + 1).to(self.device)
            precisions = cumsum / ranks

            # Average Precision (AP)
            ap = (precisions * relevance_mask).sum() / total_relevant
            average_precisions.append(ap.item())

        if not average_precisions:
            return 0.0
        return sum(average_precisions) / len(average_precisions)

    def compute_precision_recall(self, similarity_scores, is_match, threshold=0.5):
        """
        Calculates raw Precision and Recall at a specific similarity threshold.

        Definitions:
            Precision = TP / (TP + FP)
            Recall    = TP / (TP + FN)

        Args:
            similarity_scores (torch.Tensor): 1D tensor of cosine similarities (range -1.0 to 1.0).
            is_match (torch.Tensor): 1D binary tensor (Ground Truth).
            threshold (float): Cutoff for deciding if a retrieval is Positive.

        Returns:
            tuple: (precision, recall)
        """
        similarity_scores = similarity_scores.to(self.device)
        is_match = is_match.to(self.device)

        # Binarize predictions: 1 if score >= threshold (Positive), else 0 (Negative)
        predicted_positive = (similarity_scores >= threshold).float()

        # True Positives (TP): Predicted Positive AND Actually Match
        tp = (predicted_positive * is_match).sum()

        # False Positives (FP): Predicted Positive BUT Actually Non-Match
        fp = (predicted_positive * (1 - is_match)).sum()

        # False Negatives (FN): Predicted Negative BUT Actually Match
        # (We invert the prediction mask to find negatives)
        fn = ((1 - predicted_positive) * is_match).sum()

        precision = tp / (tp + fp + self.epsilon)
        recall = tp / (tp + fn + self.epsilon)

        return precision.item(), recall.item()

    def compute_recall_at_fixed_precision(self, similarity_scores, is_match, min_precision=0.95):
        """
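        Calculates Recall at a Fixed Precision (R@P).
        Ranks scores in descending order, takes the deepest cutoff (lowest threshold)
        at which cumulative precision is still >= min_precision, and returns the recall there.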
        """
        similarity_scores = similarity_scores.to(self.device)
        is_match = is_match.to(self.device)

        sorted_indices = torch.argsort(similarity_scores, descending=True)
        sorted_matches = is_match[sorted_indices]

        tps = torch.cumsum(sorted_matches, dim=0)
        total_retrieved = torch.arange(1, len(sorted_matches) + 1).to(self.device)

        precisions = tps / total_retrieved

        # Find indices where Precision satisfies the constraint
        valid_indices = torch.where(precisions >= min_precision)[0]

        if len(valid_indices) == 0:
            return 0.0

        cutoff_index = valid_indices[-1]

        # Recall = TP_at_cutoff / Total_Relevant_In_Dataset
        total_relevant_in_dataset = is_match.sum()

        if total_relevant_in_dataset == 0:
            return 0.0

        recall = tps[cutoff_index] / total_relevant_in_dataset

        return recall.item()

    def compute_f1_score(self, precision, recall):
        """
        Calculates F1 Score (Harmonic Mean).
        """
        if (precision + recall) == 0:
            return 0.0
        return 2 * (precision * recall) / (precision + recall)
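Both ranking metrics walk a list sorted by descending similarity and track cumulative precision. Pure-Python equivalents of the per-query AP inside `compute_map` and of `compute_recall_at_fixed_precision`, on toy inputs; the function names are hypothetical and ties are broken by Python's stable sort rather than `torch.argsort`.

```python
def average_precision(ranked_relevance):
    """AP for one query, given 0/1 relevance flags in ranked order (best first)."""
    total = sum(ranked_relevance)
    if total == 0:
        return 0.0
    hits, precision_sum = 0, 0.0
    for rank, relevant in enumerate(ranked_relevance, start=1):
        if relevant:
            hits += 1
            precision_sum += hits / rank  # precision at each relevant position
    return precision_sum / total

def recall_at_precision(scores, is_match, min_precision=0.95):
    """Recall at the deepest cutoff whose cumulative precision is >= min_precision."""
    total_relevant = sum(is_match)
    if total_relevant == 0:
        return 0.0
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    tp, best_recall = 0, 0.0
    for k, i in enumerate(order, start=1):
        tp += is_match[i]
        if tp / k >= min_precision:
            best_recall = tp / total_relevant  # later valid cutoffs overwrite earlier ones
    return best_recall

print(average_precision([1, 0, 1, 0]))                                # (1/1 + 2/3) / 2
print(recall_at_precision([0.9, 0.8, 0.7, 0.6], [1, 1, 0, 1], 0.95))  # 2 of 3 matches
```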

Training script for a Triplet Network (Anchor, Positive, Negative).

Key operations:
1. Data Pipeline:
   - Splits data into train/val at the brand level and applies augmentation (RandomHorizontalFlip, ColorJitter) to the training data only.
   - Initializes `DatasetTriplet` which yields triplets of images.
2. Model Setup: Initializes `LogoResNet50` on the specified device.
3. Optimization:
   - Loss: Uses `TripletMarginLoss` (margin=1.0) to ensure the Anchor is closer to the Positive than the Negative by at least the margin.
   - Optimizer: Adam with a low learning rate (1e-5).
4. Training Loop: Feeds triplets into the network to generate three embeddings, calculates loss, and updates weights.
5. Monitoring: Evaluates on the validation set each epoch, saves a checkpoint after every epoch, and keeps the 'best' model (lowest validation loss) as best_model_triplet.pth.
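The per-batch validation metrics reduce to thresholding two lists of cosine similarities. A framework-free sketch of one batch, assuming plain lists as embeddings; the helper names `cosine` and `threshold_metrics` are hypothetical, and the threshold of 0.5 mirrors `Config.prediciton_threashold`.

```python
import math

def cosine(u, v, eps=1e-12):
    """Cosine similarity of two vectors (norms clamped at eps, like F.normalize)."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = max(math.sqrt(sum(a * a for a in u)), eps)
    nv = max(math.sqrt(sum(b * b for b in v)), eps)
    return dot / (nu * nv)

def threshold_metrics(anchors, positives, negatives, threshold=0.5, eps=1e-6):
    """Accuracy, precision, recall and F1 for one batch of triplets.

    Anchor-positive pairs are ground-truth matches (1), anchor-negative pairs non-matches (0).
    """
    sims = [cosine(a, p) for a, p in zip(anchors, positives)]
    sims += [cosine(a, n) for a, n in zip(anchors, negatives)]
    truth = [1] * len(anchors) + [0] * len(anchors)
    pred = [1 if s >= threshold else 0 for s in sims]

    tp = sum(1 for y_hat, y in zip(pred, truth) if y_hat == 1 and y == 1)
    fp = sum(1 for y_hat, y in zip(pred, truth) if y_hat == 1 and y == 0)
    fn = sum(1 for y_hat, y in zip(pred, truth) if y_hat == 0 and y == 1)
    accuracy = sum(1 for y_hat, y in zip(pred, truth) if y_hat == y) / len(truth)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)
    return accuracy, precision, recall, f1

anchors   = [[1.0, 0.0], [0.0, 1.0]]
positives = [[1.0, 0.1], [0.2, 1.0]]
negatives = [[0.0, 1.0], [1.0, 0.9]]   # second negative is too similar (cosine ~0.67)
print(threshold_metrics(anchors, positives, negatives, threshold=0.5))
```

The second negative crosses the threshold and becomes a false positive, giving accuracy 0.75, precision about 2/3, recall about 1 and F1 about 0.8. train_triplet averages such per-batch values over the validation loader, which is not identical to computing them once over all pairs.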

In [None]:
def cosine_similarity(emb1, emb2):
    """
    Row-wise cosine similarity for paired batches: emb1 [N, D], emb2 [N, D] -> [N].
    """
    # Normalize along the embedding dimension
    emb1_norm = F.normalize(emb1, p=2, dim=-1)
    emb2_norm = F.normalize(emb2, p=2, dim=-1)

    # Now return the dot product
    return (emb1_norm * emb2_norm).sum(dim=-1)

def train_triplet():

    save_dir = os.path.join(Config.checkpoints_dir, "triplet_run")
    os.makedirs(save_dir, exist_ok=True)


    device = torch.device(Config.device)

    metrics_to_plot = {
        "loss": [], "mAP": [], "J": [], "accuracy": [],
        "precision": [], "recall": [], "f1": [], "r@95p": []
    }

    # 1. Dataset and DataLoader

    train_files, val_files = getTrainValPaths(
        Config.dataset_root,
        val_split=Config.val_split_ratio,
        total_set_size=100,  # caps this run at roughly 100 images (see getTrainValPaths)
        min_images_per_brand=2
    )

    # Transformations
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_dataset = DatasetTriplet(train_files, transform=train_transform)
    val_dataset = DatasetTriplet(val_files, transform=val_transform)

    # Dataloader
    train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=Config.batch_size, shuffle=False, num_workers=2)

    # 2. Model
    print("Model Initialization (Triplet)...")
    # Freezing level comes from Config.freeze_layers (0 = full fine-tuning)
    model = LogoResNet50(embedding_dim=Config.embedding_dim, pretrained=Config.pretrained, num_of_freeze_layer=Config.freeze_layers)
    model = model.to(device)

    # 3. Loss and Optimizer
    # Margin 1.0 (Config.margin is not used here)
    criterion = nn.TripletMarginLoss(margin=1.0, p=2)

    # Optimizer: only parameters left trainable by the freezing step are passed to Adam
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=Config.learning_rate)

    # 4. Training Loop
    best_val_loss = float('inf')
    num_epochs = 100  # hardcoded; Config.epochs (20) is not used here

    print(f"Starting training Triplet for {num_epochs} epochs...")

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        # Progress bar
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        # The dataset returns: anchor, positive, negative, anchor_id, neg_id (ids unused during training)
        for anchor, positive, negative, _, _ in pbar:
            anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)

            optimizer.zero_grad()

            # Triple forward pass through the same (weight-shared) network
            out_a = model(anchor)
            out_p = model(positive)
            out_n = model(negative)

            # Calculate Loss
            loss = criterion(out_a, out_p, out_n)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        avg_train_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {avg_train_loss:.4f}")

        # --- VALIDATION ---
        model.eval()
        evaluator = MetricEvaluator(device=device)

        # Metric accumulators
        accuracies = []
        precisions = []
        recalls = []
        f1_scores = []
        r_at_95p = []

        # For Discriminant Ratio
        j_embeddings = []
        j_labels = []

        # For mAP
        map_scores = []
        # For loss
        val_loss = 0.0

        with torch.no_grad():
            for anchor, positive, negative, anchor_label, neg_label in val_loader:
                anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
                anchor_label, neg_label = anchor_label.to(device), neg_label.to(device)

                out_a = model(anchor)
                out_p = model(positive)
                out_n = model(negative)

                # Compute validation loss
                loss = criterion(out_a, out_p, out_n)
                val_loss += loss.item()

                # Discriminant ratio J
                j_embeddings.append(out_a)
                j_embeddings.append(out_p)
                j_embeddings.append(out_n)
                j_labels.append(anchor_label)
                j_labels.append(anchor_label)
                j_labels.append(neg_label)

                # mAP
                qa = F.normalize(out_a, dim=1)
                gp = F.normalize(out_p, dim=1)
                gn = F.normalize(out_n, dim=1)

                gallery_emb = torch.cat([gp, gn], dim=0)
                gallery_labels = torch.cat([anchor_label, neg_label])

                batch_map = evaluator.compute_map(
                    query_emb=qa,
                    gallery_emb=gallery_emb,
                    query_labels=anchor_label,
                    gallery_labels=gallery_labels
                )
                map_scores.append(batch_map)

                # Similarity scores
                sim_pos = cosine_similarity(out_p, out_a)
                sim_neg = cosine_similarity(out_n, out_a)

                sims = torch.cat([sim_pos, sim_neg], dim=0)
                gt = torch.cat([
                    torch.ones(len(sim_pos), device=device),
                    torch.zeros(len(sim_neg), device=device)
                ])

                pred = (sims >= Config.prediciton_threashold).float()

                # Accuracy
                acc = (pred == gt).float().mean().item()
                accuracies.append(acc)

                # Precision and recall at the configured threshold
                prec, rec = evaluator.compute_precision_recall(
                    sims, gt, threshold=Config.prediciton_threashold
                )

                # F1 score
                f1 = evaluator.compute_f1_score(prec, rec)

                # Recall at 95% precision
                r95 = evaluator.compute_recall_at_fixed_precision(
                    sims, gt, min_precision=0.95
                )

                precisions.append(prec)
                recalls.append(rec)
                f1_scores.append(f1)
                r_at_95p.append(r95)



        avg_val_loss = val_loss / len(val_loader)

        j_embeddings_tensor = torch.cat(j_embeddings, dim=0)
        j_labels_tensor = torch.cat(j_labels, dim=0)
        j_score = evaluator.compute_discriminant_ratio(j_embeddings_tensor, j_labels_tensor)

        val_results = {
            "loss": avg_val_loss,
            "mAP": sum(map_scores) / len(map_scores),
            "J": j_score,
            "accuracy": sum(accuracies) / len(accuracies),
            "precision": sum(precisions) / len(precisions),
            "recall": sum(recalls) / len(recalls),
            "f1": sum(f1_scores) / len(f1_scores),
            "r@95p": sum(r_at_95p) / len(r_at_95p),
        }

        for key in metrics_to_plot.keys():
            metrics_to_plot[key].append(val_results[key])

        print(
            f"VALIDATION Epoch {epoch+1} | "
            f"Loss: {val_results['loss']:.4f} | "
            f"Acc: {val_results['accuracy']:.4f} | "
            f"Prec: {val_results['precision']:.4f} | "
            f"Rec: {val_results['recall']:.4f} | "
            f"F1: {val_results['f1']:.4f} | "
            f"R@95P: {val_results['r@95p']:.4f} | "
            f"mAP: {val_results['mAP']:.4f} | "
            f"J: {val_results['J']:.4f}"
        )


        # Saving checkpoint
        checkpoint_path = os.path.join(save_dir, f"model_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), checkpoint_path)

        # Saving Best Model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), os.path.join(save_dir, "best_model_triplet.pth"))
            print("New Best Triplet Model Saved!")

        print("-" * 50)
    plot_training_results(metrics_to_plot, save_dir)

def plot_training_results(history, save_dir):
    epochs = range(1, len(history['loss']) + 1)

    # Create a 2x4 grid to fit all 8 metrics
    fig, axes = plt.subplots(2, 4, figsize=(20, 10))
    axes = axes.flatten()

    colors = ['red', 'blue', 'green', 'orange', 'purple', 'brown', 'cyan', 'magenta']

    for i, (key, values) in enumerate(history.items()):
        axes[i].plot(epochs, values, marker='o', color=colors[i], linestyle='-', linewidth=2)
        axes[i].set_title(key.upper(), fontsize=14, fontweight='bold')
        axes[i].set_xlabel('Epoch')
        axes[i].set_ylabel('Score')
        axes[i].grid(True, linestyle='--', alpha=0.7)

    plt.tight_layout()
    plot_path = os.path.join(save_dir, "metric_plots.png")
    plt.savefig(plot_path)
    plt.show() # This will display the plot if you are in a Notebook/IDE
    print(f"Metrics plot saved to: {plot_path}")

if __name__ == "__main__":
    train_triplet()