In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import os
from PIL import Image
import torch.nn.functional as F
from dataset import BiometricDataset
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np
from tqdm import tqdm
import time
from pathlib import Path


class CNNEmbedding(nn.Module):
    """Convolutional encoder mapping a single-channel image to an embedding.

    The network is four residual stages (3x3 conv + batch norm on the main
    path, a 1x1 conv on the shortcut), with 2x2 max pooling after the first
    three stages. These are followed by a depthwise-separable convolution,
    a squeeze-and-excitation gate, adaptive average pooling to 2x2, dropout,
    a linear projection and a final layer norm.

    Args:
        embedding_dim: Size of the output embedding vector.
        dropout_p: Dropout probability applied before the linear projection.
    """

    def __init__(self, embedding_dim=256, dropout_p=0.2):
        super().__init__()

        # Main path: channel widths 1 -> 64 -> 128 -> 256 -> 512. The last two
        # convolutions are dilated; padding=2 keeps the spatial size unchanged.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=2, dilation=2)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=2, dilation=2)

        # Shortcut path: every stage changes the channel count, so each
        # shortcut is a 1x1 projection.
        self.res1 = nn.Conv2d(1, 64, kernel_size=1)
        self.res2 = nn.Conv2d(64, 128, kernel_size=1)
        self.res3 = nn.Conv2d(128, 256, kernel_size=1)
        self.res4 = nn.Conv2d(256, 512, kernel_size=1)

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        self.bn4 = nn.BatchNorm2d(512)

        self.pool = nn.MaxPool2d(2, 2)
        self.adaptive_pool = nn.AdaptiveAvgPool2d(2)

        # Depthwise-separable convolution: per-channel 3x3, then 1x1 mixing.
        self.depthwise_conv = nn.Conv2d(512, 512, kernel_size=3, padding=1, groups=512)
        self.pointwise_conv = nn.Conv2d(512, 512, kernel_size=1)
        self.depthwise_bn = nn.BatchNorm2d(512)

        # Squeeze-and-excitation gate with a 16x channel bottleneck.
        self.se_block = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 512 // 16, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(512 // 16, 512, kernel_size=1),
            nn.Sigmoid(),
        )

        # adaptive_pool always yields 512 x 2 x 2 features.
        self.fc = nn.Linear(512 * 2 * 2, embedding_dim)
        self.dropout = nn.Dropout(dropout_p)
        self.ln = nn.LayerNorm(embedding_dim)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every conv, batch-norm and linear layer in place."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor with one channel in dimension 1 (required by ``conv1``).

        Returns:
            Tensor of shape ``(batch, embedding_dim)``.
        """
        # (conv, batch norm, shortcut projection, pool afterwards?)
        stages = (
            (self.conv1, self.bn1, self.res1, True),
            (self.conv2, self.bn2, self.res2, True),
            (self.conv3, self.bn3, self.res3, True),
            (self.conv4, self.bn4, self.res4, False),
        )
        for conv, norm, shortcut, downsample in stages:
            skip = shortcut(x)
            x = F.relu(norm(conv(x)) + skip)
            if downsample:
                x = self.pool(x)

        # Depthwise-separable convolution block.
        x = F.relu(self.depthwise_bn(self.pointwise_conv(self.depthwise_conv(x))))

        # Channel-wise re-weighting by the squeeze-and-excitation gate.
        x = x * self.se_block(x)

        x = self.adaptive_pool(x).flatten(start_dim=1)
        return self.ln(self.fc(self.dropout(x)))



class FusionTransformer(nn.Module):
    """Fuse per-modality embeddings with a transformer encoder and prompts.

    Each modality embedding becomes one token. ``num_prompts`` learnable
    prompt tokens are appended, the sequence is passed through a transformer
    encoder, and the flattened token outputs are projected back to
    ``embedding_dim``.

    Args:
        embedding_dim: Width of every token and of the fused output.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout_p: Dropout used inside the encoder and before the projection.
        num_modalities: Number of modality embeddings passed to ``forward``.
        num_prompts: Number of learnable prompt tokens.
    """

    def __init__(self, embedding_dim=256, nhead=8, num_layers=4, dropout_p=0.1, num_modalities=3, num_prompts=3):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_modalities = num_modalities
        num_tokens = num_modalities + num_prompts

        # Transformer encoder providing global context across all tokens.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=nhead,
            dim_feedforward=512,
            dropout=dropout_p,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Learned positional encoding, one vector per token (modalities + prompts).
        self.pos_encoder = nn.Parameter(torch.zeros(1, num_tokens, embedding_dim), requires_grad=True)

        # Layer norms applied to the token sequence and to the fused output.
        self.ln_in = nn.LayerNorm(embedding_dim)
        self.ln_out = nn.LayerNorm(embedding_dim)

        # Learnable prompt tokens, refined by a pointwise (kernel size 1) conv.
        self.num_prompts = num_prompts
        self.prompt_emb = nn.Parameter(torch.randn(num_prompts, embedding_dim))
        self.prompt_conv = nn.Conv1d(embedding_dim, embedding_dim, kernel_size=1)
        self.prompt_relu = nn.ReLU()

        # Projection from the flattened token outputs to the fused embedding.
        self.fc = nn.Linear(embedding_dim * num_tokens, embedding_dim)
        self.dropout = nn.Dropout(dropout_p)

        self._init_parameters()

    def _init_parameters(self):
        """Xavier-initialise the prompt tokens and the positional encoding."""
        for param in (self.prompt_emb, self.pos_encoder):
            nn.init.xavier_uniform_(param)

    def forward(self, *embeddings):
        """Fuse the given modality embeddings into one vector per sample.

        Args:
            *embeddings: One tensor per modality, stacked along a new
                dimension 1. The count must equal ``num_modalities`` so the
                token count matches ``pos_encoder`` and ``fc``.

        Returns:
            Tensor of shape ``(batch, embedding_dim)``.
        """
        batch_size = embeddings[0].size(0)

        # [batch, num_modalities, embedding_dim]
        modality_tokens = torch.stack(embeddings, dim=1)

        # Broadcast the shared prompts over the batch. Conv1d expects channels
        # first, so transpose to [batch, embedding_dim, num_prompts] and back.
        prompt_tokens = self.prompt_emb.unsqueeze(0).expand(batch_size, -1, -1)
        prompt_tokens = self.prompt_relu(self.prompt_conv(prompt_tokens.transpose(1, 2)))
        prompt_tokens = prompt_tokens.transpose(1, 2)

        # [batch, num_modalities + num_prompts, embedding_dim]
        tokens = torch.cat((modality_tokens, prompt_tokens), dim=1)

        # Normalise first, then add the positional encoding.
        tokens = self.transformer(self.ln_in(tokens) + self.pos_encoder)

        # Flatten all tokens and project down to the fused embedding.
        flat = self.dropout(tokens.reshape(batch_size, -1))
        return self.ln_out(self.fc(flat))



# Complete Model
class BiometricModel(nn.Module):
    """End-to-end model: one CNN encoder per modality plus transformer fusion.

    Args:
        embedding_dim: Width of each per-modality embedding and of the fused
            output. It is passed to both the CNN encoders and the fusion
            transformer so the two stages always agree.
        num_modalities: Number of modality tokens the fusion transformer is
            sized for. ``forward`` always supplies three embeddings
            (periocular, forehead, iris), so values other than 3 will not
            match the fusion stage.
    """

    def __init__(self, embedding_dim=256, num_modalities=3):
        super().__init__()

        self.num_modalities = num_modalities

        self.periocular_cnn = CNNEmbedding(embedding_dim)
        self.forehead_cnn = CNNEmbedding(embedding_dim)
        self.iris_cnn = CNNEmbedding(embedding_dim)

        # Use the constructor's embedding_dim here. A hard-coded 256 would
        # break any model built with a different width.
        self.fusion_transformer = FusionTransformer(
            embedding_dim=embedding_dim,
            num_modalities=num_modalities,
            num_prompts=3,
        )

    def forward(self, periocular, forehead, iris):
        """Embed each modality, fuse them, and L2-normalise the result.

        Args:
            periocular: Image batch for the periocular encoder.
            forehead: Image batch for the forehead encoder.
            iris: Image batch for the iris encoder.

        Returns:
            Fused embeddings normalised to unit L2 norm along dimension 1.
        """
        periocular_emb = self.periocular_cnn(periocular)
        forehead_emb = self.forehead_cnn(forehead)
        iris_emb = self.iris_cnn(iris)
        fused_emb = self.fusion_transformer(periocular_emb, forehead_emb, iris_emb)
        return F.normalize(fused_emb, dim=1)
        

def load_image(image_path, transform):
    """Load an image as 8-bit grayscale and apply ``transform`` to it.

    Args:
        image_path: Path of the image file to open.
        transform: Callable applied to the grayscale PIL image.

    Returns:
        Whatever ``transform`` returns for the grayscale image.
    """
    # The context manager closes the file handle even if decoding or
    # conversion raises, instead of leaving that to garbage collection.
    with Image.open(image_path) as img:
        return transform(img.convert('L'))

def create_embedding_dicts(data_root, model, device='mps', splits=('test',),
                           person_ids=None, num_poses=10):
    """Compute fused embeddings for every person and pose on disk.

    Images are read from ``<data_root>/<modality>/<split>/<person_id>/`` for
    the modalities ``periocular``, ``forehead`` and ``iris``. Pose ``k`` of a
    person is the ``k``-th file (sorted by name, ignoring ``.DS_Store``) in
    each modality directory. A pose is embedded only if all three modality
    images are available.

    Args:
        data_root: Root directory of the dataset.
        model: Model called as ``model(periocular, forehead, iris)`` on
            batches of size one. It is moved to ``device`` and set to eval mode.
        device: Device used for inference.
        splits: Split directory names to process. Results for ``'train'`` go
            to the first returned dict, any other split to the second.
        person_ids: Directory names of the identities to process. Defaults to
            ``"001"`` .. ``"247"``.
        num_poses: Number of poses attempted per person.

    Returns:
        ``(train_dict, test_dict)``. Each maps a person ID to a list of CPU
        embedding tensors, one per pose that was successfully embedded.

    Raises:
        FileNotFoundError: If a required person directory does not exist.
    """
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    modalities = ['periocular', 'forehead', 'iris']
    if person_ids is None:
        person_ids = [f"{i:03d}" for i in range(1, 248)]
    train_dict = {}
    test_dict = {}

    model = model.to(device)
    model.eval()

    root = Path(data_root)
    for split in splits:
        print("Doing: ", split)
        target_dict = train_dict if split == 'train' else test_dict
        for person_id in person_ids:
            print("Person ", person_id)
            person_embeddings = []
            target_dict[person_id] = person_embeddings

            # Cache each modality's sorted listing for this person. It is
            # filled lazily so a directory is only read when a pose needs it,
            # and then only once rather than once per pose.
            listings = {}
            for pose_idx in range(1, num_poses + 1):
                images = []
                for modality in modalities:
                    if modality not in listings:
                        person_dir = root / modality / split / person_id
                        pose_files = [f for f in sorted(os.listdir(person_dir)) if f != '.DS_Store']
                        listings[modality] = (person_dir, pose_files)
                    person_dir, pose_files = listings[modality]

                    if len(pose_files) < pose_idx:
                        print(f"Warning: {person_dir} has only {len(pose_files)} poses, expected at least {pose_idx}")
                        break
                    img_path = person_dir / pose_files[pose_idx - 1]
                    if not img_path.exists():
                        # Listed entries can still be unreadable (e.g. a broken symlink).
                        print(f"Warning: {img_path} does not exist")
                        break
                    images.append(load_image(img_path, transform).to(device))
                else:
                    # Reached only when all three modality images were loaded.
                    with torch.no_grad():
                        embedding = model(
                            images[0].unsqueeze(0),  # periocular (add batch dim)
                            images[1].unsqueeze(0),  # forehead
                            images[2].unsqueeze(0)   # iris
                        ).squeeze(0).cpu()  # Remove batch dim
                    person_embeddings.append(embedding)

    return train_dict, test_dict


data_root = './dataset2'

# Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
# Falling back to 'mps' unconditionally would crash on hosts that have neither.
_mps_backend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

model = BiometricModel(embedding_dim=256)

# map_location='cpu' lets a checkpoint saved on any accelerator load here.
# create_embedding_dicts moves the model to `device` afterwards.
# NOTE: torch.load unpickles the file, so only load checkpoints from trusted
# sources (recent PyTorch versions also offer weights_only=True).
state_dict = torch.load('newmpt.pt', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()

train_dict, test_dict = create_embedding_dicts(data_root, model, device)

2025-05-16 03:37:02.979833: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Doing:  test
Person  001
Person  002
Person  003
Person  004
Person  005
Person  006
Person  007
Person  008
Person  009
Person  010
Person  011
Person  012
Person  013
Person  014
Person  015
Person  016
Person  017
Person  018
Person  019
Person  020
Person  021
Person  022
Person  023
Person  024
Person  025
Person  026
Person  027
Person  028
Person  029
Person  030
Person  031
Person  032
Person  033
Person  034
Person  035
Person  036
Person  037
Person  038
Person  039
Person  040
Person  041
Person  042
Person  043
Person  044
Person  045
Person  046
Person  047
Person  048
Person  049
Person  050
Person  051
Person  052
Person  053
Person  054
Person  055
Person  056
Person  057
Person  058
Person  059
Person  060
Person  061
Person  062
Person  063
Person  064
Person  065
Person  066
Person  067
Person  068
Person  069
Person  070
Person  071
Person  072
Person  073
Person  074
Person  075
Person  076
Person  077
Person  078
Person  079
Person  080
Person  081
Person  082
Per

In [3]:
def evaluate_rank1(test_dict, gallery_size=5, device='cuda'):
    """
    Evaluate rank-1 recognition accuracy using cosine similarity.

    For each person the first ``gallery_size`` embeddings form the gallery and
    the remaining ones are probes. A probe counts as correct when its most
    similar gallery embedding belongs to the same person. People with fewer
    than ``gallery_size + 1`` embeddings are skipped (with a printed notice).

    Args:
        test_dict: Dictionary with person IDs as keys and lists of embeddings as values
        gallery_size: Number of poses to use as gallery (remaining used as probe)
        device: Device to perform calculations on

    Returns:
        Tuple ``(rank1_rate, correct, total)``: the fraction of correctly
        identified probes, the number of correct probes and the number of
        probes. ``(0, 0, 0)`` is returned when no person has enough embeddings.

    Raises:
        ValueError: If ``gallery_size`` is smaller than 1.
    """
    if gallery_size < 1:
        raise ValueError(f"gallery_size must be at least 1, got {gallery_size}")

    # Split every person's embeddings into gallery and probe sets.
    gallery_embeddings = []
    gallery_labels = []
    probe_embeddings = []
    probe_labels = []

    for person_id, embeddings in test_dict.items():
        if len(embeddings) < gallery_size + 1:
            print(f"Skipping person {person_id} with only {len(embeddings)} embeddings")
            continue

        probes = embeddings[gallery_size:]
        gallery_embeddings.extend(embeddings[:gallery_size])
        gallery_labels.extend([person_id] * gallery_size)
        probe_embeddings.extend(probes)
        probe_labels.extend([person_id] * len(probes))

    total = len(probe_embeddings)
    if total == 0:
        # Nothing to evaluate; torch.stack would raise on an empty list.
        return 0, 0, 0

    with torch.no_grad():
        # Unit-normalise so that a dot product equals cosine similarity.
        gallery_tensor = F.normalize(torch.stack(gallery_embeddings).to(device), p=2, dim=1)
        probe_tensor = F.normalize(torch.stack(probe_embeddings).to(device), p=2, dim=1)

        print(f"Gallery size: {len(gallery_tensor)} embeddings")
        print(f"Probe size: {len(probe_tensor)} embeddings")

        # Score probes in batches to bound the similarity matrix size.
        batch_size = 100
        correct = 0
        for start in tqdm(range(0, total, batch_size)):
            stop = start + batch_size
            similarities = torch.mm(probe_tensor[start:stop], gallery_tensor.t())

            # One host transfer per batch instead of one .item() per probe.
            _, indices = torch.max(similarities, dim=1)
            nearest = indices.tolist()

            correct += sum(
                gallery_labels[gallery_idx] == probe_person
                for gallery_idx, probe_person in zip(nearest, probe_labels[start:stop])
            )

    return correct / total, correct, total

# Report rank-1 identification accuracy for the embeddings computed above.
print("Evaluating rank-1 recognition performance...")

# Gallery sizes to evaluate (currently a single setting).
for gallery_size in (5,):
    rank1_rate, c, t = evaluate_rank1(test_dict, gallery_size=gallery_size, device=device)
    print(
        f"Rank-1 recognition rate with gallery size {gallery_size}: "
        f"{rank1_rate:.4f} ({rank1_rate*100:.2f}%) {c}/{t}"
    )


Evaluating rank-1 recognition performance...
Skipping person 212 with only 4 embeddings
Gallery size: 1230 embeddings
Probe size: 1226 embeddings


100%|██████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 651.39it/s]

Rank-1 recognition rate with gallery size 5: 0.7488 (74.88%) 918/1226



