In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found anywhere under the Kaggle input mount.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for name in file_names:
        print(os.path.join(root_dir, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os, random
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, confusion_matrix, roc_curve, auc
)

import matplotlib.pyplot as plt
import seaborn as sns

IRIS_DIR = "/kaggle/input/multimodal-biometric-dataset-mulb/MULB dataset/iris dataset"

BATCH_SIZE = 16
EPOCHS = 30
LR = 1e-4

EMBED_DIM = 256          # Output size of the projection head / embedding
MARGIN = 1.0             # Triplet margin
FREEZE_EPOCHS = 5        # Freeze backbone initially

# Seed every RNG the notebook draws from: random (subject shuffle, triplet
# sampling), numpy, and torch (weight init, torchvision augmentations).
# Without this the train/val subject split changes on every run, so results
# are not comparable. GPU kernels may still add some non-determinism.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)


def collect_images(subject_dir, extensions=(".jpg", ".png", ".jpeg")):
    """Return the sorted paths of all image files under ``subject_dir`` (recursive).

    Args:
        subject_dir: directory to walk, including all subdirectories.
        extensions: filename suffixes accepted as images. Matching is
            case-insensitive, so ``.PNG`` matches ``".png"``.

    Returns:
        List of full file paths. The list is sorted because ``os.walk`` yields
        entries in filesystem-dependent order. Sorting makes any seeded random
        sampling over the list reproducible across machines.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    return sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(subject_dir)
        for name in files
        if name.lower().endswith(suffixes)
    )

def build_subject_dict(base_dir):
    """Map each subject directory name under ``base_dir`` to its image paths.

    Subjects are visited in sorted name order. Non-directory entries are
    ignored, and so are subjects with fewer than 3 images.
    """
    subject_images = {}
    for subject in sorted(os.listdir(base_dir)):
        subject_path = os.path.join(base_dir, subject)
        if not os.path.isdir(subject_path):
            continue
        images = collect_images(subject_path)
        if len(images) < 3:
            continue
        subject_images[subject] = images
    return subject_images

# Dataset Preparation

iris_dict = build_subject_dict(IRIS_DIR)
subjects = list(iris_dict.keys())
random.shuffle(subjects)

# Subject-disjoint split: no identity appears in both train and validation.
split = int(0.8 * len(subjects))
train_subjects = subjects[:split]
val_subjects = subjects[split:]

print("Total subjects:", len(subjects))
print("Train subjects:", len(train_subjects), "| Val subjects:", len(val_subjects))

# Triplet sampling draws the negative from a subject other than the anchor's,
# so each split needs at least two subjects. Fail here with a clear message
# rather than with an IndexError deep inside the DataLoader.
if len(train_subjects) < 2 or len(val_subjects) < 2:
    raise ValueError(
        f"Need at least 2 subjects per split, got {len(train_subjects)} train / "
        f"{len(val_subjects)} val from {len(subjects)} subjects under {IRIS_DIR!r}"
    )

# Data Augmentation

# Standard ImageNet channel statistics, matching the ImageNet-pretrained
# EfficientNet-B0 weights used by the backbone.
imagenet_normalize_kwargs = dict(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training pipeline: random scale/crop, small rotation and colour jitter.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=224, scale=(0.75, 1.0)),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
    transforms.ToTensor(),
    transforms.Normalize(**imagenet_normalize_kwargs),
])

# Validation pipeline: deterministic resize only, no augmentation.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(**imagenet_normalize_kwargs),
])

# Dataset

class IrisTripletDataset(Dataset):
    """Dataset of (anchor, positive, negative) image tensors for triplet training.

    Anchor and positive are two distinct images of the same subject. The
    negative is an image of a different subject.

    The dataset length is five triplets per subject. Item ``idx`` always uses
    ``subjects[idx % len(subjects)]`` as its anchor subject, so every subject
    is an anchor equally often per epoch. The positive and negative images
    are drawn at random.

    Args:
        subjects: subject ids to sample from; at least two are required.
        iris_dict: mapping subject id -> list of image paths. Each anchor
            subject needs at least two images.
        transform: callable applied to every RGB PIL image.
        seed: optional int. When given, the images drawn for item ``idx``
            depend only on ``(seed, idx)``, so repeated passes (e.g.
            validation) see identical triplets. When ``None`` (default), the
            global ``random`` module is used and each access draws afresh.

    Raises:
        ValueError: if fewer than two subjects are given.
    """

    TRIPLETS_PER_SUBJECT = 5

    def __init__(self, subjects, iris_dict, transform, seed=None):
        if len(subjects) < 2:
            raise ValueError(
                f"IrisTripletDataset needs at least 2 subjects to draw negatives, "
                f"got {len(subjects)}"
            )
        self.subjects = subjects
        self.iris_dict = iris_dict
        self.transform = transform
        self.seed = seed

    def __len__(self):
        return len(self.subjects) * self.TRIPLETS_PER_SUBJECT

    def _rng(self, idx):
        """Return the RNG for item ``idx``: per-item seeded, or the global module."""
        if self.seed is None:
            return random
        # String seeds are hashed deterministically by random.Random,
        # independent of PYTHONHASHSEED.
        return random.Random(f"{self.seed}-{idx}")

    def _sample_paths(self, idx):
        """Choose the (anchor, positive, negative) image paths for item ``idx``."""
        rng = self._rng(idx)
        num_subjects = len(self.subjects)

        anchor_pos = idx % num_subjects
        anchor_subject = self.subjects[anchor_pos]

        # Pick a different subject uniformly without rebuilding a filtered list:
        # draw from the n-1 other slots and step over the anchor's slot.
        neg_pos = rng.randrange(num_subjects - 1)
        if neg_pos >= anchor_pos:
            neg_pos += 1
        negative_subject = self.subjects[neg_pos]

        anchor, positive = rng.sample(self.iris_dict[anchor_subject], 2)
        negative = rng.choice(self.iris_dict[negative_subject])
        return anchor, positive, negative

    def _load(self, path):
        """Open ``path`` as RGB, closing the file handle, and apply the transform."""
        with Image.open(path) as img:
            # convert() reads the pixel data, so the result outlives the file.
            rgb = img.convert("RGB")
        return self.transform(rgb)

    def __getitem__(self, idx):
        anchor, positive, negative = self._sample_paths(idx)
        return (
            self._load(anchor),
            self._load(positive),
            self._load(negative),
        )

train_dataset = IrisTripletDataset(train_subjects, iris_dict, train_transform)
val_dataset = IrisTripletDataset(val_subjects, iris_dict, val_transform)

# Drop the last training batch only when it would hold a single triplet.
# The BatchNorm1d layers in ProjectionHead cannot compute batch statistics
# from one sample in train mode and would raise. Triplets are re-sampled on
# every access, so no data is permanently skipped.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=len(train_dataset) % BATCH_SIZE == 1,
)

# Evaluation runs in eval mode (BatchNorm uses running stats), so any final
# batch size is fine here.
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# Model Code

def get_backbone():
    """Build an ImageNet-pretrained EfficientNet-B0 feature extractor.

    The classification head is replaced with ``nn.Identity`` so the model
    returns pooled backbone features instead of class logits. Those features
    are what ``ProjectionHead`` consumes (in_dim=1280).
    """
    pretrained_weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1
    backbone = models.efficientnet_b0(weights=pretrained_weights)
    backbone.classifier = nn.Identity()
    return backbone

class ProjectionHead(nn.Module):
    """Two-layer MLP mapping backbone features to an L2-normalised embedding.

    Layout: Linear(in_dim, 512) -> BatchNorm1d -> ReLU -> Linear(512, out_dim)
    -> BatchNorm1d. The output is normalised to unit length along dim 1.
    """

    def __init__(self, in_dim=1280, out_dim=EMBED_DIM):
        super().__init__()
        hidden_dim = 512
        # Layer order defines the state_dict keys (net.0 ... net.4).
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
            nn.BatchNorm1d(out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        embedding = self.net(x)
        return F.normalize(embedding, p=2, dim=1)

class IrisEmbeddingNet(nn.Module):
    """EfficientNet-B0 encoder followed by a projection head.

    ``forward`` returns one unit-length embedding per input image.
    """

    def __init__(self):
        super().__init__()
        # Registration order (encoder, then proj) fixes the state_dict layout.
        self.encoder = get_backbone()
        self.proj = ProjectionHead()

    def forward(self, x):
        feats = self.encoder(x)
        # Defensive: collapse a 4-D spatial feature map to (N, C) before
        # projecting. With the classifier replaced by Identity the encoder
        # presumably already returns pooled 2-D features, so this branch
        # should not normally trigger.
        if feats.dim() == 4:
            feats = torch.flatten(F.adaptive_avg_pool2d(feats, 1), 1)
        return self.proj(feats)

# Triplet Loss

class TripletLoss(nn.Module):
    """Hinge triplet loss: mean over the batch of relu(d(a, p) - d(a, n) + margin).

    Distances are Euclidean, computed with ``F.pairwise_distance`` (which adds
    a small epsilon).
    """

    def __init__(self, margin=MARGIN):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        pos_dist = F.pairwise_distance(anchor, positive)
        neg_dist = F.pairwise_distance(anchor, negative)
        hinge = F.relu(pos_dist - neg_dist + self.margin)
        return hinge.mean()

# Training

model = IrisEmbeddingNet().to(DEVICE)
criterion = TripletLoss()
# The optimiser holds every parameter from the start. AdamW skips parameters
# whose .grad is None, so the frozen backbone is untouched until it is
# unfrozen and starts receiving gradients.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)


def _set_encoder_trainable(trainable):
    """Enable or disable gradient updates for every backbone parameter."""
    for param in model.encoder.parameters():
        param.requires_grad = trainable


# Freeze the backbone initially: only the projection head learns for the
# first FREEZE_EPOCHS epochs.
_set_encoder_trainable(False)

for epoch in range(EPOCHS):

    if epoch == FREEZE_EPOCHS:
        print("Unfreezing backbone...")
        _set_encoder_trainable(True)

    model.train()
    if epoch < FREEZE_EPOCHS:
        # requires_grad=False does not stop BatchNorm layers from updating
        # their running mean/var in train mode. Keep the frozen backbone in
        # eval mode so its pretrained statistics are preserved.
        model.encoder.eval()

    running_loss = 0.0
    num_batches = 0

    for anchor_imgs, pos_imgs, neg_imgs in train_loader:
        anchor_imgs = anchor_imgs.to(DEVICE)
        pos_imgs = pos_imgs.to(DEVICE)
        neg_imgs = neg_imgs.to(DEVICE)

        optimizer.zero_grad()
        za = model(anchor_imgs)
        zp = model(pos_imgs)
        zn = model(neg_imgs)

        loss = criterion(za, zp, zn)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Avoid ZeroDivisionError if the loader yielded no batches.
    epoch_loss = running_loss / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{EPOCHS}] Train Loss: {epoch_loss:.4f}")


# Evaluation (FAR, FRR, EER)

model.eval()
dists, labels = [], []

with torch.no_grad():
    for a, p, n in val_loader:
        a, p, n = a.to(DEVICE), p.to(DEVICE), n.to(DEVICE)

        za = model(a)
        zp = model(p)
        zn = model(n)

        # Genuine pairs (anchor vs positive, same subject): label 1
        dists.extend(F.pairwise_distance(za, zp).cpu().numpy())
        labels.extend([1] * za.size(0))

        # Impostor pairs (anchor vs negative, different subject): label 0
        dists.extend(F.pairwise_distance(za, zn).cpu().numpy())
        labels.extend([0] * za.size(0))

dists = np.array(dists)
labels = np.array(labels)

# ROC & EER
# roc_curve treats a higher score as "more likely genuine", so the distances
# are negated. At threshold t a pair is accepted when -dist >= t, i.e. when
# dist <= -t.
fpr, tpr, thresholds = roc_curve(labels, -dists)
far = fpr        # fraction of impostor pairs accepted
frr = 1 - tpr    # fraction of genuine pairs rejected

# EER: the ROC point where FAR and FRR are closest. On a finite ROC they
# rarely coincide exactly, so report their mean at that point.
eer_idx = np.nanargmin(np.abs(far - frr))
eer = (far[eer_idx] + frr[eer_idx]) / 2
eer_threshold = -thresholds[eer_idx]

# Use <= (not <) so the hard decisions reproduce the FAR/FRR at eer_idx:
# roc_curve counts a score equal to the threshold as accepted.
y_pred = (dists <= eer_threshold).astype(int)

# Metrics

# zero_division=0: if the threshold accepts no pairs, or there are no
# genuine pairs, report 0 instead of raising UndefinedMetricWarning.
acc = accuracy_score(labels, y_pred)
prec = precision_score(labels, y_pred, zero_division=0)
rec = recall_score(labels, y_pred, zero_division=0)
f1 = f1_score(labels, y_pred, zero_division=0)
roc_auc = auc(fpr, tpr)

# Only the iris modality is trained and evaluated in this notebook.
print("\n===== IRIS VERIFICATION RESULTS =====")
print(f"Accuracy  : {acc:.4f}")
print(f"Precision : {prec:.4f}")
print(f"Recall    : {rec:.4f}")
print(f"F1-score  : {f1:.4f}")
print(f"ROC-AUC   : {roc_auc:.4f}")

print("\n===== BIOMETRIC ERROR RATES =====")
print(f"FAR : {far[eer_idx]:.4f}")
print(f"FRR : {frr[eer_idx]:.4f}")
print(f"EER : {eer:.4f}")

# Persist only the learned weights (state_dict). Reload them into a freshly
# constructed IrisEmbeddingNet with load_state_dict().
checkpoint_path = "iris_triplet_efficientnet.pth"
torch.save(model.state_dict(), checkpoint_path)
print("Model saved successfully.")
