# === 1. SETUP AND IMPORTS ===

In [1]:
# Re-clone repository để đảm bảo code mới nhất
!rm -rf Deep_Learning-Based_Signature_Forgery_Detection_for_Personal_Identity_Authentication
!git clone https://github.com/trongjhuongwr/Deep_Learning-Based_Signature_Forgery_Detection_for_Personal_Identity_Authentication.git
%cd Deep_Learning-Based_Signature_Forgery_Detection_for_Personal_Identity_Authentication

import pandas as pd
import shutil
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
import seaborn as sns
import re
import json
import glob
import random
import torch.nn as nn
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.metrics import roc_curve, auc, accuracy_score, confusion_matrix, precision_recall_fscore_support
from torchvision import transforms
import torch.optim as optim
import itertools

# Thêm đường dẫn hiện tại vào sys.path để import modules
sys.path.append(os.path.abspath(os.getcwd()))

# Import các modules từ repo đã clone
from models.feature_extractor import ResNetFeatureExtractor
from models.meta_learner import MetricGenerator
from utils.model_evaluation import compute_metrics, _plot_det_curve, _plot_far_frr, _plot_confusion_matrix, _plot_score_distribution, _plot_roc_curve

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Random seed set to: {seed}")

SEED = 42
seed_everything(SEED)

print(f"Setup complete. Device: {DEVICE}")

Cloning into 'Deep_Learning-Based_Signature_Forgery_Detection_for_Personal_Identity_Authentication'...
remote: Enumerating objects: 3462, done.[K
remote: Counting objects: 100% (165/165), done.[K
remote: Compressing objects: 100% (120/120), done.[K
remote: Total 3462 (delta 81), reused 104 (delta 45), pack-reused 3297 (from 3)[K
Receiving objects: 100% (3462/3462), 248.67 MiB | 21.43 MiB/s, done.
Resolving deltas: 100% (386/386), done.
/kaggle/working/Deep_Learning-Based_Signature_Forgery_Detection_for_Personal_Identity_Authentication
Random seed set to: 42
Setup complete. Device: cuda


# === 2. DATASET LOADER WITH PROTOCOL SPLIT ===

In [2]:
class CedarAdaptationDataset(Dataset):
    """
    Custom Loader cho CEDAR hỗ trợ Domain Adaptation Protocol.
    
    Chiến lược chia (Split):
    - Adaptation Set (Train): Users có ID <= split_user_id (ví dụ: 1-10).
      Dùng để fine-tune mô hình học phong cách chữ ký của domain mới.
    - Evaluation Set (Test): Users có ID > split_user_id (ví dụ: 11-55).
      Dùng để kiểm thử hiệu năng trên những người chưa từng gặp (User-Independent).
    """
    def __init__(self, root_dir, mode='adaptation', split_user_id=10, n_pairs=500, transform=None):
        self.root_dir = root_dir
        self.mode = mode
        self.split_id = split_user_id
        self.n_pairs = n_pairs
        self.transform = transform
        self.users = {} # {uid: {'gen': [], 'forg': []}}
        self.pairs = []
        
        self._parse_cedar_structure()
        if self.mode == 'test':
            self._generate_exhaustive_pairs()
        else:
            self._generate_balanced_pairs()
        
    def _parse_cedar_structure(self):
        """Quét cấu trúc thư mục và nhóm file theo User ID."""
        print(f" > [{self.mode.upper()}] Scanning CEDAR data...")
        
        # Đường dẫn dựa trên cấu trúc dataset CEDAR chuẩn
        gen_path = os.path.join(self.root_dir, 'full_org')
        forg_path = os.path.join(self.root_dir, 'full_forg')
        
        # 1. Parse Genuine (Chữ ký thật)
        # Pattern: original_10_1.png -> ID=10
        for fpath in glob.glob(os.path.join(gen_path, "original_*.png")):
            try:
                fname = os.path.basename(fpath)
                uid = int(fname.split('_')[1])
                self._add_file(uid, fpath, 'gen')
            except: pass

        # 2. Parse Forged (Chữ ký giả)
        # Pattern: forgeries_10_1.png -> ID=10
        for fpath in glob.glob(os.path.join(forg_path, "forgeries_*.png")):
            try:
                fname = os.path.basename(fpath)
                uid = int(fname.split('_')[1])
                self._add_file(uid, fpath, 'forg')
            except: pass
            
        print(f" > Total Users Found in Split: {len(self.users)}")

    def _add_file(self, uid, fpath, ftype):
        # Logic lọc dựa trên Mode và Split ID
        if self.mode == 'adaptation' and uid <= self.split_id:
            if uid not in self.users: self.users[uid] = {'gen': [], 'forg': []}
            self.users[uid][ftype].append(fpath)
        elif self.mode == 'test' and uid > self.split_id:
            if uid not in self.users: self.users[uid] = {'gen': [], 'forg': []}
            self.users[uid][ftype].append(fpath)

    def _generate_balanced_pairs(self):
        """Sinh cặp ngẫu nhiên cân bằng cho tập Train/Adaptation."""
        if len(self.users) == 0: return

        # Phân phối n_pairs đều cho các users
        pairs_per_user = max(20, self.n_pairs // len(self.users))
        
        for uid, data in self.users.items():
            gens = data['gen']
            forgs = data['forg']
            
            if len(gens) < 2: continue
            
            # 1. Genuine Pairs (Positive - Label 1.0)
            for _ in range(pairs_per_user // 2):
                self.pairs.append((random.choice(gens), random.choice(gens), 1.0))
                
            # 2. Forged Pairs (Negative - Label 0.0)
            if len(forgs) > 0:
                for _ in range(pairs_per_user // 2):
                    self.pairs.append((random.choice(gens), random.choice(forgs), 0.0))
        
        random.shuffle(self.pairs)
        print(f" > Generated {len(self.pairs)} pairs for {self.mode}.")

    def _generate_exhaustive_pairs(self):
        """Sinh TẤT CẢ các cặp có thể cho tập Test (Evaluation)."""
        print(f" > [Exhaustive] Generating ALL possible pairs for {len(self.users)} users...")
        
        for uid, data in self.users.items():
            gens = data['gen']
            forgs = data['forg']
            
            # 1. Genuine Pairs: Tổ hợp chập 2 của các chữ ký thật
            gen_pairs = list(itertools.combinations(gens, 2))
            for p1, p2 in gen_pairs:
                self.pairs.append((p1, p2, 1.0))
            
            # 2. Forged Pairs: Tích Descartes (Mỗi chữ ký thật vs Mỗi chữ ký giả)
            for g in gens:
                for f in forgs:
                    self.pairs.append((g, f, 0.0))
                    
        print(f" > [Exhaustive] Total Pairs Generated: {len(self.pairs)} (Full Evaluation)")
    
    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        p1, p2, lbl = self.pairs[idx]
        img1 = Image.open(p1).convert('RGB')
        img2 = Image.open(p2).convert('RGB')
        
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            
        return {
            'support_images': img1, 
            'query_images': img2, 
            'query_labels': torch.tensor(lbl, dtype=torch.float32),
            'paths': (p1, p2)
        }

# === 3. EVALUATION ENGINE (UPDATED METRICS) ===

In [3]:
def evaluate_model(fe, mg, loader, device, num_users=None, output_dir=None, silent=False):
    """
    Hàm đánh giá tiêu chuẩn với đầy đủ các metrics:
    Users, Accuracy, Precision, Recall, F1, AUC, EER, Threshold.
    """
    fe.eval()
    mg.eval()
    all_scores, all_labels = [], []
    
    iter_bar = tqdm(loader, desc="Inference", leave=False) if not silent else loader
    
    with torch.no_grad():
        for batch in iter_bar:
            s = batch['support_images'].to(device)
            q = batch['query_images'].to(device)
            lbl = batch['query_labels'].to(device)
            
            combined = torch.cat((fe(s), fe(q)), dim=1)
            # Output của MetricGenerator là logits, dùng sigmoid để đưa về [0, 1]
            probs = torch.sigmoid(mg(combined)).squeeze(1)
            
            all_scores.extend(probs.cpu().numpy())
            all_labels.extend(lbl.cpu().numpy())
            
    # --- TÍNH TOÁN METRICS CHI TIẾT ---
    all_labels = np.array(all_labels)
    all_scores = np.array(all_scores)

    # 1. Tính ROC Curve để tìm EER và Threshold tối ưu
    fpr, tpr, thresholds = roc_curve(all_labels, all_scores)
    fnr = 1 - tpr
    
    # Tìm điểm EER (Equal Error Rate): nơi FAR (FPR) ~ FRR (FNR)
    # $EER = FPR$ tại điểm $FPR \approx FNR$
    eer_index = np.nanargmin(np.absolute((fnr - fpr)))
    eer = fpr[eer_index]
    eer_threshold = thresholds[eer_index]
    roc_auc = auc(fpr, tpr)

    # 2. Dự đoán nhãn (Binary Prediction) dựa trên EER Threshold
    # Score >= Threshold -> Genuine (1), ngược lại là Forged (0)
    preds = (all_scores >= eer_threshold).astype(int)

    # 3. Tính các metrics classification
    accuracy = accuracy_score(all_labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, preds, average='binary', zero_division=0)
    
    # Gom lại kết quả
    results = {
        'users': num_users if num_users else 'N/A',
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': roc_auc,
        'eer': eer,
        'threshold': eer_threshold,
        'labels': all_labels,
        'scores': all_scores
    }

    # In bảng kết quả
    if output_dir or not silent:
        print(f"\n{'='*10} FINAL TEST RESULTS (CEDAR) {'='*10}")
        print(f"{'Metric':<20} | {'Value':<10}")
        print("-" * 35)
        print(f"{'Users (Test Set)':<20} | {results['users']}")
        print(f"{'Accuracy':<20} | {results['accuracy']:.2%}")
        print(f"{'Precision':<20} | {results['precision']:.2%}")
        print(f"{'Recall':<20} | {results['recall']:.2%}")
        print(f"{'F1-Score':<20} | {results['f1']:.2%}")
        print(f"{'ROC-AUC':<20} | {results['auc']:.4f}")
        print(f"{'EER':<20} | {results['eer']:.2%}")
        print(f"{'EER Threshold':<20} | {results['threshold']:.4f}")
        print("="*35)
        
    # Vẽ biểu đồ nếu có output_dir
    if output_dir:
        try:
            _plot_roc_curve(results, output_dir)
            _plot_score_distribution(results, output_dir)
            # Chuyển format cho hàm confusion matrix nếu cần (tuỳ implement gốc)
            # Ở đây ta giả định hàm _plot_confusion_matrix nhận dict results như trên
            _plot_confusion_matrix(results, output_dir) 
            _plot_det_curve(results, output_dir)
            _plot_far_frr(results, output_dir)
        except Exception as e:
            print(f"Warning: Could not plot charts. Reason: {e}")
            
    return results

# === 4. ADAPTATION ENGINE ===

In [4]:
def run_domain_adaptation(pretrained_path, train_loader, val_loader, device, epochs=10):
    """
    Fine-tunes mô hình BHSig trên tập con của CEDAR (Few-shot Domain Adaptation).
    """
    print(f"\n{'='*10} PHASE 1: FEW-SHOT DOMAIN ADAPTATION {'='*10}")
    
    # 1. Load Pre-trained Model
    feature_extractor = ResNetFeatureExtractor(backbone_name='resnet34').to(device)
    metric_generator = MetricGenerator(embedding_dim=1024).to(device)
    
    print(f" > Loading Source Weights: {os.path.basename(pretrained_path)}")
    ckpt = torch.load(pretrained_path, map_location=device, weights_only=False)
    feature_extractor.load_state_dict(ckpt['feature_extractor'])
    metric_generator.load_state_dict(ckpt['metric_generator'])
    
    # 2. Optimizer (Learning Rate thấp để tránh quên kiến thức cũ - Catastrophic Forgetting)
    optimizer = optim.AdamW([
        {'params': feature_extractor.parameters(), 'lr': 1e-5}, 
        {'params': metric_generator.parameters(), 'lr': 5e-5}   
    ], weight_decay=1e-3)
    
    criterion = nn.BCEWithLogitsLoss()
    
    best_eer = 1.0
    best_state = None
    
    # 3. Adaptation Loop
    for epoch in range(epochs):
        feature_extractor.train()
        metric_generator.train()
        epoch_loss = 0
        
        for batch in tqdm(train_loader, desc=f"Adaptation Epoch {epoch+1}", leave=False):
            s = batch['support_images'].to(device)
            q = batch['query_images'].to(device)
            lbl = batch['query_labels'].to(device).unsqueeze(1)
            
            optimizer.zero_grad()
            
            # Forward
            combined = torch.cat((feature_extractor(s), feature_extractor(q)), dim=1)
            scores = metric_generator(combined)
            loss = criterion(scores, lbl)
            
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            
        # Validation on Target Domain (Unseen Users)
        # Lưu ý: Lúc validate trong vòng lặp, ta để silent=True để đỡ rối màn hình
        val_results = evaluate_model(feature_extractor, metric_generator, val_loader, device, silent=True)
        curr_eer = val_results['eer']
        
        print(f"   Epoch {epoch+1:02d} | Loss: {epoch_loss/len(train_loader):.4f} | Test EER: {curr_eer:.2%}")
        
        if curr_eer < best_eer:
            best_eer = curr_eer
            # Save in-memory best state
            best_state = {
                'fe': feature_extractor.state_dict(),
                'mg': metric_generator.state_dict()
            }
            
    print(f" > Adaptation Complete. Best EER Achieved: {best_eer:.2%}")
    
    # Load best weights for final testing
    feature_extractor.load_state_dict(best_state['fe'])
    metric_generator.load_state_dict(best_state['mg'])
    
    return feature_extractor, metric_generator

# === 5. EXECUTION BLOCK ===

In [5]:
# --- PATHS CONFIGURATION ---
CEDAR_ROOT = '/kaggle/input/cedardataset/signatures'
BHSIG_MODEL_PATH = '/kaggle/input/my-best-models-meta-learning/Deep_Learning-Based_Signature_Forgery_Detection_for_Personal_Identity_Authentication/checkpoints_meta/best_model_fold_2.pth' 
OUTPUT_DIR = '/kaggle/working/cedar_adaptation_results'

os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- TRANSFORMS ---
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 1. PREPARE DATASETS
# Adaptation: Lấy 10 người đầu tiên (ID <= 10) để học
print(">>> Initializing Adaptation Set...")
adapt_set = CedarAdaptationDataset(CEDAR_ROOT, mode='adaptation', split_user_id=10, n_pairs=2000, transform=test_transform)
adapt_loader = DataLoader(adapt_set, batch_size=16, shuffle=True, num_workers=2)

# Test: Lấy các người còn lại (ID > 10) để kiểm tra
print(">>> Initializing Test Set...")
test_set = CedarAdaptationDataset(CEDAR_ROOT, mode='test', split_user_id=10, transform=test_transform)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=2)

# Lấy số lượng user trong tập test để hiển thị trong metrics
num_test_users = len(test_set.users)

# 2. RUN PROCESS
if os.path.exists(BHSIG_MODEL_PATH) and len(adapt_set) > 0:
    # Phase 1: Adapt (Train)
    final_fe, final_mg = run_domain_adaptation(BHSIG_MODEL_PATH, adapt_loader, test_loader, DEVICE, epochs=15)
    
    # Phase 2: Final Test & Visualize (Eval)
    print("\n>>> Starting Final Evaluation...")
    evaluate_model(final_fe, final_mg, test_loader, DEVICE, num_users=num_test_users, output_dir=OUTPUT_DIR)
    
    # Save Adapted Model
    torch.save({
        'feature_extractor': final_fe.state_dict(),
        'metric_generator': final_mg.state_dict()
    }, os.path.join(OUTPUT_DIR, 'cedar_adapted_model.pth'))
    print(f" > Saved adapted model to {OUTPUT_DIR}")
    
else:
    print(f"Error: Model path not found ({BHSIG_MODEL_PATH}) or Dataset is empty.")

>>> Initializing Adaptation Set...
 > [ADAPTATION] Scanning CEDAR data...
 > Total Users Found in Split: 10
 > Generated 2000 pairs for adaptation.
>>> Initializing Test Set...
 > [TEST] Scanning CEDAR data...
 > Total Users Found in Split: 45
 > [Exhaustive] Generating ALL possible pairs for 45 users...
 > [Exhaustive] Total Pairs Generated: 38340 (Full Evaluation)



Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 210MB/s]


 > Loading Source Weights: best_model_fold_2.pth


Adaptation Epoch 1:   0%|          | 0/125 [00:00<?, ?it/s]

   Epoch 01 | Loss: 0.2821 | Test EER: 21.70%


Adaptation Epoch 2:   0%|          | 0/125 [00:00<?, ?it/s]

   Epoch 02 | Loss: 0.0413 | Test EER: 13.00%


Adaptation Epoch 3:   0%|          | 0/125 [00:00<?, ?it/s]

   Epoch 03 | Loss: 0.0263 | Test EER: 10.42%


Adaptation Epoch 4:   0%|          | 0/125 [00:00<?, ?it/s]

   Epoch 04 | Loss: 0.0180 | Test EER: 7.02%


Adaptation Epoch 5:   0%|          | 0/125 [00:00<?, ?it/s]

   Epoch 05 | Loss: 0.0084 | Test EER: 7.01%


Adaptation Epoch 6:   0%|          | 0/125 [00:00<?, ?it/s]

   Epoch 06 | Loss: 0.0115 | Test EER: 7.02%


Adaptation Epoch 7:   0%|          | 0/125 [00:00<?, ?it/s]

   Epoch 07 | Loss: 0.0058 | Test EER: 5.95%


Adaptation Epoch 8:   0%|          | 0/125 [00:00<?, ?it/s]

   Epoch 08 | Loss: 0.0151 | Test EER: 5.86%


Adaptation Epoch 9:   0%|          | 0/125 [00:00<?, ?it/s]

   Epoch 09 | Loss: 0.0040 | Test EER: 6.30%


Adaptation Epoch 10:   0%|          | 0/125 [00:00<?, ?it/s]

   Epoch 10 | Loss: 0.0083 | Test EER: 2.12%


Adaptation Epoch 11:   0%|          | 0/125 [00:00<?, ?it/s]

   Epoch 11 | Loss: 0.0034 | Test EER: 2.63%


Adaptation Epoch 12:   0%|          | 0/125 [00:00<?, ?it/s]

   Epoch 12 | Loss: 0.0056 | Test EER: 3.42%


Adaptation Epoch 13:   0%|          | 0/125 [00:00<?, ?it/s]

   Epoch 13 | Loss: 0.0035 | Test EER: 3.13%


Adaptation Epoch 14:   0%|          | 0/125 [00:00<?, ?it/s]

   Epoch 14 | Loss: 0.0022 | Test EER: 3.35%


Adaptation Epoch 15:   0%|          | 0/125 [00:00<?, ?it/s]

   Epoch 15 | Loss: 0.0020 | Test EER: 2.85%
 > Adaptation Complete. Best EER Achieved: 2.12%

>>> Starting Final Evaluation...


Inference:   0%|          | 0/600 [00:00<?, ?it/s]


Metric               | Value     
-----------------------------------
Users (Test Set)     | 45
Accuracy             | 97.15%
Precision            | 94.24%
Recall               | 97.15%
F1-Score             | 95.67%
ROC-AUC              | 0.9924
EER                  | 2.85%
EER Threshold        | 0.0021
 > Saved adapted model to /kaggle/working/cedar_adaptation_results


<Figure size 800x600 with 0 Axes>