# Biometric Fusion with MPT Training

This notebook implements the training process for the Multi-Modal Biometric Fusion model with Modified Prompt Tuning (MPT).

## Overview

The model combines:
1. CNN feature extractors for periocular, forehead, and iris images
2. Transformer-based fusion with modified prompt tuning
3. A weighted sum of two loss functions (Contrastive and InfoNCE) for embedding learning

In [None]:
# Import required libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import os
import time
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from sklearn.manifold import TSNE
from PIL import Image

from model.model_mpt import BiometricModel
from model.dataset import BiometricDataset
from model.loss import InfoNCELoss, ContrastiveLoss

## Configuration Parameters

Set up the hyperparameters for training.

In [None]:
# Set device: prefer CUDA, then Apple MPS, and fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Training parameters
params = {
    'data_path': './dataset2',            # Path to dataset
    'embedding_dim': 512,                 # Dimension of embeddings
    'batch_size': 32,                     # Batch size for training
    'epochs': 100,                        # Number of training epochs
    'learning_rate': 3e-4,                # Learning rate
    'weight_decay': 1e-5,                 # Weight decay for optimizer
    'save_dir': './checkpoints',          # Directory to save checkpoints
    'save_interval': 5,                   # Save model every N epochs
    'k1': 0.5,                            # Weight for ContrastiveLoss
    'k2': 0.5,                            # Weight for InfoNCELoss
    'temperature': 0.05,                  # Temperature for InfoNCE loss
    'margin': 2.0,                        # Margin for Contrastive loss
    'seed': 42                            # Seed for the torch / NumPy random number generators
}

# Seed the random number generators up front so that weight initialisation,
# data augmentation, shuffling and the random image picks made during the
# embedding visualisation are repeatable across "Restart & Run All".
torch.manual_seed(params['seed'])
np.random.seed(params['seed'])

# Create save directory if it doesn't exist
os.makedirs(params['save_dir'], exist_ok=True)

## Data Preparation

Load the training and test splits and wrap them in data loaders. The test split is also used as the validation set during training.

In [None]:
# Training-time pipeline: resize, apply mild geometric / photometric
# augmentations, then convert to a tensor and normalise.
# The single-element mean/std assumes single-channel inputs — TODO confirm
# against BiometricDataset.
train_augmentations = [
    transforms.Resize((128, 128)),
    transforms.RandomRotation(degrees=3),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)),
    transforms.RandomPosterize(bits=6, p=0.1),
    transforms.RandomAdjustSharpness(sharpness_factor=1.5, p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform_train = transforms.Compose(train_augmentations)

# Evaluation-time pipeline: no random augmentation, only resize + normalise
eval_preprocessing = [
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform_val = transforms.Compose(eval_preprocessing)

# Build both dataset splits from the same root directory
print("Loading datasets...")
dataset_kwargs = dict(root_dir=params['data_path'], instances_per_person=8)
train_dataset = BiometricDataset(transform=transform_train, split='train', **dataset_kwargs)
test_dataset = BiometricDataset(transform=transform_val, split='test', **dataset_kwargs)

# Wrap the splits in loaders; only the training data is shuffled.
# Note that the validation loader iterates over the *test* split.
loader_kwargs = dict(batch_size=params['batch_size'], num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

## Model Definition and Loss Functions

Create the model and define loss functions for training.

In [None]:
# Initialize model and move it to the selected device
model = BiometricModel(embedding_dim=params['embedding_dim']).to(device)

# Initialize loss functions
infonce_criterion = InfoNCELoss(temperature=params['temperature'])
contrastive_criterion = ContrastiveLoss(margin=params['margin'])

# Initialize optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=params['learning_rate'], 
                              weight_decay=params['weight_decay'])

# Learning rate scheduler: multiply the learning rate by `factor` once the
# monitored value (passed to `scheduler.step(...)`) has stopped decreasing
# for more than `patience` consecutive steps.
# The `verbose` argument is intentionally not passed: it is deprecated in
# current PyTorch releases. The active learning rate can be read from
# `optimizer.param_groups[0]['lr']` at any time.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5
)

## Training and Validation Functions

Define functions for training and validation steps.

In [None]:
def train_step(model, optimizer, infonce_criterion, contrastive_criterion,
               periocular, forehead, iris, labels, device, k1=0.5, k2=0.5):
    """Run one optimisation step on a single batch.

    The model is switched to training mode, the three image tensors and the
    labels are moved to ``device``, and the combined objective
    ``k1 * contrastive + k2 * infonce`` is back-propagated before the
    optimizer is stepped.

    Args:
        model: Callable as ``model(periocular, forehead, iris)``; returns the embeddings.
        optimizer: Optimizer that updates the model parameters.
        infonce_criterion: Callable ``(embeddings, labels) -> scalar loss``.
        contrastive_criterion: Callable ``(embeddings, labels) -> scalar loss``.
        periocular, forehead, iris: Image tensors for the three inputs.
        labels: Label tensor passed to both criteria.
        device: Device the tensors are moved to.
        k1: Weight of the contrastive loss.
        k2: Weight of the InfoNCE loss.

    Returns:
        Tuple ``(total_loss, infonce_loss, contrastive_loss)`` of Python floats.
    """
    model.train()
    optimizer.zero_grad()

    # Move every tensor of the batch onto the target device in one go
    periocular, forehead, iris, labels = (
        tensor.to(device) for tensor in (periocular, forehead, iris, labels)
    )

    fused = model(periocular, forehead, iris)

    # Evaluate both objectives on the same embeddings, then blend them
    nce_term = infonce_criterion(fused, labels)
    pair_term = contrastive_criterion(fused, labels)
    combined = k1 * pair_term + k2 * nce_term

    combined.backward()
    optimizer.step()

    return tuple(term.item() for term in (combined, nce_term, pair_term))

def validate(model, val_loader, infonce_criterion, contrastive_criterion, device,
             k1=0.5, k2=0.5, image_shape=(1, 128, 128)):
    """Compute the average validation losses over ``val_loader``.

    The model is put into evaluation mode and every batch is processed under
    ``torch.no_grad()``. Losses are averaged per batch (each batch counts
    equally, regardless of how many samples it holds).

    Args:
        model: Callable as ``model(perioculars, foreheads, irises)``.
        val_loader: Iterable of dict batches with the keys ``'perioculars'``,
            ``'foreheads'``, ``'irises'`` and ``'labels'``. It does not need
            to support ``len()``; batches are counted while iterating.
        infonce_criterion: Callable ``(embeddings, labels) -> scalar loss``.
        contrastive_criterion: Callable ``(embeddings, labels) -> scalar loss``.
        device: Device the batch tensors are moved to.
        k1: Weight of the contrastive loss.
        k2: Weight of the InfoNCE loss.
        image_shape: Per-image ``(channels, height, width)`` used to flatten
            each image tensor to ``(-1, *image_shape)``. Defaults to the
            ``(1, 128, 128)`` shape used throughout this notebook.

    Returns:
        Tuple ``(avg_total_loss, avg_infonce_loss, avg_contrastive_loss)``.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()
    val_loss = 0.0
    val_infonce_loss = 0.0
    val_contrastive_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in val_loader:
            # Collapse any leading dimensions so each row is a single image
            perioculars = batch['perioculars'].view(-1, *image_shape).to(device)
            foreheads = batch['foreheads'].view(-1, *image_shape).to(device)
            irises = batch['irises'].view(-1, *image_shape).to(device)
            labels = batch['labels'].view(-1).to(device)

            embeddings = model(perioculars, foreheads, irises)

            # Calculate both losses and combine them with the same weights as in training
            infonce_loss = infonce_criterion(embeddings, labels)
            contrastive_loss = contrastive_criterion(embeddings, labels)
            loss = k1 * contrastive_loss + k2 * infonce_loss

            val_loss += loss.item()
            val_infonce_loss += infonce_loss.item()
            val_contrastive_loss += contrastive_loss.item()
            num_batches += 1

    # Fail with a clear message instead of a bare ZeroDivisionError
    if num_batches == 0:
        raise ValueError("val_loader yielded no batches; cannot compute validation losses")

    avg_val_loss = val_loss / num_batches
    avg_val_infonce_loss = val_infonce_loss / num_batches
    avg_val_contrastive_loss = val_contrastive_loss / num_batches

    return avg_val_loss, avg_val_infonce_loss, avg_val_contrastive_loss

## Embedding Visualization Function

Function to visualize the learned embeddings using t-SNE.

In [None]:
def visualize_embeddings(model, dataset, device, num_persons=10, num_samples=10, epoch=0):
    """Project embeddings of a few persons to 2-D with t-SNE and plot them.

    For each of the first ``num_persons`` entries of ``dataset.person_ids``,
    ``num_samples`` embeddings are computed, each from one randomly chosen
    iris, periocular and forehead image of that person. The figure is saved
    to ``params["save_dir"]`` (module-level configuration) as
    ``embeddings_epoch_{epoch+1}.png`` and then shown.

    Args:
        model: Callable as ``model(periocular, forehead, iris)``.
        dataset: Object exposing ``person_ids``, ``label_map``, ``transform``
            and the per-person path collections ``iris_images``,
            ``periocular_images`` and ``forehead_images``.
        device: Device the image tensors are moved to.
        num_persons: Maximum number of persons to plot; fewer are plotted if
            the dataset holds fewer.
        num_samples: Number of embeddings drawn per person.
        epoch: Zero-based epoch index; ``epoch + 1`` appears in the title and file name.

    Raises:
        ValueError: If fewer than two embeddings were produced (t-SNE needs at least two).
    """
    model.eval()
    embeddings = []
    labels = []

    # Select first n persons (the slice silently yields fewer if the dataset is smaller)
    person_ids = dataset.person_ids[:num_persons]
    transform = dataset.transform

    def _load(path):
        # Close the file handle deterministically once the image is converted
        with Image.open(path) as img:
            return transform(img.convert('L')).unsqueeze(0).to(device)

    with torch.no_grad():
        for person_id in person_ids:
            label = dataset.label_map[person_id]
            for _ in range(num_samples):
                # Randomly select one iris, periocular, and forehead image
                iris_img_path = np.random.choice(dataset.iris_images[person_id])
                periocular_img_path = np.random.choice(dataset.periocular_images[person_id])
                forehead_img_path = np.random.choice(dataset.forehead_images[person_id])

                # Load and transform images
                iris_img = _load(iris_img_path)
                periocular_img = _load(periocular_img_path)
                forehead_img = _load(forehead_img_path)

                # Get embedding
                emb = model(periocular_img, forehead_img, iris_img).cpu().numpy()
                embeddings.append(emb[0])
                labels.append(label)

    if len(embeddings) < 2:
        raise ValueError(
            f"Need at least 2 embeddings for t-SNE, got {len(embeddings)}"
        )

    # Apply t-SNE (perplexity must stay below the number of samples)
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings)-1))
    embeddings_2d = tsne.fit_transform(np.array(embeddings))

    # Plot one scatter series per person. Points are grouped by the label
    # that was actually recorded for them (dataset.label_map[person_id]),
    # not by the position of the person in the list.
    fig = plt.figure(figsize=(10, 8))
    for person_id in person_ids:
        target = dataset.label_map[person_id]
        idx = [j for j, l in enumerate(labels) if l == target]
        plt.scatter(embeddings_2d[idx, 0], embeddings_2d[idx, 1], label=f'Person {person_id}', s=10)

    plt.title(f'Embedding Visualization (Epoch {epoch+1})')
    plt.xlabel('t-SNE Component 1')
    plt.ylabel('t-SNE Component 2')
    plt.grid(True)
    # Show the per-person labels that were attached to the scatter series
    plt.legend(fontsize='small', ncol=2)
    plt.savefig(f'{params["save_dir"]}/embeddings_epoch_{epoch+1}.png')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    plt.close(fig)

## Training Loop

Execute the training process.

In [None]:
# Initialize tracking variables (one value is appended per finished epoch)
epochs = []
losses = []
infonce_losses = []
contrastive_losses = []
val_losses = []
val_infonce_losses = []
val_contrastive_losses = []

# Track best validation loss for model saving
best_val_loss = float('inf')

# Start training
print("Starting training...")
start_time = time.time()

num_epochs = params['epochs']
for epoch in range(num_epochs):
    epoch_start_time = time.time()

    # Training phase
    total_loss = 0
    total_infonce_loss = 0
    total_contrastive_loss = 0
    model.train()
    
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # Reshape batch so that every row is a single 1x128x128 image
        perioculars = batch['perioculars'].view(-1, 1, 128, 128)
        foreheads = batch['foreheads'].view(-1, 1, 128, 128)
        irises = batch['irises'].view(-1, 1, 128, 128)
        labels = batch['labels'].view(-1)
        
        # Train step (moves the tensors to the device and updates the weights)
        loss, infonce_loss, contrastive_loss = train_step(
            model, optimizer, infonce_criterion, contrastive_criterion, 
            perioculars, foreheads, irises, labels, device, 
            k1=params['k1'], k2=params['k2']
        )
        
        total_loss += loss
        total_infonce_loss += infonce_loss
        total_contrastive_loss += contrastive_loss
    
    # Calculate average losses (per batch)
    avg_loss = total_loss / len(train_loader)
    avg_infonce_loss = total_infonce_loss / len(train_loader)
    avg_contrastive_loss = total_contrastive_loss / len(train_loader)
    
    # Validation phase
    val_loss, val_infonce_loss, val_contrastive_loss = validate(
        model, val_loader, infonce_criterion, contrastive_criterion, 
        device, k1=params['k1'], k2=params['k2']
    )
    
    # Update learning rate based on validation loss
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    
    # Track metrics
    epochs.append(epoch + 1)
    losses.append(avg_loss)
    infonce_losses.append(avg_infonce_loss)
    contrastive_losses.append(avg_contrastive_loss)
    val_losses.append(val_loss)
    val_infonce_losses.append(val_infonce_loss)
    val_contrastive_losses.append(val_contrastive_loss)
    
    # Print metrics
    print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}, Val Loss: {val_loss:.4f}")
    print(f"  InfoNCE Loss: {avg_infonce_loss:.4f}, Val InfoNCE Loss: {val_infonce_loss:.4f}")
    print(f"  Contrastive Loss: {avg_contrastive_loss:.4f}, Val Contrastive Loss: {val_contrastive_loss:.4f}")
    # Report the learning rate explicitly so scheduler reductions are visible in the log
    print(f"  Learning rate: {current_lr:.2e}")
    
    # Save model: always overwrite the "latest" weights, and keep a separate
    # copy whenever the validation loss improves
    torch.save(model.state_dict(), f"{params['save_dir']}/model_latest_mpt.pt")
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), f"{params['save_dir']}/model_latest_mpt_best.pt")
        print(f"  Model saved with validation loss: {best_val_loss:.4f}")
    
    # Save a full checkpoint at intervals.
    # Note: the stored 'epoch' is zero-based while the file name is one-based.
    if (epoch + 1) % params['save_interval'] == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': avg_loss,
            'val_loss': val_loss
        }, f"{params['save_dir']}/checkpoint_epoch_{epoch+1}.pt")
    
    # Visualize embeddings after the first epoch and then every 10th epoch
    if (epoch + 1) % 10 == 0 or epoch == 0:
        visualize_embeddings(model, test_dataset, device, num_persons=20, epoch=epoch)
    
    # Plot learning curves (the saved image is overwritten every epoch)
    fig, (ax_total, ax_components) = plt.subplots(1, 2, figsize=(12, 5))
    ax_total.plot(epochs, losses, label='Total Training Loss')
    ax_total.plot(epochs, val_losses, label='Total Validation Loss')
    ax_total.set_xlabel('Epoch')
    ax_total.set_ylabel('Loss')
    ax_total.set_title('Total Loss')
    ax_total.legend()
    
    ax_components.plot(epochs, infonce_losses, label='InfoNCE Training')
    ax_components.plot(epochs, val_infonce_losses, label='InfoNCE Validation')
    ax_components.plot(epochs, contrastive_losses, label='Contrastive Training')
    ax_components.plot(epochs, val_contrastive_losses, label='Contrastive Validation')
    ax_components.set_xlabel('Epoch')
    ax_components.set_ylabel('Loss')
    ax_components.set_title('Component Losses')
    ax_components.legend()
    
    fig.tight_layout()
    fig.savefig(f"{params['save_dir']}/learning_curves.png")
    plt.show()
    # Close the figure explicitly: one is created per epoch, and leaving them
    # open accumulates memory over a long run
    plt.close(fig)
    
    # Print time estimates
    epoch_time = time.time() - epoch_start_time
    avg_epoch_time = (time.time() - start_time) / (epoch + 1)
    remaining_time = avg_epoch_time * (num_epochs - (epoch + 1))
    print(f"Epoch {epoch+1} completed in {epoch_time:.2f}s. Estimated time left: {remaining_time/60:.2f} minutes.")

print("Training completed!")

## Final Evaluation

Evaluate the trained model and visualize final embeddings.

In [None]:
# Load best model for final evaluation.
# map_location remaps the stored tensors onto the current device, so this
# cell also works when the weights were saved on a different device
# (e.g. trained on a GPU, evaluated on a CPU-only machine).
best_model_path = f"{params['save_dir']}/model_latest_mpt_best.pt"
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.eval()

# Final validation
final_val_loss, final_val_infonce, final_val_contrastive = validate(
    model, val_loader, infonce_criterion, contrastive_criterion, 
    device, k1=params['k1'], k2=params['k2']
)

print(f"Final validation loss: {final_val_loss:.4f}")
print(f"Final validation InfoNCE loss: {final_val_infonce:.4f}")
print(f"Final validation Contrastive loss: {final_val_contrastive:.4f}")

# Generate final t-SNE visualization with more persons.
# visualize_embeddings names its output with `epoch + 1`, so passing
# `num_epochs` writes embeddings_epoch_{num_epochs + 1}.png and does not
# overwrite the plot saved for the last training epoch.
visualize_embeddings(model, test_dataset, device, num_persons=30, num_samples=10, epoch=num_epochs)

print("Training and evaluation completed!")