=== SETUP AND IMPORTS ===

In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, models
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import numpy as np
import random
import os
import json
import sys
import shutil

# Make the repository's own packages (dataloader/, models/, losses/) importable.
# REPO_ROOT is taken from the current working directory, so this assumes the
# notebook is started from the repository root.
REPO_ROOT = os.path.abspath(os.getcwd())
# Guard the append so re-running this cell does not pile up duplicate
# sys.path entries.
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

# Project-specific modules (resolved through the sys.path entry added above)
from dataloader.tDCBAM_trainloader import SignaturePretrainDataset, get_pretraining_transforms
from models.Triplet_Siamese_Similarity_Network import tDCBAM
from losses.triplet_loss import TripletLoss

print(f" > Repo root: {REPO_ROOT}")

In [None]:
# HYPERPARAMETER CONFIGURATION
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Training settings used with the DenseNet121 backbone
INPUT_SHAPE = (224, 224)
BATCH_SIZE, EPOCHS = 30, 30
LEARNING_RATE = 1e-4
# NOTE(review): MARGIN is passed to TripletLoss(mode='cosine'). If that mode
# uses a distance bounded by 2 (e.g. 1 - cosine similarity), a margin of 2.0
# can only be satisfied by perfectly opposed embeddings -- confirm against
# losses/triplet_loss.py.
MARGIN = 2.0

# Build the augmentation pipeline from the project's helper, sized to INPUT_SHAPE.
train_transform = get_pretraining_transforms(input_shape=INPUT_SHAPE)

# Echo the key settings so they are recorded in the notebook output.
print(f" > Computation Device: {DEVICE}")
print(f" > Input Shape: {INPUT_SHAPE}")
print(f" > Margin: {MARGIN}")

def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Sets PYTHONHASHSEED, seeds Python's ``random``, NumPy and PyTorch
    (CPU and CUDA), and switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Apply the same seed to each library in turn.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    print(f"Random seed set to: {seed}")


SEED = 42
seed_everything(seed=SEED)

In [None]:
# DATA PREPARATION STRATEGY (Ubuntu Local)
# Location of the raw BHSig download; adjust to wherever the data lives locally.
DATA_ROOT = os.path.expanduser('~/data/bhsig260-hindi-bengali')  # Update to your local data path

# All derived data lives under <repo>/data.
working_dir = os.path.join(REPO_ROOT, 'data')
genuine_dir, forged_dir, splits_dir = (
    os.path.join(working_dir, leaf)
    for leaf in ('all_genuine', 'all_forged', 'splits')
)

# Start from empty output directories: remove anything left by a previous run,
# then recreate each directory.
for target_dir in (genuine_dir, forged_dir, splits_dir):
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir, exist_ok=True)

print("Status: Consolidating BHSig dataset (Hindi + Bengali) into a unified structure...")
print(f" > Source: {DATA_ROOT}")
print(f" > Destination: {working_dir}")

def _copy_visible_entries(src_dir, dst_dir):
    """Copy every non-hidden entry of ``src_dir`` into ``dst_dir``.

    Shell-free replacement for ``cp -r src_dir/* dst_dir/``: sub-directories
    are merged into any existing directory of the same name, and plain files
    are copied over. Copying is best-effort: a failure is reported with a
    warning and the remaining entries are still processed.

    Args:
        src_dir: Existing directory whose entries are copied.
        dst_dir: Existing directory that receives the copies.
    """
    try:
        entries = sorted(os.listdir(src_dir))
    except OSError as err:
        print(f"WARNING: could not list {src_dir}: {err}")
        return

    for entry in entries:
        if entry.startswith('.'):
            # The shell glob `*` used previously did not match dotfiles.
            continue
        src_path = os.path.join(src_dir, entry)
        dst_path = os.path.join(dst_dir, entry)
        try:
            if os.path.isdir(src_path):
                shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
            else:
                shutil.copy(src_path, dst_path)
        except OSError as err:  # shutil.Error is a subclass of OSError
            print(f"WARNING: could not copy {src_path}: {err}")


# Verify source data exists
if os.path.isdir(DATA_ROOT):
    # Genuine signatures (Hindi + Bengali)
    hindi_gen = os.path.join(DATA_ROOT, 'BHSig160_Hindi', 'Genuine')
    bengali_gen = os.path.join(DATA_ROOT, 'BHSig100_Bengali', 'Genuine')

    # Forged signatures (Hindi + Bengali)
    hindi_forg = os.path.join(DATA_ROOT, 'BHSig160_Hindi', 'Forged')
    bengali_forg = os.path.join(DATA_ROOT, 'BHSig100_Bengali', 'Forged')

    # Both languages are merged into one genuine and one forged directory.
    # Missing source folders are skipped, as before.
    copy_plan = (
        (hindi_gen, genuine_dir),
        (bengali_gen, genuine_dir),
        (hindi_forg, forged_dir),
        (bengali_forg, forged_dir),
    )
    for source_dir, target_dir in copy_plan:
        if os.path.isdir(source_dir):
            _copy_visible_entries(source_dir, target_dir)

    print(f" > Genuine files: {len(os.listdir(genuine_dir))}")
    print(f" > Forged files: {len(os.listdir(forged_dir))}")
    print("Status: Data consolidation complete.")
else:
    print(f"ERROR: Data source not found at {DATA_ROOT}")
    print("Please download BHSig data and update DATA_ROOT path")

# Execute the data restructuring script.
# Objective: split the dataset into disjoint 'Background' (pre-training) and
# 'Evaluation' (meta-learning) user sets. 150 users are requested for
# pre-training; per the original note the remaining 110 are used for evaluation.
import subprocess

print(" > Generating dataset splits...")
# An argument list (no shell) keeps paths with spaces intact, and
# sys.executable runs the script with the same interpreter as this notebook.
script_cmd = [
    sys.executable,
    os.path.join(REPO_ROOT, 'scripts', 'restructure_bhsig.py'),
    '--base_dir', DATA_ROOT,
    '--output_dir', splits_dir,
    '--pretrain_users', '150',
]
split_result = subprocess.run(script_cmd, cwd=REPO_ROOT, check=False)
# Stay best-effort (the following step reports a missing split file), but do
# not hide a failing script.
if split_result.returncode != 0:
    print(f"WARNING: restructure_bhsig.py exited with status {split_result.returncode}")

# Read the background (pre-training) user list written to the splits directory.
# If the file is missing, report it and continue with an empty list.
background_users_path = os.path.join(splits_dir, 'bhsig_background_users.json')
try:
    users_file = open(background_users_path, 'r')
except FileNotFoundError:
    print("Error: Background users file not found. Ensure the restructuring script executed correctly.")
    print(f" > Expected path: {background_users_path}")
    background_users = []
else:
    with users_file:
        background_users = json.load(users_file)
    print(f"Success: Loaded {len(background_users)} users for the Pre-training phase (Background Set).")

=== DATASET AND DATALOADER INITIALIZATION ===

In [None]:
# Build the pre-training dataset from the consolidated genuine/forged
# directories, restricted to the background users.
dataset_options = {
    'org_dir': genuine_dir,
    'forg_dir': forged_dir,
    'transform': train_transform,
    'user_list': background_users,
}
train_dataset = SignaturePretrainDataset(**dataset_options)

# Batch loader settings
loader_options = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,       # reshuffle the samples on every pass
    num_workers=4,      # load batches in 4 worker processes
    pin_memory=True,    # page-locked host memory for host-to-GPU copies
    drop_last=True,     # discard the final batch if it is smaller than BATCH_SIZE
)
train_loader = DataLoader(train_dataset, **loader_options)

print(f" > Dataset Prepared.")
print(f" > Total Training Triplets available per epoch: {len(train_dataset)}")
print(f" > Batch Size: {BATCH_SIZE}")

=== MODEL ARCHITECTURE AND LOSS FUNCTION ===

In [None]:
# 1. Initialize Model
# DenseNet121-backed tDCBAM network, moved to the selected device.
model = tDCBAM(backbone_name='densenet121', output_dim=1024, pretrained=True).to(DEVICE)

# 2. Loss and Optimizer
criterion = TripletLoss(margin=MARGIN, mode='cosine')
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Halve the learning rate once the monitored value (the epoch's mean training
# loss) has not improved for 3 consecutive epochs.
# `verbose=True` was dropped: the argument is deprecated in recent PyTorch
# releases, and the training loop already prints the current LR every epoch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

# 3. Training Loop
# history['loss'] collects one mean triplet loss value per epoch.
history = {'loss': []}
CHECKPOINT_DIR = os.path.join(REPO_ROOT, 'checkpoints', 'pretraining')
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f"Starting Training for {EPOCHS} epochs...")
print(f"Checkpoints will be saved to: {CHECKPOINT_DIR}")

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    batches_seen = 0

    # Ask the dataset to regenerate its triplets before each pass (described
    # in the original note as reshuffling pairs to surface new hard negatives).
    train_dataset.on_epoch_end()

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)

    # Each batch is a 4-tuple; the last element is not needed for the loss.
    for batches_seen, (anchor, positive, negative, _) in enumerate(pbar, start=1):
        anchor, positive, negative = anchor.to(DEVICE), positive.to(DEVICE), negative.to(DEVICE)

        optimizer.zero_grad()

        # Forward pass: one embedding per triplet member
        anchor_emb, pos_emb, neg_emb = model(anchor, positive, negative)

        # Compute Loss
        loss = criterion(anchor_emb, pos_emb, neg_emb)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # Read the scalar once; it feeds both the running sum and the progress bar.
        loss_value = loss.item()
        running_loss += loss_value
        pbar.set_postfix({'loss': loss_value})

    # An empty loader would otherwise surface as a bare ZeroDivisionError below.
    if batches_seen == 0:
        raise RuntimeError(
            f"No batches were produced in epoch {epoch+1}; check that the dataset "
            f"contains at least one full batch of {BATCH_SIZE} triplets."
        )

    avg_loss = running_loss / batches_seen
    history['loss'].append(avg_loss)

    # Update Scheduler with the epoch's mean loss
    scheduler.step(avg_loss)

    print(f"Epoch [{epoch+1}/{EPOCHS}] | Triplet Loss: {avg_loss:.4f} | LR: {optimizer.param_groups[0]['lr']:.2e}")

    # Save Checkpoint every 5 epochs (full model weights)
    if (epoch + 1) % 5 == 0:
        ckpt_path = os.path.join(CHECKPOINT_DIR, f"tDCBAM_pretrain_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), ckpt_path)
        print(f"   >>> Checkpoint saved: {os.path.basename(ckpt_path)}")

# Persist only the feature extractor's weights (not the whole model) at the
# repository root.
final_weights_path = os.path.join(REPO_ROOT, "background_pretrain.pth")
feature_extractor_state = model.feature_extractor.state_dict()
torch.save(feature_extractor_state, final_weights_path)
print(f"\nTraining Complete. Final weights saved to: {final_weights_path}")

=== TRAINING VISUALIZATION ===

In [None]:
# Draw the per-epoch mean triplet loss recorded in history['loss'].
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.plot(history['loss'], label='Triplet Loss')
loss_ax.set_title('Pre-training Convergence (Triplet Loss)')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.grid(True)
plt.show()