# Signature Verification Model

Objective: Differentiate forged offline signatures from their genuine (original) counterparts.

Pre-trained `ResNet` backbones to choose from:

- `ResNet18`
- `ResNet34`
- `ResNet50`
- `ResNet101`
- `ResNet152`

This model uses a Triplet Loss Function to aid in differentiating forged signatures from the originals.

Dataset: [CEDAR](https://www.kaggle.com/datasets/shreelakshmigp/cedardataset)

PyTorch documentation: [PyTorch Documentation](https://docs.pytorch.org/docs/stable/index.html)

Installing PyTorch (CUDA 12.8 wheels):
```bash
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
```


# Notebook Setup and Configuration


## Import Libraries


In [None]:
import sys
import random
import re
from pathlib import Path
from typing import Optional, Dict, Tuple, Any, List, DefaultDict, Callable
from collections import defaultdict

# PyTorch
import torch
import torch.nn as nn
import torchvision.models as models # type: ignore
import torchvision.transforms as transforms # type: ignore
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.models import get_model_weights # type: ignore
from torch import autocast, GradScaler

# NumPy
import numpy as np
import numpy.typing as npt

# Sklearn
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_curve, auc, confusion_matrix # type: ignore
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split, KFold # type: ignore

# Graphing
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go # type: ignore

# Visualisation
from PIL import Image

## Seeding

> Did you know? 42 is a reference to Douglas Adams's *The Hitchhiker's Guide to the Galaxy*!
> In the book, the supercomputer Deep Thought reveals that 42 is the answer to the great question of “life, the universe and everything”


In [None]:
def set_seed(seed: int = 42) -> None:
    """
    Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU and CUDA generators with the same value, then switches cuDNN to
    deterministic kernels with autotuning disabled so repeated runs match.

    Args
    ----
        seed: int
            Value used for every generator. Defaults to 42.
    """

    # Python, NumPy and PyTorch (CPU) generators, in that order.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # CUDA generators: the current device first, then every visible device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Prefer reproducible cuDNN kernels over benchmark-selected fast ones.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

In [None]:
# Global seed applied once here so every later cell starts from the same RNG state.
FIXED_SEED: int = 42
set_seed(seed=FIXED_SEED)

## Configurations

Modify these to control how the fine-tuning of the model goes.


In [None]:
# Folder containing the preprocessed signature images (relative path).
DATASET_CONFIG: Dict[str, str] = {"DATASET_PATH": "processed_signature_images"}

# Training hyper-parameters. Key spellings are read elsewhere in the notebook,
# so rename keys only together with their consumers.
LEARNING_CONFIG: Dict[str, str | int | float] = {
    "BATCH_SIZE": 64,
    # Number of training epochs; also copied into SCHEDULER_PARAMS["epochs"].
    "EPOCH": 150,
    # NOTE(review): when OneCycleLR is used it overrides the optimiser's lr;
    # keep this in sync with SCHEDULER_PARAMS["max_lr"] if both are meant to match.
    "LEARNING_RATE": 1e-3,
    # Presumably the embedding size given to FeatureExtraction — confirm at call site.
    "EMBEDDING_DIM": 256,
    # "NUM_CLASSES": 2 ,
    # Presumably early-stopping patience in epochs (key is spelled "PATIENT").
    "EARLY_STOPPING_PATIENT": 10,
    "CHECKPOINT_DIR": "checkpoint/",
    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    # Presumably the max gradient norm for clipping — confirm in the training loop.
    "GRAD_CLIP": 1.0,
    # Presumably the number of KFold splits for cross-validation — confirm.
    "K_FOLDS": 5,
}

# Optimiser settings. "optimiser" names the class; "betas" and "weight_decay"
# match torch.optim.Adam's keyword arguments (Adam applies weight_decay as an
# L2 term added to the gradient, unlike AdamW's decoupled decay).
OPTIMISER_PARAMS: Dict[str, str | Tuple[float, float] | float] = {
    "optimiser": "Adam",
    "betas": (0.9, 0.999),
    "weight_decay": 1e-4,
}

# Learning-rate schedule. "scheduler" names the class and is not itself a
# OneCycleLR argument. Per OneCycleLR: the initial lr is max_lr / div_factor
# (1e-4), the final lr is that initial lr / final_div_factor (1e-8), and
# pct_start is the fraction of the cycle spent increasing the lr.
# NOTE(review): OneCycleLR also needs steps_per_epoch (or total_steps), which is
# not listed here — presumably supplied where the scheduler is constructed.
SCHEDULER_PARAMS: Dict[str, str | float | int] = {
    "scheduler": "OneCycleLR",
    "max_lr": 1e-3,
    "epochs": LEARNING_CONFIG["EPOCH"],
    "pct_start": 0.2,
    "div_factor": 10.0,
    "final_div_factor": 10000.0,
}

# Supported backbones. "builder" is the torchvision constructor; "out_channels"
# is the channel count of the backbone's final feature map, which
# FeatureExtraction uses as the input width of its extra layers / head.
BACKBONE_CONFIG: Dict[str, Dict[str, Any]] = {
    "resnet18": {"builder": models.resnet18, "out_channels": 512},
    "resnet34": {"builder": models.resnet34, "out_channels": 512},
    "resnet50": {"builder": models.resnet50, "out_channels": 2048},
    "resnet101": {"builder": models.resnet101, "out_channels": 2048},
    "resnet152": {"builder": models.resnet152, "out_channels": 2048},
    # Add other models if needed, and determine their output channels before pooling
    # Example for a different model:
    # 'vgg16': {'builder': models.vgg16, 'out_channels': 512},
}

# Functions and Classes


## Model Architecture

Definition of the neural network model.

The model's first convolutional layer has been modified to accept single-channel (grayscale) images instead of RGB images, and the final classification layer has been removed.

This version uses a pre-trained `ResNet` model from PyTorch's torchvision library.



In [None]:
class FeatureExtraction(nn.Module):
    """
    FeatureExtraction embeds grayscale signature images with a torchvision ResNet backbone.

    The backbone's trailing average-pool and fully-connected (classification)
    children are dropped. Its 3-channel stem convolution is replaced by a
    1-channel convolution: when pretrained weights are loaded, the new kernel is
    the channel-wise mean of the pretrained RGB kernel; otherwise it is
    Kaiming-normal initialised. The resulting feature map optionally passes
    through extra 3x3 Conv-BN-ReLU blocks. A pooling head then produces an
    ``embedding_dim`` vector, which is L2-normalised.

    Parameters
    ----------
    embedding_dim: int
        Dimension of the output embedding. Must be positive. Default 256.
    weights: Optional[str]
        Name of a member of the backbone's torchvision weights enum
        (e.g. ``"IMAGENET1K_V1"`` or ``"DEFAULT"``). ``None`` means random
        initialisation. An unknown name also falls back to random
        initialisation after printing a warning to stderr.
    use_extra_layers: bool
        Whether to insert the extra convolutional blocks. Only takes effect
        when ``extra_channels`` is non-empty.
    backbone_type: str
        Key into ``BACKBONE_CONFIG`` (e.g. ``"resnet50"``).
    extra_channels: Optional[List[int]]
        Output channels of each extra 3x3 block. A final block always maps to
        ``embedding_dim`` channels. ``None`` is treated as an empty list.
    dropout_rate: float
        Dropout probability at the end of the head, in [0, 1].

    Raises
    ------
    ValueError
        If an argument is out of range or ``backbone_type`` is unsupported.
    """

    def __init__(
        self,
        embedding_dim: int = 256,
        weights: Optional[str] = None,
        use_extra_layers: bool = True,
        backbone_type: str = "resnet50",
        extra_channels: Optional[List[int]] = None,
        dropout_rate: float = 0.3,
    ) -> None:

        super().__init__()  # type: ignore

        if embedding_dim <= 0:
            raise ValueError(
                f"Embedding dimension must be positive, got {embedding_dim}"
            )
        if dropout_rate < 0 or dropout_rate > 1:
            raise ValueError(
                f"Dropout rate must be between 0 and 1, got {dropout_rate}"
            )
        if backbone_type not in BACKBONE_CONFIG:
            raise ValueError(
                f"Unsupported backbone type: {backbone_type}. Choose from {list(BACKBONE_CONFIG.keys())}"
            )
        if extra_channels is not None and not all(c > 0 for c in extra_channels):
            raise ValueError(
                f"extra_channels must be a list of positive integers or None, got {extra_channels}"
            )

        self.embedding_dim = embedding_dim
        self.use_extra_layers = use_extra_layers
        self.backbone_type = backbone_type
        self.extra_channels = extra_channels if extra_channels is not None else []
        self.dropout_rate = dropout_rate

        backbone_builder = BACKBONE_CONFIG[self.backbone_type]["builder"]
        backbone_out_channels = BACKBONE_CONFIG[self.backbone_type]["out_channels"]

        # None here means "random initialisation" (either requested or a failed lookup).
        weights_enum = self._resolve_weights(backbone_builder, weights)

        # NOTE: self.model stays registered as a submodule, so its unused
        # avgpool/fc children still appear in parameters() and state_dict().
        # Kept as-is for compatibility with existing checkpoints.
        self.model = backbone_builder(weights=weights_enum)
        self._adapt_first_conv(pretrained=weights_enum is not None)

        # Drop the last two children (average pool and classifier) to keep the spatial feature map.
        self.backbone = nn.Sequential(*list(self.model.children())[:-2])

        has_extra_layers = self.use_extra_layers and len(self.extra_channels) > 0
        self.extra_layers: Optional[nn.Sequential] = (
            self._build_extra_layers(backbone_out_channels) if has_extra_layers else None
        )

        final_fc_in_features = (
            self.embedding_dim if has_extra_layers else backbone_out_channels
        )

        # NOTE(review): the ReLU runs before the L2 normalisation in forward(),
        # so every embedding lies in the non-negative orthant — confirm intended.
        self.final_processing = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.BatchNorm1d(final_fc_in_features),
            nn.Linear(final_fc_in_features, self.embedding_dim),
            nn.BatchNorm1d(self.embedding_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
        )

        self._initialise_custom_weights()

    def _resolve_weights(
        self, backbone_builder: Callable[..., Any], weights: Optional[str]
    ) -> Any:
        """Return the torchvision weights enum member named ``weights``, or None on failure."""
        if weights is None:
            return None

        try:
            weights_enum_type = get_model_weights(backbone_builder)

            if weights_enum_type is None:  # type: ignore
                print(
                    f"Warning: Could not get weights type for backbone '{self.backbone_type}'. Using default random initialisation.",
                    file=sys.stderr,
                )
                return None

            return getattr(weights_enum_type, weights)

        except AttributeError:
            print(
                f"Warning: Specified weights alias '{weights}' not found for {self.backbone_type}. Check available weights in torchvision documentation. Using default random initialisation.",
                file=sys.stderr,
            )
            return None
        except Exception as e:
            # Deliberate best-effort: any lookup failure degrades to random init.
            print(
                f"Warning: An unexpected error occurred looking up weights '{weights}' for {self.backbone_type}: {e}. Using default random initialisation.",
                file=sys.stderr,
            )
            return None

    def _adapt_first_conv(self, pretrained: bool) -> None:
        """Replace the backbone's 3-channel stem convolution with a 1-channel one."""
        original_conv1 = self.model.conv1
        if original_conv1.in_channels != 3:
            return

        grayscale_conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            bias=original_conv1.bias is not None,
        )
        self.model.conv1 = grayscale_conv1

        if pretrained:
            # Collapse the RGB kernels into one grayscale kernel by averaging them.
            with torch.no_grad():
                grayscale_conv1.weight.copy_(
                    original_conv1.weight.mean(dim=1, keepdim=True)
                )
        else:
            nn.init.kaiming_normal_(
                grayscale_conv1.weight, mode="fan_out", nonlinearity="relu"
            )
            if grayscale_conv1.bias is not None:
                nn.init.constant_(grayscale_conv1.bias, 0)

    def _build_extra_layers(self, in_channels: int) -> nn.Sequential:
        """Chain 3x3 Conv-BN-ReLU blocks through ``extra_channels`` and finally to ``embedding_dim``."""
        extra_layer_list: List[nn.Module] = []
        for out_channels in [*self.extra_channels, self.embedding_dim]:
            extra_layer_list.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                )
            )
            extra_layer_list.append(nn.BatchNorm2d(out_channels))
            extra_layer_list.append(nn.ReLU(inplace=True))
            in_channels = out_channels

        return nn.Sequential(*extra_layer_list)

    def _initialise_custom_weights(self) -> None:
        """
        Initialise the layers this class adds on top of the backbone.

        - Conv2d / Linear: Kaiming-normal weights (fan_out, ReLU gain), zero bias.
        - BatchNorm1d / BatchNorm2d: weight 1, bias 0.

        The extra layers (if any) are initialised before the final head, which
        keeps the RNG consumption order stable for seeded runs.
        """
        custom_heads: List[nn.Module] = [self.final_processing]
        if self.extra_layers is not None:
            custom_heads.insert(0, self.extra_layers)

        for head in custom_heads:
            for m in head.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    nn.init.kaiming_normal_(
                        m.weight, mode="fan_out", nonlinearity="relu"
                    )
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Embed a batch of images.

        Args
        ----
        x: torch.Tensor
            4-D tensor (batch, channels, H, W). ``channels`` must equal the stem
            convolution's ``in_channels`` (1 when it was adapted from RGB).

        Returns
        -------
        torch.Tensor
            (batch, embedding_dim) embeddings, L2-normalised along dim 1.

        Raises
        ------
        ValueError
            If ``x`` is not 4-D or has the wrong number of channels.
        """
        if x.ndim != 4:
            raise ValueError(
                f"Expected input tensor to be 4D (batch, channels, H, W), but got {x.ndim}D"
            )

        if x.size(1) != self.model.conv1.in_channels:
            raise ValueError(
                f"Input tensor has {x.size(1)} channels, but the model expects {self.model.conv1.in_channels} channels."
            )

        features = self.backbone(x)

        if self.use_extra_layers and self.extra_layers is not None:
            features = self.extra_layers(features)

        embeddings = self.final_processing(features)

        return F.normalize(embeddings, p=2, dim=1)

## Loss Function

The chosen loss function is the **Triplet Loss Function**. It measures the distance between an anchor sample and a positive sample (same signer), as well as the distance between the anchor and a negative sample (different signer). By minimising this loss, the model learns to increase the distance between anchors and negatives while decreasing the distance between anchors and positives.

This loss function aligns perfectly with verification tasks, such as signature verification. 

In [None]:
class BatchTripletLoss(nn.Module):
    """
    Triplet loss with online, batch-wise mining.

    For each anchor in the batch, the hardest positive (same label, largest
    distance) is paired with one negative (different label) chosen by the
    mining strategy. Anchors lacking a positive or a negative are ignored, and
    the remaining per-anchor losses are averaged. An optional diversity
    regulariser penalises cosine similarity between all pairs of embeddings.

    Attributes
    ----------
    margin: float
        Non-negative margin of the hinge loss. It also bounds the semi-hard
        window (used even when ``soft_margin`` is True).
    mining_strategy: str
        'batch_semi_hard' or 'batch_hard'. Default is 'batch_semi_hard'.
    distance_metric: str
        'euclidean' or 'cosine'. Default is 'euclidean'.
    soft_margin: bool
        If True, use ``softplus(d(a,p) - d(a,n))`` instead of
        ``relu(d(a,p) - d(a,n) + margin)``.
    lambda_diversity: float
        Diversity regularisation weight.
    use_diversity: bool
        Enables the diversity term. It is applied only when ``lambda_diversity > 0``.
    p: int
        Norm degree used by the 'euclidean' metric.
    normalise_embeddings: bool
        L2-normalise embeddings before computing distances.

    Methods
    -------

    _compute_euclidean_distance(self, embedding_one: torch.Tensor, embedding_two: torch.Tensor) -> torch.Tensor:
        Computes the p-norm distance between two tensors.
    _compute_cosine_distance(self, embedding_one: torch.Tensor, embedding_two: torch.Tensor) -> torch.Tensor:
        Computes the cosine distance between two tensors.
    _get_triplet_mask(self, labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        Generates boolean masks for positive and negative pairs from the labels.
    _batch_hard_mining(self, distances: torch.Tensor, mask_anchor_positive: torch.Tensor, mask_anchor_negative: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        Selects the hardest positive and hardest negative per anchor.
    _batch_semi_hard_mining(self, distances: torch.Tensor, mask_anchor_positive: torch.Tensor, mask_anchor_negative: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        Selects the hardest positive and a semi-hard negative per anchor.
    _compute_diversity_regularisation(self, embeddings: torch.Tensor) -> torch.Tensor:
        Computes the weighted diversity regularisation term.

    Raises
    ------
    ValueError
        On a negative margin or an unknown mining strategy / distance metric.
    """

    def __init__(
        self,
        margin: float = 1.0,
        mining_strategy: str = "batch_semi_hard",
        distance_metric: str = "euclidean",
        soft_margin: bool = False,
        lambda_diversity: float = 0.1,
        use_diversity: bool = True,
        p: int = 2,
        normalise_embeddings: bool = False,
    ) -> None:

        super().__init__()  # type: ignore

        if margin < 0:
            raise ValueError(f"Margin must be non-negative, got {margin}")
        if mining_strategy not in ["batch_hard", "batch_semi_hard"]:
            raise ValueError(f"Invalid mining strategy, got {mining_strategy}")
        if distance_metric not in ["euclidean", "cosine"]:
            raise ValueError(f"Invalid distance metric, got {distance_metric}")

        self.margin = margin
        self.mining_strategy = mining_strategy
        self.distance_metric = distance_metric
        self.soft_margin = soft_margin
        self.lambda_diversity = lambda_diversity
        self.use_diversity = use_diversity
        self.p = p
        self.normalise_embeddings = normalise_embeddings

    def _compute_euclidean_distance(
        self, embedding_one: torch.Tensor, embedding_two: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the ``self.p``-norm distance along the last dimension.

        Args
        ----
            embedding_one: torch.Tensor
                First operand, broadcastable against ``embedding_two``.
            embedding_two: torch.Tensor
                Second operand.

        Returns
        -------
            torch.Tensor
                Distances with the last dimension reduced. ``F.pairwise_distance``
                adds ``eps=1e-6`` to the difference, so a self-distance is tiny
                but non-zero.
        """

        return F.pairwise_distance(embedding_one, embedding_two, p=self.p)

    def _compute_cosine_distance(
        self, embedding_one: torch.Tensor, embedding_two: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute ``1 - cosine_similarity`` along the last dimension.

        Args
        ----
            embedding_one: torch.Tensor
                First operand, broadcastable against ``embedding_two``.
            embedding_two: torch.Tensor
                Second operand.

        Returns
        -------
            torch.Tensor
                Cosine distances in [0, 2] with the last dimension reduced.
        """

        return 1 - F.cosine_similarity(embedding_one, embedding_two, dim=-1)

    def _get_triplet_mask(
        self, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Build boolean (B, B) masks of anchor-positive and anchor-negative pairs.

        Args
        ----
        labels: torch.Tensor
            1-D tensor of signer ids, one per embedding.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            ``(mask_anchor_positive, mask_anchor_negative)``. Positives share a
            label and exclude the diagonal (an anchor is not its own positive).
            Negatives have different labels.
        """

        labels = labels.unsqueeze(1)
        mask_anchor_positive = (labels == labels.T) & ~torch.eye(
            labels.size(0), dtype=bool, device=labels.device  # type: ignore
        )
        mask_anchor_negative = labels != labels.T

        return (mask_anchor_positive, mask_anchor_negative)

    def _batch_hard_mining(
        self,
        distances: torch.Tensor,
        mask_anchor_positive: torch.Tensor,
        mask_anchor_negative: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Select, per anchor, the hardest positive and the hardest negative.

        Args
        ----
        distances: torch.Tensor
            (B, B) pairwise distance matrix.
        mask_anchor_positive: torch.Tensor
            Mask of pairs with the same signer.
        mask_anchor_negative: torch.Tensor
            Mask of pairs with different signers.

        Returns
        ------
        Tuple[torch.Tensor, torch.Tensor]
            Largest positive distance (``-inf`` if the anchor has no positive)
            and smallest negative distance (``inf`` if it has no negative).
        """

        positive_distance = torch.where(
            mask_anchor_positive, distances, torch.full_like(distances, -float("inf"))
        )
        max_positive_distance, _ = torch.max(positive_distance, dim=1)

        negative_distance = torch.where(
            mask_anchor_negative, distances, torch.full_like(distances, float("inf"))
        )
        min_negative_distance, _ = torch.min(negative_distance, dim=1)

        return max_positive_distance, min_negative_distance

    def _batch_semi_hard_mining(
        self,
        distances: torch.Tensor,
        mask_anchor_positive: torch.Tensor,
        mask_anchor_negative: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Select, per anchor, the hardest positive and the hardest semi-hard negative.

        Args
        ----
        distances: torch.Tensor
            (B, B) pairwise distance matrix.
        mask_anchor_positive: torch.Tensor
            Mask of pairs with the same signer.
        mask_anchor_negative: torch.Tensor
            Mask of pairs with different signers.

        Returns
        ------
        Tuple[torch.Tensor, torch.Tensor]
            Largest positive distance and the selected negative distance.

        Rules enforced, with the window measured against the anchor's hardest
        positive:
        - The positive is the one at the largest distance.
        - A semi-hard negative satisfies ``d(a,p) < d(a,n) < d(a,p) + margin``.
          The closest such negative is chosen.
        - If no semi-hard negative exists, the closest negative overall is used.
        """

        positive_distance_filtered = torch.where(
            mask_anchor_positive, distances, torch.full_like(distances, -float("inf"))
        )
        hardest_positive_dist, _ = torch.max(positive_distance_filtered, dim=1)

        valid_negative_distances = torch.where(
            mask_anchor_negative, distances, torch.full_like(distances, float("inf"))
        )

        # Condition 1: d(a,n) > d(a,p)
        is_harder_than_positive = (
            valid_negative_distances > hardest_positive_dist.unsqueeze(1)
        )

        # Condition 2: d(a,n) < d(a,p) + margin
        is_easier_than_margin = valid_negative_distances < (
            hardest_positive_dist.unsqueeze(1) + self.margin
        )

        semi_hard_mask = (
            mask_anchor_negative & is_harder_than_positive & is_easier_than_margin
        )

        semi_hard_negative_distances_filtered = torch.where(
            semi_hard_mask, distances, torch.full_like(distances, float("inf"))
        )

        # Smallest distance inside the window = hardest semi-hard negative.
        hardest_semi_hard_negative_dist, _ = torch.min(
            semi_hard_negative_distances_filtered, dim=1
        )

        hardest_overall_negative_dist, _ = torch.min(valid_negative_distances, dim=1)

        selected_negative_dist = torch.where(
            semi_hard_mask.sum(dim=1) > 0,
            hardest_semi_hard_negative_dist,
            hardest_overall_negative_dist,
        )

        return hardest_positive_dist, selected_negative_dist

    def _compute_diversity_regularisation(
        self, embeddings: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the weighted diversity regularisation for a batch of embeddings.

        The embeddings are L2-normalised and their cosine-similarity matrix is
        compared with the identity. The squared Frobenius norm of the
        difference is averaged over the ``B * (B - 1)`` off-diagonal entries.
        This penalises every pair of similar embeddings.

        NOTE(review): this also penalises similarity between same-signer
        embeddings, which works against the triplet objective's pull on
        positives — confirm the weighting is intended.

        Args
        ----
        embeddings: torch.Tensor
            (B, D) embeddings to regularise.

        Returns
        -------
        torch.Tensor
            Scalar loss scaled by ``self.lambda_diversity``. For B < 2 there are
            no pairs, so a zero connected to ``embeddings`` is returned (instead
            of dividing by zero).
        """

        batch = embeddings.size(0)
        if batch < 2:
            return (embeddings * 0.0).sum()

        normalised_embeddings = F.normalize(embeddings, p=2, dim=1)

        similarity_matrix = normalised_embeddings @ normalised_embeddings.T

        identity_matrix = torch.eye(
            batch, device=embeddings.device, dtype=embeddings.dtype
        )

        diversity_loss = (similarity_matrix - identity_matrix).pow(2).sum() / (
            batch * (batch - 1)
        )

        return self.lambda_diversity * diversity_loss

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Compute the mined triplet loss (plus optional diversity term) for a batch.

        Args
        ----
        embeddings: torch.Tensor
            (B, D) embeddings; row i belongs to ``labels[i]``.
        labels: torch.Tensor
            1-D tensor of signer ids.

        Returns
        -------
        torch.Tensor
            Scalar loss. If no anchor has both a positive and a negative, the
            result is a zero that is still attached to ``embeddings``' graph.
            ``backward()`` then yields zero gradients rather than leaving
            ``.grad`` unset.
        """

        if self.normalise_embeddings:
            embeddings = F.normalize(embeddings, p=2, dim=1)

        # (B, 1, D) against (1, B, D) broadcasts to a (B, B) distance matrix.
        anchors = embeddings.unsqueeze(1)
        candidates = embeddings.unsqueeze(0)
        if self.distance_metric == "euclidean":
            pairwise_distance = self._compute_euclidean_distance(anchors, candidates)
        elif self.distance_metric == "cosine":
            pairwise_distance = self._compute_cosine_distance(anchors, candidates)
        else:
            raise ValueError(f"Invalid distance metrics: {self.distance_metric}")

        mask_anchor_positive, mask_anchor_negative = self._get_triplet_mask(labels)

        if self.mining_strategy == "batch_hard":
            positive_distance, negative_distance = self._batch_hard_mining(
                pairwise_distance, mask_anchor_positive, mask_anchor_negative
            )
        elif self.mining_strategy == "batch_semi_hard":
            positive_distance, negative_distance = self._batch_semi_hard_mining(
                pairwise_distance, mask_anchor_positive, mask_anchor_negative
            )
        else:
            raise ValueError(f"Invalid mining strategy: {self.mining_strategy}")

        # Mining marks anchors without a positive (-inf) or negative (inf); drop them.
        valid_triplets_mask = ~(
            torch.isinf(positive_distance) | torch.isinf(negative_distance)
        )
        positive_distance = positive_distance[valid_triplets_mask]
        negative_distance = negative_distance[valid_triplets_mask]

        if positive_distance.numel() == 0:
            # Return a zero tied to the graph rather than a fresh leaf tensor.
            # A detached leaf leaves every parameter's .grad as None, which can
            # make GradScaler.step() fail ("No inf checks were recorded").
            return (embeddings * 0.0).sum()

        if self.soft_margin:
            triplet_loss = F.softplus(positive_distance - negative_distance)
        else:
            triplet_loss = F.relu(positive_distance - negative_distance + self.margin)

        if self.use_diversity and self.lambda_diversity > 0:
            diversity_loss = self._compute_diversity_regularisation(embeddings)
        else:
            diversity_loss = 0

        return triplet_loss.mean() + diversity_loss

## Dataset Classes

These classes are responsible for loading and preprocessing signature image data along with its metadata (signer IDs, image types) for the subsequent fine-tuning. 

`TrainingSignatureDataset` returns each image as a `torch.Tensor` together with the signer's ID as an integer (`torch.long`) label tensor, while `TestingSignatureDataset` prepares anchor/positive/negative triplets for evaluation.

In [None]:
class TrainingSignatureDataset(Dataset[Tuple[torch.Tensor, torch.Tensor]]):
    """
    Flat dataset of every signature image (original and forged) for training.

    Each item is an ``(image_tensor, signer_label_tensor)`` pair. The label is
    the signer id, as a ``torch.long`` scalar, under which the image is filed
    in ``data_map``. This applies to forged images as well as originals.

    Attributes
    ----------

    data_map: Dict[str, Dict[str, List[str]]]
        Maps a signer id (a string parseable by ``int``) to a dict with optional
        ``"original"`` and ``"forged"`` lists of image paths.
    transform: Optional[Callable[..., Any]]
        Callable applied to the grayscale PIL image (e.g. a
        ``torchvision.transforms.Compose``). Its output is returned as the image.
    signer_ids: List[str]
        Signer ids sorted numerically.
    all_image_references: List[Tuple[str, str, int]]
        ``(signer_id, image_type, index)`` for every image. Entries are grouped
        by signer in ``signer_ids`` order, with originals before forgeries.

    Methods
    -------

    __len__(self) -> int:
        Returns the total number of signature images.
    __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        Loads one image and returns it with its signer label.

    """

    def __init__(
        self,
        data_map: Dict[str, Dict[str, List[str]]],
        transform: Optional[Callable[..., Any]] = None,
    ) -> None:
        self.data_map = data_map
        self.transform = transform
        self.signer_ids = sorted(data_map, key=int)

        self.all_image_references: List[Tuple[str, str, int]] = []
        for signer_id in self.signer_ids:
            for image_type in ("original", "forged"):
                image_count = len(data_map[signer_id].get(image_type, []))
                self.all_image_references.extend(
                    (signer_id, image_type, index) for index in range(image_count)
                )

        print("--- Inside TrainingSignatureDataset.__init__ ---")
        print(f"  Length of self.data_map (signer_ids): {len(self.data_map)}")
        print(
            f"  Length of self.all_image_references after population: {len(self.all_image_references)}"
        )
        print("----------------------------------------")

    def __len__(self) -> int:
        return len(self.all_image_references)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Retrieve an image tensor and its associated signer label.

        Args
        ----
            index: int
                Index into ``all_image_references``.

        Returns
        -------
            Tuple[torch.Tensor, torch.Tensor]:
                - image_tensor: the grayscale image after ``transform``. Without
                  a transform, a float tensor of shape (1, H, W) scaled to [0, 1].
                - signer_label_tensor: the signer id as a ``torch.long`` scalar.

        Raises
        ------
            ValueError: If the referenced image type is neither 'original' nor 'forged'.
        """

        signer_id, image_type, image_index = self.all_image_references[index]
        if image_type not in ("original", "forged"):
            raise ValueError(f"Unknown image type: {image_type}")
        path = self.data_map[signer_id][image_type][image_index]

        # Close the file handle right away. convert("L") returns a new,
        # fully-loaded image (a copy when already grayscale), so it remains
        # valid after the context manager exits.
        with Image.open(path) as source_image:
            image_pil = source_image.convert("L")

        if self.transform:
            image_tensor = self.transform(image_pil)
        else:
            # A PIL image is not a NumPy array; convert explicitly for from_numpy.
            pixels = np.array(image_pil, dtype=np.uint8)
            image_tensor = torch.from_numpy(pixels).unsqueeze(0).float() / 255.0

        signer_label_tensor = torch.tensor(int(signer_id), dtype=torch.long)

        return image_tensor, signer_label_tensor

In [None]:
class TestingSignatureDataset(Dataset[Tuple[torch.Tensor, str, str]]):
    """

    Prepare offline triplets for testing.

    Every genuine ("original") signature of a signer with at least two distinct
    genuine signatures becomes an anchor. For each anchor:

    - the positive is a different genuine signature of the same signer;
    - the negative is either a forgery of the anchor's signer or any image
      (original or forged) of another signer that has at least two images.

    Forgeries are never used as anchors or positives. This means an anchor-positive
    pair is always a genuine match and an anchor-negative pair never is, and the
    same image can never be drawn as both positive and negative for one anchor.

    Attributes
    ----------

    data_map: Dict[str, Dict[str, List[str]]]
        Maps each signer ID to its image paths under the keys "original" and "forged"
    transform: Optional[transforms.Compose]
        Transformation applied to every loaded image; when None, images are only
        converted with ``transforms.ToTensor()``
    all_signers_ids: List[str]
        All signer IDs, sorted numerically
    anchor_candidates: List[Tuple[str, str]]
        (image path, signer ID) pairs used as anchors, one per dataset item
    all_signers_ids_filtered: List[str]
        Signer IDs with at least two images in total; their images serve as
        negatives for anchors of other signers

    Methods
    -------

    __len__(self) -> int:
        Returns the number of anchors
    __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
        Returns anchor, positive, negative, and signer id

    """

    def __init__(
        self,
        data_map: Dict[str, Dict[str, List[str]]],
        transform: Optional[transforms.Compose] = None,
    ):
        self.data_map = data_map
        self.transform = transform

        self.all_signers_ids: List[str] = sorted(list(data_map.keys()), key=int)

        # Only genuine signatures are anchors, and only for signers that have at
        # least two distinct genuine images, so a genuine positive different from
        # the anchor always exists.
        self.anchor_candidates: List[Tuple[str, str]] = []
        for signer_id, signature_type in data_map.items():
            originals = signature_type.get("original", [])
            if len(set(originals)) < 2:
                continue
            for path in originals:
                self.anchor_candidates.append((path, signer_id))

        self.all_signers_ids_filtered: List[str] = []
        for signer_id in self.all_signers_ids:
            total_images_for_signer = len(
                data_map[signer_id].get("original", [])
            ) + len(data_map[signer_id].get("forged", []))

            if total_images_for_signer >= 2:
                self.all_signers_ids_filtered.append(signer_id)

        if len(self.all_signers_ids_filtered) < 2:
            raise ValueError(
                "Not enough distinct signers with at least two images each to form valid triplets. "
                "Ensure your dataset has at least two signers, each with >= 2 images for testing."
            )

        if not self.anchor_candidates:
            raise ValueError(
                "No signer has at least two distinct original signatures, so no "
                "anchor-positive pair can be formed for testing."
            )

        print("--- Inside TestingSignatureDataset.__init__ ---")
        print(f"  Length of self.data_map (signer_ids): {len(self.data_map)}")
        print(f"  Number of anchor candidates: {len(self.anchor_candidates)}")
        print("----------------------------------------")

    def __len__(self) -> int:
        return len(self.anchor_candidates)

    def __getitem__(  # type: ignore
        self, index: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
        """
        Form the (anchor, positive, negative, signer id) triplet for ``index``.

        Args
        ----
            index (int): Index into ``self.anchor_candidates``.

        Returns
        -------
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
                - anchor_tensor: the anchor's genuine signature image.
                - positive_tensor: a different genuine signature of the same signer,
                  drawn with ``random.choice``.
                - negative_tensor: a forgery of the anchor's signer or an image of
                  another signer, drawn with ``random.choice``.
                - anchor_label_int: the anchor's signer ID as an int.

        Raises
        ------
            ValueError: If no positive or no negative candidate exists for the anchor.

        Images are loaded as grayscale, transformed if a transformation is set, and
        otherwise converted with ``transforms.ToTensor()``.
        """

        anchor_path, anchor_signer_id = self.anchor_candidates[index]
        anchor_label_int: int = int(anchor_signer_id)
        signer_images = self.data_map[anchor_signer_id]

        positive_candidates = [
            p for p in signer_images.get("original", []) if p != anchor_path
        ]
        if not positive_candidates:
            # Should be unreachable given the filtering in __init__; raise rather
            # than resample recursively, which could recurse without bound.
            raise ValueError(
                f"Signer {anchor_signer_id} has no original signature other than the anchor."
            )

        negative_candidates: List[str] = list(signer_images.get("forged", []))
        for other_signer_id in self.all_signers_ids_filtered:
            if other_signer_id == anchor_signer_id:
                continue
            other_images = self.data_map[other_signer_id]
            negative_candidates.extend(other_images.get("original", []))
            negative_candidates.extend(other_images.get("forged", []))

        if not negative_candidates:
            raise ValueError(
                f"No eligible negative signers found for anchor_signer_id {anchor_signer_id}. "
                "The dataset might contain too few unique signers."
            )

        positive_path = random.choice(positive_candidates)
        negative_path = random.choice(negative_candidates)

        anchor_tensor = self._load_tensor(anchor_path)
        positive_tensor = self._load_tensor(positive_path)
        negative_tensor = self._load_tensor(negative_path)

        return anchor_tensor, positive_tensor, negative_tensor, anchor_label_int  # type: ignore

    def _load_tensor(self, path: str) -> torch.Tensor:
        """Load ``path`` as a grayscale image and convert it to a tensor."""
        # The context manager closes the file immediately; convert("L") returns an
        # in-memory copy, so the image remains usable after the file is closed.
        with Image.open(path) as image:
            grayscale = image.convert("L")
        if self.transform:
            return self.transform(grayscale)  # type: ignore
        return transforms.ToTensor()(grayscale)

## Utility Functions

These functions assist with loading signature images, creating datasets and dataloaders, and evaluating the model.

### Preparing Signature Images

For these functions to work, the image file names must follow this pattern (e.g. `original_12_3.png` or `forgeries_12_3.png`):

\<original/forgeries>\_\<signer's id>\_\<image's index>

In [None]:
def extract_signer_id(file_name: str) -> str:
    """

    Extract the signer ID from a signature image's file name.

    The name must follow the ``<original|forgeries>_<signer id>_<image index>``
    convention (e.g. ``original_12_3.png``). The function returns the digits that
    follow the first ``original_`` or ``forgeries_`` prefix and precede the next
    underscore. If a path is given, only its final component is inspected, so a
    directory name can never be mistaken for the signer ID.

    Parameters
    ----------
    file_name: str
        File name (or path) of the signature image

    Returns
    -------
    str
        The signer ID as a string of digits, or 'UNKNOWN_SIGNER' if the name does
        not follow the expected convention.

    """

    match = re.search(r"(?:original|forgeries)_(\d+)_", Path(file_name).name)
    if match:
        return match.group(1)
    return "UNKNOWN_SIGNER"


# One liner because why not

# extract_author_id = lambda file_name: int(match.group(0)) if (match := re.search(r'(\d+)', file_name)) else 0

In [None]:
def retrieve_signature_images(
    dataset_path: Path, image_format: Optional[List[str]] = None
) -> List[Tuple[str, str]]:
    """

    Retrieve the signature images in a directory, each paired with its signer ID.

    Args
    ----

    dataset_path: Path
        Directory containing the signature images (a ``str`` is also accepted)
    image_format: Optional[List[str]]
        Accepted file extensions, compared case-insensitively.
        Defaults to ``[".png", ".jpg", ".jpeg", ".bmp"]``.

    Returns
    -------

    List[Tuple[str, str]]
        ``(signer_id, image_path)`` pairs, sorted by path. Files whose name does not
        yield a signer ID are skipped with a warning. An empty list is returned
        (with a warning) if ``dataset_path`` is not a directory.

    """

    # None sentinel instead of a mutable list default.
    if image_format is None:
        image_format = [".png", ".jpg", ".jpeg", ".bmp"]
    allowed_suffixes = {suffix.lower() for suffix in image_format}

    dataset_path = Path(dataset_path)
    if not dataset_path.is_dir():
        print(f"Warning: Directory not found! {dataset_path}")
        return []

    images: List[Tuple[str, str]] = []
    # Sorted so the result (and anything later sampled from it with a fixed seed)
    # does not depend on the filesystem's iteration order.
    for image_path in sorted(dataset_path.iterdir()):
        if not (image_path.is_file() and image_path.suffix.lower() in allowed_suffixes):
            continue

        # Only the file name carries the signer ID; directory names must not match.
        signer_id = extract_signer_id(image_path.name)
        if signer_id == "UNKNOWN_SIGNER":
            print(
                f"Warning: Could not extract signer ID from file: {image_path.name}"
            )
            continue
        images.append((signer_id, str(image_path)))
    return images

In [None]:
def prepare_signature_map(
    original_signatures: List[Tuple[str, str]], forged_signatures: List[Tuple[str, str]]
) -> Dict[str, Dict[str, List[str]]]:
    """

    Bucket original and forged signature image paths by signer.

    Args
    ----
    original_signatures: List[Tuple[str, str]]
        ``(signer_id, image_path)`` pairs of genuine signatures
    forged_signatures: List[Tuple[str, str]]
        ``(signer_id, image_path)`` pairs of forged signatures

    Returns
    -------
    Dict[str, Dict[str, List[str]]]
        ``{signer_id: {"original": [...], "forged": [...]}}``. Signers appear in
        first-seen order (originals are scanned before forgeries), and every signer
        has both keys, even if one of the lists is empty.

    """

    grouped: Dict[str, Dict[str, List[str]]] = {}
    sources = (("original", original_signatures), ("forged", forged_signatures))

    for category, pairs in sources:
        for signer_id, image_path in pairs:
            buckets = grouped.setdefault(signer_id, {"original": [], "forged": []})
            buckets[category].append(image_path)

    return {signer_id: dict(buckets) for signer_id, buckets in grouped.items()}

In [None]:
def training_and_testing_split(
    signature_dictionary: Dict[str, Dict[str, List[str]]],
    test_ratio: float = 0.2,
    random_state: int = 42,
) -> Tuple[Dict[str, Dict[str, List[str]]], Dict[str, Dict[str, List[str]]]]:
    """

    Split a signature map into training/validation and testing maps by signer.

    Signer IDs are sorted numerically and split with ``train_test_split``, so each
    signer's images end up entirely on one side of the split.

    Args
    ----
    signature_dictionary: Dict[str, Dict[str, List[str]]]
        ``{signer_id: {"original": [...], "forged": [...]}}`` mapping
    test_ratio: float
        Fraction of signers assigned to the test map
    random_state: int
        Seed passed to ``train_test_split`` for reproducibility

    Returns
    -------
    Tuple[Dict[str, Dict[str, List[str]]], Dict[str, Dict[str, List[str]]]]
        ``(train_val_map, test_map)``, each with the same structure as the input
        and with freshly copied path lists

    """

    signer_ids: List[str] = sorted(signature_dictionary, key=int)

    train_validate_ids, test_ids = train_test_split(  # type: ignore
        signer_ids, test_size=test_ratio, random_state=random_state
    )

    def subset_for(ids: List[str]) -> Dict[str, Dict[str, List[str]]]:
        # Copy each path list so the returned maps never alias the input's lists.
        return {
            sid: {
                "original": list(signature_dictionary[sid].get("original", [])),
                "forged": list(signature_dictionary[sid].get("forged", [])),
            }
            for sid in ids
            if sid in signature_dictionary
        }

    return subset_for(train_validate_ids), subset_for(test_ids)  # type: ignore

### Evaluation, Plots, and Graphs

Ensure that the model built by the `model_builder` passed to `load_model_for_inference` (e.g. `build_feature_extraction_model`) has the same architecture as the model that produced the checkpoint being loaded.

In [None]:
def build_feature_extraction_model():
    """

    Construct the `FeatureExtraction` model with this notebook's configuration.

    The configuration is a resnet18 backbone with "IMAGENET1K_V1" weights, extra
    layers with channels [512, 256], a 256-dimensional embedding and a dropout rate
    of 0.3. It must match the architecture of any checkpoint loaded through
    `load_model_for_inference`.

    """
    model_config = dict(
        embedding_dim=256,
        weights="IMAGENET1K_V1",
        use_extra_layers=True,
        backbone_type="resnet18",
        extra_channels=[512, 256],
        dropout_rate=0.3,
    )
    return FeatureExtraction(**model_config)

In [None]:
def build_batch_triplet_loss():
    """

    Construct the `BatchTripletLoss` criterion with this notebook's settings.

    The settings are a margin of 0.5, "batch_hard" mining, "euclidean" distance,
    and embedding normalisation disabled.

    """
    loss_settings = dict(
        margin=0.5,
        mining_strategy="batch_hard",
        distance_metric="euclidean",
        normalise_embeddings=False,
    )
    return BatchTripletLoss(**loss_settings)

In [None]:
def load_model_for_inference(
    checkpoint_path: str,
    device: torch.device,
    model_builder: Callable[[], nn.Module]
) -> nn.Module:
    """

    Load a trained model from a checkpoint and prepare it for evaluation.

    Args
    ----

    checkpoint_path: str
        Path to a file written with ``torch.save`` that holds a dict with a
        "model_state_dict" entry
    device: torch.device
        Device the checkpoint tensors are mapped to and the model is moved to
        ('cuda' or 'cpu')
    model_builder: Callable[[], nn.Module]
        Zero-argument factory returning a freshly constructed model whose
        architecture matches the checkpoint

    Returns
    -------
    nn.Module
        The model with the checkpoint weights loaded, moved to ``device`` and
        switched to evaluation mode

    Raises
    ------
    FileNotFoundError
        If ``checkpoint_path`` does not exist
    RuntimeError
        If ``torch.load`` fails to deserialise the file (the original error is
        chained as the cause); ``load_state_dict`` also raises RuntimeError when
        the checkpoint does not match the built architecture
    ValueError
        If the checkpoint is not a dict containing "model_state_dict"

    """

    if not Path(checkpoint_path).exists():
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")

    model = model_builder()

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
    except Exception as e:
        raise RuntimeError(
            f"Error loading checkpoint from {checkpoint_path}: {e}"
        ) from e

    # Without the isinstance check, a non-dict checkpoint (e.g. a bare tensor)
    # would make the membership test raise an unrelated error.
    if not isinstance(checkpoint, dict) or "model_state_dict" not in checkpoint:
        raise ValueError(
            f"'{checkpoint_path}' does not contain 'model_state_dict'. Please ensure you saved the model's state_dict correctly"
        )

    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    print(f"Model loaded successfully from {checkpoint_path} for inference.")
    return model

In [None]:
def calculate_eer_eer_threshold(
    thresholds: npt.NDArray[np.float64],
    fpr: npt.NDArray[np.float64],
    fnr: npt.NDArray[np.float64],
) -> Tuple[float, float]:
    """

    Locate the equal error rate (EER): the ROC point where the false positive rate
    and the false negative rate are closest to each other.

    The returned threshold is the NEGATED ROC threshold. `calculate_auc` feeds
    negated distances to ``roc_curve``, so negating back yields a distance
    threshold (distance <= threshold means "match").

    Args
    ----
    thresholds: npt.NDArray[np.float64]
        ROC thresholds, aligned index-by-index with ``fpr`` and ``fnr``
    fpr: npt.NDArray[np.float64]
        False positive rate at each threshold
    fnr: npt.NDArray[np.float64]
        False negative rate at each threshold

    Returns
    -------
    Tuple[float, float]
        ``((fpr + fnr) / 2, -threshold)`` at the first index with the smallest
        ``|fpr - fnr|``; ``(0.0, 0.0)`` when there is no usable point (empty input,
        or every gap is NaN)

    """

    threshold_values = np.asarray(thresholds)
    n_points = len(threshold_values)
    if n_points == 0:
        return 0.0, 0.0

    fp_rates = np.asarray(fpr)
    fn_rates = np.asarray(fnr)

    gaps = np.abs(fp_rates[:n_points] - fn_rates[:n_points])
    # A NaN gap can never win a strict "<" comparison, so rank it last.
    gaps = np.where(np.isnan(gaps), np.inf, gaps)

    # argmin returns the first occurrence, matching a strict "<" scan.
    best = int(np.argmin(gaps))
    if not gaps[best] < np.inf:
        return 0.0, 0.0

    return (fp_rates[best] + fn_rates[best]) / 2, -threshold_values[best]

In [None]:
def calculate_auc(
    all_labels_np: npt.NDArray[np.int64], all_distances_np: npt.NDArray[np.float32]
) -> Tuple[
    npt.NDArray[np.float64],
    npt.NDArray[np.float64],
    npt.NDArray[np.float64],
    npt.NDArray[np.float64],
    float,
]:
    """

    Compute the ROC curve, its area, and the false negative rate from pair distances.

    Distances are negated before being passed to ``roc_curve``, because a smaller
    distance must mean a higher "match" score. The returned thresholds are
    therefore negated distances.

    Args
    ----
    all_labels_np: npt.NDArray[np.int64]
        Binary pair labels (1 = positive/match class, 0 = negative)
    all_distances_np: npt.NDArray[np.float32]
        Distance for each pair, aligned with ``all_labels_np``

    Returns
    -------
    Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64], float]
        ``(fpr, tpr, fnr, thresholds, roc_auc)`` where ``fnr = 1 - tpr``
    """

    similarity_scores: npt.NDArray[np.float32] = -all_distances_np
    false_positive_rate, true_positive_rate, roc_thresholds = roc_curve(  # type: ignore
        all_labels_np, similarity_scores
    )
    area_under_curve = float(auc(false_positive_rate, true_positive_rate))  # type: ignore
    false_negative_rate = 1 - true_positive_rate  # type: ignore

    return (  # type: ignore
        false_positive_rate,
        true_positive_rate,
        false_negative_rate,
        roc_thresholds,
        area_under_curve,
    )

In [None]:
def evaluate_triplet_network(
    model: nn.Module,
    test_loader: DataLoader[Tuple[torch.Tensor, str, str]],
    device: torch.device,
    margin: float = 1.0,
) -> Tuple[
    npt.NDArray[np.float32],
    npt.NDArray[np.int64],
    Dict[str, npt.NDArray[np.float32]],
    Dict[str, float],
    Dict[str, npt.NDArray[np.float64]],
]:
    """

    Evaluate the model against a dataset of triplets.

    Every anchor-positive distance is recorded with label 1 and every
    anchor-negative distance with label 0. ROC/AUC, the equal error rate (EER)
    and threshold-based metrics are computed from these distance/label pairs.

    Args
    ----
    model: nn.Module
        Model to be evaluated; it is switched to eval mode
    test_loader: DataLoader[Tuple[torch.Tensor, str, str]]
        Each batch must unpack into (anchor, positive, negative, label); the label
        is ignored
    device: torch.device
        'cuda' or 'cpu'
    margin: float
        Margin of the triplet loss reported as ``avg_triplet_loss``
        (``mean(relu(d(a, p) - d(a, n) + margin))`` per batch). Must be non-negative.

    Returns
    -------
    Tuple[npt.NDArray[np.float32], npt.NDArray[np.int64], Dict[str, npt.NDArray[np.float32]], Dict[str, float], Dict[str, npt.NDArray[np.float64]]]
        - all_distances_np: for each batch, the anchor-positive distances followed
          by the anchor-negative distances
        - all_labels_np: labels aligned with the distances (1 = anchor-positive,
          0 = anchor-negative)
        - final_embeddings: "anchor", "positive" and "negative" embeddings, each
          concatenated along the batch axis
        - metrics: avg_triplet_loss (mean of per-batch losses), avg_pos_dist,
          avg_neg_dist, auc_roc, eer, eer_threshold, accuracy, precision, recall.
          Accuracy, precision and recall predict a match when
          distance <= eer_threshold.
        - auc_roc: the ROC curve arrays "fpr", "tpr" and "fnr"

    Raises
    ------
    ValueError
        If ``margin`` is negative or ``test_loader`` yields no batches
    """

    if margin < 0.0:
        raise ValueError(f"{margin} is less than 0.0")

    model.eval()

    all_labels: List[int] = []
    all_distances: List[float] = []

    embeddings_list: Dict[str, List[npt.NDArray[np.float32]]] = {
        "anchor": [],
        "positive": [],
        "negative": [],
    }

    total_loss: float = 0.0
    num_batches: int = 0

    with torch.no_grad():
        for anchor, positive, negative, _ in test_loader:
            anchor_embed = model(anchor.to(device))
            positive_embed = model(positive.to(device))
            negative_embed = model(negative.to(device))

            pos_dist = F.pairwise_distance(anchor_embed, positive_embed)
            neg_dist = F.pairwise_distance(anchor_embed, negative_embed)

            loss = torch.mean(torch.relu(pos_dist - neg_dist + margin))
            total_loss += loss.item()
            num_batches += 1

            embeddings_list["anchor"].append(anchor_embed.cpu().numpy())
            embeddings_list["positive"].append(positive_embed.cpu().numpy())
            embeddings_list["negative"].append(negative_embed.cpu().numpy())

            all_distances.extend(pos_dist.cpu().tolist())  # type: ignore [attr-defined]
            all_labels.extend([1] * len(pos_dist))

            all_distances.extend(neg_dist.cpu().tolist())  # type: ignore [attr-defined]
            all_labels.extend([0] * len(neg_dist))

    # ROC/EER are undefined without data; fail with a clear message instead of an
    # opaque error from the array handling or sklearn below.
    if num_batches == 0:
        raise ValueError("test_loader yielded no batches; there is nothing to evaluate.")

    final_embeddings: Dict[str, npt.NDArray[np.float32]] = {
        key: np.concatenate(chunks, axis=0) for key, chunks in embeddings_list.items()
    }

    avg_loss: float = total_loss / num_batches

    all_distances_np: npt.NDArray[np.float32] = np.array(
        all_distances, dtype=np.float32
    )
    all_labels_np: npt.NDArray[np.int64] = np.array(all_labels, dtype=np.int64)

    fpr, tpr, fnr, thresholds, roc_auc = calculate_auc(all_labels_np, all_distances_np)
    eer, eer_threshold = calculate_eer_eer_threshold(thresholds, fpr, fnr)

    positive_mask = all_labels_np == 1
    negative_mask = all_labels_np == 0
    avg_pos_dist = (
        float(np.mean(all_distances_np[positive_mask])) if positive_mask.any() else 0.0
    )
    avg_neg_dist = (
        float(np.mean(all_distances_np[negative_mask])) if negative_mask.any() else 0.0
    )

    # A pair is predicted as a match (1) when its distance is at most the EER threshold.
    predicted_labels_at_eer_threshold = (all_distances_np <= eer_threshold).astype(int)

    accuracy = accuracy_score(all_labels_np, predicted_labels_at_eer_threshold)
    precision = precision_score(all_labels_np, predicted_labels_at_eer_threshold, pos_label=1)  # type: ignore
    recall = recall_score(all_labels_np, predicted_labels_at_eer_threshold, pos_label=1)  # type: ignore

    metrics: Dict[str, float] = {
        "avg_triplet_loss": avg_loss,
        "avg_pos_dist": avg_pos_dist,
        "avg_neg_dist": avg_neg_dist,
        "auc_roc": roc_auc,
        "eer": float(eer),
        "eer_threshold": float(eer_threshold),
        "accuracy": float(accuracy),
        "precision": float(precision),
        "recall": float(recall),
    }

    auc_roc: Dict[str, npt.NDArray[np.float64]] = {
        "fpr": fpr,
        "tpr": tpr,
        "fnr": fnr,
    }

    return all_distances_np, all_labels_np, final_embeddings, metrics, auc_roc

In [None]:
def plot_confusion_matrix(
    all_labels_np: npt.NDArray[np.int64],
    all_predictions_np: npt.NDArray[np.float32],
    best_threshold: float,
) -> None:
    """

    Plot the confusion matrix of threshold-based match predictions, save it to
    ``confusion_matrix.png`` and show it.

    Args
    ----
    all_labels_np: npt.NDArray[np.int64]
        Ground-truth pair labels (1 = match, 0 = non-match)
    all_predictions_np: npt.NDArray[np.float32]
        Pair distances (not labels); a pair is predicted as a match (1) when its
        distance is less than or equal to ``best_threshold``
    best_threshold: float
        Distance threshold separating matches from non-matches (e.g. the EER threshold)

    """

    predictions = (np.asarray(all_predictions_np) <= best_threshold).astype(int)
    # Pin the label set so the matrix is always 2x2 (rows/columns ordered 0, 1),
    # even when one class is missing from the labels or the predictions.
    cm = confusion_matrix(all_labels_np, predictions, labels=[0, 1])  # type: ignore

    plt.figure(figsize=(8, 6))  # type: ignore
    sns.heatmap(cm, annot=True, fmt="g", cmap="Blues")  # type: ignore
    plt.title("Confusion Matrix")  # type: ignore
    plt.ylabel("True Label")  # type: ignore
    plt.xlabel("Predicted Label")  # type: ignore
    plt.savefig("confusion_matrix.png")  # type: ignore
    plt.show()  # type: ignore

In [None]:
def plot_auc_roc(fpr: npt.NDArray[np.float64], tpr: npt.NDArray[np.float64], auc_roc: float) -> None:
    """
    Plot the ROC curve against the chance diagonal and show it.

    Args
    ----
    fpr: npt.NDArray[np.float64]
        False positive rates (x axis)
    tpr: npt.NDArray[np.float64]
        True positive rates (y axis)
    auc_roc: float
        Area under the curve, shown in the legend
    """
    _, axes = plt.subplots(figsize=(8, 6))  # type: ignore

    axes.plot(fpr, tpr, color="orange", label=f"AUC = {auc_roc}")  # type: ignore
    # Dashed diagonal: the ROC curve of a random classifier.
    axes.plot([0, 1], [0, 1], linestyle="--", color="gray")  # type: ignore

    axes.set_xlabel("False Positive Rate")  # type: ignore
    axes.set_ylabel("True Positive Rate")  # type: ignore
    axes.set_title("ROC Curve")  # type: ignore
    axes.legend(loc="lower right")  # type: ignore
    axes.grid()  # type: ignore

    plt.show()  # type: ignore

In [None]:
def plot_embedding_distances(
    anchors: npt.NDArray[np.float32],
    positives: npt.NDArray[np.float32],
    negatives: npt.NDArray[np.float32],
) -> None:
    """
    Project anchor, positive and negative embeddings to 2D with one shared PCA and
    plot them. Each anchor is joined to its positive (green) and to its negative
    (red) by a faint line. The figure is saved to ``embedding_distance.png`` and
    shown.

    Args
    ----
    anchors: npt.NDArray[np.float32]
        Anchor embeddings, one row per triplet
    positives: npt.NDArray[np.float32]
        Positive embeddings, row-aligned with ``anchors``
    negatives: npt.NDArray[np.float32]
        Negative embeddings, row-aligned with ``anchors``
    """

    stacked = np.concatenate([anchors, positives, negatives], axis=0)
    projected = PCA(n_components=2).fit_transform(stacked)  # type: ignore

    count = len(anchors)
    anchor_points = projected[:count]  # type: ignore
    positive_points = projected[count : 2 * count]  # type: ignore
    negative_points = projected[2 * count :]  # type: ignore

    plt.figure(figsize=(10, 8))  # type: ignore

    point_groups = (
        (anchor_points, "blue", "Anchors"),
        (positive_points, "green", "Positives"),
        (negative_points, "red", "Negatives"),
    )
    for points, colour, legend_label in point_groups:
        plt.scatter(points[:, 0], points[:, 1], c=colour, label=legend_label, alpha=0.6)  # type: ignore

    # All anchor->positive links are drawn first, then all anchor->negative links.
    for targets, line_style in ((positive_points, "g-"), (negative_points, "r-")):
        for idx, anchor_xy in enumerate(anchor_points):  # type: ignore
            target_xy = targets[idx]
            plt.plot(  # type: ignore
                [anchor_xy[0], target_xy[0]],
                [anchor_xy[1], target_xy[1]],
                line_style,
                alpha=0.1,
            )

    plt.title("2D Visualization of Embedding Distances")  # type: ignore
    plt.xlabel("First Principal Component")  # type: ignore
    plt.ylabel("Second Principal Component")  # type: ignore
    plt.legend()  # type: ignore
    plt.grid(True)  # type: ignore
    plt.savefig("embedding_distance.png")  # type: ignore
    plt.show()  # type: ignore

In [None]:
def plot_3d_embeddings_interactive(
    anchors: npt.NDArray[np.float32],
    positives: npt.NDArray[np.float32],
    negatives: npt.NDArray[np.float32],
) -> None:
    """
    Plots an interactive 3D visualization of embedding distances using PCA to reduce dimensionality.

    One PCA is fitted on all embeddings together, so the three groups share the
    same projected axes.

    Parameters
    ----------
    anchors : npt.NDArray[np.float32]
        Array of anchor embeddings.
    positives : npt.NDArray[np.float32]
        Array of positive embeddings, row-aligned with ``anchors``.
    negatives : npt.NDArray[np.float32]
        Array of negative embeddings, row-aligned with ``anchors``.

    Returns
    -------
    None
    """

    pca = PCA(n_components=3)
    all_embeddings = np.concatenate([anchors, positives, negatives])
    embeddings_3d = pca.fit_transform(all_embeddings)  # type: ignore

    n_samples = len(anchors)
    anchors_3d = embeddings_3d[:n_samples]  # type: ignore
    positives_3d = embeddings_3d[n_samples : 2 * n_samples]  # type: ignore
    negatives_3d = embeddings_3d[2 * n_samples :]  # type: ignore

    fig = go.Figure()

    traces = (
        (anchors_3d, "Anchors", "blue"),
        (positives_3d, "Positives", "green"),
        (negatives_3d, "Negatives", "red"),
    )
    for points, trace_name, colour in traces:
        fig.add_trace(  # type: ignore
            go.Scatter3d(
                x=points[:, 0],
                y=points[:, 1],
                z=points[:, 2],
                mode="markers",
                name=trace_name,
                marker=dict(size=5, color=colour, opacity=0.6),
            )
        )

    # A single layout update; the previous duplicate call without a camera setting
    # was fully overridden by this one.
    fig.update_layout(  # type: ignore
        title="Interactive 3D Visualization of Embedding Distances",
        scene=dict(
            xaxis_title="First Principal Component",
            yaxis_title="Second Principal Component",
            zaxis_title="Third Principal Component",
            camera=dict(eye=dict(x=1.25, y=1.25, z=1.25)),
        ),
        width=800,
        height=800,
        margin=dict(l=0, r=0, b=0, t=30),
    )

    fig.show()  # type: ignore

## Training & Evaluation Loop

`TrainingClass` can be used for a single training run or for k-fold cross-validation.

In [None]:
class TrainingClass(nn.Module):
    """

    Combines `FeatureExtraction` class with the `BatchTripletLoss` class. It extract features from input
    images using `ResNet`-based architecture. This class implements Mixed Precision

    Attributes
    ----------

    model: nn.Module
        An instance of `FeatureExtraction` class
    train_dataloader: Dataloader[Tuple[torch.Tensor, str, str]]
        Signature images for the model's fine-tuning
    learning_config: Dict[pstr, str | int | float]
        Fine-tuning parameters (from LEARNING_CONFIG) 
    optimiser_config: Dict[str, str | Tuple[float, float] | float]
        Optmiser's parameters (from OPTIMISER_PARAMS)
    scheduler_config: Dict[str, str | float | int]
        Scheduler's parameters (from SCHEDULER_PARAMS)
    loss_function: nn.Module
        An instance of `BatchTripletLoss` class
    checkpoint_path: Optional[str]
        The path in which checkpoints of the model are saved. Default is "checkpoint"
    save_checkpoints: bool
        Chooses whether checkpoints are to be saved
        
    Methods
    -------

    save_checkpoint(self, epoch: int, loss: float)
        Saves a version of the model.
        Parameters:
            - epoch
            - model_state_dict
            - optimiser_state_dict
            - loss
            - best_loss
            - patience_counter
    load_checkpoint(self, checkpoint_path)
        Loads a version of the model.
    train_epoch(self, dataloader: DataLoader[Tuple[torch.Tensor, torch.Tensor]]) -> float:
        Trains the model for one epoch and returns the average training loss
    validate_epoch(self, dataloader: DataLoader[Tuple[torch.Tensor, torch.Tensor]]) -> float
        Validates the model and returns the average validation loss
    fit(train_dataloder: DataLoader[Tuple[torch.Tensor, torch.Tensor]], val_dataloader: Optional[DataLoader[Tuple[torch.Tensor, torch.Tensor]]], start_epoch) -> Dict[str, Any]
        Trains the model for a specified number of epochs.

    """

    def __init__(
        self,
        model: nn.Module,
        train_dataloader: DataLoader[Tuple[torch.Tensor, str, str]],
        learning_config: Dict[str, str | int | float],
        optimiser_config: Dict[str, str | Tuple[float, float] | float],
        scheduler_config: Dict[str, str | float | int],
        loss_function: nn.Module,
        checkpoint_path: Optional[str] = "checkpoint",
        save_checkpoints: bool = True,
    ) -> None:
        super().__init__() # type: ignore

        if int(learning_config.get("MARGIN", 1.0)) < 0:
            raise ValueError(
                f"Margin must be non-negative, got {learning_config.get('MARGIN')}"
            )
        if int(learning_config.get("lambda_diversity", 0.0)) < 0:
            raise ValueError(
                f"lambda_diversity must be non-negative, got {learning_config.get('lambda_diversity')}"
            )

        self.model = model
        self.train_dataloader = train_dataloader
        self.learning_config = learning_config
        self.optimiser_config = optimiser_config
        self.scheduler_config = scheduler_config
        self.loss_function = loss_function
        self.checkpoint_path = checkpoint_path
        self.save_checkpoints = save_checkpoints
        
        self.scaler = GradScaler()

        self.device = torch.device(str(self.learning_config.get("DEVICE", "cpu")))
        self.model.to(self.device)
        self.loss_function.to(self.device)

        print(f"Model and Loss function moved to device: {self.device}")

        self.best_loss = float("inf")
        self.patience_counter = 0
        self.early_stopping_patience = self.learning_config.get(
            "EARLY_STOPPING_PATIENCE", 10
        )

        if self.save_checkpoints:
            self.checkpoint_path = Path(self.checkpoint_path) # type: ignore 
            self.checkpoint_path.mkdir(exist_ok=True)
        else:
            self.checkpoint_path = None

        optimiser_name = str(self.optimiser_config.get("optimiser"))
        if not optimiser_name:
            raise ValueError("Optimiser name 'optimiser' not found")

        optimiser_class = getattr(optim, optimiser_name)

        optimiser_params_for_init = {
            k: v for k, v in self.optimiser_config.items() if k != "optimiser"
        }

        learning_rate = self.learning_config.get("LEARNING_RATE")
        if learning_rate is None:
            raise ValueError("LEARNING_RATE key missing in learning_config.")

        optimiser_params_for_init["lr"] = learning_rate

        self.optimiser = optimiser_class(
            self.model.parameters(), **optimiser_params_for_init
        )

        scheduler_name = str(self.scheduler_config.get("scheduler"))
        if scheduler_name:
            scheduler_class = getattr(lr_scheduler, scheduler_name)
            scheduler_params_for_init = {
                k: v for k, v in self.scheduler_config.items() if k != "scheduler"
            }

            if scheduler_name == "OneCycleLR":
                num_epochs = self.learning_config.get("EPOCH")
                if num_epochs is None:
                    raise ValueError(
                        "EPOCH key missing in LEARNING_CONFIG for OneCycleLR scheduler setup."
                    )
                if not isinstance(num_epochs, int):
                    raise TypeError(
                        "EPOCH in LEARNING_CONFIG must be an integer for scheduler setup."
                    )

                total_steps = len(self.train_dataloader) * num_epochs
                scheduler_params_for_init["total_steps"] = total_steps

            self.scheduler = scheduler_class(
                self.optimiser, **scheduler_params_for_init
            )
        else:
            self.scheduler = None

    def save_checkpoint(self, epoch: int, loss: float) -> None:
        """

        Saves a version of the model.
        Parameters:
            - epoch
            - model_state_dict
            - optimiser_state_dict
            - loss
            - best_loss
            - patience_counter

        Args
        ----
        epoch: int
            The current epoch number
        loss: float
            The current loss value.

        """

        if not self.save_checkpoints:
            return

        checkpoint: Dict[str, int | float | Dict[str, Any] | None] = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimiser_state_dict": self.optimiser.state_dict(),
            "loss": loss,
            "best_loss": self.best_loss,
            "patience_counter": self.patience_counter,
        }

        if self.scheduler:
            checkpoint["scheduler_state_dict"] = self.scheduler.state_dict()

        torch.save(
            checkpoint, self.checkpoint_path / f"model_{epoch}_loss_{loss:.4f}.pt" # type: ignore
        )

    def load_checkpoint(self, checkpoint_path: str) -> int:
        """

        Loads a version of the model.

        Args
        ----
        checkpoint_path: str
            Path to the checkpoint model

        Returns
        -------
        int
            The training epoch number to start from.

        """

        if not Path(checkpoint_path).exists():
            raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")

        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
        except Exception as e:
            raise RuntimeError(f"Error loading checkpoint from {checkpoint_path}: {e}")

        self.model.load_state_dict(checkpoint["model_state_dict"])

        self.optimiser.load_state_dict(checkpoint["optimiser_state_dict"])

        if self.scheduler:
            if (
                "scheduler_state_dict" in checkpoint
                and checkpoint["scheduler_state_dict"] is not None
            ):
                try:
                    self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
                except Exception as e:
                    print(
                        f"Warning: Could not load scheduler state_dict. It might be incompatible. Error: {e}"
                    )
            else:
                print(
                    "Warning: Scheduler initialized but no scheduler_state_dict found in checkpoint."
                )
        elif (
            "scheduler_state_dict" in checkpoint
            and checkpoint["scheduler_state_dict"] is not None
        ):
            print(
                "Warning: Checkpoint contains scheduler state, but no scheduler is initialized in this "
                " instance. State ignored."
            )

        loaded_epoch = checkpoint["epoch"]
        self.best_loss = checkpoint.get("best_loss", float("inf"))
        self.patience_counter = checkpoint.get("patience_counter", 0)

        loaded_loss = checkpoint["loss"]

        print(
            f"Checkpoint loaded successfully: Resuming from epoch {loaded_epoch + 1}."
        )
        print(
            f"  Loaded metrics: Last Epoch Loss = {loaded_loss:.4f}, Best Validation Loss = {self.best_loss:.4f}, Patience Counter = {self.patience_counter}"
        )

        self.model.train()

        return loaded_epoch + 1

    def train_epoch(
        self, dataloader: DataLoader[Tuple[torch.Tensor, str, str]]
    ) -> float:
        """

        Trains the model for one epoch and returns the average training loss.

        Parameters
        ----------
        dataloader: DataLoader[Tuple[torch.Tensor, torch.Tensor]]
            The training signature images dataset

        Returns
        -------
        float
            Average training loss calculated in the epoch.

        """

        self.model.train()
        total_loss = 0.0
        num_batches: int = len(dataloader)

        print("Entering train_epoch.")
        for batch_index, (embedding_batch, labels_batch) in enumerate(dataloader):
            embedding_batch = embedding_batch.to(self.device)
            labels_batch = labels_batch.to(self.device)
            
            self.optimiser.zero_grad()
            
            with autocast('cuda'):
                extracted_embedding = self.model(embedding_batch)
                loss = self.loss_function(extracted_embedding, labels_batch)
                
            self.scaler.scale(loss).backward() #type: ignore
            
            grad_clip_val = self.learning_config.get("GRAD_CLIP")
            if grad_clip_val is not None:
                self.scaler.unscale_(self.optimiser)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip_val) # type: ignore
                
            self.scaler.step(self.optimiser)
            # self.optimiser.step()
            self.scaler.update()
            
            if (
                self.scheduler
                and self.scheduler_config.get("scheduler") != "OneCycleLR"
            ):
                self.scheduler.step()
                
            total_loss += loss.item()
            print(
                f"Training batch {batch_index}/{ num_batches }, "
                f"Loss: {loss.item():.4f}"
            )

        return total_loss / num_batches

    def validate_epoch(
        self, dataloader: DataLoader[Tuple[torch.Tensor, str, str]]
    ) -> float:
        """

        Validates the model and returns the average validation loss

        Parameters
        ----------
        dataloader: DataLoader[Tuple[torch.Tensor, str, str]]
            The validation signature images dataset

        Returns
        -------
        float
            Average training loss for the epoch.

        """

        self.model.eval()
        total_loss = 0.0

        with torch.no_grad(): 
            with autocast('cuda'):
                for _, (embeddings_batch, labels_batch) in enumerate(dataloader):
                    embeddings_batch = embeddings_batch.to(self.device)
                    labels_batch = labels_batch.to(self.device)

                    embeddings_batch = self.model(embeddings_batch)

                    loss = self.loss_function(embeddings_batch, labels_batch)
                    total_loss += loss.item()

        return total_loss / len(dataloader)

    def fit(
        self,
        train_dataloader: DataLoader[Tuple[torch.Tensor, str, str]],
        val_dataloader: Optional[DataLoader[Tuple[torch.Tensor, str, str]]] = None,
        start_epoch: int = 0,
    ) -> Dict[str, List[float] | float | None]:
        """

        Trains the model for a specified number of epochs

        Args
        ----
        train_dataloader: DataLoader[Tuple[torch.Tensor, torch.Tensor]]
            The training signature images dataset
        val_dataloader: Optional[DataLoader[Tuple[torch.Tensor, torch.Tensor]]]
            The validation signature images dataset
        start_epoch
            The starting epoch number

        Returns
        -------
        Dict[str, Any]
            A dictionary containing:
                - Training loss
                - Validation loss
                - Best validation loss
                - Final training loss
                - Final validation loss

        """

        epochs = self.learning_config.get("EPOCH", 100)

        scheduler_name = self.scheduler_config.get("scheduler")

        if scheduler_name == "OneCycleLR" and not self.scheduler:
            scheduler_class = getattr(lr_scheduler, scheduler_name) # type: ignore
            scheduler_params_for_init = {
                k: v for k, v in self.scheduler_config.items() if k != "scheduler"
            }
            total_steps = len(train_dataloader) * int(epochs)
            scheduler_params_for_init["total_steps"] = total_steps
            self.scheduler = scheduler_class(
                self.optimiser, **scheduler_params_for_init
            )

        print(
            f"Starting training for {epochs} epochs from epoch {start_epoch} on device {self.device}"
        )

        history: Dict[str, List[float] | float | None] = {
            "train_loss": [],
            "val_loss": [],
            "best_val_loss": float("inf"),
            "final_train_loss": None,
            "final_val_loss": None,
        }

        for epoch in range(start_epoch, int(epochs)):
            print(f"Epoch {epoch+1}/{epochs}")
            print(f"Starting Epoch {epoch+1} training phase.")

            train_loss = self.train_epoch(train_dataloader)
            history["train_loss"].append(train_loss) # type: ignore
            print(f"  Train Loss: {train_loss:.4f}")

            val_loss = None
            if val_dataloader:
                print(f"Starting Epoch {epoch+1} validation phase.")
                val_loss = self.validate_epoch(val_dataloader)
                history["val_loss"].append(val_loss) # type: ignore
                print(f"Validation Loss: {val_loss:.4f}")

            if self.scheduler and scheduler_name != "OneCycleLR":
                self.scheduler.step(val_loss if val_loss is not None else train_loss)

            # Use train_loss if no val_dataloader
            current_val_loss = (
                val_loss if val_dataloader else train_loss
            )  

            if current_val_loss < self.best_loss: # type: ignore
                self.best_loss = current_val_loss
                self.patience_counter = 0
                history["best_val_loss"] = self.best_loss
                self.save_checkpoint(epoch, self.best_loss) # type: ignore
            else:
                self.patience_counter += 1
                print(
                    f"  Patience: {self.patience_counter}/{self.early_stopping_patience}"
                )
                if self.patience_counter >= int(self.early_stopping_patience):
                    print(f"Early stopping triggered at epoch {epoch+1}.")
                    break

        # Store final losses
        if history["train_loss"]:
            history["final_train_loss"] = history["train_loss"][-1]  # type:ignore
        if history["val_loss"]:
            history["final_val_loss"] = history["val_loss"][-1]  # type:ignore

        print("Training completed.")
        return history

In [None]:
class KFoldCrossValidator:
    """

    This class is largely similar to `TrainingClass` in terms of attributes. The class uses `TrainingClass` to perform k-fold cross validation. 

    Attributes
    ----------
    
    dataset: Dataset[Tuple[torch.Tensor, str, str]],
        An instance of `TrainingSignatureDataset` 
    model_builder: Callable[[], nn.Module]
        A function that creates an instance of `FeatureExtraction`
    loss_function_builder: Callable[[], nn.Module],
        A function that creates an instance of `BatchTripletLoss`
    learning_config: Dict[pstr, str | int | float]
        Fine-tuning parameters (from LEARNING_CONFIG) 
    optimiser_config: Dict[str, str | Tuple[float, float] | float]
        Optmiser's parameters (from OPTIMISER_PARAMS)
    scheduler_config: Dict[str, str | float | int]
        Scheduler's parameters (from SCHEDULER_PARAMS)
    loss_function: nn.Module
        An instance of `BatchTripletLoss` class
    checkpoint_path: Optional[str]
        The path in which checkpoints of the model are saved. Default is "kfold_checkpoints"
    save_checkpoints: bool
        Chooses whether checkpoints are to be saved

    Methods
    -------

    _aggregate_results(self) -> Dict[str, Any]:
        Calculates the average loss value calculated on both training datasets and validation datasets
    run_cross_validation(self) -> Dict[str, Any]:
        Creates new dataloaders, model, and loss function in each fold.  

    """
    
    def __init__(
        self,
        dataset: Dataset[Tuple[torch.Tensor, str, str]],
        model_builder: Callable[[], nn.Module],
        loss_function_builder: Callable[[], nn.Module],
        learning_config: Dict[str, str | int | float],
        optimiser_config: Dict[str, str | Tuple[float, float] | float],
        scheduler_config: Dict[str, str | float | int],
        checkpoint_directory: str = "kfold_checkpoints",
        save_checkpoints: bool = False
    ) -> None:
        
        self.model_builder = model_builder
        self.dataset = dataset
        self.loss_function_builder = loss_function_builder
        self.learning_config = learning_config
        self.optimiser_config = optimiser_config
        self.scheduler_config = scheduler_config

        if not callable(self.model_builder):
            raise ValueError(
                "Model builder must be a callable function returning an nn.Module"
            )
        if (
            "LEARNING_RATE" in self.learning_config
            and float(self.learning_config["LEARNING_RATE"]) <= 0
        ):
            raise ValueError("Learning rate must be positive")

        self.save_checkpoints = save_checkpoints
        
        if self.save_checkpoints:
            self.checkpoint_directory = Path(checkpoint_directory)
            self.checkpoint_directory.mkdir(exist_ok=True)

        self.k_folds_number = int(self.learning_config.get("K_FOLDS", 5))
        self.batch_size = int(self.learning_config.get("BATCH_SIZE", 32))

        self.k_fold = KFold(n_splits=self.k_folds_number, shuffle=True)

        self.fold_results: List[Dict[str, Any]] = []

    def _aggregate_results(self) -> Dict[str, Any]:
        average_train_losses = [
            result["final_train_loss"] 
            for result in self.fold_results 
            if "final_train_loss" in result and result["final_train_loss"] is not None
        ]  
        average_validation_losses = [
            result["final_val_loss"]
            for result in self.fold_results
            if "final_val_loss" in result and result["final_val_loss"] is not None
        ]  # type:ignore

        return {
            "mean_final_train_loss": sum(average_train_losses)
            / len(average_train_losses),
            "mean_final_validation_loss": sum(average_validation_losses)
            / len(average_validation_losses),
        }

    def run_cross_validation(self) -> Dict[str, Any]:
        print(f"--- Running {self.k_folds_number}-Fold Cross-Validation")

        for fold, (train_indices, val_indices) in enumerate(self.k_fold.split(self.dataset)): # type: ignore
            print(f"\n-- Fold {fold+1//self.k_folds_number}")

            train_subset = Subset(self.dataset, train_indices) # type: ignore
            validation_subset = Subset(self.dataset, val_indices) # type: ignore

            train_loader = DataLoader(
                train_subset, batch_size=self.batch_size, shuffle=True, num_workers=0
            )
            validation_loader = DataLoader(
                validation_subset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=0,
            )

            model = self.model_builder()
            loss_function = self.loss_function_builder()
            print(f"Model and Loss Function built for fold {fold+1}")

            fold_trainer = TrainingClass(
                model=model,
                train_dataloader=train_loader,
                loss_function=loss_function,
                learning_config=self.learning_config,
                optimiser_config=self.optimiser_config,
                scheduler_config=self.scheduler_config,
                save_checkpoints=self.save_checkpoints,
            )

            try:
                fold_history = fold_trainer.fit(
                    train_dataloader=train_loader, val_dataloader=validation_loader, start_epoch=0
                )
                self.fold_results.append(fold_history)
            except Exception as e:
                print(f"Error during fold {fold + 1} training: {e}")
                self.fold_results.append({"error": str(e)})

        aggregated_metrics = self._aggregate_results()
        print("\n--- K-Fold Cross-Validation Complete ---")
        return aggregated_metrics

# Data Loading & Preparation


## Load Raw Data Map


In [None]:
# Collect original and forged signature images, pair them per signer, then hold out 20% for testing.
dataset_path: Path = Path(DATASET_CONFIG["DATASET_PATH"])
original_signatures, forged_signatures = (
    retrieve_signature_images(dataset_path / subfolder)
    for subfolder in ("original", "forged")
)
signature_map = prepare_signature_map(original_signatures, forged_signatures)
train_val_signatures_map, test_signatures_map = training_and_testing_split(
    signature_map,
    test_ratio=0.2,
    random_state=42,
)

## Define Transforms

For more augmentation variety, consider adding `RandomRotation`.

In [None]:
# Training pipeline: resize, small random affine jitter, mild random crop, then tensor + normalise.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05), shear=(-5, 5)),
    transforms.RandomResizedCrop((224, 224), scale=(0.9, 1.05), ratio=(0.95, 1.05), antialias=True),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Evaluation/inference pipeline: deterministic resize, tensor conversion and the same normalisation.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

## Instantiate Datasets


In [None]:
# The training split gets augmentations; the held-out test split uses deterministic preprocessing.
train_dataset = TrainingSignatureDataset(data_map=train_val_signatures_map, transform=train_transform)
test_dataset = TestingSignatureDataset(data_map=test_signatures_map, transform=test_transform)

## Create DataLoaders


In [None]:
# Only the training loader is shuffled; test order stays fixed for reproducible evaluation.
test_dataloader = DataLoader(
    test_dataset, batch_size=LEARNING_CONFIG["BATCH_SIZE"], shuffle=False  # type: ignore
)
train_dataloader = DataLoader(
    train_dataset, batch_size=LEARNING_CONFIG["BATCH_SIZE"], shuffle=True  # type: ignore
)

# Model Initialisation & Setup


In [None]:
# Factories (not instances) so each training run or fold can build a fresh model and loss.
model_builder, loss_function_builder = build_feature_extraction_model, build_batch_triplet_loss

In [None]:
# Cross-validation over the training split only; the test split stays untouched.
k_fold_validator = KFoldCrossValidator(
    dataset=train_dataset,
    model_builder=model_builder,
    loss_function_builder=loss_function_builder,
    checkpoint_directory="kfold_checkpoints",
    save_checkpoints=False,
    learning_config=LEARNING_CONFIG,
    optimiser_config=OPTIMISER_PARAMS,
    scheduler_config=SCHEDULER_PARAMS,
)

In [None]:
# Final model over the whole training split, saving checkpoints to the default directory.
final_model_trainer = TrainingClass(
    model=model_builder(),
    loss_function=loss_function_builder(),
    train_dataloader=train_dataloader,
    save_checkpoints=True,
    learning_config=LEARNING_CONFIG,
    optimiser_config=OPTIMISER_PARAMS,
    scheduler_config=SCHEDULER_PARAMS,
)

# Model Training

## K-Fold Cross Validation

In [None]:
# Trains one fresh model per fold and returns the mean final train/validation losses across folds.
k_fold_result = k_fold_validator.run_cross_validation()

In [None]:
k_fold_result

## Training Final model

In [None]:
# No validation dataloader is passed, so early stopping and "best" checkpoints follow the training loss.
model_result = final_model_trainer.fit(train_dataloader=train_dataloader)

In [None]:
model_result

# Model Evaluation

Load the best model


In [None]:
# Load a saved checkpoint onto LEARNING_CONFIG["DEVICE"] for evaluation.
# NOTE(review): final_model_trainer saves to its default checkpoint_path ("checkpoint") as
# model_<epoch>_loss_<loss>.pt; this hard-coded "model/..." file must point at the run's best checkpoint — confirm.
model_for_evaluation = load_model_for_inference("model/model_29_loss_0.0510.pt", LEARNING_CONFIG["DEVICE"], model_builder) # type: ignore

## Training Dataset

In [None]:
# Re-wrap the training split with the deterministic test transform, so evaluation sees no augmentation.
train_dataset_for_evaluation = TestingSignatureDataset(data_map=train_val_signatures_map, transform=test_transform)

In [None]:
# Fixed order and a larger batch: this loader is only used for evaluation.
train_dataloader_for_evaluation = DataLoader(
    dataset=train_dataset_for_evaluation,
    shuffle=False,
    batch_size=64,
)

In [None]:
train_all_distances, train_all_labels, train_embeddings, train_metrics, auc_roc_metrics = (
    evaluate_triplet_network(
        model_for_evaluation,
        train_dataloader_for_evaluation,
        LEARNING_CONFIG["DEVICE"], # type: ignore
        0.5,
    )
)

In [None]:
train_metrics

## Testing Dataset

In [None]:
test_all_distances, test_all_labels, test_embeddings, test_metrics, auc_roc_metrics = (
    evaluate_triplet_network(
        model_for_evaluation, test_dataloader, LEARNING_CONFIG["DEVICE"], 0.5 # type: ignore
    )
)

In [None]:
test_metrics

# Plots, and Graphs

## Training Dataset

In [None]:
plot_auc_roc(auc_roc_metrics["fpr"], auc_roc_metrics["tpr"], train_metrics["auc_roc"])

In [None]:
# NOTE(review): the third argument is presumably the distance threshold; 0.2 here differs from the
# 1.07 used for the test confusion matrix below and in apply_threshold — confirm which is intended.
plot_confusion_matrix(train_all_labels, train_all_distances, 0.2)

In [None]:
# Anchor / positive / negative embeddings from the training-set evaluation.
plot_embedding_distances(
    train_embeddings["anchor"],
    train_embeddings["positive"],
    train_embeddings["negative"],
)

In [None]:
# Interactive 3D view of the same training-set embeddings.
plot_3d_embeddings_interactive(
    train_embeddings["anchor"],
    train_embeddings["positive"],
    train_embeddings["negative"],
)

## Testing Dataset

In [None]:
# `auc_roc_metrics` was last assigned by the test-set evaluation cell.
plot_auc_roc(auc_roc_metrics["fpr"], auc_roc_metrics["tpr"], test_metrics["auc_roc"])

In [None]:
# 1.07 matches the distance cut-off used in apply_threshold.
plot_confusion_matrix(test_all_labels, test_all_distances, 1.07)

In [None]:
plot_embedding_distances(test_embeddings['anchor'], test_embeddings['positive'], test_embeddings['negative'])

In [None]:
plot_3d_embeddings_interactive(test_embeddings['anchor'], test_embeddings['positive'], test_embeddings['negative'])

# Utilising The Model

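For the model to work, input signature images must be preprocessed the same way as the images used during fine-tuning. Here that means `test_transform` (resize, tensor conversion and normalisation), without the training augmentations.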

The following demonstration uses signatures from the testing split, whose signer IDs are:
```text
Signer ids: ['6', '13', '14', '20', '27', '32', '33', '42', '44', '50', '53']
```

## Preparing signature images

In [None]:
# Genuine reference and genuine input from signer 53, a genuine signature from signer 27, and
# forgeries targeting signers 53 and 50.
reference_signature_image_path: Path = Path("processed_signature_images") / "original" / "original_53_9.png"
original_input_signature_path: Path = Path("processed_signature_images") / "original" / "original_53_22.png"
original_input_different_signer: Path = Path("processed_signature_images") / "original" / "original_27_12.png"
forged_input_same_signer: Path = Path("processed_signature_images") / "forged" / "forgeries_53_13.png"
forged_input_different_signer: Path = Path("processed_signature_images") / "forged" / "forgeries_50_5.png"


In [None]:
# Open each demo signature and convert it to single-channel greyscale ("L") before transforming.
reference_image = Image.open(str(reference_signature_image_path)).convert("L")
input_image = Image.open(str(original_input_signature_path)).convert("L")
forged_same_signer = Image.open(str(forged_input_same_signer)).convert("L")
forged_different_signer_image = Image.open(str(forged_input_different_signer)).convert("L")
original_different_signer_image = Image.open(str(original_input_different_signer)).convert("L")

In [None]:
# Apply the deterministic evaluation preprocessing to every demo image, in a fixed order.
(
    reference_tensor,
    input_tensor,
    forged_same_signer_tensor,
    forged_different_signer_tensor,
    original_different_signer_tensor,
) = (
    test_transform(image)  # type: ignore
    for image in (
        reference_image,
        input_image,
        forged_same_signer,
        forged_different_signer_image,
        original_different_signer_image,
    )
)

In [None]:
# Prepend a batch dimension of size 1, since the model consumes batches.
(
    reference_tensor,
    input_tensor,
    forged_same_signer_tensor,
    forged_different_signer_tensor,
    original_different_signer_tensor,
) = (
    single.unsqueeze(0)  # type: ignore
    for single in (
        reference_tensor,
        input_tensor,
        forged_same_signer_tensor,
        forged_different_signer_tensor,
        original_different_signer_tensor,
    )
)

In [None]:
# Move the inputs to the device the evaluation model was loaded onto (LEARNING_CONFIG["DEVICE"]),
# instead of hard-coding "cuda", so inputs and model always share a device.
reference_tensor = reference_tensor.to(str(LEARNING_CONFIG["DEVICE"]))  # type: ignore
input_tensor = input_tensor.to(str(LEARNING_CONFIG["DEVICE"]))  # type: ignore
forged_same_signer_tensor = forged_same_signer_tensor.to(str(LEARNING_CONFIG["DEVICE"]))  # type: ignore
forged_different_signer_tensor = forged_different_signer_tensor.to(str(LEARNING_CONFIG["DEVICE"]))  # type: ignore
original_different_signer_tensor = original_different_signer_tensor.to(str(LEARNING_CONFIG["DEVICE"]))  # type: ignore

## Applying The Model

Relying purely on the model itself is prone to misclassification in edge cases, so I added a cosine-similarity check on the embeddings to complement the model.

Cosine similarity measures how similar two vectors are based on the angle between them rather than their magnitudes.
1. If the two vectors point in the same direction, the similarity score is close to 1.0.
2. If the two vectors are orthogonal, the similarity score is 0.0, indicating an absence of similarity.
3. Vectors pointing in opposite directions have a similarity score near -1.0, signifying total dissimilarity.

As for the distance, I implemented a simple Euclidean distance calculation. 

Additionally, a confidence score is calculated as a weighted combination of the similarity score and a normalised distance score.

*The more time you spend thinking about this, the more you realise the amount of additional work needed to make it a complete system. So, I will only keep this example simple.*

In [None]:
def apply_model(
    reference_tensor: torch.Tensor,
    input_tensor: torch.Tensor,
    forged_same_signer_tensor: torch.Tensor,
    forged_different_signer_tensor: torch.Tensor,
    original_different_signer_tensor: torch.Tensor,
    model: Optional[nn.Module] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute feature embeddings for the five demo signature tensors.

    Args
    ----
    reference_tensor, input_tensor, forged_same_signer_tensor,
    forged_different_signer_tensor, original_different_signer_tensor: torch.Tensor
        Preprocessed, batched image tensors on the model's device.
    model: Optional[nn.Module]
        Embedding model to use. Defaults to the notebook-level `model_for_evaluation`.

    Returns
    -------
    Tuple of five torch.Tensor
        The embeddings, in the same order as the inputs.

    NOTE(review): this does not switch the model to eval mode; presumably the loader already
    did — confirm.
    """
    embedding_model = model_for_evaluation if model is None else model

    # Inference only: disabling autograd avoids building a computation graph for each pass.
    with torch.no_grad():
        return (
            embedding_model(reference_tensor),
            embedding_model(input_tensor),
            embedding_model(forged_same_signer_tensor),
            embedding_model(forged_different_signer_tensor),
            embedding_model(original_different_signer_tensor),
        )

In [None]:
def calculate_metrics(
    reference_feature_embeddings: torch.Tensor,
    input_feature_embeddings: torch.Tensor,
    max_distance: float = 1.3,
    similarity_weight: float = 0.8,
    distance_weight: float = 0.2,
) -> Tuple[float, float, float, float, float]:
    """
    Compare two embeddings and derive a confidence score.

    Both embeddings must hold a single pair, e.g. shape (1, d): `.item()` fails for more.

    Args
    ----
    reference_feature_embeddings, input_feature_embeddings: torch.Tensor
        Embeddings to compare (similarity is taken along dim 1).
    max_distance: float
        Euclidean distance mapped to a normalised distance of 1.0; larger distances are clamped.
    similarity_weight, distance_weight: float
        Weights of cosine similarity and distance score in the confidence score.

    Returns
    -------
    Tuple[float, float, float, float, float]
        (similarity, confidence_score, normalised_distance, distance_score, distance)

    Raises
    ------
    ValueError
        If `max_distance` is not positive.
    """
    if max_distance <= 0:
        raise ValueError(f"max_distance must be positive, got {max_distance}")

    similarity = F.cosine_similarity(
        reference_feature_embeddings, input_feature_embeddings
    ).item()

    distance = F.pairwise_distance(
        reference_feature_embeddings, input_feature_embeddings
    ).item()

    # Clamp with float bounds so the result is always a float in [0.0, 1.0].
    normalised_distance = min(max(distance / max_distance, 0.0), 1.0)
    distance_score = 1.0 - normalised_distance

    # By default greater emphasis is placed on cosine similarity (0.8) than on distance (0.2).
    confidence_score = similarity_weight * similarity + distance_weight * distance_score

    return similarity, confidence_score, normalised_distance, distance_score, distance

In [None]:
def apply_threshold(
    similarity: float, 
    distance: float,
    confidence_score: float,
    result: Dict[str, bool | float | str]
) -> None:
    pass_thresholds = (
        similarity >= 0.5 and 
        distance <= 1.07
    )
    
    if pass_thresholds:
        if confidence_score >= 0.9:
            result['prediction_level'] = 'Very High Confidence'
            result['is_genuine'] = True
        elif confidence_score >= 0.8:
            result['prediction_level'] = 'High Confidence'
            result['is_genuine'] = True
        elif confidence_score >= 0.7:
            result['prediction_level'] = 'Medium Confidence'
            result['is_genuine'] = True
        elif confidence_score >= 0.6:
            result['prediction_level'] = 'Low Confidence'
        else:
            result['prediction_level'] = "Very Low Confidence"
    else:
        result['prediction_level'] = 'Failed Threshold Check'
            

In [None]:
# Embed every signature once; the comparison cells below reuse these
# embeddings against the reference embedding.
[
    reference_feature_embeddings,
    input_feature_embeddings,
    forged_same_signer_feature_embeddings,
    forged_different_signer_feature_embeddings,
    original_different_signer_feature_embeddings,
] = apply_model(
    reference_tensor=reference_tensor,  # type: ignore
    input_tensor=input_tensor,  # type: ignore
    forged_same_signer_tensor=forged_same_signer_tensor,  # type: ignore
    forged_different_signer_tensor=forged_different_signer_tensor,  # type: ignore
    original_different_signer_tensor=original_different_signer_tensor,  # type: ignore
)

In [None]:
def denormalize_image_tensor(
    tensor: torch.Tensor, 
    mean: float, 
    std: float
    ) -> torch.Tensor: 
    
    mean_tensor = torch.tensor(mean).view(-1, 1, 1).to(tensor.device)
    std_tensor = torch.tensor(std).view(-1, 1, 1).to(tensor.device)
    denormalized_tensor = tensor * std_tensor + mean_tensor
    
    return torch.clamp(denormalized_tensor, 0.0, 1.0)

In [None]:
def visualise_image(
    reference_tensor: torch.Tensor,
    input_tensor: torch.Tensor,
    mean: float = 0.5,
    std: float = 0.5,
) -> None:
    """Plot a reference signature and an input signature side by side.

    Both tensors are un-normalised with ``denormalize_image_tensor``,
    converted to PIL images and shown in a 1x2 grid titled "Reference" and
    "Input".

    Args:
        reference_tensor: Normalised reference image. A leading dimension of
            size 1 is squeezed away if present.
        input_tensor: Normalised image to compare, handled the same way.
        mean: Normalisation mean to undo (default 0.5, as before).
        std: Normalisation standard deviation to undo (default 0.5, as before).
    """
    to_pil = transforms.ToPILImage()

    def _prepare(tensor: torch.Tensor):
        # Convert one normalised tensor into a displayable PIL image.
        # The image ends up in host memory anyway, so stay on the CPU. The
        # earlier hard-coded `.to('cuda')` only added a device round-trip and
        # failed outright on machines without CUDA.
        image = denormalize_image_tensor(tensor.detach().cpu(), mean, std)
        return to_pil(image.squeeze(0))

    _, axes = plt.subplots(1, 2, figsize=(15, 5)) # type: ignore

    panels = zip(axes, (reference_tensor, input_tensor), ("Reference", "Input"))
    for axis, tensor, title in panels:
        axis.imshow(_prepare(tensor), cmap='gray')
        axis.set_title(title)
        axis.axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # type: ignore
    plt.show() # type: ignore

### Same Signer (Original)

In [None]:
# Show the reference signature next to the same signer's original input signature.
visualise_image(torch.Tensor(reference_tensor), torch.Tensor(input_tensor))

In [None]:
# Compare the reference embedding with the same signer's original input.
similarity, confidence_score, normalised_distance, distance_score, distance = calculate_metrics(
    reference_feature_embeddings=reference_feature_embeddings,
    input_feature_embeddings=input_feature_embeddings,
)

In [None]:
# Result record for the same-signer (original) comparison; the threshold cell
# below fills in the verdict fields.
# NB: 'euclidean_distance' holds the normalised (clamped) distance, not the raw one.
first_result: Dict[str, bool |  float | str] = dict(
    is_genuine=False,
    confidence_score=confidence_score,
    similarity_score=similarity,
    euclidean_distance=normalised_distance,
    distance_score=distance_score,
    prediction_level='unknown',
    passed_thresholds=False,
)

In [None]:
# `apply_threshold` checks `distance <= 1.07`. Both `distance_score` and the
# stored 'euclidean_distance' (the normalised distance) lie in [0, 1], so
# passing either would make that check always succeed. Pass the raw Euclidean
# `distance` returned by the preceding `calculate_metrics` cell instead.
apply_threshold(
    float(first_result["similarity_score"]),
    float(distance),
    float(first_result["confidence_score"]),
    first_result,
)

In [None]:
# Display the verdict for the same-signer (original) comparison.
first_result

### Same Signer (Forged)

In [None]:
# Show the reference signature next to a forgery of the same signer.
visualise_image(torch.Tensor(reference_tensor), torch.Tensor(forged_same_signer_tensor))

In [None]:
# Compare the reference embedding with the same-signer forgery.
similarity, confidence_score, normalised_distance, distance_score, distance = calculate_metrics(
    reference_feature_embeddings=reference_feature_embeddings,
    input_feature_embeddings=forged_same_signer_feature_embeddings,
)

In [None]:
# Result record for the same-signer (forged) comparison; the threshold cell
# below fills in the verdict fields.
# NB: 'euclidean_distance' holds the normalised (clamped) distance, not the raw one.
second_result: Dict[str, bool |  float | str] = dict(
    is_genuine=False,
    confidence_score=confidence_score,
    similarity_score=similarity,
    euclidean_distance=normalised_distance,
    distance_score=distance_score,
    prediction_level='unknown',
    passed_thresholds=False,
)

In [None]:
# `apply_threshold` checks `distance <= 1.07`. Both `distance_score` and the
# stored 'euclidean_distance' (the normalised distance) lie in [0, 1], so
# passing either would make that check always succeed. Pass the raw Euclidean
# `distance` returned by the preceding `calculate_metrics` cell instead.
apply_threshold(
    float(second_result["similarity_score"]),
    float(distance),
    float(second_result["confidence_score"]),
    second_result,
)

In [None]:
# Display the verdict for the same-signer (forged) comparison.
second_result

### Different Signer (Forged)

In [None]:
# Show the reference signature next to a forgery of a different signer.
visualise_image(torch.Tensor(reference_tensor), torch.Tensor(forged_different_signer_tensor))

In [None]:
# Compare the reference embedding with the different-signer forgery.
similarity, confidence_score, normalised_distance, distance_score, distance = calculate_metrics(
    reference_feature_embeddings=reference_feature_embeddings,
    input_feature_embeddings=forged_different_signer_feature_embeddings,
)

In [None]:
# Result record for the different-signer (forged) comparison; the threshold
# cell below fills in the verdict fields.
# NB: 'euclidean_distance' holds the normalised (clamped) distance, not the raw one.
third_result: Dict[str, bool |  float | str] = dict(
    is_genuine=False,
    confidence_score=confidence_score,
    similarity_score=similarity,
    euclidean_distance=normalised_distance,
    distance_score=distance_score,
    prediction_level='unknown',
    passed_thresholds=False,
)

In [None]:
# `apply_threshold` checks `distance <= 1.07`. Both `distance_score` and the
# stored 'euclidean_distance' (the normalised distance) lie in [0, 1], so
# passing either would make that check always succeed. Pass the raw Euclidean
# `distance` returned by the preceding `calculate_metrics` cell instead.
apply_threshold(
    float(third_result["similarity_score"]),
    float(distance),
    float(third_result["confidence_score"]),
    third_result,
)

In [None]:
# Display the verdict for the different-signer (forged) comparison.
third_result

### Different Signer (Original)

In [None]:
# Show the reference signature next to an original signature of a different signer.
visualise_image(torch.Tensor(reference_tensor), torch.Tensor(original_different_signer_tensor))

In [None]:
# Compare the reference embedding with the different signer's original signature.
similarity, confidence_score, normalised_distance, distance_score, distance = calculate_metrics(
    reference_feature_embeddings=reference_feature_embeddings,
    input_feature_embeddings=original_different_signer_feature_embeddings,
)

In [None]:
# Result record for the different-signer (original) comparison; the threshold
# cell below fills in the verdict fields.
# NB: 'euclidean_distance' holds the normalised (clamped) distance, not the raw one.
fourth_result: Dict[str, bool |  float | str] = dict(
    is_genuine=False,
    confidence_score=confidence_score,
    similarity_score=similarity,
    euclidean_distance=normalised_distance,
    distance_score=distance_score,
    prediction_level='unknown',
    passed_thresholds=False,
)

In [None]:
# `apply_threshold` checks `distance <= 1.07`. Both `distance_score` and the
# stored 'euclidean_distance' (the normalised distance) lie in [0, 1], so
# passing either would make that check always succeed. Pass the raw Euclidean
# `distance` returned by the preceding `calculate_metrics` cell instead.
apply_threshold(
    float(fourth_result["similarity_score"]),
    float(distance),
    float(fourth_result["confidence_score"]),
    fourth_result,
)

In [None]:
# Display the verdict for the different-signer (original) comparison.
fourth_result