## **Deep Metric Learning Training Pipeline**
Overview: This notebook implements a comprehensive Deep Metric Learning pipeline specifically designed for Logo Recognition. The pipeline covers every stage of the process, including data preparation, model architecture definition, loss function implementation, and a robust training loop equipped with logging and checkpointing mechanisms.

# **Configuration System Analysis (Config)**
The Config class separates data paths and hyperparameters from the execution logic, so that an experiment is reproducible and portable across different environments.
1. Environment & Path Management
This section handles the setup of the working environment to ensure smooth execution.
* WORK_DIR: The root directory of the project.
* dataset_root: The source location of the image dataset.
* checkpoints_base_dir: The destination folder where model weights (.pth files) and training logs are saved.
* device: Automatically detects and assigns hardware acceleration (cuda for GPU or cpu).
* seed: A fixed integer (set to 42) that seeds the torch.Generator used for the brand-level train/validation split. The data-loading cell separately seeds random and torch with SEED = 101.
2. Model Architecture
Defines the neural network structure without modifying the model class directly.
* backbone: The pre-trained network architecture used as the feature extractor (set to 'resnet50').
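* embedding_dim: The dimension of the output vector (set to 256). This defines the dimensionality of the geometric space into which images are mapped.
* initial_freeze_layers / final_freeze_layers: How many initial backbone blocks are locked to preserve pre-trained features, before and after the unfreeze step (set to 4 and 3; 0 indicates full fine-tuning). When unfreeze_layers is True, the switch happens at epoch unfreeze_at_epoch (set to 15), and the base learning rate is multiplied by LR_reduction (set to 1, i.e. unchanged).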
3. Training Strategy
Contains the general parameters governing the learning cycle.
* loss_type: The master switch (options: 'triplet', 'euclidean', 'cosine'). This determines which Dataset class is loaded and which Loss function is instantiated.
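* epochs: The total number of full passes over the training set (set to 50).
* train_split_ratio / val_split_ratio: The fractions of brands assigned to the Training and Validation sets (0.8 / 0.2). The split covers only the train_val folder of the dataset; getTrainValPaths uses val_split_ratio and gives the remaining brands to training.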
4. Hyperparameter Dictionary (HYPERPARAMS)
A nested dictionary containing optimized configuration recipes tailored for each specific loss function. This allows for context-aware parameter loading.
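* Triplet / Euclidean Loss: Uses margin=1.0. These loss functions operate on Euclidean distances; since the training loop L2-normalizes the embeddings, those distances lie in [0, 2].
* Cosine Loss: Uses margin=0.2. This works on normalized angular similarities (ranging from -1 to 1).
* Optimizer: Uses SGD (momentum 0.9, weight decay 0) by default, with Adam available through the optimizer setting. Base learning rates are 1e-3 (euclidean), 1e-2 (cosine), and 1e-1 (triplet), and batch sizes are 32 (euclidean) and 256 (cosine, triplet). The pre-trained backbone trains at backbone_reduction_factor (0.25) times the base rate to preserve its features.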
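The context-aware loading described above amounts to a dictionary lookup keyed on loss_type. Below is a minimal sketch, assuming the values from the Config cell; resolve_hyperparams and DEFAULTS are hypothetical names, and the derived backbone_lr/head_lr mirror what build_optimizer computes later in the notebook. Unlike main(), which only rejects an unknown loss type when building the criterion, this sketch validates it up front.

```python
# Recipes copied from Config.HYPERPARAMS ("folder_name"/"description" omitted)
HYPERPARAMS = {
    "euclidean": {"batch_size": 32, "margin": 1.0, "learning_rate": 1e-3},
    "cosine": {"batch_size": 256, "margin": 0.2, "learning_rate": 1e-2},
    "triplet": {"batch_size": 256, "margin": 1.0, "learning_rate": 1e-1},
}

# Fallback values main() uses when Config defines no HYPERPARAMS at all
DEFAULTS = {"batch_size": 32, "margin": 1.0, "learning_rate": 1e-4}


def resolve_hyperparams(loss_type, table=HYPERPARAMS, backbone_reduction_factor=0.25):
    """Pick the recipe for loss_type and derive the two optimizer learning rates."""
    if loss_type not in ("triplet", "euclidean", "cosine"):
        raise ValueError(f"Loss type {loss_type} not supported")
    params = dict(DEFAULTS if table is None else table[loss_type])
    # build_optimizer: the backbone gets a scaled-down LR, the new head the full LR
    params["backbone_lr"] = params["learning_rate"] * backbone_reduction_factor
    params["head_lr"] = params["learning_rate"]
    return params
```

For example, resolve_hyperparams("cosine") yields batch size 256, margin 0.2, and learning rates of 2.5e-3 for the backbone and 1e-2 for the head.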


**Dataset**

In [1]:
import sys
import os
import shutil
import random
import glob
import matplotlib.pyplot as plt
import pandas as pd
import torch
import csv
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from torchvision.models import ResNet50_Weights
from PIL import Image
from collections import defaultdict
from google.colab import drive
import zipfile
from torchsummary import summary

zip_path = '/content/drive/MyDrive/LogoDet-3K-divided.zip'
extract_path = '/content/LogoDet-3K'

if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

if not os.path.exists(extract_path):
    if os.path.exists(zip_path):
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_path)
        print(f"Dataset in: {extract_path}")
    else:
        print(f"ERROR: File {zip_path} not found")

else:
    print("Dataset found.")

Mounted at /content/drive
Dataset in: /content/LogoDet-3K


In [2]:
import torch

class Config:
    loss_type = "cosine"

    WORK_DIR = "/content/drive/MyDrive/LogoDet_Experiments"
    dataset_root = "/content/LogoDet-3K/LogoDet-3K-divided"

    checkpoints_base_dir = os.path.join(WORK_DIR, "runs")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    seed = 42

    # Dataset Split
    total_size = "full"
    train_split_ratio = 0.8
    val_split_ratio = 0.2

    # Training Params
    epochs = 50
    backbone = "resnet50"
    pretrained = True
    embedding_dim = 256

    optimizer = "SGD"
    SGD_momentum = 0.9

    initial_freeze_layers = 4
    final_freeze_layers = 3
    LR_reduction = 1
    unfreeze_layers = True
    unfreeze_at_epoch = 15
    weight_decay = 0
    backbone_reduction_factor = 0.25

    HYPERPARAMS = {
        "euclidean": {
            "batch_size": 32,
            "margin": 1.0,
            "learning_rate": 1e-3,
            "folder_name": "contrastive_euclidean",
            "description": "Contrastive Loss"
        },
        "cosine": {
            "batch_size": 256,
            "margin": 0.2,
            "learning_rate": 1e-2,
            "folder_name": "contrastive_cosine",
            "description": "Contrastive Loss"
        },
        "triplet": {
            "batch_size": 256,
            "margin": 1.0,
            "learning_rate": 1e-1,
            "folder_name": "triplet_loss",
            "description": "Triplet Loss"
        }
    }

**Data Loading and Processing Logic**

This cell defines how images are loaded and then paired (for Contrastive Loss) or grouped into triplets (for Triplet Loss) to train the network with PyTorch.

Key components:

1. DatasetTriplet Class
    * Purpose: training with Triplet Loss. Each sample provides three images at once:
        * Anchor: the reference image.
        * Positive: a different image of the same brand.
        * Negative: an image of a different brand.
    * Mechanism: the label (brand name) is parsed from the parent folder in the file path, and a dictionary (label_to_indices) records which images belong to which brand. In __getitem__, a Positive and a Negative are sampled at random to form a valid triplet.
    * Safety features: if a brand has only one image, the Positive falls back to the Anchor itself. The Negative brand is redrawn until it differs from the Anchor's brand, so the dataset needs at least two brands.

2. DatasetContrastive Class
    * Purpose: training with Contrastive Loss (Siamese networks). Each sample is a pair of images plus a binary label.
    * Mechanism: it loads a first image (img1), then flips a fair coin to decide whether the second image (img2) comes from the same brand (positive pair) or a different brand (negative pair). If only one brand exists, every pair is positive.
    * Output: (img1, img2, label), where each image is a dict holding the tensor under the "image" key and label is a float tensor: 1 for the same brand, 0 for different brands. The training loop remaps 0 to -1 for the cosine loss, which expects targets in {1, -1}.

3. getTrainValPaths Function
    * This utility splits the dataset while keeping brands disjoint. Unlike a standard random split, it first groups images by brand and then assigns entire brands to either the Training or the Validation set. This prevents data leakage and simulates a realistic Few-Shot scenario in which, during validation, the model must recognize similarity between logos it has never seen before.

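4. Technical Details
    * Reproducibility: the cell seeds random and torch with a fixed SEED = 101, which is intended to keep the optional image subsampling and the pair/triplet sampling consistent across runs. The brand-level train/validation split itself uses a torch.Generator seeded with Config.seed (42).
    * Cross-platform compatibility: .replace('\\', '/') normalizes file paths so that label parsing works on both Windows and Linux/Colab.
    * Robustness: load_image does not currently catch errors, so a corrupted image or a bad path raises an exception and stops the DataLoader. Wrapping it in a try-except block (for example, returning a black image on failure) would let training continue past bad files.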
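The brand-disjoint idea behind getTrainValPaths can be shown on plain path strings. This is an illustrative sketch, not the notebook's function: it shuffles brands with random.Random rather than torch's random_split, so the same seed will not reproduce the notebook's split, and it groups by the brand folder name rather than the full folder path. brand_disjoint_split and the example paths are hypothetical.

```python
import random
from collections import defaultdict


def brand_of(path):
    # Same rule as the dataset classes: the parent folder name is the brand label
    return path.replace("\\", "/").split("/")[-2]


def brand_disjoint_split(paths, val_split=0.2, seed=42):
    """Split image paths so that no brand appears in both train and val."""
    by_brand = defaultdict(list)
    for p in paths:
        by_brand[brand_of(p)].append(p)
    brands = sorted(by_brand)              # sort first so the shuffle is reproducible
    random.Random(seed).shuffle(brands)
    n_val = int(len(brands) * val_split)   # same rounding as getTrainValPaths
    val_brands, train_brands = brands[:n_val], brands[n_val:]
    train = [p for b in train_brands for p in by_brand[b]]
    val = [p for b in val_brands for p in by_brand[b]]
    return train, val
```

With 10 brands and val_split=0.2, two whole brands (and all of their images) go to validation, and none of their images appear in training.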

In [3]:

SEED = 101
random.seed(SEED)
torch.manual_seed(SEED)


class DatasetTriplet(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

        self.label_to_indices = defaultdict(list)
        for idx, img_path in enumerate(self.file_list):
            label = img_path.replace('\\', '/').split('/')[-2]
            self.label_to_indices[label].append(idx)


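        self.labels = list(self.label_to_indices.keys())
        # A negative must come from a different brand, so at least two brands are needed
        # (otherwise the negative-sampling loop in __getitem__ would never terminate)
        if len(self.labels) < 2:
            raise ValueError("DatasetTriplet needs images from at least two brands")

        print(f"[DatasetTriplet] Unique brands found: {len(self.labels)}")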

    def __len__(self):
        return len(self.file_list)

    def load_image(self, image_path):
        img = Image.open(image_path).convert('RGB')

        if self.transform:
            img_transformed = self.transform(img)
        else:
            img_transformed = transforms.ToTensor()(img)


        return img_transformed

    def __getitem__(self, idx):
        # 1. Anchor
        anchor_img_path = self.file_list[idx]
        anchor_label = anchor_img_path.replace('\\', '/').split('/')[-2]
        anchor = self.load_image(anchor_img_path)

        # 2. Positive
        positive_indices = [i for i in self.label_to_indices[anchor_label] if i != idx]
        if len(positive_indices) > 0:
            positive_idx = random.choice(positive_indices)
            pos_path = self.file_list[positive_idx]
        else:
            pos_path = anchor_img_path
        positive = self.load_image(pos_path)

        # 3. Negative: redraw the brand until it differs from the anchor's brand
        while True:
            neg_label = random.choice(self.labels)
            if neg_label != anchor_label:
                break
        # idx belongs to the anchor's brand, so no need to filter it out here
        negative_idx = random.choice(self.label_to_indices[neg_label])
        neg_path = self.file_list[negative_idx]
        negative = self.load_image(neg_path)

        return anchor, positive, negative


class DatasetContrastive(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform
        self.label_to_indices = defaultdict(list)
        for idx, img_path in enumerate(self.file_list):
            label = img_path.replace('\\', '/').split('/')[-2]
            self.label_to_indices[label].append(idx)


        self.labels = list(self.label_to_indices.keys())

    def __len__(self): return len(self.file_list)

    def load_image(self, image_path):
        img = Image.open(image_path).convert('RGB')
        if self.transform: img = self.transform(img)
        else: img = transforms.ToTensor()(img)
        return {"image": img}

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        label = img_path.replace('\\', '/').split('/')[-2]
        img1 = self.load_image(img_path)

        is_positive_pair = random.choice([0, 1])

        if len(self.labels) < 2:
            is_positive_pair = 1

        if is_positive_pair:
            pos_indices = [i for i in self.label_to_indices[label] if i != idx]
            if len(pos_indices) == 0:
                # Brand with a single image: pair the image with itself (as DatasetTriplet
                # does) instead of calling exit(), which would kill the DataLoader worker
                path2 = img_path
            else:
                idx2 = random.choice(pos_indices)
                path2 = self.file_list[idx2]
        else:
            neg_label = random.choice([l for l in self.label_to_indices.keys() if l != label])
            idx2 = random.choice(self.label_to_indices[neg_label])
            path2 = self.file_list[idx2]

        img2 = self.load_image(path2)
        return img1, img2, torch.tensor(is_positive_pair, dtype=torch.float32)

# ==========================================
# UTILS
# ==========================================
class DatasetTest(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform
    def __len__(self): return len(self.file_list)
    def load_image(self, image_path):
        img = Image.open(image_path).convert('RGB')
        if self.transform: img = self.transform(img)
        return {"image": img, "label": -1}
    def __getitem__(self, idx): return self.load_image(self.file_list[idx])

def getTrainValPaths(root_dir, val_split, total_set_size=None, min_images_per_brand=2):
    train_val_path = os.path.join(root_dir, 'train_val')
    train_val_brands = []

    # Collect brand folders
    if not os.path.exists(train_val_path):
        print(f"Warning: {train_val_path} not found.")
        return [], []

    for category in os.listdir(train_val_path):
        cat_path = os.path.join(train_val_path, category)
        if os.path.isdir(cat_path):
            for brand in os.listdir(cat_path):
                brand_full_path = os.path.join(cat_path, brand)
                if os.path.isdir(brand_full_path):
                    train_val_brands.append(brand_full_path)

    # Split brands into Train and Val
    val_size = int(len(train_val_brands) * val_split)
    train_size = len(train_val_brands) - val_size
    generator = torch.Generator().manual_seed(Config.seed)
    train_subset, val_subset = random_split(train_val_brands, [train_size, val_size], generator=generator)

    train_brand_list = [train_val_brands[i] for i in train_subset.indices]
    val_brand_list = [train_val_brands[i] for i in val_subset.indices]

    train_data_list = []
    val_data_list = []

    # Sampling Logic
    if total_set_size is not None:
        images_per_brand = round(total_set_size / len(train_val_brands))

        if images_per_brand < min_images_per_brand:
            print(f"Not enough images per brand ({images_per_brand}), downscaling brand sets to ensure {min_images_per_brand} images/brand.")

            # Calculate how many brands we can actually afford
            new_total_brand_count = round(total_set_size / min_images_per_brand)
            new_val_size = round(new_total_brand_count * val_split)
            new_train_size = new_total_brand_count - new_val_size

            train_brand_list = random.sample(train_brand_list, min(len(train_brand_list), new_train_size))
            val_brand_list = random.sample(val_brand_list, min(len(val_brand_list), new_val_size))
            images_per_brand = min_images_per_brand

        for brand in train_brand_list:
            imgs = glob.glob(os.path.join(brand, '*.jpg'))

            if len(imgs) < min_images_per_brand:
                print(f"images are less than {min_images_per_brand} for this brand: {brand} in the TRAIN set")

            train_data_list.extend(random.sample(imgs, min(images_per_brand, len(imgs))))

        for brand in val_brand_list:
            imgs = glob.glob(os.path.join(brand, '*.jpg'))

            if len(imgs) < min_images_per_brand:
                print(f"images are less than {min_images_per_brand} for this brand: {brand} in the VALIDATION set")

            val_data_list.extend(random.sample(imgs, min(images_per_brand, len(imgs))))
    else:
        for brand in train_brand_list:
            train_data_list.extend(glob.glob(os.path.join(brand, '*.jpg')))
        for brand in val_brand_list:
            val_data_list.extend(glob.glob(os.path.join(brand, '*.jpg')))

    return train_data_list, val_data_list

def show_contrastive_with_bboxes(img1, img2): pass
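The sampling rules of DatasetTriplet and DatasetContrastive can be checked without images or a GPU. The sketch below works on integer indices and brand labels only; build_index, sample_triplet, and sample_pair are hypothetical helpers, not part of the notebook. Picking the negative brand uniformly among the other brands gives the same distribution as the rejection loop in DatasetTriplet. The single-image fallback in sample_pair (pairing an image with itself) is an assumption that matches DatasetTriplet's behavior.

```python
import random


def build_index(labels):
    """label -> list of indices, like label_to_indices in the dataset classes."""
    label_to_indices = {}
    for idx, label in enumerate(labels):
        label_to_indices.setdefault(label, []).append(idx)
    return label_to_indices


def sample_triplet(idx, labels, label_to_indices, rng=random):
    """Return (anchor, positive, negative) indices following DatasetTriplet's rules."""
    anchor_label = labels[idx]
    positives = [i for i in label_to_indices[anchor_label] if i != idx]
    positive = rng.choice(positives) if positives else idx  # single-image brand
    other_brands = [b for b in label_to_indices if b != anchor_label]
    if not other_brands:
        raise ValueError("a negative needs at least two brands")
    negative = rng.choice(label_to_indices[rng.choice(other_brands)])
    return idx, positive, negative


def sample_pair(idx, labels, label_to_indices, rng=random):
    """Return (idx1, idx2, label) with label 1 = same brand, 0 = different brand."""
    label = labels[idx]
    is_positive = rng.choice([0, 1])  # fair coin, as in DatasetContrastive
    if len(label_to_indices) < 2:
        is_positive = 1
    if is_positive:
        # Assumption: single-image brands are paired with themselves
        candidates = [i for i in label_to_indices[label] if i != idx] or [idx]
    else:
        other = rng.choice([b for b in label_to_indices if b != label])
        candidates = label_to_indices[other]
    return idx, rng.choice(candidates), is_positive
```

For the cosine loss, the training loop then maps a pair label y to 2 * y - 1, turning {0, 1} into {-1, 1}.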

**LogoResNet50**: a ResNet50-based architecture modified for metric learning (it outputs embeddings rather than class scores).

Key operations:

1. Backbone Initialization: Loads a standard ResNet50 (optionally with ImageNet weights).
2. Head Replacement: Replaces the original 1000-class classifier with a two-layer projection head (Linear 2048 -> 1024, ReLU, Linear 1024 -> embedding_dim), optionally followed by an activation function. The model itself does not L2-normalize its output; the training loop does.
3. Progressive Freezing: Implements a custom freeze_numer_of_layer method that first unfreezes everything and then freezes the first k backbone blocks (from the shallow conv1/bn1 stem up to the deep layer4) for controlled fine-tuning. k = 0 means full fine-tuning; k = 5 freezes the entire backbone.
4. Differential Learning Rates (build_optimizer): Splits the trainable parameters into two groups. The pre-trained backbone receives the base learning rate scaled down by backbone_reduction_factor to preserve generic visual features, while the newly initialized embedding head receives the full base learning rate to adapt quickly. Frozen parameters are skipped, and SGD (with momentum) or Adam is selected from Config.optimizer.
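The freezing levels, the unfreeze schedule, and the two learning-rate groups can be traced without building the network. This is a minimal sketch assuming the block order of LogoResNet50.blocks and the Config defaults (initial level 4, final level 3, unfreeze at epoch 15, reduction factor 0.25); frozen_blocks, freeze_level_at, and group_lrs are hypothetical helpers. Unlike freeze_numer_of_layer, which only clamps the upper bound, frozen_blocks also clamps negative levels to 0.

```python
# Block order copied from LogoResNet50.blocks (level k freezes the first k groups)
BLOCKS = [["conv1", "bn1"], ["layer1"], ["layer2"], ["layer3"], ["layer4"]]


def frozen_blocks(level):
    """Backbone block names frozen at a given level (0 = full fine-tuning)."""
    level = min(max(level, 0), len(BLOCKS))
    return [name for group in BLOCKS[:level] for name in group]


def freeze_level_at(epoch, initial=4, final=3, unfreeze=True, unfreeze_at=15):
    """Freeze level in effect during the 0-indexed epoch, per the training loop's schedule."""
    if unfreeze and unfreeze_at > 0 and epoch >= unfreeze_at:
        return final
    return initial


def group_lrs(base_lr, backbone_reduction_factor=0.25):
    """(backbone_lr, head_lr) as assigned by build_optimizer's two parameter groups."""
    return base_lr * backbone_reduction_factor, base_lr
```

With the notebook's settings, layer3 is the only block released at epoch index 15 (shown as "Epoch 16/50" in the progress bar); layer4 and the head are trainable from the start.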

In [4]:

class LogoResNet50(nn.Module):
    def __init__(self, embedding_dim=128, pretrained=True, num_of_freeze_layer=5, activation_fn=None):
        super(LogoResNet50, self).__init__()

        # 1. Load Pre-trained Weights
        # Initialize the model with weights pretrained on ImageNet for transfer learning
        if pretrained:
            weights = ResNet50_Weights.DEFAULT
            self.model = models.resnet50(weights=weights)
        else:
            self.model = models.resnet50(weights=None)

        # 2. Modify the Head (Fully Connected Layer)
        # We need to produce feature embeddings instead of class probabilities
        input_features_fc = self.model.fc.in_features # Typically 2048 for ResNet50

        head_layers = []
        # Two-layer projection head: 2048 -> 1024 -> embedding_dim (e.g., 256)
        head_layers.append(nn.Linear(input_features_fc, 1024))
        head_layers.append(nn.ReLU())
        head_layers.append(nn.Linear(1024, embedding_dim))

        # Add an optional activation function if provided
        if activation_fn is not None:
            head_layers.append(activation_fn)

        # Replace the original classifier with our custom embedding head
        self.model.fc = nn.Sequential(*head_layers)

        # 3. Freezing Management
        # Define the blocks here to access them in the freeze method.
        # This structure allows progressive freezing/unfreezing strategies
        self.blocks = [
            ['conv1', 'bn1'],   # Level 1
            ['layer1'],         # Level 2
            ['layer2'],         # Level 3
            ['layer3'],         # Level 4
            ['layer4'],         # Level 5: Entire backbone frozen
        ]

        # Apply the initial freezing configuration
        self.freeze_numer_of_layer(num_of_freeze_layer)

    def forward(self, x):
        return self.model(x)

    def freeze_numer_of_layer(self, num_of_freeze_layer):
        """
        Manages layer freezing for transfer learning strategies.

        Args:
            num_of_freeze_layer (int):
              0   -> All layers unlocked (Full Fine-Tuning)
              1-5 -> Progressively freezes the backbone layers from shallow to deep
        """

        # STEP 1: RESET. Unfreeze everything (requires_grad = True).
        # This ensures we start from a clean state before applying new constraints.
        for param in self.model.parameters():
            param.requires_grad = True

        # If num is 0, exit immediately (Full Fine-Tuning mode)
        if num_of_freeze_layer == 0:
            print("Configuration: Full Fine-Tuning (All layers are trainable)")
            return

        # Safety check to avoid index out of bounds
        limit = min(num_of_freeze_layer, len(self.blocks))

        frozen_list = []

        # STEP 2: Progressively freeze the requested blocks
        for i in range(limit):
            current_blocks = self.blocks[i]
            for block_name in current_blocks:
                # Retrieve the layer by name
                layer = getattr(self.model, block_name)

                # Freeze parameters for this specific block
                for param in layer.parameters():
                    param.requires_grad = False

                frozen_list.append(block_name)

        print(f"Freezing Level {limit}. Frozen blocks: {frozen_list}")


def build_optimizer(model, base_lr):
    backbone_params = []
    head_params = []

    for name, param in model.model.named_parameters():
        if not param.requires_grad:
            continue
        if name.startswith("fc."):
            head_params.append(param)
        else:
            backbone_params.append(param)

    param_groups = [
        {"params": backbone_params, "lr": base_lr * Config.backbone_reduction_factor},
        {"params": head_params,     "lr": base_lr},
    ]

    weight_decay = Config.weight_decay

    if Config.optimizer == "Adam":
        optimizer = optim.Adam(
            param_groups,
            weight_decay=weight_decay
        )
    else:
        optimizer = optim.SGD(
            param_groups,
            momentum=Config.SGD_momentum,
            weight_decay=weight_decay
        )

    return optimizer

**Loss Function** definitions used to train the LogoResNet50 model. These classes define the mathematical rules that teach the neural network how to distinguish between similar and dissimilar logo images in a vector space.
Here is a breakdown of the three classes defined in the script:
1. ContrastiveLossEuclidean
This class implements the classic Contrastive Loss based on Euclidean distance.
* Function: It creates a "Siamese" objective. It pulls pairs of images belonging to the same class (Label 1) closer together and pushes pairs from different classes (Label 0) apart.
* Mechanism:
    * If images are Similar: It minimizes the squared Euclidean distance between their embeddings.
    * If images are Different: It penalizes the model if the distance is smaller than the margin (the class default is 2.0; the euclidean recipe in Config passes 1.0).
2. ContrastiveLossCosine
This class implements Cosine Embedding Loss.
* Function: Instead of measuring the straight-line (Euclidean) distance, this measures the angle between the two feature vectors. This suits high-dimensional embeddings where the direction of a vector carries more information than its magnitude.
* Input Requirements: It expects labels where 1 represents similar pairs and -1 represents dissimilar pairs (the training loop converts the dataset's 0 labels to -1). Dissimilar pairs are penalized only when their cosine similarity exceeds the margin.
3. TripletLoss
This class implements the standard Triplet Margin Loss.
* Function: It uses a three-part input structure rather than pairs:
    * Anchor: The reference image.
    * Positive: An image of the same class as the Anchor.
    * Negative: An image of a different class.
* Goal: It forces the distance between the Anchor and the Positive to be smaller than the distance between the Anchor and the Negative by at least the specified margin (default 1.0). Because it optimizes relative rather than absolute distances, it is often preferred over Contrastive Loss for ranking tasks such as logo retrieval.
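The three objectives can be written per sample in plain Python to make the formulas concrete. This sketch is an illustration only: the notebook's classes average over a batch of tensors, and PyTorch adds small numerical-stability terms (for example the eps in pairwise distances) that are ignored here. The function names are hypothetical.

```python
import math


def euclidean(u, v):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def contrastive_euclidean(u, v, label, margin=1.0):
    # label 1 = similar, 0 = dissimilar (ContrastiveLossEuclidean convention)
    d = euclidean(u, v)
    return label * d ** 2 + (1 - label) * max(margin - d, 0.0) ** 2


def cosine_embedding(u, v, target, margin=0.2):
    # target 1 = similar, -1 = dissimilar (nn.CosineEmbeddingLoss convention)
    c = cosine(u, v)
    return 1.0 - c if target == 1 else max(0.0, c - margin)


def triplet_margin(a, p, n, margin=1.0):
    # Per-sample form of nn.TripletMarginLoss with p=2
    return max(euclidean(a, p) - euclidean(a, n) + margin, 0.0)
```

For two orthogonal unit vectors, a similar pair costs 2.0 under the Euclidean contrastive loss (the squared distance) and 1.0 under the cosine loss, while a dissimilar pair costs nothing under either, since the distance sqrt(2) already exceeds margin 1.0 and the similarity 0 is below margin 0.2.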




In [5]:
class ContrastiveLossEuclidean(nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLossEuclidean, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # Label 1 = Similar | Label 0 = Dissimilar
        # Keep the distance 1-D (shape [B]) so it lines up with label (shape [B]);
        # keepdim=True would give [B, 1], which broadcasts against label to [B, B]
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean(
            label * torch.pow(euclidean_distance, 2) +
            (1 - label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        )
        return loss_contrastive

class ContrastiveLossCosine(nn.Module):
    def __init__(self, margin=0.2):
        super(ContrastiveLossCosine, self).__init__()
        self.loss_fn = nn.CosineEmbeddingLoss(margin=margin)

    def forward(self, output1, output2, label):
        return self.loss_fn(output1, output2, label)

class TripletLoss(nn.Module):
    """
    Triplet Margin Loss Standard.
    Input: Anchor, Positive, Negative.
    """
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.loss_fn = nn.TripletMarginLoss(margin=margin, p=2)

    def forward(self, anchor, positive, negative):
        return self.loss_fn(anchor, positive, negative)

# **Training script**
This cell coordinates all core components (configuration, dataset handling, model architecture, and loss functions) to execute the full training workflow.
The script:
* Sets up the execution environment: the device and an output folder named after the selected loss.
* Loads training settings from the Config class and selects hyperparameters based on the chosen loss function (Triplet, Euclidean, or Cosine).
* Splits the dataset into brand-disjoint training and validation sets and applies data augmentation to the training images for better generalization.
* Initializes a ResNet-50-based embedding model and the corresponding metric learning loss.
* Runs the training and validation loop for multiple epochs, performing forward passes, loss computation, backpropagation, and optimization, and optionally unfreezing backbone blocks partway through.
* Saves model checkpoints, logs training history to a CSV file (which also allows resuming an interrupted run), and generates plots to monitor convergence (Loss Curve, F1 Score Curve, Threshold Evolution).
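The validation code that fills the "Val F1 Score" and "Val threshold" columns is not part of this excerpt. One common way to obtain such a threshold, shown here as an assumption rather than the notebook's exact method, is to sweep candidate distance thresholds over the validation pairs and keep the one with the highest F1. best_threshold and f1_from_counts are hypothetical; f1_from_counts uses the same formula as calculate_f1 below, minus the unused tn argument. For cosine similarity scores the comparison would need to be flipped (or the distance taken as 1 - similarity).

```python
def f1_from_counts(tp, fp, fn):
    """F1 = 2TP / (2TP + FP + FN), returning 0.0 when undefined."""
    denominator = 2 * tp + fp + fn
    return 0.0 if denominator == 0 else 2 * tp / denominator


def best_threshold(distances, labels):
    """Predict 'same brand' when distance <= threshold; return (best_f1, best_threshold).

    distances: one embedding distance per validation pair
    labels:    1 = same brand, 0 = different brand
    """
    best_f1, best_thr = 0.0, None
    for thr in sorted(set(distances)):  # every observed distance is a candidate
        tp = fp = fn = 0
        for d, y in zip(distances, labels):
            pred = d <= thr
            if pred and y == 1:
                tp += 1
            elif pred and y == 0:
                fp += 1
            elif not pred and y == 1:
                fn += 1
        f1 = f1_from_counts(tp, fp, fn)
        if f1 > best_f1:
            best_f1, best_thr = f1, thr
    return best_f1, best_thr
```

Because the threshold is re-selected every epoch, plotting it over time (the Threshold Evolution panel) shows how the embedding scale drifts during training.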


In [None]:
import torch.nn.functional as F  
def save_plots(model, optimizer, train_losses, val_losses, val_f1_scores, val_thresholds, output_dir, train_transform, title_suffix="", params=None, unfreeze_epoch=None):
    fig = plt.figure(figsize=(20, 8))
    gs = fig.add_gridspec(2, 3, height_ratios=[1, 1])

    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[0, 2])

    # --- Ax1: Loss ---
    ax1.set_title(f"Loss Curve {title_suffix}")
    ax1.plot(train_losses, label="Train Loss", color="blue")
    ax1.plot(val_losses, label="Val Loss", color="red")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # --- Ax2: F1 Score ---
    ax2.set_title(f"F1 Score Curve {title_suffix}")
    ax2.plot(val_f1_scores, label="Val F1 Score", color="green")
    ax2.set_xlabel("Epochs")
    ax2.set_ylabel("F1 Score")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # --- Ax3: Threshold Evolution ---
    ax3.set_title(f"Threshold Evolution {title_suffix}")
    ax3.plot(val_thresholds, label="Best Threshold", color="purple", linestyle='-')
    ax3.set_xlabel("Epochs")
    ax3.set_ylabel("Distance Value")
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    if params:

        freeze_text = f"Freeze: {Config.initial_freeze_layers} (Fixed)"

        lr_text = f"LR={params['learning_rate']}"


        if Config.unfreeze_layers and unfreeze_epoch and len(train_losses) > unfreeze_epoch:
            freeze_text = f"Freeze: {Config.initial_freeze_layers} -> {Config.final_freeze_layers} (at ep {unfreeze_epoch})"

            current_lr = params['learning_rate']
            prev_lr = current_lr * (1/Config.LR_reduction)
            lr_text = f"LR: {prev_lr:.1e} -> {current_lr:.1e}"

        info_text = (f"Config: Batch={params['batch_size']} | "
                     f"With Normalization on output | "
                     f"With lower backbone LR by a factor of {Config.backbone_reduction_factor} | "
                     f"Optimizer= {Config.optimizer} | "
                     f"weight decay: {Config.weight_decay} | "
                     f"Margin={params['margin']} | "
                     f"{lr_text} | \n"
                     f"total dataset size: {Config.total_size} | \n"
                     f"train transformations: {train_transform} | \n"
                     f"Emb={model.model.fc}\n"
                     f"{freeze_text}\n"
                     f"LR decrease factor: {Config.LR_reduction}")

        ax_text = fig.add_subplot(gs[1, :])
        ax_text.axis('off')

        ax_text.text(0.5, 0.95, info_text, ha="center", va="top", fontsize=10,
                     wrap=True, bbox={"facecolor":"orange", "alpha":0.2, "pad":10})

    # Add unfreeze lines to all plots
    for ax in [ax1, ax2, ax3]:
        if Config.unfreeze_layers and unfreeze_epoch and len(train_losses) > unfreeze_epoch:
            ax.axvline(x=unfreeze_epoch, color='gray', linestyle='--', alpha=0.8, label='Unfreeze')

    plt.tight_layout()
    plot_path = os.path.join(output_dir, "training_plot.png")
    plt.savefig(plot_path)
    plt.close()

def calculate_f1(tp, tn, fp, fn):
    denominator = (2 * tp) + fp + fn
    if denominator == 0:
        return 0.0

    f1_score = (2 * tp) / denominator

    return f1_score

# =========================================================================
# MAIN
# =========================================================================
def main():
    print("starting execution")
    loss_type = Config.loss_type

    if hasattr(Config, 'HYPERPARAMS'):
        params = Config.HYPERPARAMS[loss_type]
        BATCH_SIZE = params['batch_size']
        MARGIN = params['margin']
        LR = params['learning_rate']
        SAVE_FOLDER = params['folder_name']
    else:
        BATCH_SIZE = 32
        MARGIN = 1.0
        LR = 1e-4
        SAVE_FOLDER = f"{loss_type}_run"

    device = torch.device(Config.device)
    save_dir = os.path.join(Config.checkpoints_base_dir, SAVE_FOLDER)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    print(f"\n{'-'*50}")
    print(f"STARTING TRAINING: {loss_type.upper()}")
    print(f"Batch Size: {BATCH_SIZE} | Margin: {MARGIN} | LR: {LR}")
    print(f"Output Folder: {save_dir}")
    print(f"{'-'*50}\n")

    # --- DATASET ---
    print("Loading dataset paths...")

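    # min_images_per_brand only takes effect when total_set_size is also passed;
    # without it, every image of every brand is used.
    train_files, val_files = getTrainValPaths(
        Config.dataset_root,
        val_split=Config.val_split_ratio,
        min_images_per_brand=5
    )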
    print(f"train_files length: {len(train_files)}")
    print(f"val_files length: {len(val_files)}")

    # Transforms
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomPerspective(p=0.3),
        transforms.RandomHorizontalFlip(p=0.3),
        transforms.RandomRotation(degrees=15),
        transforms.RandomGrayscale(p=0.1),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    if loss_type == "triplet":
        print("Using DatasetTriplet")
        train_dataset = DatasetTriplet(train_files, transform=train_transform)
        val_dataset = DatasetTriplet(val_files, transform=val_transform)
    else:
        print("Using DatasetContrastive (Pairs)")
        train_dataset = DatasetContrastive(train_files, transform=train_transform)
        val_dataset = DatasetContrastive(val_files, transform=val_transform)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True, persistent_workers=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=True)

    # --- MODEL & LOSS ---
    print("Initializing LogoResNet50...")
    model = LogoResNet50(embedding_dim=Config.embedding_dim, pretrained=Config.pretrained, num_of_freeze_layer=Config.initial_freeze_layers)
    model = model.to(device)
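    # torchsummary defaults to device="cuda"; pass the configured device so CPU runs work
    summary(model, input_size=(3, 224, 224), device=Config.device)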

    if loss_type == "euclidean":
        criterion = ContrastiveLossEuclidean(margin=MARGIN)
    elif loss_type == "cosine":
        criterion = ContrastiveLossCosine(margin=MARGIN)
    elif loss_type == "triplet":
        criterion = TripletLoss(margin=MARGIN)
    else:
        raise ValueError(f"Loss type {loss_type} not supported")

    optimizer = build_optimizer(model, LR)

    # --- SETUP CSV E HISTORY ---
    train_losses = []
    val_losses = []
    val_f1_scores = []
    val_thresholds = []
    csv_path = os.path.join(save_dir, "training_history.csv")
    # F1 values
    val_tp = 0
    val_tn = 0
    val_fp = 0
    val_fn = 0

    start_epoch = 0
    if os.path.exists(csv_path):
        print("Existing CSV found, loading historical data for charts")
        try:
            df = pd.read_csv(csv_path)
            train_losses = df['Train Loss'].tolist()
            val_losses = df['Val Loss'].tolist()
            if 'Val F1 Score' in df.columns:
                 val_f1_scores = df['Val F1 Score'].tolist()
            if 'Val threshold' in df.columns:
                 val_thresholds = df['Val threshold'].tolist()
            start_epoch = len(train_losses)
            print(f"Resuming from epoch {start_epoch + 1}")
            if (Config.unfreeze_layers and start_epoch > Config.unfreeze_at_epoch):
                model.freeze_numer_of_layer(Config.final_freeze_layers)
                LR = LR * Config.LR_reduction
                optimizer = build_optimizer(model, LR)
        except Exception as e:
            print(f"Error reading CSV: {e}")
    else:
        with open(csv_path, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Epoch", "Train Loss", "Val Loss", "Val F1 Score", "Val threshold"])

    # Resume (if the file for the current epoch exists)
    resume_ckpt = os.path.join(save_dir, f"model_epoch_{start_epoch}.pth")
    if os.path.exists(resume_ckpt) and start_epoch > 0:
        print(f"Loading weights from: {resume_ckpt}")
        model.load_state_dict(torch.load(resume_ckpt, map_location=device))
        print("Weights loaded")

    try:
        from tqdm import tqdm
    except ImportError:
        # Fallback: iterate without a progress bar
        tqdm = lambda iterator, desc="": iterator

    # --- TRAINING LOOP ---
    for epoch in range(start_epoch, Config.epochs):
        if Config.unfreeze_layers and Config.unfreeze_at_epoch > 0 and epoch == Config.unfreeze_at_epoch:
            print(f"\nUNFREEZING LAYERS AT EPOCH {epoch}")
            model.freeze_numer_of_layer(Config.final_freeze_layers)
            summary(model, input_size=(3, 224, 224))

            # Lower the learning rate for fine-tuning after unfreezing
            LR = LR * Config.LR_reduction
            print(f"Learning Rate reduced to {LR}")

            # Rebuild the optimizer with the reduced learning rate
            optimizer = build_optimizer(model, LR)

        model.train()
        running_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.epochs}")

        for batch in pbar:
            optimizer.zero_grad()

            if loss_type == "triplet":
                anc, pos, neg = batch[0].to(device), batch[1].to(device), batch[2].to(device)
                emb_a = F.normalize(model(anc), p=2, dim=1)
                emb_p = F.normalize(model(pos), p=2, dim=1)
                emb_n = F.normalize(model(neg), p=2, dim=1)

                loss = criterion(emb_a, emb_p, emb_n)
            else:
                img1 = batch[0]['image'].to(device)
                img2 = batch[1]['image'].to(device)
                label = batch[2].to(device)

                out1 = F.normalize(model(img1), p=2, dim=1)
                out2 = F.normalize(model(img2), p=2, dim=1)

                # Euclidean contrastive loss takes {0, 1} targets; the cosine variant takes {-1, 1}
                target = label.float().clone()
                if loss_type != "euclidean":
                    target[target == 0] = -1

                loss = criterion(out1, out2, target)

            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if hasattr(pbar, "set_postfix"): pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        avg_loss = running_loss / len(train_loader)
        train_losses.append(avg_loss)
        print(f"Epoch {epoch+1} DONE. Train Loss: {avg_loss:.4f}")

        # --- VALIDATION ---
        model.eval()
        val_loss_acc = 0.0
        val_tp, val_tn, val_fp, val_fn = 0, 0, 0, 0
        threshold = MARGIN

        all_distances = []
        all_labels = []

        with torch.no_grad():
            for batch in val_loader:
                if loss_type == "triplet":
                    anc, pos, neg = batch[0].to(device), batch[1].to(device), batch[2].to(device)
                    emb_a = F.normalize(model(anc), p=2, dim=1)
                    emb_p = F.normalize(model(pos), p=2, dim=1)
                    emb_n = F.normalize(model(neg), p=2, dim=1)
                    loss = criterion(emb_a, emb_p, emb_n)

                    dist_pos = F.pairwise_distance(emb_a, emb_p).cpu().tolist()
                    dist_neg = F.pairwise_distance(emb_a, emb_n).cpu().tolist()

                    # Accumulate for optimal threshold search
                    all_distances.extend(dist_pos)
                    all_labels.extend([1] * len(dist_pos))
                    all_distances.extend(dist_neg)
                    all_labels.extend([0] * len(dist_neg))
                else:
                    img1, img2, label = batch[0]['image'].to(device), batch[1]['image'].to(device), batch[2].to(device)
                    out1 = F.normalize(model(img1), p=2, dim=1)
                    out2 = F.normalize(model(img2), p=2, dim=1)

                    target = label.float().clone()
                    if loss_type != "euclidean":
                        target[target == 0] = -1
                    loss = criterion(out1, out2, target)
                    if loss_type == "cosine":
                        dist = (1 - F.cosine_similarity(out1, out2)).cpu().tolist()
                    else:
                        dist = (F.pairwise_distance(out1, out2)).cpu().tolist()

                    all_distances.extend(dist)
                    all_labels.extend(label.cpu().tolist())

                val_loss_acc += loss.item()

        # --- FINAL F1 CALCULATION ---
        import numpy as np
        all_distances = np.array(all_distances)
        all_labels = np.array(all_labels)

        best_f1 = 0.0
        best_threshold = None  # stays None if no candidate yields F1 > 0
        # Try 100 thresholds between the min and max distance observed;
        # a pair is predicted "same logo" when its distance is below the threshold
        threshold_candidates = np.linspace(all_distances.min(), all_distances.max(), 100)

        for t in threshold_candidates:
            preds = (all_distances < t)
            tp = int(np.sum((preds == 1) & (all_labels == 1)))
            fp = int(np.sum((preds == 1) & (all_labels == 0)))
            fn = int(np.sum((preds == 0) & (all_labels == 1)))
            tn = 0  # F1 depends only on TP, FP and FN, so TN is not counted

            current_f1 = calculate_f1(tp, tn, fp, fn)
            if current_f1 > best_f1:
                best_f1 = current_f1
                best_threshold = float(t)
        epoch_f1 = best_f1

        avg_val_loss = val_loss_acc / len(val_loader)
        # Fall back to the default threshold (MARGIN) only when no candidate was found;
        # 'is not None' avoids discarding a legitimate threshold of 0.0
        if best_threshold is not None:
            threshold = best_threshold

        val_thresholds.append(threshold)
        val_losses.append(avg_val_loss)
        val_f1_scores.append(epoch_f1)
        print(f"VALIDATION Epoch {epoch+1}: Loss = {avg_val_loss:.4f} | F1 Score = {epoch_f1:.4f}")
        print("-" * 30)

        # SAVING
        ckpt_path = os.path.join(save_dir, f"model_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), ckpt_path)
        print(f"Checkpoint saved: {ckpt_path}")

        # UPDATE CSV AND PLOTS
        with open(csv_path, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([epoch+1, avg_loss, avg_val_loss, epoch_f1, threshold])
        print("CSV updated.")
        current_params = {
            'batch_size': BATCH_SIZE,
            'margin': MARGIN,
            'learning_rate': LR,
        }
        save_plots(model, optimizer, train_losses, val_losses, val_f1_scores, val_thresholds, save_dir,
                   title_suffix=f"({loss_type})", params=current_params,
                   unfreeze_epoch=Config.unfreeze_at_epoch, train_transform=train_transform)

    print("Training finished.")

if __name__ == "__main__":
    main()

starting execution

--------------------------------------------------
STARTING TRAINING: COSINE
Batch Size: 256 | Margin: 0.2 | LR: 0.01
Output Folder: /content/drive/MyDrive/LogoDet_Experiments/runs/contrastive_cosine
--------------------------------------------------

Loading dataset paths...
train_files length: 108171
val_files length: 27528
Using DatasetContrastive (Pairs)
Initializing LogoResNet50...
Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
Freezing Level 4. Frozen blocks: ['conv1', 'bn1', 'layer1', 'layer2', 'layer3']
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13   

Epoch 43/50: 100%|██████████| 423/423 [10:38<00:00,  1.51s/it, loss=0.1580]

Epoch 43 DONE. Train Loss: 0.1870

VALIDATION Epoch 43: Loss = 0.2014 | F1 Score = 0.8103
------------------------------
Checkpoint saved: /content/drive/MyDrive/LogoDet_Experiments/runs/contrastive_cosine/model_epoch_43.pth
CSV updated.


Epoch 44/50: 100%|██████████| 423/423 [10:36<00:00,  1.51s/it, loss=0.1640]

Epoch 44 DONE. Train Loss: 0.1864

VALIDATION Epoch 44: Loss = 0.2045 | F1 Score = 0.8079
------------------------------
Checkpoint saved: /content/drive/MyDrive/LogoDet_Experiments/runs/contrastive_cosine/model_epoch_44.pth
CSV updated.


Epoch 45/50: 100%|██████████| 423/423 [10:35<00:00,  1.50s/it, loss=0.1972]

Epoch 45 DONE. Train Loss: 0.1840

VALIDATION Epoch 45: Loss = 0.2025 | F1 Score = 0.8096
------------------------------
Checkpoint saved: /content/drive/MyDrive/LogoDet_Experiments/runs/contrastive_cosine/model_epoch_45.pth
CSV updated.


Epoch 46/50: 100%|██████████| 423/423 [10:35<00:00,  1.50s/it, loss=0.1790]

Epoch 46 DONE. Train Loss: 0.1833

VALIDATION Epoch 46: Loss = 0.2046 | F1 Score = 0.8047
------------------------------
Checkpoint saved: /content/drive/MyDrive/LogoDet_Experiments/runs/contrastive_cosine/model_epoch_46.pth
CSV updated.


Epoch 47/50: 100%|██████████| 423/423 [10:36<00:00,  1.51s/it, loss=0.1474]

Epoch 47 DONE. Train Loss: 0.1813

VALIDATION Epoch 47: Loss = 0.2024 | F1 Score = 0.8054
------------------------------
Checkpoint saved: /content/drive/MyDrive/LogoDet_Experiments/runs/contrastive_cosine/model_epoch_47.pth
CSV updated.


Epoch 48/50: 100%|██████████| 423/423 [10:36<00:00,  1.51s/it, loss=0.1822]

Epoch 48 DONE. Train Loss: 0.1813

VALIDATION Epoch 48: Loss = 0.1997 | F1 Score = 0.8117
------------------------------
Checkpoint saved: /content/drive/MyDrive/LogoDet_Experiments/runs/contrastive_cosine/model_epoch_48.pth
CSV updated.


Epoch 49/50: 100%|██████████| 423/423 [10:36<00:00,  1.51s/it, loss=0.2018]

Epoch 49 DONE. Train Loss: 0.1794

VALIDATION Epoch 49: Loss = 0.2039 | F1 Score = 0.8049
------------------------------
Checkpoint saved: /content/drive/MyDrive/LogoDet_Experiments/runs/contrastive_cosine/model_epoch_49.pth
CSV updated.


Epoch 50/50: 100%|██████████| 423/423 [10:36<00:00,  1.51s/it, loss=0.1566]

Epoch 50 DONE. Train Loss: 0.1773

VALIDATION Epoch 50: Loss = 0.2040 | F1 Score = 0.8082
------------------------------
Checkpoint saved: /content/drive/MyDrive/LogoDet_Experiments/runs/contrastive_cosine/model_epoch_50.pth
CSV updated.
Training finished.
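The progress bars report 423 iterations per epoch. This count can be checked with simple arithmetic under two assumptions. First, the pair dataset yields one item per training file (108171). Second, the DataLoader keeps the last partial batch (`drop_last=False`, PyTorch's default). Under those assumptions the count is ⌈108171 / 256⌉. At roughly 1.51 s/it this gives about 639 s per epoch, consistent with the 10:35 to 10:38 epochs in the log. A small sketch of that arithmetic (the helper name `batches_per_epoch` is hypothetical):

```python
import math


def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of iterations a DataLoader performs over num_samples items."""
    if drop_last:
        return num_samples // batch_size  # the last partial batch is discarded
    return math.ceil(num_samples / batch_size)  # the last partial batch is kept


# Figures taken from the log; 1.51 s/it is tqdm's rounded average
n_train_batches = batches_per_epoch(108171, 256)
approx_epoch_seconds = n_train_batches * 1.51

print(n_train_batches)              # 423
print(round(approx_epoch_seconds))  # 639
```

If the dataset length differed from the number of files (for example, several pairs sampled per image), the same formula would apply with that length instead.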
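Each epoch appends one row to `training_history.csv` and saves a `model_epoch_{n}.pth` checkpoint. That makes it possible to choose the checkpoint to keep from the CSV instead of by reading the log. Among the epochs visible here (43 to 50), epoch 48 has the highest validation F1 (0.8117). Earlier epochs are not shown, however, so only the CSV can identify the overall best epoch. The sketch below assumes the five columns written by the header row in `main()`. `select_best_epoch` and `BestEpoch` are hypothetical names:

```python
import csv
import os
from typing import NamedTuple, Optional


class BestEpoch(NamedTuple):
    epoch: int
    val_f1: float
    threshold: float
    checkpoint: str


def select_best_epoch(save_dir: str) -> Optional[BestEpoch]:
    """Pick the epoch with the highest 'Val F1 Score' from training_history.csv."""
    csv_path = os.path.join(save_dir, "training_history.csv")
    best: Optional[BestEpoch] = None
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            val_f1 = float(row["Val F1 Score"])
            # Strict '>' keeps the earliest epoch when two epochs tie
            if best is None or val_f1 > best.val_f1:
                epoch = int(row["Epoch"])
                best = BestEpoch(
                    epoch=epoch,
                    val_f1=val_f1,
                    threshold=float(row["Val threshold"]),
                    # Checkpoint naming follows the torch.save call in main()
                    checkpoint=os.path.join(save_dir, f"model_epoch_{epoch}.pth"),
                )
    return best
```

Call `select_best_epoch(save_dir)` with the run's output folder, for example the one printed at the start of the log. The function returns `None` when the CSV has a header but no rows. The stored threshold is the one that maximized validation F1 for that epoch, which makes it a reasonable starting point for same/different decisions at test time.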
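Both the training and validation steps L2-normalize the embeddings with `F.normalize` before comparing them. The validation loop then measures `1 - F.cosine_similarity` for the cosine loss and `F.pairwise_distance` otherwise. For unit vectors these two distances are linked by a fixed increasing function, as shown by the standard identity below:

$$
\lVert a - b \rVert_2^2 = \lVert a \rVert_2^2 + \lVert b \rVert_2^2 - 2\,a^\top b = 2\,\bigl(1 - \cos\theta_{ab}\bigr) \qquad \text{for } \lVert a \rVert_2 = \lVert b \rVert_2 = 1
$$

$$
d_{\mathrm{euc}}(a, b) = \sqrt{2\, d_{\cos}(a, b)}, \qquad d_{\cos}(a, b) = 1 - \cos\theta_{ab} \in [0, 2]
$$

Because $\sqrt{2x}$ is increasing, a Euclidean threshold $t_{\mathrm{euc}}$ selects the same pairs as the cosine threshold $t_{\cos} = t_{\mathrm{euc}}^2 / 2$. This holds up to the small `eps` that `F.pairwise_distance` adds before taking the norm. As a result, the best F1 achievable over all thresholds is the same for either distance. The 100-point grid used in the sweep may still land on slightly different values for each.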
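The F1 step in the validation loop can also be exercised outside training. The sketch below mirrors its logic in plain Python:

- Triplet batches are turned into labeled pairs: anchor-positive distances get label 1 and anchor-negative distances get label 0.
- 100 evenly spaced thresholds between the smallest and largest distance are tried.
- A pair counts as "same logo" when its distance is strictly below the threshold.

Because `calculate_f1` is defined outside this section, `f1_from_counts` is a hypothetical stand-in. It uses the standard formula F1 = 2TP / (2TP + FP + FN), the harmonic mean of precision and recall. The other helper names are illustrative as well:

```python
from typing import List, Optional, Sequence, Tuple


def f1_from_counts(tp: int, fp: int, fn: int) -> float:
    """Hypothetical stand-in for calculate_f1: F1 = 2TP / (2TP + FP + FN)."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom > 0 else 0.0


def triplets_to_pairs(
    dist_pos: Sequence[float], dist_neg: Sequence[float]
) -> Tuple[List[float], List[int]]:
    """Anchor-positive distances get label 1, anchor-negative distances label 0."""
    distances = list(dist_pos) + list(dist_neg)
    labels = [1] * len(dist_pos) + [0] * len(dist_neg)
    return distances, labels


def linspace(lo: float, hi: float, num: int) -> List[float]:
    """num evenly spaced values from lo to hi (like numpy.linspace)."""
    if num == 1:
        return [lo]
    step = (hi - lo) / (num - 1)
    return [lo + i * step for i in range(num)]


def best_threshold_f1(
    distances: Sequence[float], labels: Sequence[int], num_candidates: int = 100
) -> Tuple[float, Optional[float]]:
    """Return (best F1, threshold); the threshold is None if no candidate beats F1 = 0."""
    if not distances:
        return 0.0, None
    best_f1, best_t = 0.0, None
    for t in linspace(min(distances), max(distances), num_candidates):
        tp = fp = fn = 0
        for d, y in zip(distances, labels):
            predicted_same = d < t  # strict comparison, as in the training loop
            if predicted_same and y == 1:
                tp += 1
            elif predicted_same and y == 0:
                fp += 1
            elif not predicted_same and y == 1:
                fn += 1
        f1 = f1_from_counts(tp, fp, fn)
        if f1 > best_f1:
            best_f1, best_t = f1, t
    return best_f1, best_t


# Usage: two matching and two non-matching pairs that a threshold can separate
dist, lab = triplets_to_pairs(dist_pos=[0.1, 0.2], dist_neg=[0.8, 0.9])
f1, t = best_threshold_f1(dist, lab)
print(f1, round(t, 3))  # 1.0 0.205
```

Because the comparison is strict, nothing is predicted positive at the lowest candidate, and the pair at the maximum distance is never predicted positive. When no candidate beats F1 = 0, the function returns `None` rather than a sentinel such as 0.0, which could also be a genuine threshold. This lets the caller fall back to its default threshold (`MARGIN` in the loop) explicitly.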
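The freeze/unfreeze schedule is applied in two places in `main()`:

- inside the epoch loop, when `epoch == Config.unfreeze_at_epoch`;
- in the resume branch, when training restarts after that epoch.

For a given 0-based epoch index, both paths should leave the model in the same state. A minimal sketch of that invariant as a pure function follows. It takes the `Config` fields used in the code (`unfreeze_layers`, `unfreeze_at_epoch`, `initial_freeze_layers`, `final_freeze_layers`, `LR_reduction`) as plain arguments. `setup_for_epoch` and `EpochSetup` are hypothetical names, and the final freeze level, unfreeze epoch, and reduction factor in the example are illustrative values. Only the base LR (0.01) and initial freeze level (4) come from the log:

```python
from typing import NamedTuple


class EpochSetup(NamedTuple):
    freeze_level: int  # freeze level the model should have for this epoch
    lr: float          # learning rate the optimizer should use


def setup_for_epoch(
    epoch: int,
    base_lr: float,
    initial_freeze_layers: int,
    final_freeze_layers: int,
    unfreeze_layers: bool,
    unfreeze_at_epoch: int,
    lr_reduction: float,
) -> EpochSetup:
    """State the model should be in while training the 0-based epoch `epoch`."""
    unfrozen = unfreeze_layers and unfreeze_at_epoch > 0 and epoch >= unfreeze_at_epoch
    if unfrozen:
        # One-off switch: the final freeze level and a single LR reduction
        return EpochSetup(final_freeze_layers, base_lr * lr_reduction)
    return EpochSetup(initial_freeze_layers, base_lr)


# Resuming after 42 completed epochs means the next 0-based epoch index is 42
state = setup_for_epoch(42, base_lr=0.01, initial_freeze_layers=4, final_freeze_layers=2,
                        unfreeze_layers=True, unfreeze_at_epoch=10, lr_reduction=0.1)
print(state.freeze_level)  # 2
```

Deriving the state from the epoch index, rather than mutating `LR` step by step, makes it harder for the resume path and the loop to drift apart. This sketch does not restore optimizer state, and neither do the checkpoints in this pipeline, which store only `model.state_dict()`.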
