In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found (recursively) under the Kaggle input mount.
_input_paths = (
    os.path.join(root, name)
    for root, _, names in os.walk('/kaggle/input')
    for name in names
)
for _p in _input_paths:
    print(_p)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os, random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, roc_curve, auc, confusion_matrix
)


# Dataset locations (Kaggle input mount for the MULB dataset).
MULB_ROOT = "/kaggle/input/multimodal-biometric-dataset-mulb/MULB dataset"
HAND_DIR = MULB_ROOT + "/hand dataset"
IRIS_DIR = MULB_ROOT + "/iris dataset"

# Optimisation.
BATCH_SIZE, EPOCHS, LR = 16, 40, 1e-4

# Embedding space / loss.
EMBED_DIM = 256   # output size of each ProjectionHead and of the fused embedding
MARGIN = 1.0      # triplet-loss margin

# Fine-tuning schedule and early stopping.
FREEZE_EPOCHS = 5  # epochs before the pretrained encoders are unfrozen
PATIENCE = 6       # epochs without improvement tolerated before stopping
MIN_DELTA = 1e-4   # smallest loss decrease that counts as an improvement

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Using device:", DEVICE)


def collect_images(subject_dir, extensions=(".jpg", ".png", ".jpeg")):
    """Return every image file path under *subject_dir*, searched recursively.

    A file counts as an image when its lower-cased name ends with one of
    *extensions*. Paths are returned sorted: os.walk gives no ordering
    guarantee, so without sorting the list (and any seeded random sampling
    done on it) would depend on the filesystem's listing order.

    Args:
        subject_dir: directory to search. A missing directory yields ``[]``
            because os.walk skips errors by default.
        extensions: filename suffix or tuple of suffixes to accept
            (compared case-insensitively).

    Returns:
        Sorted list of matching file paths.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)
    return sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(subject_dir)
        for name in files
        if name.lower().endswith(suffixes)
    )

def build_subject_dict(base_dir, min_images=3):
    """Map each subject sub-directory of *base_dir* to its image paths.

    Sub-directories are visited in sorted name order. Each one's images are
    gathered with ``collect_images``; subjects with fewer than *min_images*
    images are skipped. Plain files directly inside *base_dir* are ignored.

    Args:
        base_dir: directory whose immediate sub-directories are subjects.
        min_images: minimum number of images a subject needs to be kept
            (default 3; the triplet sampler draws two distinct images per
            subject, so values below 2 would break it).

    Returns:
        dict mapping subject directory name -> list of image paths.

    Raises:
        FileNotFoundError: if *base_dir* does not exist (raised by os.listdir).
    """
    subjects = {}
    for name in sorted(os.listdir(base_dir)):
        path = os.path.join(base_dir, name)
        if not os.path.isdir(path):
            continue
        images = collect_images(path)
        if len(images) >= min_images:
            subjects[name] = images
    return subjects


# Fixed seed so the subject-level train/val split (and downstream sampling and
# weight initialisation) is reproducible across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

hand_dict = build_subject_dict(HAND_DIR)
iris_dict = build_subject_dict(IRIS_DIR)

# Only subjects present in both modalities can form (hand, iris) pairs.
subjects = sorted(set(hand_dict) & set(iris_dict))
random.shuffle(subjects)

# Subject-disjoint 80/20 split: no identity appears in both train and val.
split = int(0.8 * len(subjects))
train_subjects = subjects[:split]
val_subjects = subjects[split:]

print("Total subjects:", len(subjects))
print("Train subjects:", len(train_subjects), "| Val subjects:", len(val_subjects))

# Every triplet needs a negative from a *different* subject, so each split
# must contain at least two subjects.
if min(len(train_subjects), len(val_subjects)) < 2:
    raise ValueError(
        "Need at least 2 subjects in both train and val splits; got "
        f"{len(train_subjects)} train / {len(val_subjects)} val"
    )

# Transformation

# ImageNet channel statistics, matching the ImageNet-pretrained backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop / rotation / colour jitter for augmentation,
# then tensor conversion and normalisation.
train_tf = transforms.Compose(
    [
        transforms.RandomResizedCrop(224, scale=(0.75, 1.0)),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# Validation pipeline: deterministic resize only.
val_tf = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# Dataset

class MultimodalTripletDataset(Dataset):
    """On-the-fly (anchor, positive, negative) triplets of paired hand+iris images.

    Each item is a 6-tuple of transformed images in the order
    (anchor_hand, anchor_iris, pos_hand, pos_iris, neg_hand, neg_iris).
    Anchor and positive come from the same subject (two distinct images per
    modality); the negative comes from a different subject.

    Args:
        subjects: subject ids to sample from; at least two are required so a
            negative can always be drawn.
        hand_dict: subject id -> list of hand image paths (>= 2 per subject).
        iris_dict: subject id -> list of iris image paths (>= 2 per subject).
        transform: callable applied to every RGB PIL image.
        samples_per_subject: number of triplets each subject anchors per epoch.
        seed: if None (default), image choices are re-drawn from the global
            ``random`` module on every access. If an int, item ``idx`` always
            yields the same image choices, so metrics are comparable across
            epochs (random transforms are not affected).

    Raises:
        ValueError: if fewer than two subjects are given.
    """

    def __init__(self, subjects, hand_dict, iris_dict, transform,
                 samples_per_subject=5, seed=None):
        self.subjects = list(subjects)
        if len(self.subjects) < 2:
            raise ValueError(
                "MultimodalTripletDataset needs at least 2 subjects to draw "
                f"negatives, got {len(self.subjects)}"
            )
        self.hand = hand_dict
        self.iris = iris_dict
        self.transform = transform
        self.samples_per_subject = samples_per_subject
        self.seed = seed

    def __len__(self):
        return len(self.subjects) * self.samples_per_subject

    def _load(self, path):
        """Read *path* as an RGB PIL image (closing the file), then transform it."""
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)

    def __getitem__(self, idx):
        # A per-index RNG makes a seeded dataset return the same triplet for the
        # same idx; unseeded datasets keep drawing from the global generator.
        rng = random if self.seed is None else random.Random(self.seed * len(self) + idx)

        # Derive the anchor from idx so every subject anchors exactly
        # samples_per_subject triplets per epoch.
        a_id = self.subjects[idx % len(self.subjects)]
        n_id = rng.choice([s for s in self.subjects if s != a_id])

        ah, ph = rng.sample(self.hand[a_id], 2)
        ai, pi = rng.sample(self.iris[a_id], 2)
        nh = rng.choice(self.hand[n_id])
        ni = rng.choice(self.iris[n_id])

        return tuple(self._load(p) for p in (ah, ai, ph, pi, nh, ni))

# Dataloaders

train_dataset = MultimodalTripletDataset(train_subjects, hand_dict, iris_dict, train_tf)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # BatchNorm1d (ProjectionHead) raises on a training batch of one sample,
    # so drop the final batch only when it would be exactly that size.
    drop_last=len(train_dataset) % BATCH_SIZE == 1,
)

# seed=0 fixes the validation triplets so per-epoch metrics are comparable.
val_loader = DataLoader(
    MultimodalTripletDataset(val_subjects, hand_dict, iris_dict, val_tf, seed=0),
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# Model Code

def get_backbone():
    """Build an ImageNet-pretrained EfficientNet-B0 used as a feature extractor.

    The classification head is replaced by ``nn.Identity`` so the network
    returns its pooled feature vector instead of class logits (1280-dim for
    B0, matching ProjectionHead's default ``in_dim``).
    """
    pretrained = models.EfficientNet_B0_Weights.IMAGENET1K_V1
    backbone = models.efficientnet_b0(weights=pretrained)
    backbone.classifier = nn.Identity()
    return backbone

class ProjectionHead(nn.Module):
    """Two-layer MLP mapping backbone features to an L2-normalised embedding.

    Layout: Linear -> BatchNorm1d -> ReLU -> Linear -> BatchNorm1d, followed by
    L2 normalisation along dim 1.

    Args:
        in_dim: input feature size (1280 by default).
        out_dim: embedding size (defaults to EMBED_DIM).
    """

    def __init__(self, in_dim=1280, out_dim=EMBED_DIM):
        super().__init__()
        hidden = 512
        layers = [
            nn.Linear(in_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
            nn.BatchNorm1d(out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return F.normalize(projected, dim=1)

class GatedFusion(nn.Module):
    """Fuse hand and iris embeddings with a learned per-dimension gate.

    The gate sees both embeddings concatenated and emits sigmoid weights
    ``g`` in (0, 1); the output is ``g * fh + (1 - g) * fi`` element-wise.

    Args:
        dim: size of each input embedding and of the fused output.
    """

    def __init__(self, dim=EMBED_DIM):
        super().__init__()
        gate_layers = [nn.Linear(2 * dim, dim), nn.Sigmoid()]
        self.gate = nn.Sequential(*gate_layers)

    def forward(self, fh, fi):
        weights = self.gate(torch.cat((fh, fi), dim=1))
        hand_part = weights * fh
        iris_part = (1 - weights) * fi
        return hand_part + iris_part

class MultimodalEmbeddingNet(nn.Module):
    """Two-stream hand/iris network producing one fused, L2-normalised embedding.

    Each modality has its own (unshared) EfficientNet-B0 encoder and
    ProjectionHead; the two projections are merged by GatedFusion and
    re-normalised.
    """

    def __init__(self):
        super().__init__()
        self.hand_enc, self.iris_enc = get_backbone(), get_backbone()
        self.hand_proj, self.iris_proj = ProjectionHead(), ProjectionHead()
        self.gated_fusion = GatedFusion()

    def forward(self, hand, iris):
        # hand / iris: image batches, presumably (B, 3, 224, 224) per the
        # transforms above — TODO confirm for other callers.
        hand_feat = self.hand_enc(hand)
        iris_feat = self.iris_enc(iris)
        if hand_feat.ndim == 4:
            # Defensive: collapse spatial feature maps if an encoder returns
            # (B, C, H, W) rather than pooled (B, C) features.
            hand_feat = F.adaptive_avg_pool2d(hand_feat, 1).flatten(1)
            iris_feat = F.adaptive_avg_pool2d(iris_feat, 1).flatten(1)
        fused = self.gated_fusion(self.hand_proj(hand_feat), self.iris_proj(iris_feat))
        return F.normalize(fused, dim=1)

# Loss Function

class TripletLoss(nn.Module):
    """Mean hinge triplet loss on Euclidean distances.

    loss = mean(relu(||a - p|| - ||a - n|| + margin))

    Args:
        margin: required gap between the negative and positive distances
            (defaults to MARGIN).
    """

    def __init__(self, margin=MARGIN):
        super().__init__()
        self.margin = margin

    def forward(self, a, p, n):
        pos_dist = F.pairwise_distance(a, p)
        neg_dist = F.pairwise_distance(a, n)
        violation = pos_dist - neg_dist + self.margin
        return F.relu(violation).mean()

# Training loop

model = MultimodalEmbeddingNet().to(DEVICE)
criterion = TripletLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)


def _set_backbones_trainable(net, trainable):
    """Enable or disable gradient updates for both pretrained encoders."""
    for encoder in (net.hand_enc, net.iris_enc):
        for param in encoder.parameters():
            param.requires_grad = trainable


# Warm up only the projection heads and fusion gate for FREEZE_EPOCHS epochs.
_set_backbones_trainable(model, False)

train_losses, val_losses, val_accs = [], [], []
best_loss = float("inf")  # best *validation* loss seen so far
patience_counter = 0
best_state = None

for epoch in range(EPOCHS):

    if epoch == FREEZE_EPOCHS:
        _set_backbones_trainable(model, True)

    model.train()
    if epoch < FREEZE_EPOCHS:
        # The frozen encoders still contain BatchNorm layers; eval() keeps their
        # running statistics from drifting while the weights are frozen.
        model.hand_enc.eval()
        model.iris_enc.eval()

    running_loss, n_batches = 0.0, 0
    for batch in train_loader:
        ah, ai, ph, pi, nh, ni = (t.to(DEVICE) for t in batch)

        optimizer.zero_grad()
        loss = criterion(model(ah, ai), model(ph, pi), model(nh, ni))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise RuntimeError("train_loader produced no batches")
    epoch_loss = running_loss / n_batches
    train_losses.append(epoch_loss)

    # Validation: triplet loss plus triplet accuracy, i.e. the fraction of
    # triplets whose genuine pair is closer than the impostor pair. (Checking
    # only genuine distances against MARGIN rewards collapsed embeddings.)
    model.eval()
    val_loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for batch in val_loader:
            ah, ai, ph, pi, nh, ni = (t.to(DEVICE) for t in batch)
            za, zp, zn = model(ah, ai), model(ph, pi), model(nh, ni)
            val_loss_sum += criterion(za, zp, zn).item() * za.size(0)
            d_ap = F.pairwise_distance(za, zp)
            d_an = F.pairwise_distance(za, zn)
            correct += (d_ap < d_an).sum().item()
            total += za.size(0)

    if total == 0:
        raise RuntimeError("val_loader produced no samples")
    val_loss = val_loss_sum / total
    val_acc = correct / total
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    print(
        f"Epoch {epoch+1:02d} | Loss: {epoch_loss:.4f} | "
        f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}"
    )

    # Model selection / early stopping on validation loss, not training loss.
    if val_loss < best_loss - MIN_DELTA:
        best_loss = val_loss
        # state_dict() returns references to the live parameter tensors; clone
        # them (to CPU) so later optimizer steps cannot overwrite the snapshot.
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print("Early stopping.")
            break

if best_state is not None:
    model.load_state_dict(best_state)

# Verification evaluation and plots


def collect_verification_scores(net, loader):
    """Return ``(labels, scores)`` for every genuine and impostor pair in *loader*.

    Each triplet contributes one genuine pair (anchor, positive) with label 1
    and one impostor pair (anchor, negative) with label 0. The score is the
    negated Euclidean distance between fused embeddings, so a larger score
    means "more likely the same subject", as sklearn's roc_curve expects.
    """
    net.eval()
    labels_out, scores_out = [], []
    with torch.no_grad():
        for ah, ai, ph, pi, nh, ni in loader:
            za = net(ah.to(DEVICE), ai.to(DEVICE))
            zp = net(ph.to(DEVICE), pi.to(DEVICE))
            zn = net(nh.to(DEVICE), ni.to(DEVICE))
            genuine = F.pairwise_distance(za, zp).cpu().numpy()
            impostor = F.pairwise_distance(za, zn).cpu().numpy()
            scores_out.extend((-genuine).tolist())
            labels_out.extend([1] * len(genuine))
            scores_out.extend((-impostor).tolist())
            labels_out.extend([0] * len(impostor))
    return np.asarray(labels_out), np.asarray(scores_out)


labels, scores = collect_verification_scores(model, val_loader)
if labels.size == 0:
    raise RuntimeError("Validation loader produced no pairs; cannot evaluate.")

fpr, tpr, thresholds = roc_curve(labels, scores)
roc_auc = auc(fpr, tpr)
fnr = 1 - tpr
# Equal Error Rate: the ROC operating point where false accepts ~= false rejects.
eer_idx = int(np.nanargmin(np.abs(fnr - fpr)))
eer = (fpr[eer_idx] + fnr[eer_idx]) / 2
eer_threshold = thresholds[eer_idx]
y_pred = (scores >= eer_threshold).astype(int)

print(f"AUC: {roc_auc:.4f} | EER: {eer:.4f} (distance threshold {-eer_threshold:.4f})")
print(f"Accuracy:  {accuracy_score(labels, y_pred):.4f}")
print(f"Precision: {precision_score(labels, y_pred, zero_division=0):.4f}")
print(f"Recall:    {recall_score(labels, y_pred, zero_division=0):.4f}")
print(f"F1:        {f1_score(labels, y_pred, zero_division=0):.4f}")

# labels=[0, 1] keeps the matrix 2x2 (Impostor, Genuine) even if one class is never predicted.
cm = confusion_matrix(labels, y_pred, labels=[0, 1])

plt.figure(figsize=(5,4))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=["Impostor","Genuine"],
            yticklabels=["Impostor","Genuine"])
plt.title("Confusion Matrix @ EER Threshold")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.show()

plt.figure()
plt.plot(fpr, tpr, label=f"ROC (AUC = {roc_auc:.3f})")
plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend()
plt.show()

plt.figure()
plt.plot(train_losses, label="Train Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss Curve")
plt.legend()
plt.show()

plt.figure()
plt.plot(val_accs, label="Validation Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Validation Accuracy Curve")
plt.legend()
plt.show()

torch.save(model.state_dict(), "multimodal_hand_iris_gated_triplet.pth")
print("Model saved.")
