# Install the necessary libraries and define the constants

In [None]:
from IPython.display import clear_output
# NOTE(review): `datasets`, `torchvision`, `tqdm` and `matplotlib` are imported later in this
# notebook but not installed here — presumably preinstalled in the notebook environment
# (e.g. Colab/Kaggle); verify before running elsewhere.
!pip install ultralytics huggingface_hub[hf_xet]
clear_output()  # Hide the verbose pip install log

In [None]:
import math
import os
import warnings

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from PIL import Image
from tqdm.auto import tqdm
from ultralytics import YOLO

# Size that every signature image (reference sample and detected crop) is resized to
# before being fed into the Siamese Network. torchvision's T.Resize reads a 2-tuple as
# (height, width).
SIGN_IMG_SIZE = (105, 105)

___
# Define the structure for the Siamese Model before loading the weights

In [None]:
class ContrastiveLoss(nn.Module):
    """
    Margin-based contrastive loss for metric (similarity) learning.

    Embeddings of similar pairs (label 0) are pulled together by penalizing their
    squared distance; embeddings of dissimilar pairs (label 1) are pushed apart until
    they are at least ``margin`` away, after which they stop contributing.
    """

    def __init__(self, margin=1.5):
        """
        Args:
            margin (float, optional): Distance beyond which a dissimilar pair no longer
                adds to the loss. Defaults to 1.5.
        """
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """
        Compute the batch-averaged contrastive loss.

        Args:
            output1 (torch.Tensor): Embeddings of the first item of each pair,
                shape (batch_size, embedding_dimension).
            output2 (torch.Tensor): Embeddings of the second item of each pair,
                shape (batch_size, embedding_dimension).
            label (torch.Tensor): Pair labels of shape (batch_size,):
                0 = similar pair, 1 = dissimilar pair.

        Returns:
            torch.Tensor: Scalar mean loss over the batch.
        """
        # Per-pair Euclidean distance between the two embeddings.
        dist = F.pairwise_distance(output1, output2)

        # Similar pairs: squared distance (selected when label == 0).
        pull_term = (1 - label) * dist.pow(2)

        # Dissimilar pairs: squared hinge on (margin - distance), zero once the pair
        # is already farther apart than the margin (selected when label == 1).
        hinge = (self.margin - dist).clamp(min=0.0)
        push_term = label * hinge.pow(2)

        return (pull_term + push_term).mean()

In [None]:
class SiameseModel(nn.Module):
    """Siamese network for signature verification.

    A single shared encoder (``cnn`` followed by ``fc``) maps a 1-channel image to an
    ``embedding_size``-dimensional vector. Both images of a pair go through the same
    weights; callers compare the two embeddings (e.g. by Euclidean distance).
    """

    def __init__(self, lr=1e-4, from_file=None, embedding_size=256, margin=1.5):
        """
        Initializes the SiameseModel.

        Args:
            lr (float, optional): Learning rate for the optimizer. Only used by the
                (currently commented-out) training setup below. Defaults to 1e-4.
            from_file (str, optional): Path to a checkpoint written by ``save_to_file``.
                If the path does not exist, a RuntimeWarning is emitted and the model
                keeps its random initialization. Defaults to None.
            embedding_size (int, optional): Size of the final feature vector. Defaults to 256.
            margin (float, optional): Margin for ``ContrastiveLoss``. Only used by the
                (currently commented-out) training setup below. Defaults to 1.5.
        """
        super().__init__()

        # Define convolutional layers for feature extraction.
        # NOTE: the layer order fixes the state_dict keys; changing it breaks loading
        # existing checkpoints.
        self.cnn = nn.Sequential(
            # Block 1: 1 input channel (grayscale), 64 output channels
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(),  # Replace ReLU
            nn.BatchNorm2d(64),  # Add BatchNorm to stabilize
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),  # Stride=2 for downsampling (replaces MaxPool)

            # Block 2: 64 input channels, 128 output channels
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),

            # Block 3: 128 input channels, 256 output channels
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),

            # Block 4: 256 input channels, 512 output channels
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.BatchNorm2d(512),
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),

            # Global average pooling instead of flattening: fewer parameters, less overfitting,
            # and independent of the input resolution.
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.fc = nn.Sequential(
            nn.Linear(512, 512),
            nn.LeakyReLU(),
            nn.Linear(512, embedding_size),
        )

        ## Define this if you want to train the model again with a training loop.
        ## It must come AFTER the layers above: self.parameters() is empty before them.
        # self.criterion = ContrastiveLoss(margin=margin)
        # self.optimizer = torch.optim.AdamW(self.parameters(), lr=lr, weight_decay=1e-4)

        # Load the model weights if a checkpoint path was given.
        if from_file:
            if os.path.exists(from_file):
                self.load_from_file(from_file)
            else:
                # Warn instead of continuing silently: a wrong path would otherwise look
                # like a badly trained model (random weights).
                warnings.warn(
                    f"Siamese weights file {from_file!r} not found; "
                    "using randomly initialized weights.",
                    RuntimeWarning,
                    stacklevel=2,
                )

    def forward_once(self, x):
        """Embed a batch of 1-channel images: (batch, 1, H, W) -> (batch, embedding_size)."""
        features = torch.flatten(self.cnn(x), 1)  # (batch, 512, 1, 1) -> (batch, 512)
        return self.fc(features)

    def forward(self, input1, input2):
        """Compute embeddings for both images in the pair with the shared encoder."""
        return self.forward_once(input1), self.forward_once(input2)

    def load_from_file(self, file_path):
        """Load the ``cnn`` and ``fc`` weights from a .pt checkpoint written by ``save_to_file``."""
        # map_location='cpu' lets GPU-saved checkpoints load anywhere; callers move the
        # model to their device afterwards. weights_only=True avoids unpickling arbitrary objects.
        checkpoint = torch.load(file_path, map_location=torch.device('cpu'), weights_only=True)
        self.cnn.load_state_dict(checkpoint['cnn'])
        self.fc.load_state_dict(checkpoint['fc'])

    def save_to_file(self, file_path):
        """Save the ``cnn`` and ``fc`` weights to a .pt checkpoint."""
        torch.save({
            'cnn': self.cnn.state_dict(),
            'fc': self.fc.state_dict()
        }, file_path)
        print(f"✅ Saved Siamese weights to {file_path}")

    def train_model(self, data_loader, epochs):
        """Training entry point; not implemented in this notebook.

        The training methods live in the separate Siamese Model file.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Training is implemented in the separate Siamese Model file."
        )

In [None]:
def crop_from_xywhn_pil(image_pil, xywhn_box, padding=0):
    """
    Crop a PIL image to a YOLO-style normalized box.

    Args:
        image_pil (PIL.Image.Image): Image to crop; only ``.size`` and ``.crop`` are used.
        xywhn_box (Sequence[float] | torch.Tensor): ``(x_center, y_center, width, height)``,
            each normalized to [0, 1] relative to the image size.
        padding (int, optional): Extra pixels added on every side before clamping to the
            image bounds. Defaults to 0.

    Returns:
        PIL.Image.Image: The cropped region. Fractional box edges are rounded outward so the
        whole box is kept, and the crop is always at least 1x1 pixel.
    """
    W, H = image_pil.size
    x_center, y_center, w, h = map(float, xywhn_box)

    # Convert from normalized center/size to absolute pixel units.
    x_c, y_c = x_center * W, y_center * H
    bw, bh = w * W, h * H

    # Round outward (floor left/top, ceil right/bottom) so partially covered edge
    # pixels of the signature are kept, then clamp to the image bounds.
    x1 = max(math.floor(x_c - bw / 2) - padding, 0)
    y1 = max(math.floor(y_c - bh / 2) - padding, 0)
    x2 = min(math.ceil(x_c + bw / 2) + padding, W)
    y2 = min(math.ceil(y_c + bh / 2) + padding, H)

    # Guarantee a non-empty crop: a zero-area region (degenerate box, or a box sitting
    # on the right/bottom border) cannot be meaningfully resized and verified downstream.
    x1, y1 = min(x1, W - 1), min(y1, H - 1)
    x2, y2 = max(x2, x1 + 1), max(y2, y1 + 1)

    # Crop out the image
    return image_pil.crop((x1, y1, x2, y2))

___
# Define the full pipeline of Detection + Verification

In [None]:
class DetectToVerify:
    """
    A unified pipeline that performs both signature detection and signature verification.

    Components:
    - YOLO11 object detector to find signature bounding boxes in a document image.
    - Siamese neural network to verify whether the detected signature matches a reference sample.
    """
    def __init__(self, detector_filename, verifier_filename, repo_id=None, threshold=0.5, local_dir=None):
        """
        Initializes the detection and verification pipeline.

        Args:
            detector_filename (str): Path to the YOLO detector model (.pt file). If it does not
                exist locally, it is downloaded from ``repo_id`` under this filename.
            verifier_filename (str): Path to the Siamese verification model (.pt file), resolved
                the same way as ``detector_filename``.
            repo_id (str, optional): Hugging Face Hub repo ID to download models if not present locally.
            threshold (float): Distance threshold below which a match is considered genuine.
            local_dir (str, optional): Directory that downloaded weights are saved to.
                Defaults to the module-level ``LOCAL_DIR``.

        Raises:
            FileNotFoundError: If a weights file is missing locally and ``repo_id`` is None.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.threshold = threshold

        # Load YOLO signature detector model
        detector_path = self._resolve_weights(detector_filename, repo_id, local_dir)
        self.detector = YOLO(detector_path)
        self.detector.to(self.device)
        print(f"✅ Loaded detector from {detector_path}\n")

        # Load Siamese signature verifier model
        verifier_path = self._resolve_weights(verifier_filename, repo_id, local_dir)
        self.verifier = SiameseModel(from_file=verifier_path)
        self.verifier.to(self.device)
        self.verifier.eval()  # BatchNorm uses its running statistics at inference time
        print(f"✅ Loaded verifier from {verifier_path}\n")

        # Preprocessing pipeline for both the sample and cropped signature images (verify process)
        self.transform = T.Compose([
            T.Grayscale(),
            T.Resize(SIGN_IMG_SIZE),
            T.ToTensor(),
            T.Normalize(mean=[0.5], std=[0.5]),
        ])

    @staticmethod
    def _resolve_weights(filename, repo_id, local_dir):
        """Return a local path to `filename`, downloading it from the Hub when it is missing."""
        if os.path.isfile(filename):
            return filename
        if repo_id is None:
            raise FileNotFoundError(
                f"{filename!r} not found locally. Please pass in repo_id to pull the model"
            )
        if local_dir is None:
            local_dir = LOCAL_DIR  # Module-level default defined in the setup cell
        print(f"Download {filename} from hub")
        # hf_hub_download returns the path the file was actually written to; use it rather
        # than assuming it landed at `filename` relative to the working directory.
        return hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)

    @staticmethod
    def _load_image(image):
        """Return `image` unchanged if it is a PIL image; otherwise open it from disk as RGB."""
        if isinstance(image, Image.Image):
            return image
        with Image.open(image) as img:  # Close the file handle once converted
            return img.convert("RGB")

    def _to_tensor(self, image):
        """Apply the verifier preprocessing and add a batch dimension on the model device."""
        return self.transform(image).unsqueeze(0).to(self.device)

    def infer(self, sample_image, document_image, min_conf=0.5, show=False):
        """
        Performs inference by detecting signatures in the document and verifying each against the sample.

        Args:
            sample_image (str or PIL.Image): The reference signature image (ground truth).
            document_image (str or PIL.Image): The document containing potential signatures.
            min_conf (float): The minimum confidence for the detected bboxes.
            show (bool): Whether to visualize each detected signature and distance score.

        Returns:
            bool: True if any detected signature is verified as genuine, False otherwise
            (including when no signature is detected).
        """
        # Load and prepare sample and document images
        sample_image = self._load_image(sample_image)
        document_image = self._load_image(document_image)

        # Run signature detection with YOLO
        results = self.detector.predict(document_image, verbose=False, conf=min_conf)[0]

        is_genuine = False  # Stays False when nothing is detected
        # Pure inference: no autograd graph is needed for any verifier pass.
        with torch.no_grad():
            # The reference embedding does not depend on the detection, so compute it once.
            sample_embedding = self.verifier.forward_once(self._to_tensor(sample_image))

            for box, cls in zip(results.boxes.xywhn, results.boxes.cls):
                if int(cls) != 0:
                    continue  # Skip non-signature detections

                # Crop out the detected signature and embed it
                cropped = crop_from_xywhn_pil(document_image, box)
                cropped_embedding = self.verifier.forward_once(self._to_tensor(cropped))

                # Euclidean distance between embeddings; genuine if below the threshold
                distance = F.pairwise_distance(sample_embedding, cropped_embedding).item()
                match = distance < self.threshold
                is_genuine = is_genuine or match

                # Optionally visualize cropped signature with distance
                if show:
                    plt.imshow(cropped)
                    plt.title(f"Distance={distance} | Label {match}")
                    plt.axis('off')
                    plt.show()

        return is_genuine

    def evaluate(self, dataset, min_conf=0.5):
        """
        Evaluate the signature detection + verification pipeline on a dataset.

        Args:
            dataset (Iterable[Dict]): Items with keys:
                - 'sample_signature': reference signature (genuine), path or PIL image
                - 'document': document containing a signature, path or PIL image
                - 'label': 0 if genuine, 1 if forged
            min_conf (float): Minimum detector confidence passed to ``infer``.

        Returns:
            float: Accuracy of the pipeline (0.0 for an empty dataset).
        """
        correct, total = 0, 0
        pbar = tqdm(dataset, desc="Evaluating", dynamic_ncols=True)  # Create the progress bar

        for item in pbar:
            label = item['label']  # 0 = genuine, 1 = forged
            prediction = self.infer(item['sample_signature'], item['document'], min_conf=min_conf)

            # infer() returns True for "genuine", which corresponds to label 0
            if prediction == (label == 0):
                correct += 1
            total += 1

            # Live accuracy update
            pbar.set_postfix({'accuracy': f'{correct / total:.2%}'})

        accuracy = correct / total if total > 0 else 0.0
        print(f"\n✅ Accuracy: {accuracy:.2%}")
        return accuracy

    def evaluate_siamese(self, dataset):
        """
        Evaluate only the Siamese verifier on a dataset of signature pairs.

        Args:
            dataset (Iterable[Dict]): Items with keys:
                - 'to_verify_signature': signature to check (genuine or forged), path or PIL image.
                - 'sample_signature': reference signature (genuine), path or PIL image.
                - 'label': 0 if genuine, 1 if forged.

        Returns:
            float: Accuracy of the verifier model (0.0 for an empty dataset).
        """
        correct, total = 0, 0
        progress_bar = tqdm(dataset, desc="Evaluating Siamese Verifier", dynamic_ncols=True)

        for item in progress_bar:
            label = item['label']

            # Load and transform both images onto the model device
            tensor1 = self._to_tensor(self._load_image(item['to_verify_signature']))
            tensor2 = self._to_tensor(self._load_image(item['sample_signature']))

            # Run prediction without tracking gradients
            with torch.no_grad():
                out1, out2 = self.verifier(tensor1, tensor2)
                distance = F.pairwise_distance(out1, out2).item()
            prediction = 0 if distance < self.threshold else 1  # 0 = genuine, 1 = forged

            if prediction == label:
                correct += 1
            total += 1

            # Live accuracy update
            progress_bar.set_postfix({'accuracy': f'{correct / total:.2%}'})

        final_acc = correct / total if total > 0 else 0.0
        print(f"\n✅ Siamese Verifier Final Accuracy: {final_acc:.4%}")
        return final_acc


___
# Load the dataset and prepare the models (both pulled from the Hugging Face Hub)

***Dataset***  
We use our own custom dataset, stored at [Mels22/SigDetectVerifyFlow](https://huggingface.co/datasets/Mels22/SigDetectVerifyFlow). In short, it supports the full flow from signature detection to signature verification.

Each sample in the dataset contains the following fields:

- `document` *(Image)*: The full document image that contains one or more handwritten signatures.
- `bbox` *(List of Bounding Boxes)*: The coordinates of the signature(s) detected in the `document`. Format: `[x_min, y_min, x_max, y_max]`.
- `to_verify_signature` *(Image)*: A cropped signature from the document image that needs to be verified.
- `sample_signature` *(Image)*: A standard reference signature used for comparison.
- `label` *(int)*: Indicates if the `to_verify_signature` is **genuine (0)** or **forged (1)** when compared to the `sample_signature`.

___
***Model***

We store our trained models in [Mels22/Signature-Detection-Verification](https://huggingface.co/Mels22/Signature-Detection-Verification), which contains the following files:
- `detector_yolo_1cls.pt`: Detection model trained to recognize the `signature` class only.
- `detector_yolo_4cls.pt`: Detection model trained to recognize four handwritten elements: `signature`, `initial`, `redaction`, and `date`.
- `verifier_siamese.pt`: Verification model (the Siamese architecture defined above) that classifies a pair of signature images as genuine or forged.

In [None]:
from datasets import load_dataset

LOCAL_DIR = "." # Directory that model weights downloaded from the Hub are saved to (read by DetectToVerify)

# Model and dataset repo on Hugging Face Hub
MODEL_REPO = "Mels22/Signature-Detection-Verification"
DATASET_REPO = "Mels22/SigDetectVerifyFlow"

In [None]:
# NOTE: despite the name, this is not a torch DataLoader. load_dataset() without `split`
# returns a DatasetDict keyed by split name (the 'test' split is used for evaluation below).
data_loader = load_dataset(DATASET_REPO)

In [None]:
# Build the detection + verification pipeline; weights missing locally are pulled from MODEL_REPO.
flow = DetectToVerify(
    detector_filename="detector_yolo_1cls.pt", # Detector weights (signature class only)
    verifier_filename="verifier_siamese.pt", # Siamese verifier weights
    repo_id=MODEL_REPO,
    threshold=0.5, # Max embedding distance for a signature pair to be judged genuine
)

___
# Inference / Evaluate the model on the dataset

In [None]:
# End-to-end accuracy (detect signatures in each document, verify against the sample) on the test split.
flow.evaluate(data_loader['test'])