## **Deep Metric Learning Training Pipeline**
Overview: This notebook implements a comprehensive Deep Metric Learning pipeline specifically designed for Logo Recognition. The pipeline covers every stage of the process, including data preparation, model architecture definition, loss function implementation, and a robust training loop equipped with logging and checkpointing mechanisms.

# **Configuration System Analysis (Config)**
The Config class separates data paths and hyperparameters from the execution logic, so that the experiment is both reproducible and portable across different environments.
1. Environment & Path Management
This section handles the setup of the working environment to ensure smooth execution.
* WORK_DIR: The root directory of the project.
* dataset_root: The source location of the image dataset.
* checkpoints_base_dir: The destination folder where model weights (.pth files) and training logs are saved.
* device: Automatically detects and assigns hardware acceleration (cuda for GPU or cpu).
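* seed: A fixed integer (set to 42) intended to enforce reproducibility in data splits and weight initialization. Note that the data-loading cell seeds random and torch with its own constant (SEED = 101) rather than reading this value.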
2. Model Architecture
Defines the neural network structure without modifying the model class directly.
* backbone: The pre-trained network architecture used as the feature extractor (set to 'resnet50').
* embedding_dim: The dimension of the output vector (set to 128). This defines the coordinates of the geometric space where images are mapped.
* freeze_layers: An integer from 0 to 5 determining how many initial backbone blocks (from conv1/bn1 up to layer4) are frozen to preserve pre-trained features (0 indicates Full Fine-tuning).
3. Training Strategy
Contains the general parameters governing the learning cycle.
* loss_type: The master switch (options: 'triplet', 'euclidean', 'cosine'). This determines which Dataset class is loaded and which Loss function is instantiated.
* epochs: The total number of full passes over the training set (set to 20).
* train_split_ratio / val_split_ratio: The fractions of brands assigned to the Training and Validation sets (0.7 and 0.2). The remainder (0.1) becomes the Test set.
4. Hyperparameter Dictionary (HYPERPARAMS)
A nested dictionary containing optimized configuration recipes tailored for each specific loss function. This allows for context-aware parameter loading.
* Triplet / Euclidean Loss: Uses margin=1.0. These loss functions work on absolute Euclidean distances.
* Cosine Loss: Uses margin=0.2. This works on normalized angular similarities (ranging from -1 to 1).
* Learning rate: Every recipe uses a conservative learning_rate of 1e-5 (with batch_size 32). The training script passes it to the Adam optimizer to preserve the quality of the pre-trained features.
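
As a minimal sketch of the context-aware loading described above, the snippet below looks up a recipe by loss_type and fails early on an unknown key. The helper name resolve_hyperparams and the trimmed-down dictionary are illustrative assumptions, not part of the notebook.

```python
# Trimmed copy of the recipe table (values taken from the Config cell).
HYPERPARAMS = {
    "euclidean": {"batch_size": 32, "margin": 1.0, "learning_rate": 1e-5},
    "cosine": {"batch_size": 32, "margin": 0.2, "learning_rate": 1e-5},
    "triplet": {"batch_size": 32, "margin": 1.0, "learning_rate": 1e-5},
}


def resolve_hyperparams(loss_type):
    """Hypothetical helper: return the recipe for loss_type or raise a clear error."""
    try:
        return HYPERPARAMS[loss_type]
    except KeyError:
        valid = ", ".join(sorted(HYPERPARAMS))
        raise ValueError(f"Unknown loss_type {loss_type!r}; expected one of: {valid}") from None


params = resolve_hyperparams("cosine")
print(params["margin"], params["learning_rate"])
```

Failing on an unknown key at lookup time surfaces a typo in loss_type before any dataset or model is built.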


**Dataset**

In [None]:
import sys
import os
import shutil
import random
import glob
import matplotlib.pyplot as plt
import pandas as pd
import torch
import csv
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from torchvision.models import ResNet50_Weights
from PIL import Image
from collections import defaultdict

zip_path = '/content/LogoDet-3K.zip'
extract_path = '/content/LogoDet-3K'

if not os.path.exists(extract_path):
    if os.path.exists(zip_path):

        !unzip -q -o {zip_path} -d /content/
        print(f"Dataset in: {extract_path}")
    else:
        print(f"ERROR: File {zip_path} not found")

else:
    print("Dataset found.")

Dataset found.


In [None]:
import torch

class Config:
    loss_type = "triplet"

    WORK_DIR = "/content"
    dataset_root = "/content/LogoDet-3K/LogoDet-3K"

    checkpoints_base_dir = os.path.join(WORK_DIR, "runs")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    seed = 42

    # Dataset Split
    train_split_ratio = 0.7
    val_split_ratio = 0.2

    # Training Params
    epochs = 20
    backbone = "resnet50"
    pretrained = True
    embedding_dim = 128

    freeze_layers = 0
    freeze_early_layers = True
    unfreeze_at_epoch = 5

    HYPERPARAMS = {
        "euclidean": {
            "batch_size": 32,
            "margin": 1.0,
            "learning_rate": 1e-5,
            "folder_name": "contrastive_euclidean",
            "description": "Contrastive Loss"
        },
        "cosine": {
            "batch_size": 32,
            "margin": 0.2,
            "learning_rate": 1e-5,
            "folder_name": "contrastive_cosine",
            "description": "Contrastive Loss"
        },
        "triplet": {
            "batch_size": 32,
            "margin": 1.0,
            "learning_rate": 1e-5,
            "folder_name": "triplet_loss",
            "description": "Triplet Loss"
        }
    }


**Data Loading and Processing logic**

This cell defines how images are loaded and grouped into pairs or triplets to train a neural network with PyTorch.

Key components:

1. DatasetTriplet Class

This class is designed for training with Triplet Loss.

Goal: To provide the model with three images at once:
* Anchor: The reference image.
* Positive: A different image of the same brand.
* Negative: An image of a different brand.

Mechanism:
* It parses the file path to extract the label (Brand Name).
* It maintains a dictionary (label_to_indices) to know which images belong to which brand.
* In __getitem__, it randomly samples a Positive and a Negative to form a valid triplet.

Safety Features: It includes fallback logic. If a brand has only one image, the Positive becomes the Anchor itself. When drawing the Negative, it retries up to 50 times to find a different brand before falling back to the Anchor's own brand.
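
The sampling rules above can be sketched without loading any images. The function below works on file paths only. The name sample_triplet, the example paths, and the explicit rng argument are assumptions made for illustration. Unlike the notebook class, which builds label_to_indices once in __init__, this sketch rebuilds it on every call for brevity.

```python
import random
from collections import defaultdict


def sample_triplet(paths, idx, rng, max_attempts=50):
    """Return (anchor, positive, negative) paths for the anchor at index idx."""
    def brand(path):
        # The brand is the name of the folder that contains the image.
        return path.replace("\\", "/").split("/")[-2]

    label_to_indices = defaultdict(list)
    for i, path in enumerate(paths):
        label_to_indices[brand(path)].append(i)
    labels = list(label_to_indices)
    anchor_label = brand(paths[idx])

    # Positive: another image of the same brand, or the anchor itself as a fallback.
    candidates = [i for i in label_to_indices[anchor_label] if i != idx]
    pos_idx = rng.choice(candidates) if candidates else idx

    # Negative: redraw until a different brand comes up, giving up after max_attempts.
    neg_label = anchor_label
    if len(labels) > 1:
        for _ in range(max_attempts):
            candidate = rng.choice(labels)
            if candidate != anchor_label:
                neg_label = candidate
                break
    neg_idx = rng.choice(label_to_indices[neg_label])
    return paths[idx], paths[pos_idx], paths[neg_idx]


paths = [
    "data/Food/Pepsi/1.jpg",
    "data/Food/Pepsi/2.jpg",
    "data/Food/Fanta/1.jpg",
    "data/Tech/Acer/1.jpg",
]
anchor, positive, negative = sample_triplet(paths, 0, random.Random(0))
print(anchor, positive, negative)
```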

2. DatasetContrastive Class

This class is designed for training with Contrastive Loss (Siamese Networks).

Goal: To provide the model with a Pair of images and a binary label.

Mechanism:
* It selects a first image (img1).
* It flips a coin (50% chance) to decide if the second image (img2) should be the Same Brand (Positive pair) or a Different Brand (Negative pair).

Output: Returns (img1, img2, label), where label is 1 for a similar pair and 0 for a different pair. Both images are returned as dictionaries with an "image" key, and the training loop remaps 0 to -1 when the cosine loss is selected.
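
A rough, image-free sketch of the pairing rule and of the two label conventions follows. The names sample_pair and to_cosine_target are hypothetical. As a simplification, the negative is drawn directly from the images of other brands instead of using the notebook's retry loop.

```python
import random


def brand_of(path):
    """Return the name of the folder that contains the image."""
    return path.replace("\\", "/").split("/")[-2]


def sample_pair(paths, idx, rng):
    """Return (path1, path2, label) with label 1 for same brand and 0 otherwise."""
    first = paths[idx]
    same = [p for i, p in enumerate(paths) if i != idx and brand_of(p) == brand_of(first)]
    other = [p for p in paths if brand_of(p) != brand_of(first)]

    # Coin flip; a positive pair is forced when no other brand exists.
    want_positive = rng.random() < 0.5 or not other
    if want_positive:
        # Pair the image with itself when it is the only one of its brand.
        return first, (rng.choice(same) if same else first), 1
    return first, rng.choice(other), 0


def to_cosine_target(label):
    """Map the 0/1 convention to the -1/+1 convention used by cosine embedding losses."""
    return 1 if label == 1 else -1


paths = [
    "data/Food/Pepsi/1.jpg",
    "data/Food/Pepsi/2.jpg",
    "data/Food/Fanta/1.jpg",
    "data/Food/Fanta/2.jpg",
]
path1, path2, label = sample_pair(paths, 0, random.Random(3))
print(path1, path2, label, to_cosine_target(label))
```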

3. getPathsSetsByBrand Function

This is a crucial utility for Data Splitting.

Unique Feature: Unlike standard splitters that shuffle all images, this function splits the dataset by Brand (Class), not by Image.

Example: If "Coca-Cola" is in the Training set, all Coca-Cola images go to Training. The Validation set will contain completely unseen brands (e.g., "Pepsi").

Purpose: This simulates an Open Set or Few-Shot learning scenario, testing if the model can generalize to recognize similarity rather than just memorizing specific brands.

It splits the brands into Train, Validation, and Test sets based on the provided ratios (e.g., 0.7, 0.2, 0.1).
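
To illustrate the brand-level split, here is a small sketch on brand names alone. It swaps in random.Random(...).shuffle for the random_split call used in the notebook, and the function name split_brands and the synthetic brand names are assumptions.

```python
import random


def split_brands(brands, val_split, test_split, seed=101):
    """Shuffle brand names once and cut them into train/val/test groups."""
    shuffled = list(brands)
    random.Random(seed).shuffle(shuffled)

    # Same size arithmetic as the notebook: train receives whatever is left over.
    test_size = int(len(shuffled) * test_split)
    val_size = int(len(shuffled) * val_split)
    train_size = len(shuffled) - val_size - test_size

    train = shuffled[:train_size]
    val = shuffled[train_size:train_size + val_size]
    test = shuffled[train_size + val_size:]
    return train, val, test


brands = [f"brand_{i:02d}" for i in range(10)]
train, val, test = split_brands(brands, val_split=0.2, test_split=0.1)
print(len(train), len(val), len(test))  # prints: 7 2 1

# No brand appears in more than one partition.
assert not (set(train) & set(val) or set(train) & set(test) or set(val) & set(test))
```

Because whole brands are assigned to one partition, every validation or test image belongs to a brand the model never saw during training.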

4. Technical Details

Reproducibility: It sets a fixed SEED = 101 for random and torch to ensure that data splits and pairings are consistent every time you run the code.

Cross-Platform Compatibility: It uses .replace('\\', '/') to ensure file paths work correctly on both Windows and Linux/Colab.

Robustness: Both dataset classes have try-except blocks in load_image to handle corrupted images or path errors gracefully (returning a black image instead of crashing).
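
A quick check of the path normalization: the helper below (brand_from_path is an illustrative name, and the Windows path is made up) returns the same brand for both separator styles.

```python
def brand_from_path(image_path):
    """Return the parent folder name, treating both path separators alike."""
    return image_path.replace("\\", "/").split("/")[-2]


print(brand_from_path("/content/LogoDet-3K/LogoDet-3K/Food/Pepsi/12.jpg"))  # prints: Pepsi
print(brand_from_path(r"C:\data\LogoDet-3K\Food\Pepsi\12.jpg"))             # prints: Pepsi
```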

In [None]:


SEED = 101
random.seed(SEED)
torch.manual_seed(SEED)


class DatasetTriplet(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

        self.label_to_indices = defaultdict(list)
        for idx, img_path in enumerate(self.file_list):
            label = img_path.replace('\\', '/').split('/')[-2]
            self.label_to_indices[label].append(idx)


        self.labels = list(self.label_to_indices.keys())

        print(f"[DatasetTriplet] Unique brands found: {len(self.labels)}")

    def __len__(self):
        return len(self.file_list)

    def load_image(self, image_path):
        try:
            img = Image.open(image_path).convert('RGB')
        except Exception:
            # Fall back to a black image if the file is missing or corrupted
            img = Image.new('RGB', (224, 224))

        if self.transform:
            img_transformed = self.transform(img)
        else:
            img_transformed = transforms.ToTensor()(img)


        return {"image": img_transformed}

    def __getitem__(self, idx):
        anchor_img_path = self.file_list[idx]
        anchor_label = anchor_img_path.replace('\\', '/').split('/')[-2]

        # 1. Anchor
        anchor_dict = self.load_image(anchor_img_path)

        # 2. Positive
        positive_indices = [i for i in self.label_to_indices[anchor_label] if i != idx]
        if len(positive_indices) > 0:
            positive_idx = random.choice(positive_indices)
            pos_path = self.file_list[positive_idx]
        else:
            pos_path = anchor_img_path
        positive_dict = self.load_image(pos_path)

        # 3. Negative
        if len(self.labels) < 2:
            neg_label = anchor_label
        else:
            attempts = 0
            while True:
                neg_label = random.choice(self.labels)
                if neg_label != anchor_label:
                    break
                attempts += 1
                if attempts > 50:
                    neg_label = anchor_label
                    break

        neg_idx = random.choice(self.label_to_indices[neg_label])
        neg_path = self.file_list[neg_idx]
        negative_dict = self.load_image(neg_path)

        return anchor_dict['image'], positive_dict['image'], negative_dict['image']


class DatasetContrastive(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform
        self.label_to_indices = defaultdict(list)
        for idx, img_path in enumerate(self.file_list):
            label = img_path.replace('\\', '/').split('/')[-2]
            self.label_to_indices[label].append(idx)


        self.labels = list(self.label_to_indices.keys())

    def __len__(self): return len(self.file_list)

    def load_image(self, image_path):
        try:
            img = Image.open(image_path).convert('RGB')
        except Exception:
            # Fall back to a black image if the file is missing or corrupted
            img = Image.new('RGB', (224, 224))
        if self.transform:
            img = self.transform(img)
        else:
            img = transforms.ToTensor()(img)
        return {"image": img}

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        label = img_path.replace('\\', '/').split('/')[-2]
        img1 = self.load_image(img_path)

        is_positive_pair = random.choice([0, 1])

        if len(self.labels) < 2:
            is_positive_pair = 1

        if is_positive_pair:
            pos_indices = [i for i in self.label_to_indices[label] if i != idx]
            if len(pos_indices) > 0:
                idx2 = random.choice(pos_indices)
                path2 = self.file_list[idx2]
            else:
                path2 = img_path
        else:
            attempts = 0
            while True:
                neg_label = random.choice(self.labels)
                if neg_label != label:
                    break
                attempts += 1
                if attempts > 50:
                    neg_label = label
                    break
            idx2 = random.choice(self.label_to_indices[neg_label])
            path2 = self.file_list[idx2]

        img2 = self.load_image(path2)
        return img1, img2, torch.tensor(is_positive_pair, dtype=torch.float32)

# ==========================================
# UTILS
# ==========================================
class DatasetTest(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform
    def __len__(self): return len(self.file_list)
    def load_image(self, image_path):
        img = Image.open(image_path).convert('RGB')
        if self.transform: img = self.transform(img)
        return {"image": img, "label": -1}
    def __getitem__(self, idx): return self.load_image(self.file_list[idx])

def getPathsSetsByBrand(dir, val_split, test_split, total_set_size=None, min_images_per_brand=2):
    brand_list = []
    for category in os.listdir(dir):
        category_path = os.path.join(dir, category)
        if os.path.isdir(category_path):
            for brand in os.listdir(category_path):
                brand_path = os.path.join(category_path, brand)
                if os.path.isdir(brand_path):
                    brand_list.append(brand_path)

    if not brand_list: return [], [], []

    test_size = int(len(brand_list) * test_split)
    val_size = int(len(brand_list) * val_split)
    train_size = len(brand_list) - test_size - val_size

    generator = torch.Generator().manual_seed(SEED)
    tr_sub, val_sub, te_sub = random_split(brand_list, [train_size, val_size, test_size], generator=generator)

    train_brand_list = [brand_list[i] for i in tr_sub.indices]
    val_brand_list   = [brand_list[i] for i in val_sub.indices]
    test_brand_list  = [brand_list[i] for i in te_sub.indices]

    def collect(brands, limit=None):
        imgs = []
        for b in brands:
            f = glob.glob(os.path.join(b, '*.jpg'))
            if f:
                if limit: imgs.extend(random.sample(f, min(len(f), limit)))
                else: imgs.extend(f)
        return imgs

    limit = None
    if total_set_size:
        limit = max(min_images_per_brand, int(total_set_size / len(brand_list)))

    return collect(train_brand_list, limit), collect(val_brand_list, limit), collect(test_brand_list, limit)

def show_contrastive_with_bboxes(img1, img2): pass

The next cell defines a ResNet50-based architecture modified for Metric Learning (generating embeddings).

Key operations:

1.	Backbone Initialization: Loads a standard ResNet50 (optionally with ImageNet weights).
2.	Head Replacement: Swaps the original 1000-class classifier with a linear projection layer to output embeddings of size embedding_dim.
3.	Progressive Freezing: Implements a custom freeze_numer_of_layer method to selectively freeze backbone blocks (from shallow 'conv1' to deep 'layer4') for controlled fine-tuning.
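
The mapping from freezing level to frozen blocks can be sketched without instantiating the network. The helper frozen_block_names is hypothetical. It mirrors the blocks table in the model class and clamps the level the same way the notebook method does with min().

```python
# Same grouping as self.blocks in LogoResNet50.
BLOCKS = [
    ["conv1", "bn1"],  # Level 1
    ["layer1"],        # Level 2
    ["layer2"],        # Level 3
    ["layer3"],        # Level 4
    ["layer4"],        # Level 5
]


def frozen_block_names(level):
    """Hypothetical helper: list the block names frozen at a given level."""
    level = max(0, min(level, len(BLOCKS)))
    return [name for group in BLOCKS[:level] for name in group]


for level in range(len(BLOCKS) + 1):
    print(level, frozen_block_names(level))
```

The embedding head (fc) is not in the table, so it stays trainable at every level, including level 5.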

In [None]:
class LogoResNet50(nn.Module):
    def __init__(self, embedding_dim=128, pretrained=True, num_of_freeze_layer=5, activation_fn=None):
        super(LogoResNet50, self).__init__()

        # 1. Load Pre-trained Weights
        # Initialize the model with weights pretrained on ImageNet for transfer learning
        if pretrained:
            weights = ResNet50_Weights.DEFAULT
            self.model = models.resnet50(weights=weights)
        else:
            self.model = models.resnet50(weights=None)

        # 2. Modify the Head (Fully Connected Layer)
        # We need to produce feature embeddings instead of class probabilities
        input_features_fc = self.model.fc.in_features # Typically 2048 for ResNet50

        head_layers = []
        # Project features to the desired embedding dimension (e.g., 128)
        head_layers.append(nn.Linear(input_features_fc, embedding_dim))

        # Add an optional activation function if provided
        if activation_fn is not None:
            head_layers.append(activation_fn)

        # Replace the original classifier with our custom embedding head
        self.model.fc = nn.Sequential(*head_layers)

        # 3. Freezing Management
        # Define the blocks here to access them in the freeze method.
        # This structure allows progressive freezing/unfreezing strategies
        self.blocks = [
            ['conv1', 'bn1'],   # Level 1
            ['layer1'],         # Level 2
            ['layer2'],         # Level 3
            ['layer3'],         # Level 4
            ['layer4'],         # Level 5: Entire backbone frozen
        ]

        # Apply the initial freezing configuration
        self.freeze_numer_of_layer(num_of_freeze_layer)

    def forward(self, x):
        return self.model(x)

    def freeze_numer_of_layer(self, num_of_freeze_layer):
        """
        Manages layer freezing for transfer learning strategies.

        Args:
            num_of_freeze_layer (int):
              0   -> All layers unlocked (Full Fine-Tuning)
              1-5 -> Progressively freezes the backbone layers from shallow to deep
        """

        # STEP 1: RESET. Unfreeze everything (requires_grad = True).
        # This ensures we start from a clean state before applying new constraints.
        for param in self.model.parameters():
            param.requires_grad = True

        # If num is 0, exit immediately (Full Fine-Tuning mode)
        if num_of_freeze_layer == 0:
            print("Configuration: Full Fine-Tuning (All layers are trainable)")
            return

        # Safety check to avoid index out of bounds
        limit = min(num_of_freeze_layer, len(self.blocks))

        frozen_list = []

        # STEP 2: Progressively freeze the requested blocks
        for i in range(limit):
            current_blocks = self.blocks[i]
            for block_name in current_blocks:
                # Retrieve the layer by name
                layer = getattr(self.model, block_name)

                # Freeze parameters for this specific block
                for param in layer.parameters():
                    param.requires_grad = False

                frozen_list.append(block_name)

        print(f"Freezing Level {limit}. Frozen blocks: {frozen_list}")

**Loss Function** definitions: the following classes are used to train the LogoResNet50 model. They define the mathematical rules that teach the neural network to distinguish between similar and dissimilar logo images in a vector space.
Here is a breakdown of the three classes defined in the script:
1. ContrastiveLossEuclidean
This class implements the classic Contrastive Loss based on Euclidean distance.
* Function: It creates a "Siamese" objective. It pulls pairs of images belonging to the same class (Label 1) closer together and pushes pairs from different classes (Label 0) apart.
* Mechanism:
    * If images are Similar: It minimizes the squared Euclidean distance between their embeddings.
    * If images are Different: It penalizes the model if the distance is smaller than the defined margin (the class default is 2.0, while the Config recipe passes 1.0).
2. ContrastiveLossCosine
This class implements Cosine Embedding Loss.
* Function: Instead of measuring the straight-line distance (Euclidean), this measures the angle between the two feature vectors. This is often more effective for high-dimensional spaces where the magnitude of the vector matters less than its direction.
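* Input Requirements: It expects labels where 1 represents similar pairs and -1 represents dissimilar pairs. The dataset emits 0 for dissimilar pairs, so the training loop converts 0 to -1 before calling this loss.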
3. TripletLoss
This class implements the standard Triplet Margin Loss.
* Function: It uses a three-part input structure rather than pairs:
    * Anchor: The reference image.
    * Positive: An image of the same class as the Anchor.
    * Negative: An image of a different class.
* Goal: It forces the distance between the Anchor and the Positive to be smaller than the distance between the Anchor and the Negative by at least the specified margin (default 1.0). Because it constrains relative rather than absolute distances, it is often preferred over Contrastive Loss for ranking tasks like logo retrieval.
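
As a worked example, the three objectives can be evaluated by hand on 2-D embeddings. This is a framework-free sketch using plain Python floats. The function names and the sample vectors are assumptions, and PyTorch's own distance functions add a small epsilon, so its results may differ slightly.

```python
import math


def euclidean(u, v):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.hypot(*u) * math.hypot(*v))


def contrastive_euclidean(u, v, label, margin=1.0):
    # label 1: squared distance; label 0: squared shortfall below the margin.
    d = euclidean(u, v)
    return label * d ** 2 + (1 - label) * max(margin - d, 0.0) ** 2


def contrastive_cosine(u, v, target, margin=0.2):
    # target 1: 1 - cos; target -1: how far the cosine exceeds the margin.
    c = cosine(u, v)
    return 1.0 - c if target == 1 else max(0.0, c - margin)


def triplet(anchor, positive, negative, margin=1.0):
    return max(euclidean(anchor, positive) - euclidean(anchor, negative) + margin, 0.0)


anchor, positive, negative = (1.0, 0.0), (0.6, 0.8), (0.0, 1.0)
print(round(contrastive_euclidean(anchor, positive, 1), 3))  # prints: 0.8
print(round(contrastive_euclidean(anchor, negative, 0), 3))  # prints: 0.0
print(round(contrastive_cosine(anchor, positive, 1), 3))     # prints: 0.4
print(round(contrastive_cosine(anchor, negative, -1), 3))    # prints: 0.0
print(round(triplet(anchor, positive, negative), 3))         # prints: 0.48
```

In this example the negative already lies beyond the margin for both contrastive losses, so only the positive pair contributes. The triplet loss stays positive because the gap between the two distances (about 0.52) is smaller than the margin of 1.0.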




In [None]:
class ContrastiveLossEuclidean(nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLossEuclidean, self).__init__()
        self.margin = margin

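    def forward(self, output1, output2, label):
        # Label 1 = Similar | Label 0 = Dissimilar
        # No keepdim here: the distance must have shape (B,) to match `label`.
        # With keepdim=True it would be (B, 1) and broadcast against (B,) to (B, B),
        # silently mixing every label with every pair.
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean(
            (label) * torch.pow(euclidean_distance, 2) +
            (1-label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        )
        return loss_contrastive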

class ContrastiveLossCosine(nn.Module):
    def __init__(self, margin=0.2):
        super(ContrastiveLossCosine, self).__init__()
        self.loss_fn = nn.CosineEmbeddingLoss(margin=margin)

    def forward(self, output1, output2, label):
        return self.loss_fn(output1, output2, label)

class TripletLoss(nn.Module):
    """
    Standard Triplet Margin Loss.
    Input: Anchor, Positive, Negative.
    """
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.loss_fn = nn.TripletMarginLoss(margin=margin, p=2)

    def forward(self, anchor, positive, negative):
        return self.loss_fn(anchor, positive, negative)

# **Training script**
It coordinates all core components (configuration, dataset handling, model architecture, and loss functions) to execute the full training workflow.
The script:
* Sets up the execution environment and imports all required custom and PyTorch modules.
* Loads training settings from a configuration class and dynamically selects hyperparameters based on the chosen loss function (Triplet, Euclidean, or Cosine).
* Splits the brands into training, validation, and test sets (only the first two are used here) and applies data augmentation to the training images for better generalization.
* Initializes a ResNet-50–based embedding model and the corresponding metric learning loss.
* Runs the training and validation loop for multiple epochs, performing forward passes, loss computation, backpropagation, and optimization.
* Saves model checkpoints, logs training history to a CSV file, and generates loss curve plots to monitor convergence.
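
The resume behaviour relies on the CSV history: the number of logged rows gives the next epoch index and the name of the checkpoint to reload. Here is a minimal sketch, assuming the same column header and file naming as the script. It uses csv.DictReader instead of pandas, and the helper epochs_completed and the loss values are made up for illustration.

```python
import csv
import os
import tempfile


def epochs_completed(csv_path):
    """Return how many epoch rows the history file holds (0 if it is missing)."""
    if not os.path.exists(csv_path):
        return 0
    with open(csv_path, newline="") as handle:
        return len(list(csv.DictReader(handle)))


with tempfile.TemporaryDirectory() as run_dir:
    history = os.path.join(run_dir, "training_history.csv")
    with open(history, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(["Epoch", "Train Loss", "Val Loss"])
        writer.writerow([1, 0.91, 0.88])  # made-up numbers
        writer.writerow([2, 0.74, 0.80])

    start_epoch = epochs_completed(history)
    resume_ckpt = os.path.join(run_dir, f"model_epoch_{start_epoch}.pth")
    print(start_epoch, os.path.basename(resume_ckpt))  # prints: 2 model_epoch_2.pth
```

If the matching checkpoint file is missing, the script still continues from start_epoch but with freshly initialized weights, so it is worth checking the "Loading weights from" message when resuming.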


In [None]:
def save_plots(train_losses, val_losses, output_dir, title_suffix=""):
    plt.figure(figsize=(10, 5))
    plt.title(f"Loss Curve {title_suffix}")
    plt.plot(train_losses, label="Train Loss", color="blue")
    plt.plot(val_losses, label="Val Loss", color="red")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plot_path = os.path.join(output_dir, "training_plot.png")
    plt.savefig(plot_path)
    plt.close()

# =========================================================================
# MAIN
# =========================================================================
def main():
    loss_type = Config.loss_type

    if hasattr(Config, 'HYPERPARAMS'):
        params = Config.HYPERPARAMS[loss_type]
        BATCH_SIZE = params['batch_size']
        MARGIN = params['margin']
        LR = params['learning_rate']
        SAVE_FOLDER = params['folder_name']
    else:
        BATCH_SIZE = 32
        MARGIN = 1.0
        LR = 1e-5
        SAVE_FOLDER = f"{loss_type}_run"

    device = torch.device(Config.device)
    save_dir = os.path.join(Config.checkpoints_base_dir, SAVE_FOLDER)
    os.makedirs(save_dir, exist_ok=True)

    print(f"\n{'-'*50}")
    print(f"STARTING TRAINING: {loss_type.upper()}")
    print(f"Batch Size: {BATCH_SIZE} | Margin: {MARGIN} | LR: {LR}")
    print(f"Output Folder: {save_dir}")
    print(f"{'-'*50}\n")

    # --- DATASET ---
    print("Loading dataset paths...")
    test_ratio = 1.0 - Config.train_split_ratio - Config.val_split_ratio

    train_files, val_files, test_files = getPathsSetsByBrand(
        Config.dataset_root,
        val_split=Config.val_split_ratio,
        test_split=test_ratio,
        min_images_per_brand=2
    )

    # Transforms
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    if loss_type == "triplet":
        print("Using DatasetTriplet")
        train_dataset = DatasetTriplet(train_files, transform=train_transform)
        val_dataset = DatasetTriplet(val_files, transform=val_transform)
    else:
        print("Using DatasetContrastive (Pairs)")
        train_dataset = DatasetContrastive(train_files, transform=train_transform)
        val_dataset = DatasetContrastive(val_files, transform=val_transform)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, persistent_workers=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # --- MODEL & LOSS ---
    print("Initializing LogoResNet50...")
    model = LogoResNet50(embedding_dim=Config.embedding_dim, pretrained=Config.pretrained, num_of_freeze_layer=Config.freeze_layers)
    model = model.to(device)

    if loss_type == "euclidean":
        criterion = ContrastiveLossEuclidean(margin=MARGIN)
    elif loss_type == "cosine":
        criterion = ContrastiveLossCosine(margin=MARGIN)
    elif loss_type == "triplet":
        criterion = TripletLoss(margin=MARGIN)
    else:
        raise ValueError(f"Loss type {loss_type} not supported")

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR)

    # --- SETUP CSV AND HISTORY ---
    train_losses = []
    val_losses = []
    csv_path = os.path.join(save_dir, "training_history.csv")


    start_epoch = 0
    if os.path.exists(csv_path):
        print("Existing CSV found, loading historical data for charts")
        try:
            df = pd.read_csv(csv_path)
            train_losses = df['Train Loss'].tolist()
            val_losses = df['Val Loss'].tolist()
            start_epoch = len(train_losses)
            print(f"Resuming from epoch {start_epoch + 1}")
        except Exception as e:
            print(f"Error reading CSV: {e}")
    else:
        with open(csv_path, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Epoch", "Train Loss", "Val Loss"])

    # Resume (if the file for the current epoch exists)
    resume_ckpt = os.path.join(save_dir, f"model_epoch_{start_epoch}.pth")
    if os.path.exists(resume_ckpt) and start_epoch > 0:
        print(f"Loading weights from: {resume_ckpt}")
        model.load_state_dict(torch.load(resume_ckpt, map_location=device))
        print("Weights loaded")

    try:
        from tqdm import tqdm
    except ImportError:
        # Fallback: iterate without a progress bar
        def tqdm(iterator, desc=""):
            return iterator

    # --- TRAINING LOOP ---
    for epoch in range(start_epoch, Config.epochs):
        model.train()
        running_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.epochs}")

        for batch in pbar:
            optimizer.zero_grad()

            if loss_type == "triplet":
                anc, pos, neg = batch[0].to(device), batch[1].to(device), batch[2].to(device)
                loss = criterion(model(anc), model(pos), model(neg))
            else:
                img1 = batch[0]['image'].to(device)
                img2 = batch[1]['image'].to(device)
                label = batch[2].to(device)

                out1 = model(img1)
                out2 = model(img2)


                target = label.float()
                if loss_type == "cosine":
                    # nn.CosineEmbeddingLoss expects -1 (not 0) for dissimilar pairs
                    target[target == 0] = -1

                loss = criterion(out1, out2, target)

            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if hasattr(pbar, "set_postfix"): pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        avg_loss = running_loss / len(train_loader)
        train_losses.append(avg_loss)
        print(f"Epoch {epoch+1} DONE. Train Loss: {avg_loss:.4f}")

        # --- VALIDATION ---
        model.eval()
        val_loss_acc = 0.0
        with torch.no_grad():
            for batch in val_loader:
                if loss_type == "triplet":
                    anc, pos, neg = batch[0].to(device), batch[1].to(device), batch[2].to(device)
                    loss = criterion(model(anc), model(pos), model(neg))
                else:
                    img1 = batch[0]['image'].to(device)
                    img2 = batch[1]['image'].to(device)
                    label = batch[2].to(device)
                    out1, out2 = model(img1), model(img2)

                    target = label.float()
                    if loss_type == "cosine":
                        # nn.CosineEmbeddingLoss expects -1 (not 0) for dissimilar pairs
                        target[target == 0] = -1
                    loss = criterion(out1, out2, target)
                val_loss_acc += loss.item()

        avg_val_loss = val_loss_acc / len(val_loader)
        val_losses.append(avg_val_loss)
        print(f"VALIDATION Epoch {epoch+1}: Loss = {avg_val_loss:.4f}")
        print("-" * 30)

        # SAVING
        ckpt_path = os.path.join(save_dir, f"model_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), ckpt_path)
        print(f"Checkpoint saved: {ckpt_path}")

        # UPDATING CSV AND PLOT
        with open(csv_path, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([epoch+1, avg_loss, avg_val_loss])
        print("CSV updated.")
        save_plots(train_losses, val_losses, save_dir, title_suffix=f"({loss_type})")

    print("Training finished.")

if __name__ == "__main__":
    main()