In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc
import seaborn as sns
from tqdm import tqdm


In [None]:
# --- Configuration ---
# Every tunable knob for this evaluation notebook lives here.
CONFIG = dict(
    dataset_path="/kaggle/input/cedardataset/signatures",  # Adjust if needed
    model_path="/kaggle/input/siamese-transformer/pytorch/default/1/best_siamese_transformer.pth",  # Update this path
    img_size=(224, 224),
    batch_size=32,
    embedding_dim=128,
    transformer_heads=4,
    transformer_layers=2,
    dropout=0.1,
    device="cuda" if torch.cuda.is_available() else "cpu",
    seed=42,
    threshold=0.5,  # embedding distance at/above which a pair is called "dissimilar"
)

In [None]:
# --- Seeding for Reproducibility ---
def seed_everything(seed):
    """Seed every RNG this notebook relies on so runs are repeatable.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), and configures cuDNN for deterministic kernels.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects Python processes launched after this call; the running
    # interpreter's hash randomisation is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # torch.cuda.manual_seed seeds only the current device; cover all GPUs.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes kernels per run and can select different
    # algorithms between runs, which defeats `deterministic = True`.
    torch.backends.cudnn.benchmark = False

seed_everything(CONFIG['seed'])

# Echo the key run settings so the cell output records what was used.
for _name, _key in (("Device", "device"), ("Dataset Path", "dataset_path"), ("Model Path", "model_path")):
    print(f"{_name}: {CONFIG[_key]}")

In [None]:
# --- Dataset Class (Same as training) ---
class SignatureDataset(Dataset):
    """Writer-independent signature-pair dataset over a CEDAR-style folder.

    Expects ``root_dir/full_org/original_<writer_id>_*.png`` (genuine) and
    ``root_dir/full_forg/forgeries_<writer_id>_*.png`` (forgeries). Writers are
    split into train/test so no writer appears in both splits. Each item is
    ``(img1, img2, label)`` with label 0 for a genuine-genuine pair and 1 for a
    genuine-forged pair.

    Args:
        root_dir: Folder containing ``full_org`` and ``full_forg``.
        split: ``'train'`` selects the training writers; any other value
            selects the held-out writers.
        transform: Optional callable applied to each PIL image.
        seed: ``random_state`` for the writer split; defaults to ``CONFIG['seed']``.
        test_size: Fraction of writers held out (default 0.2).
    """

    def __init__(self, root_dir, split='train', transform=None, seed=None, test_size=0.2):
        self.root_dir = root_dir
        self.transform = transform
        self.genuine_path = os.path.join(root_dir, 'full_org')
        self.forged_path = os.path.join(root_dir, 'full_forg')

        # Group signatures by writer ID: {writer_id: {'genuine': [...], 'forged': [...]}}
        self.writers = {}
        for fname, writer_id in self._scan_writer_files(self.genuine_path, "original_"):
            entry = self.writers.setdefault(writer_id, {'genuine': [], 'forged': []})
            entry['genuine'].append(os.path.join(self.genuine_path, fname))
        for fname, writer_id in self._scan_writer_files(self.forged_path, "forgeries_"):
            # Forgeries of a writer with no genuine sample could never be paired.
            if writer_id in self.writers:
                self.writers[writer_id]['forged'].append(os.path.join(self.forged_path, fname))

        # Split writers (Writer-Independent Split). Writer order follows the
        # sorted genuine filenames, so the split depends only on seed/test_size.
        # NOTE(review): keep these equal to the values used in training so the
        # test writers really are unseen.
        if seed is None:
            seed = CONFIG['seed']
        writer_ids = list(self.writers.keys())
        train_ids, test_ids = train_test_split(writer_ids, test_size=test_size, random_state=seed)

        self.active_writers = train_ids if split == 'train' else test_ids
        self.pairs = self._generate_pairs()

    @staticmethod
    def _scan_writer_files(directory, prefix):
        """Yield ``(filename, writer_id)`` for ``<prefix><id>_*.png`` files, sorted by name."""
        for fname in sorted(os.listdir(directory)):
            if not fname.startswith(prefix) or not fname.endswith(".png"):
                continue
            try:
                writer_id = int(fname.split('_')[1])
            except (IndexError, ValueError):
                continue
            yield fname, writer_id

    @staticmethod
    def _load_image(path):
        """Load ``path`` as grayscale replicated to 3 channels, closing the file."""
        with Image.open(path) as img:
            # L then RGB: drop colour/alpha, then give the 3-channel input the
            # EfficientNet backbone expects.
            return img.convert("L").convert("RGB")

    def _generate_pairs(self):
        """Build ``[path_a, path_b, label]`` pairs for the active writers.

        Per writer: every unordered genuine-genuine pair (label 0 = similar),
        then every genuine x forged combination (label 1 = dissimilar).
        """
        pairs = []
        for wid in self.active_writers:
            gens = self.writers[wid]['genuine']
            forgs = self.writers[wid]['forged']

            # Positive Pairs (Genuine-Genuine)
            for i in range(len(gens)):
                for j in range(i + 1, len(gens)):
                    pairs.append([gens[i], gens[j], 0])  # 0 = Similar

            # Negative Pairs (Genuine-Forged)
            for g in gens:
                for f in forgs:
                    pairs.append([g, f, 1])  # 1 = Dissimilar

        return pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        img1_path, img2_path, label = self.pairs[idx]

        img1 = self._load_image(img1_path)
        img2 = self._load_image(img2_path)

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return img1, img2, torch.tensor(label, dtype=torch.float32)

In [None]:
# --- Model Architecture (Same as training) ---
class SiameseTransformer(nn.Module):
    """Weight-shared Siamese encoder: EfficientNet-B0 features -> Transformer -> embedding.

    Attribute names (backbone, pos_embedding, transformer, fc) determine the
    state_dict keys, so they must not change or the checkpoint will not load.
    """

    def __init__(self):
        super().__init__()

        # Randomly initialised; trained weights are loaded from the checkpoint.
        self.backbone = models.efficientnet_b0(pretrained=False).features

        self.feature_dim = 1280   # channels of EfficientNet-B0's last feature map
        self.seq_len = 7 * 7      # one token per cell of the 7x7 feature map

        self.pos_embedding = nn.Parameter(torch.randn(1, self.seq_len, self.feature_dim))

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.feature_dim,
                nhead=CONFIG['transformer_heads'],
                dim_feedforward=2 * self.feature_dim,
                dropout=CONFIG['dropout'],
                batch_first=True,
            ),
            num_layers=CONFIG['transformer_layers'],
        )

        hidden = 256
        self.fc = nn.Sequential(
            nn.Linear(self.feature_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, CONFIG['embedding_dim']),
        )

    def forward_one(self, x):
        """Embed a batch of images into ``(B, embedding_dim)`` vectors."""
        fmap = self.backbone(x)                                   # (B, 1280, 7, 7)
        tokens = fmap.view(fmap.size(0), self.feature_dim, -1)    # (B, 1280, 49)
        tokens = tokens.permute(0, 2, 1) + self.pos_embedding     # (B, 49, 1280)
        pooled = self.transformer(tokens).mean(dim=1)             # average over tokens
        return self.fc(pooled)

    def forward(self, img1, img2):
        """Embed both inputs with the shared weights."""
        return self.forward_one(img1), self.forward_one(img2)

In [None]:
# --- Transformations ---
# Resize, tensorise, then ImageNet mean/std normalisation; keep in sync with
# the preprocessing used during training.
test_transforms = transforms.Compose(
    [
        transforms.Resize(CONFIG['img_size']),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# --- Load Test Dataset ---
print("\nLoading test dataset...")
test_ds = SignatureDataset(CONFIG['dataset_path'], split='test', transform=test_transforms)
test_loader = DataLoader(
    test_ds,
    batch_size=CONFIG['batch_size'],
    shuffle=False,  # stable order so results line up with test_ds.pairs
    num_workers=2,
)
print(f"Test Pairs: {len(test_ds)}")

# --- Load Model ---
print("\nLoading trained model...")
model = SiameseTransformer().to(CONFIG['device'])
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (on PyTorch versions that support it, pass weights_only=True).
state_dict = torch.load(CONFIG['model_path'], map_location=CONFIG['device'])
model.load_state_dict(state_dict)
model.eval()
print("Model loaded successfully!")

# --- Evaluation Function ---
def evaluate_model(model, dataloader, threshold=0.5, device=None):
    """
    Evaluate the model and return predictions, labels, and distances.

    Args:
        model: Siamese network; ``model(img1, img2)`` returns two embedding batches.
        dataloader: Yields batches whose first three fields are
            ``(img1, img2, label)`` (label 0 = similar, 1 = dissimilar). Extra
            trailing fields (e.g. file paths) are ignored.
        threshold: Distance at or above which a pair is predicted dissimilar (1).
        device: Device to run inference on; defaults to ``CONFIG['device']``.

    Returns:
        ``(predictions, labels, distances)`` as 1-D float32 NumPy arrays in the
        dataloader's iteration order.
    """
    if device is None:
        device = CONFIG['device']

    dist_chunks = []
    label_chunks = []

    print("\nEvaluating model...")
    # Dropout/BatchNorm must be in inference mode regardless of caller state.
    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Testing"):
            img1, img2, labels = batch[:3]

            # Get embeddings
            emb1, emb2 = model(img1.to(device), img2.to(device))

            # Euclidean distance between the two embeddings of each pair
            dist_chunks.append(F.pairwise_distance(emb1, emb2).cpu().numpy())
            label_chunks.append(labels.cpu().numpy())

    if not dist_chunks:
        # Empty loader: return empty arrays rather than failing in concatenate.
        empty = np.array([], dtype=np.float32)
        return empty, empty.copy(), empty.copy()

    all_distances = np.concatenate(dist_chunks)
    all_labels = np.concatenate(label_chunks)
    # Distance < threshold -> Similar (0), Distance >= threshold -> Dissimilar (1)
    all_predictions = (all_distances >= threshold).astype(np.float32)

    return all_predictions, all_labels, all_distances

# --- Run Evaluation ---
predictions, true_labels, distances = evaluate_model(model, test_loader, threshold=CONFIG['threshold'])

# --- Calculate Metrics ---
# Positive class (1) is "dissimilar" (genuine-forged pair), so precision/recall
# describe forgery detection. zero_division=0 reports 0 without warnings when a
# class is never predicted.
accuracy = accuracy_score(true_labels, predictions)
precision = precision_score(true_labels, predictions, zero_division=0)
recall = recall_score(true_labels, predictions, zero_division=0)
f1 = f1_score(true_labels, predictions, zero_division=0)

print("\n" + "="*50)
print("MODEL EVALUATION RESULTS")
print("="*50)
print(f"Threshold: {CONFIG['threshold']}")
print(f"Total Test Pairs: {len(test_ds)}")
print(f"\nAccuracy:  {accuracy*100:.2f}%")
print(f"Precision: {precision*100:.2f}%")
print(f"Recall:    {recall*100:.2f}%")
print(f"F1-Score:  {f1*100:.2f}%")

# --- Confusion Matrix ---
# Fixing labels guarantees a 2x2 matrix even if one class is absent from both
# arrays; otherwise cm[1, 0] below would raise IndexError.
cm = confusion_matrix(true_labels, predictions, labels=[0.0, 1.0])
print("\nConfusion Matrix:")
print(f"{'':<18}Predicted")
print(f"{'':<18}{'Similar':>8}  {'Dissimilar':>10}")
print(f"{'Actual Similar':<18}{cm[0, 0]:>8}  {cm[0, 1]:>10}")
print(f"{'Actual Dissimilar':<18}{cm[1, 0]:>8}  {cm[1, 1]:>10}")

# --- Visualizations ---
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
ax_cm, ax_dist, ax_roc, ax_bar = axes.flat

# 1. Confusion Matrix Heatmap
sns.heatmap(
    cm, annot=True, fmt='d', cmap='Blues', ax=ax_cm,
    xticklabels=['Similar', 'Dissimilar'],
    yticklabels=['Similar', 'Dissimilar'],
)
ax_cm.set_title('Confusion Matrix')
ax_cm.set_ylabel('True Label')
ax_cm.set_xlabel('Predicted Label')

# 2. Distance Distribution: one histogram per ground-truth class, plus the cut-off
for cls_value, cls_name, cls_color in ((0, 'Genuine Pairs', 'green'), (1, 'Forged Pairs', 'red')):
    ax_dist.hist(distances[true_labels == cls_value], bins=50, alpha=0.6, label=cls_name, color=cls_color)
ax_dist.axvline(CONFIG['threshold'], color='black', linestyle='--', linewidth=2, label='Threshold')
ax_dist.set_xlabel('Distance')
ax_dist.set_ylabel('Frequency')
ax_dist.set_title('Distance Distribution')
ax_dist.legend()

# 3. ROC Curve: the distance is the score for the positive ("dissimilar") class
fpr, tpr, thresholds = roc_curve(true_labels, distances)
roc_auc = auc(fpr, tpr)
ax_roc.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
ax_roc.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax_roc.set_xlim([0.0, 1.0])
ax_roc.set_ylim([0.0, 1.05])
ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
ax_roc.set_title('ROC Curve')
ax_roc.legend(loc="lower right")

# 4. Metrics Bar Chart (percentages, value printed above each bar)
metrics = ['Accuracy', 'Precision', 'Recall', 'F1-Score']
values = [score * 100 for score in (accuracy, precision, recall, f1)]
colors = ['#3498db', '#2ecc71', '#e74c3c', '#f39c12']
ax_bar.bar(metrics, values, color=colors)
ax_bar.set_ylim([0, 100])
ax_bar.set_ylabel('Score (%)')
ax_bar.set_title('Performance Metrics')
for pos, val in enumerate(values):
    ax_bar.text(pos, val + 2, f'{val:.2f}%', ha='center', fontweight='bold')

fig.tight_layout()
fig.savefig('evaluation_results.png', dpi=300, bbox_inches='tight')
plt.show()

# --- Find Optimal Threshold ---
print("\n" + "="*50)
print("THRESHOLD OPTIMIZATION")
print("="*50)

# NOTE: this sweep runs on the *test* pairs, so the "optimal" accuracy is an
# optimistic, in-sample figure. Choose the operating threshold on a separate
# validation split before quoting it as test performance.
#
# Embeddings are not L2-normalised, so distances are not bounded by 2.0. Sweep
# the observed range (never narrower than the old 0.1-2.0 window) so the
# optimum cannot fall outside the grid.
sweep_lo = min(0.1, float(distances.min()))
sweep_hi = max(2.0, float(distances.max()))
thresholds_to_test = np.linspace(sweep_lo, sweep_hi, 100)
accuracies = [accuracy_score(true_labels, (distances >= t).astype(float)) for t in thresholds_to_test]

optimal_idx = int(np.argmax(accuracies))
optimal_threshold = thresholds_to_test[optimal_idx]
optimal_accuracy = accuracies[optimal_idx]

print(f"Optimal Threshold: {optimal_threshold:.4f}")
print(f"Optimal Accuracy: {optimal_accuracy*100:.2f}%")

# Plot threshold vs accuracy
fig_thr, ax_thr = plt.subplots(figsize=(10, 6))
ax_thr.plot(thresholds_to_test, np.array(accuracies)*100, marker='o', markersize=3, linewidth=2)
ax_thr.axvline(optimal_threshold, color='red', linestyle='--', linewidth=2, label=f'Optimal: {optimal_threshold:.4f}')
ax_thr.axvline(CONFIG['threshold'], color='green', linestyle='--', linewidth=2, label=f'Current: {CONFIG["threshold"]:.4f}')
ax_thr.set_xlabel('Threshold', fontsize=12)
ax_thr.set_ylabel('Accuracy (%)', fontsize=12)
ax_thr.set_title('Threshold vs Accuracy', fontsize=14, fontweight='bold')
ax_thr.legend(fontsize=10)
ax_thr.grid(True, alpha=0.3)
fig_thr.savefig('threshold_optimization.png', dpi=300, bbox_inches='tight')
plt.show()

# --- Sample Predictions Visualization ---
print("\n" + "="*50)
print("SAMPLE PREDICTIONS")
print("="*50)

def visualize_predictions(dataset, model, num_samples=6, threshold=0.5, device=None):
    """Visualize random predictions from the test set.

    Draws one row per randomly chosen pair: reference image, test image,
    verdict text and a distance bar against ``threshold``. The grid has
    ``num_samples // 2`` rows, so ``num_samples // 2`` pairs are shown.

    Args:
        dataset: Indexable dataset whose items start with ``(img1, img2, label)``
            (label 0 = genuine pair, 1 = forged pair); extra fields are ignored.
        model: Siamese model whose forward returns two embeddings.
        num_samples: Controls the row count (``num_samples // 2``); must be >= 2.
        threshold: Distance below which a pair is called GENUINE.
        device: Inference device; defaults to ``CONFIG['device']``.

    Raises:
        ValueError: if ``num_samples < 2`` (no rows to draw).
    """
    n_rows = num_samples // 2
    if n_rows < 1:
        raise ValueError(f"num_samples must be >= 2, got {num_samples}")
    if device is None:
        device = CONFIG['device']

    # Undo the ImageNet Normalize() of the test transforms for display.
    inv_normalize = transforms.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
        std=[1/0.229, 1/0.224, 1/0.225]
    )

    # squeeze=False keeps `axes` 2-D so axes[i, j] also works for a single row.
    fig, axes = plt.subplots(n_rows, 4, figsize=(16, num_samples*2), squeeze=False)
    model.eval()

    for i in range(n_rows):
        # Get random sample
        idx = random.randint(0, len(dataset)-1)
        img1, img2, label = dataset[idx][:3]

        # Get prediction
        with torch.no_grad():
            emb1, emb2 = model(img1.unsqueeze(0).to(device),
                               img2.unsqueeze(0).to(device))
            dist = F.pairwise_distance(emb1, emb2).item()

        pred_label = "GENUINE" if dist < threshold else "FORGED"
        actual_label = "GENUINE" if label.item() == 0 else "FORGED"
        is_correct = pred_label == actual_label

        # Denormalize images
        img1_display = inv_normalize(img1).permute(1, 2, 0).cpu().numpy()
        img2_display = inv_normalize(img2).permute(1, 2, 0).cpu().numpy()

        # Plot
        axes[i, 0].imshow(np.clip(img1_display, 0, 1))
        axes[i, 0].set_title("Reference", fontsize=10)
        axes[i, 0].axis('off')

        axes[i, 1].imshow(np.clip(img2_display, 0, 1))
        axes[i, 1].set_title("Test", fontsize=10)
        axes[i, 1].axis('off')

        # Info
        color = 'green' if is_correct else 'red'
        verdict = '✓ CORRECT' if is_correct else '✗ WRONG'
        info_text = f"True: {actual_label}\nPred: {pred_label}\nDist: {dist:.4f}\n{verdict}"
        axes[i, 2].text(0.5, 0.5, info_text, ha='center', va='center',
                        fontsize=11, fontweight='bold', color=color,
                        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        axes[i, 2].axis('off')

        # Distance bar
        axes[i, 3].barh(['Distance'], [dist], color='blue' if dist < threshold else 'red')
        axes[i, 3].axvline(threshold, color='black', linestyle='--', linewidth=2)
        axes[i, 3].set_xlim([0, max(2, dist+0.5)])
        axes[i, 3].set_xlabel('Distance', fontsize=9)

    fig.tight_layout()
    fig.savefig('sample_predictions.png', dpi=300, bbox_inches='tight')
    plt.show()

visualize_predictions(test_ds, model, num_samples=6, threshold=CONFIG['threshold'])

banner = "=" * 50
print("\n" + banner)
print("Evaluation complete! Results saved as PNG files.")
print(banner)

In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import cv2
from tqdm import tqdm

# --- Configuration ---
# NOTE(review): this re-declares the CONFIG from the first cell; the two must
# stay identical (same seed -> same writer split, same architecture sizes).
CONFIG = dict(
    dataset_path="/kaggle/input/cedardataset/signatures",
    model_path="/kaggle/input/siamese-transformer/pytorch/default/1/best_siamese_transformer.pth",
    img_size=(224, 224),
    batch_size=32,
    embedding_dim=128,
    transformer_heads=4,
    transformer_layers=2,
    dropout=0.1,
    device="cuda" if torch.cuda.is_available() else "cpu",
    seed=42,
    threshold=0.5,
)

def seed_everything(seed):
    """Seed every RNG this notebook relies on so runs are repeatable.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), and configures cuDNN for deterministic kernels.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects Python processes launched after this call; the running
    # interpreter's hash randomisation is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # torch.cuda.manual_seed seeds only the current device; cover all GPUs.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes kernels per run and can select different
    # algorithms between runs, which defeats `deterministic = True`.
    torch.backends.cudnn.benchmark = False

# Re-seed so this cell's random pair sampling is reproducible on its own.
seed_everything(seed=CONFIG['seed'])

# --- Dataset Class ---
class SignatureDataset(Dataset):
    """Writer-independent signature-pair dataset that also returns file paths.

    Expects ``root_dir/full_org/original_<writer_id>_*.png`` (genuine) and
    ``root_dir/full_forg/forgeries_<writer_id>_*.png`` (forgeries). Writers are
    split into train/test so no writer appears in both splits. Each item is
    ``(img1, img2, label, img1_path, img2_path)`` with label 0 for a
    genuine-genuine pair and 1 for a genuine-forged pair; the paths let callers
    reload the untransformed images for display.

    Args:
        root_dir: Folder containing ``full_org`` and ``full_forg``.
        split: ``'train'`` selects the training writers; any other value
            selects the held-out writers.
        transform: Optional callable applied to each PIL image.
        seed: ``random_state`` for the writer split; defaults to ``CONFIG['seed']``.
        test_size: Fraction of writers held out (default 0.2).
    """

    def __init__(self, root_dir, split='train', transform=None, seed=None, test_size=0.2):
        self.root_dir = root_dir
        self.transform = transform
        self.genuine_path = os.path.join(root_dir, 'full_org')
        self.forged_path = os.path.join(root_dir, 'full_forg')

        # {writer_id: {'genuine': [...], 'forged': [...]}}
        self.writers = {}
        for fname, writer_id in self._scan_writer_files(self.genuine_path, "original_"):
            entry = self.writers.setdefault(writer_id, {'genuine': [], 'forged': []})
            entry['genuine'].append(os.path.join(self.genuine_path, fname))
        for fname, writer_id in self._scan_writer_files(self.forged_path, "forgeries_"):
            # Forgeries of a writer with no genuine sample could never be paired.
            if writer_id in self.writers:
                self.writers[writer_id]['forged'].append(os.path.join(self.forged_path, fname))

        # Writer order follows the sorted genuine filenames, so the split depends
        # only on seed/test_size. NOTE(review): keep them equal to training.
        if seed is None:
            seed = CONFIG['seed']
        writer_ids = list(self.writers.keys())
        train_ids, test_ids = train_test_split(writer_ids, test_size=test_size, random_state=seed)

        self.active_writers = train_ids if split == 'train' else test_ids
        self.pairs = self._generate_pairs()

    @staticmethod
    def _scan_writer_files(directory, prefix):
        """Yield ``(filename, writer_id)`` for ``<prefix><id>_*.png`` files, sorted by name."""
        for fname in sorted(os.listdir(directory)):
            if not fname.startswith(prefix) or not fname.endswith(".png"):
                continue
            try:
                writer_id = int(fname.split('_')[1])
            except (IndexError, ValueError):
                continue
            yield fname, writer_id

    @staticmethod
    def _load_image(path):
        """Load ``path`` as grayscale replicated to 3 channels, closing the file."""
        with Image.open(path) as img:
            return img.convert("L").convert("RGB")

    def _generate_pairs(self):
        """Build ``[path_a, path_b, label]`` pairs for the active writers.

        Per writer: every unordered genuine-genuine pair (label 0), then every
        genuine x forged combination (label 1).
        """
        pairs = []
        for wid in self.active_writers:
            gens = self.writers[wid]['genuine']
            forgs = self.writers[wid]['forged']

            for i in range(len(gens)):
                for j in range(i + 1, len(gens)):
                    pairs.append([gens[i], gens[j], 0])

            for g in gens:
                for f in forgs:
                    pairs.append([g, f, 1])

        return pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        img1_path, img2_path, label = self.pairs[idx]

        img1 = self._load_image(img1_path)
        img2 = self._load_image(img2_path)

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return img1, img2, torch.tensor(label, dtype=torch.float32), img1_path, img2_path

# --- Modified Model with Patch Embedding Extraction ---
class SiameseTransformerExplainable(nn.Module):
    """Siamese transformer that can also expose per-patch token embeddings.

    Its attribute names match SiameseTransformer, so the same checkpoint loads.
    With ``return_patches=True`` the transformer outputs (B, 49, 1280) are
    returned instead of the projected embeddings.
    """

    def __init__(self):
        super().__init__()

        # Randomly initialised; trained weights come from the checkpoint.
        self.backbone = models.efficientnet_b0(pretrained=False).features

        self.feature_dim = 1280
        self.seq_len = 7 * 7        # 49 patches
        self.patch_grid = (7, 7)    # spatial layout of those patches

        self.pos_embedding = nn.Parameter(torch.randn(1, self.seq_len, self.feature_dim))

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.feature_dim,
                nhead=CONFIG['transformer_heads'],
                dim_feedforward=2 * self.feature_dim,
                dropout=CONFIG['dropout'],
                batch_first=True,
            ),
            num_layers=CONFIG['transformer_layers'],
        )

        hidden = 256
        self.fc = nn.Sequential(
            nn.Linear(self.feature_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, CONFIG['embedding_dim']),
        )

    def forward_one(self, x, return_patches=False):
        """Encode one batch; return per-patch tokens or the final embedding."""
        fmap = self.backbone(x)                                            # [B, 1280, 7, 7]
        tokens = fmap.view(fmap.size(0), self.feature_dim, -1).permute(0, 2, 1)  # [B, 49, 1280]
        patch_embeddings = self.transformer(tokens + self.pos_embedding)   # [B, 49, 1280]

        if return_patches:
            return patch_embeddings

        # Mean-pool the tokens, then project to the embedding space.
        return self.fc(patch_embeddings.mean(dim=1))

    def forward(self, img1, img2, return_patches=False):
        """Run both inputs through the shared encoder."""
        return (self.forward_one(img1, return_patches=return_patches),
                self.forward_one(img2, return_patches=return_patches))

# --- Heatmap Generation Functions ---
def compute_patch_difference_heatmap(model, img1, img2, method='cosine'):
    """
    Compute patch-level differences between two signatures.

    Args:
        model: Siamese model where ``model(img1, img2, return_patches=True)``
            returns two ``[1, num_patches, dim]`` tensors, and which exposes
            ``patch_grid`` (rows, cols) with ``rows * cols == num_patches``.
        img1, img2: Input batches of size 1, e.g. ``[1, 3, 224, 224]``.
        method: 'cosine' -> ``1 - cosine similarity`` per patch (range [0, 2]);
                'euclidean' -> per-patch L2 distance divided by the largest
                one (range [0, 1]; all zeros when the inputs are identical).

    Returns:
        heatmap: 2D NumPy array of shape ``model.patch_grid``.

    Raises:
        ValueError: if ``method`` is neither 'cosine' nor 'euclidean'.
    """
    if method not in ('cosine', 'euclidean'):
        raise ValueError(f"method must be 'cosine' or 'euclidean', got {method!r}")

    model.eval()
    with torch.no_grad():
        # Patch embeddings, e.g. [1, 49, 1280]
        patches1, patches2 = model(img1, img2, return_patches=True)

        # Remove batch dimension -> [num_patches, dim]
        patches1 = patches1.squeeze(0)
        patches2 = patches2.squeeze(0)

        if method == 'cosine':
            patches1_norm = F.normalize(patches1, p=2, dim=1)
            patches2_norm = F.normalize(patches2, p=2, dim=1)
            # Row-wise cosine similarity, turned into a difference score.
            similarity = (patches1_norm * patches2_norm).sum(dim=1)
            difference = 1 - similarity
        else:
            difference = torch.norm(patches1 - patches2, dim=1)
            max_diff = difference.max()
            # Identical inputs give all-zero distances; dividing would produce NaN.
            if max_diff > 0:
                difference = difference / max_diff

        # Reshape to the 2D patch grid, e.g. [7, 7]
        heatmap = difference.reshape(model.patch_grid).cpu().numpy()

    return heatmap

def overlay_heatmap(original_img, heatmap, alpha=0.6, colormap=cv2.COLORMAP_JET):
    """
    Overlay a patch-difference heatmap on an image.

    The heatmap is min-max rescaled to 0-255 on every call, so colours are
    relative to this one heatmap: even a near-identical pair shows "hot"
    patches. Compare absolute values using the raw heatmap instead.

    Args:
        original_img: PIL Image or uint8 numpy array; grayscale (H, W) or
            (H, W, 1), RGB (H, W, 3) or RGBA (H, W, 4).
        heatmap: 2D numpy array (e.g. the 7x7 patch grid).
        alpha: Heatmap weight in the blend (0 = image only, 1 = heatmap only).
        colormap: OpenCV colormap id.

    Returns:
        (overlayed, heatmap_resized): the blended RGB uint8 image of shape
        (H, W, 3), and the 0-255 uint8 heatmap resized to (H, W).
    """
    # Convert PIL to numpy if needed
    if isinstance(original_img, Image.Image):
        original_img = np.array(original_img)

    # Bring every supported layout to 3-channel RGB. Checking ndim (not just
    # shape[-1] != 3) avoids misfiring on RGBA images and 2-D images of width 3.
    if original_img.ndim == 3 and original_img.shape[2] == 1:
        original_img = original_img[:, :, 0]
    if original_img.ndim == 2:
        original_img = cv2.cvtColor(original_img, cv2.COLOR_GRAY2RGB)
    elif original_img.shape[2] == 4:
        original_img = cv2.cvtColor(original_img, cv2.COLOR_RGBA2RGB)

    h, w = original_img.shape[:2]

    # Normalize heatmap to [0, 255]; epsilon guards a constant heatmap.
    heatmap_normalized = ((heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8) * 255).astype(np.uint8)

    # Resize heatmap to match original image size using bilinear interpolation
    heatmap_resized = cv2.resize(heatmap_normalized, (w, h), interpolation=cv2.INTER_LINEAR)

    # Apply colormap (OpenCV returns BGR)
    heatmap_colored = cv2.applyColorMap(heatmap_resized, colormap)
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

    overlayed = cv2.addWeighted(original_img, 1-alpha, heatmap_colored, alpha, 0)

    return overlayed, heatmap_resized

def create_comprehensive_visualization(model, img1_tensor, img2_tensor, img1_pil, img2_pil, 
                                       label, threshold=0.5, method='cosine'):
    """
    Build a 3x4 explanation figure for one signature pair.

    Row 1: both input images, the raw patch-difference grid and the verdict.
    Row 2: both images with the heatmap overlaid, the upsampled heatmap and a
    colour guide. Row 3: the most different patches and summary statistics.

    Args:
        model: Model supporting ``return_patches=True`` and exposing
            ``patch_grid`` (e.g. SiameseTransformerExplainable).
        img1_tensor, img2_tensor: Normalised input batches of size 1.
        img1_pil, img2_pil: Display images, same size as the model input.
        label: Ground truth, 0 = genuine pair, 1 = forged pair.
        threshold: Embedding distance below which the pair is called GENUINE.
        method: 'cosine' or 'euclidean' patch difference.

    Returns:
        (fig, heatmap): the matplotlib Figure and the raw 2D difference grid.
    """
    model.eval()

    # The final embedding distance drives the verdict.
    with torch.no_grad():
        emb1, emb2 = model(img1_tensor, img2_tensor)
        distance = F.pairwise_distance(emb1, emb2).item()

    # Patch-level differences drive the explanation.
    heatmap = compute_patch_difference_heatmap(model, img1_tensor, img2_tensor, method=method)
    grid_rows, grid_cols = heatmap.shape

    overlay1, heatmap_resized = overlay_heatmap(img1_pil, heatmap, alpha=0.5)
    overlay2, _ = overlay_heatmap(img2_pil, heatmap, alpha=0.5)
    up_h, up_w = heatmap_resized.shape[:2]

    is_genuine = distance < threshold
    pred_label = "GENUINE" if is_genuine else "FORGED"
    true_label = "GENUINE" if label == 0 else "FORGED"
    is_correct = pred_label == true_label
    color = 'green' if is_correct else 'red'
    status = '✓ CORRECT' if is_correct else '✗ INCORRECT'

    fig = plt.figure(figsize=(18, 10))
    gs = fig.add_gridspec(3, 4, hspace=0.3, wspace=0.3)

    # Row 1: Original images
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.imshow(img1_pil)
    ax1.set_title("Reference Signature", fontsize=12, fontweight='bold')
    ax1.axis('off')

    ax2 = fig.add_subplot(gs[0, 1])
    ax2.imshow(img2_pil)
    ax2.set_title("Test Signature", fontsize=12, fontweight='bold')
    ax2.axis('off')

    # Row 1: Difference heatmap (raw patch grid)
    ax3 = fig.add_subplot(gs[0, 2])
    im = ax3.imshow(heatmap, cmap='jet', interpolation='nearest')
    ax3.set_title(f"Patch Difference Map\n({grid_rows}×{grid_cols} patches)", fontsize=12, fontweight='bold')
    ax3.axis('off')
    fig.colorbar(im, ax=ax3, fraction=0.046, pad=0.04)

    # Row 1: Prediction info
    ax4 = fig.add_subplot(gs[0, 3])
    ax4.axis('off')

    info_text = f"""
    PREDICTION RESULTS
    {'='*30}
    
    Ground Truth: {true_label}
    Prediction: {pred_label}
    
    Distance: {distance:.4f}
    Threshold: {threshold:.4f}
    
    Status: {status}
    
    Method: {method.upper()}
    """

    ax4.text(0.1, 0.5, info_text, fontsize=11, family='monospace',
             verticalalignment='center',
             bbox=dict(boxstyle='round', facecolor=color, alpha=0.2))

    # Row 2: Overlayed heatmaps
    ax5 = fig.add_subplot(gs[1, 0])
    ax5.imshow(overlay1)
    ax5.set_title("Reference + Heatmap", fontsize=12, fontweight='bold')
    ax5.axis('off')

    ax6 = fig.add_subplot(gs[1, 1])
    ax6.imshow(overlay2)
    ax6.set_title("Test + Heatmap", fontsize=12, fontweight='bold')
    ax6.axis('off')

    # Row 2: High-resolution heatmap
    ax7 = fig.add_subplot(gs[1, 2])
    im2 = ax7.imshow(heatmap_resized, cmap='jet', interpolation='bilinear')
    ax7.set_title(f"Upsampled Heatmap\n({up_w}×{up_h})", fontsize=12, fontweight='bold')
    ax7.axis('off')
    fig.colorbar(im2, ax=ax7, fraction=0.046, pad=0.04)

    # Row 2: Interpretation guide. Plain text (no emoji): the monospace font
    # has no glyphs for them. Must be an f-string so the rule line renders.
    ax8 = fig.add_subplot(gs[1, 3])
    ax8.axis('off')

    guide_text = f"""
    HEATMAP INTERPRETATION
    {'='*30}
    Colours are relative to this
    pair (min-max scaled).
    
    BLUE regions:
       Low difference
       Patches match well
       Similar strokes
    
    GREEN regions:
       Moderate difference
       Some variation
    
    YELLOW regions:
       High difference
       Significant variation
    
    RED regions:
       Very high difference
       Strong mismatch
       Different strokes/angles
    """

    ax8.text(0.1, 0.5, guide_text, fontsize=10, family='monospace',
             verticalalignment='center',
             bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.3))

    # Row 3: Patch-level analysis
    ax9 = fig.add_subplot(gs[2, :2])

    flat_heatmap = heatmap.flatten()
    top_k = min(5, flat_heatmap.size)
    # Indices of the top_k largest differences, largest first.
    top_indices = np.argsort(flat_heatmap)[-top_k:][::-1]

    patch_positions = []
    patch_values = []
    for idx in top_indices:
        row, col = divmod(int(idx), grid_cols)
        patch_positions.append(f"Patch ({row},{col})")
        patch_values.append(float(flat_heatmap[idx]))

    colors_bar = ['red' if v > 0.7 else 'orange' if v > 0.5 else 'yellow' for v in patch_values]
    ax9.barh(patch_positions, patch_values, color=colors_bar)
    ax9.set_xlabel('Difference Score', fontsize=11)
    ax9.set_title(f'Top {top_k} Most Different Patches', fontsize=12, fontweight='bold')
    # Cosine differences can reach 2, so don't clip the bars at 1.
    ax9.set_xlim([0, max(1.0, max(patch_values) * 1.05)])

    # Row 3: Statistics
    ax10 = fig.add_subplot(gs[2, 2:])
    ax10.axis('off')

    stats_text = f"""
    STATISTICAL ANALYSIS
    {'='*35}
    
    Heatmap Statistics:
    • Mean Difference:    {heatmap.mean():.4f}
    • Max Difference:     {heatmap.max():.4f}
    • Min Difference:     {heatmap.min():.4f}
    • Std Deviation:      {heatmap.std():.4f}
    
    Patch Analysis:
    • Total Patches:      {heatmap.size}
    • High Diff (>0.7):   {(heatmap > 0.7).sum()} patches
    • Medium Diff (0.5-0.7): {((heatmap > 0.5) & (heatmap <= 0.7)).sum()} patches
    • Low Diff (<=0.5):   {(heatmap <= 0.5).sum()} patches
    
    Classification:
    • Distance:           {distance:.4f}
    • Margin from threshold: {abs(distance - threshold):.4f}
    """

    ax10.text(0.1, 0.5, stats_text, fontsize=10, family='monospace',
              verticalalignment='center',
              bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.2))

    # Main title
    fig.suptitle(f'Explainable Signature Verification - {pred_label} ({status})',
                 fontsize=16, fontweight='bold', color=color, y=0.98)

    return fig, heatmap

# --- Load Model and Dataset ---
print("Loading model and dataset...")

# Resize, tensorise, then ImageNet mean/std normalisation; keep in sync with
# the preprocessing used during training.
test_transforms = transforms.Compose(
    [
        transforms.Resize(CONFIG['img_size']),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

test_ds = SignatureDataset(CONFIG['dataset_path'], split='test', transform=test_transforms)
print(f"Test Pairs: {len(test_ds)}")

model = SiameseTransformerExplainable().to(CONFIG['device'])
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (on PyTorch versions that support it, pass weights_only=True).
state_dict = torch.load(CONFIG['model_path'], map_location=CONFIG['device'])
model.load_state_dict(state_dict)
model.eval()
print("Model loaded successfully!")

# --- Generate Visualizations ---
print("\n" + "="*60)
print("GENERATING EXPLAINABLE HEATMAP VISUALIZATIONS")
print("="*60)

# Inverse normalization for display
inv_normalize = transforms.Normalize(
    mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
    std=[1/0.229, 1/0.224, 1/0.225]
)

num_samples = 4  # Number of sample pairs to visualize

for i in range(num_samples):
    # Get a random sample
    idx = random.randint(0, len(test_ds)-1)
    img1_tensor, img2_tensor, label, path1, path2 = test_ds[idx]

    # Reload the originals for display, resized to the model input so the
    # upsampled heatmap lines up; `with` releases the file handles.
    with Image.open(path1) as raw1, Image.open(path2) as raw2:
        img1_pil = raw1.convert("RGB").resize(CONFIG['img_size'])
        img2_pil = raw2.convert("RGB").resize(CONFIG['img_size'])

    # Add batch dimension
    img1_batch = img1_tensor.unsqueeze(0).to(CONFIG['device'])
    img2_batch = img2_tensor.unsqueeze(0).to(CONFIG['device'])

    print(f"\nGenerating visualization {i+1}/{num_samples}...")
    fig, heatmap = create_comprehensive_visualization(
        model, img1_batch, img2_batch, img1_pil, img2_pil,
        label.item(), threshold=CONFIG['threshold'], method='cosine'
    )

    out_path = f'explainable_signature_{i+1}.png'
    # Save the returned figure explicitly rather than pyplot's "current" one.
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated iterations don't accumulate open figures.
    plt.close(fig)

    print(f"Saved: {out_path}")

print("\n" + "="*60)
print("VISUALIZATION COMPLETE!")
print("="*60)
print("\nKey Insights from Heatmap Analysis:")
print("• RED regions indicate strong mismatches (different strokes/angles)")
print("• BLUE regions show good matches (similar strokes)")
print("• Colours are relative within each pair (min-max scaled), not absolute scores")
print("• The heatmap helps identify WHICH parts of signatures differ")
print("• This provides explainability beyond just a distance number")
print("• Useful for forensic analysis and understanding model decisions")