# Triplet Distil Net Embedding Generator

**Note:** In this notebook we can experiment with 

- Embedding Sizes
- Custom Loss Functions
- Epochs and Learning Rate
- contact finetuning


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import random
import os, glob
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import re
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE

# Semi-Hard Mining

In [4]:
def generate_triplets_from_folders(s1_dir, s2_dir, split=(0.7, 0.15, 0.15), seed=42, triplets_per_id=2):
    """
    Generate multiple triplets per identity (anchor, positive, negative) from s1, s2 directories.
    Returns dict of train/val/test triplet lists.
    """
    random.seed(seed)
    s1_images = sorted(glob.glob(os.path.join(s1_dir, "*")))
    s2_images = sorted(glob.glob(os.path.join(s2_dir, "*")))

    # could change depending on how you've named your dataset
    def basename_noext(p): return os.path.splitext(os.path.basename(p))[0]
    id_to_paths = {}
    for p in s1_images:
        id_to_paths[basename_noext(p)] = [p, None]
    for p in s2_images:
        id_ = basename_noext(p)
        if id_ in id_to_paths:
            id_to_paths[id_][1] = p

    valid_ids = [k for k, v in id_to_paths.items() if v[0] and v[1]]
    valid_ids.sort()
    print(f"Found {len(valid_ids)} identities with images in both s1 and s2")

    triplets = []
    for anchor_id in valid_ids:
        anchor = id_to_paths[anchor_id][0]
        positive = id_to_paths[anchor_id][1]
        neg_ids = [i for i in valid_ids if i != anchor_id]

        for _ in range(triplets_per_id):
            negative_id = random.choice(neg_ids)
            negative = id_to_paths[negative_id][1]
            triplets.append((anchor, positive, negative))

    random.shuffle(triplets)

    n = len(triplets)
    n_train = int(split[0] * n)
    n_val = int(split[1] * n)
    train_triplets = triplets[:n_train]
    val_triplets = triplets[n_train:n_train+n_val]
    test_triplets = triplets[n_train+n_val:]

    print(f"Total triplets: {len(triplets)} -> train {len(train_triplets)}, val {len(val_triplets)}, test {len(test_triplets)}")
    return {"train": train_triplets, "val": val_triplets, "test": test_triplets}


In [None]:
# could change depending on how you've named your dataset
def get_last_two_parts(path):
    """Return the last two components of a path as a string."""
    parts = os.path.normpath(path).split(os.sep)
    return os.path.join(*parts[-2:])

class TripletDatasetWithTeacher(Dataset):
    """
    triplets: list of tuples (anchor_path, positive_path, negative_path)
    teacher_embeddings: np.array of shape (num_images, emb_dim)
        Order must match s1/s2 dirs: even=s1, odd=s2
    """
    def __init__(self, triplets, s1_dir, s2_dir, teacher_embeddings, transform=None):
        self.triplets = triplets
        self.transform = transform
        self.teacher_embeddings = teacher_embeddings
        s1_paths = sorted(glob.glob(os.path.join(s1_dir, "*")))
        s2_paths = sorted(glob.glob(os.path.join(s2_dir, "*")))

        self.path_to_idx = {}
        for i, path in enumerate(s1_paths):
            self.path_to_idx[get_last_two_parts(path)] = 2*i  # even idx for s1
        for i, path in enumerate(s2_paths):
            self.path_to_idx[get_last_two_parts(path)] = 2*i + 1  # odd idx for s2

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        a_path, p_path, n_path = self.triplets[idx]

        # Load images
        anchor = Image.open(a_path)
        positive = Image.open(p_path)
        negative = Image.open(n_path)

        anchor = Image.open(a_path).convert('L')  
        positive = Image.open(p_path).convert('L')
        negative = Image.open(n_path).convert('L')

        if self.transform:
            anchor = self.transform(anchor)
            positive = self.transform(positive)
            negative = self.transform(negative)

        # Fetch teacher embeddings
        key_a = get_last_two_parts(a_path)
        key_p = get_last_two_parts(p_path)
        key_n = get_last_two_parts(n_path)

        t_a = torch.tensor(self.teacher_embeddings[self.path_to_idx[key_a]], dtype=torch.float32)
        t_p = torch.tensor(self.teacher_embeddings[self.path_to_idx[key_p]], dtype=torch.float32)
        t_n = torch.tensor(self.teacher_embeddings[self.path_to_idx[key_n]], dtype=torch.float32)

        # L2-normalize
        t_a = F.normalize(t_a, dim=0)
        t_p = F.normalize(t_p, dim=0)
        t_n = F.normalize(t_n, dim=0)

        return anchor, positive, negative, t_a, t_p, t_n
    
class TripletDataset(Dataset):
    """
    triplets: list of tuples (anchor_path, positive_path, negative_path)
    """
    def __init__(self, triplets, transform=None):
        self.triplets = triplets
        self.transform = transform

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        a_path, p_path, n_path = self.triplets[idx]

        anchor = Image.open(a_path) 
        positive = Image.open(p_path)
        negative = Image.open(n_path)

        anchor = Image.open(a_path).convert('L') 
        positive = Image.open(p_path).convert('L')
        negative = Image.open(n_path).convert('L')

        if self.transform:
            anchor = self.transform(anchor)
            positive = self.transform(positive)
            negative = self.transform(negative)

        return anchor, positive, negative


# Transforms (strong but safe for ridge details)
def get_train_transforms(img_size=512):
    return transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomApply([
            transforms.RandomRotation(degrees=15, fill=0)
        ], p=0.9),
        transforms.RandomApply([transforms.RandomAffine(degrees=0, translate=(0.05,0.05))], p=0.6),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1,1.0))], p=0.2),
        transforms.RandomApply([transforms.ColorJitter(brightness=0.05, contrast=0.05)], p=0.5),
        transforms.ToTensor(),  
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

def get_eval_transforms(img_size=512):
    return transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])


In [None]:
# Build triplets list: (anchor, positive, negative)
s1_dir = "<DIRNAME>"
s2_dir = "<DIRNAME>"

## TripletDistilNet Architecture

In [8]:
class TripletNet(nn.Module):
    def __init__(self, embedding_dim=256, pretrained=False):
        super().__init__()
        resnet = models.resnet18(pretrained=pretrained)
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.backbone = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.2),
            nn.Linear(256, embedding_dim)
        )

    def forward_once(self, x):
        x = self.backbone(x)
        x = x.view(x.size(0), -1)
        x = self.head(x)
        x = F.normalize(x, p=2, dim=1)
        return x
    
    def forward(self, anchor, positive, negative):
        e_a = self.forward_once(anchor)
        e_p = self.forward_once(positive)
        e_n = self.forward_once(negative)
        return e_a, e_p, e_n

class TripletDistillLoss(nn.Module):
    def __init__(self, margin=0.3, alpha=0.7):
        super().__init__()
        self.margin = margin
        self.alpha = alpha

    def forward(self, s_a, s_p, s_n, t_a, t_p, t_n):
        cos_ap = F.cosine_similarity(s_a, s_p)
        cos_an = F.cosine_similarity(s_a, s_n)
        triplet_loss = torch.clamp(cos_an - cos_ap + self.margin, min=0.0).mean()
        distill_loss = ((1 - F.cosine_similarity(s_a, t_a)).mean() +
                        (1 - F.cosine_similarity(s_p, t_p)).mean() +
                        (1 - F.cosine_similarity(s_n, t_n)).mean()) / 3.0
        return triplet_loss + self.alpha * distill_loss

In [None]:
prev_model = "previously_trained_model.pt"

model = TripletNet(embedding_dim=256, pretrained=True)
device = "mps"  
model.load_state_dict(torch.load(prev_model, map_location=device))
model.to(device)
model.eval()

## Create semi-hard triplets

In [10]:
splits = generate_triplets_from_folders(s1_dir, s2_dir, triplets_per_id=1)
transform_eval  = get_eval_transforms(256)
all_data = splits["train"] + splits["val"] + splits["test"]
dataset_test  = TripletDataset(all_data,  transform=transform_eval)
loader_test  = DataLoader(dataset_test, batch_size=32, shuffle=False, num_workers=0)

all_embeddings = []

with torch.no_grad():
    for a, p, n in tqdm(loader_test, desc="Testing"):
        a, p, n = a.to(device), p.to(device), n.to(device)
        e_a, e_p, e_n = model(a, p, n)
        all_embeddings.append(e_a.cpu().numpy())
        all_embeddings.append(e_p.cpu().numpy())

all_embeddings = np.concatenate(all_embeddings, axis=0)

Found 785 identities with images in both s1 and s2
Total triplets: 785 -> train 549, val 117, test 119


Testing: 100%|██████████| 25/25 [00:58<00:00,  2.33s/it]


In [11]:
def pairwise_distances(embeddings):
    """Compute pairwise squared Euclidean distances between embeddings."""
    dot_product = torch.matmul(embeddings, embeddings.t())
    square_norm = torch.diag(dot_product)
    distances = square_norm.unsqueeze(1) - 2 * dot_product + square_norm.unsqueeze(0)
    distances = torch.clamp(distances, min=0.0)
    return torch.sqrt(distances + 1e-8)

def semi_hard_triplet_mining_multiple(embeddings, labels, margin=0.2, triplets_per_anchor=6):
    """
    Offline semi-hard mining with multiple triplets per anchor.
    - Anchor always from S1 (even indices)
    - Positive from S2 (same ID)
    - Negatives from S2 of different IDs
    """
    device = embeddings.device
    distances = pairwise_distances(embeddings)
    triplets = []

    num_embeddings = embeddings.shape[0]
    for anchor_idx in range(0, num_embeddings, 2):  # only S1 anchors
        anchor_label = labels[anchor_idx].item()
        positive_idx = anchor_idx + 1

        # all S2 indices for negatives
        neg_indices = torch.arange(1, num_embeddings, 2, device=device)
        neg_indices = neg_indices[labels[neg_indices] != anchor_label]

        d_ap = distances[anchor_idx, positive_idx].item()

        # semi-hard negatives: d_ap < d_an < d_ap + margin
        semi_hard_negatives = neg_indices[
            (distances[anchor_idx, neg_indices] > d_ap) &
            (distances[anchor_idx, neg_indices] < d_ap + margin)
        ]

        # select up to `triplets_per_anchor` negatives
        chosen_negatives = []

        if len(semi_hard_negatives) >= triplets_per_anchor:
            chosen_negatives = np.random.choice(semi_hard_negatives.cpu().numpy(),
                                               triplets_per_anchor, replace=False)
        else:
            # use all semi-hard negatives
            chosen_negatives = semi_hard_negatives.cpu().numpy().tolist()
            if len(chosen_negatives) < triplets_per_anchor:
                remaining = triplets_per_anchor - len(chosen_negatives)
                hard_negatives = neg_indices[torch.argsort(distances[anchor_idx, neg_indices])]
                for hn in hard_negatives.cpu().numpy():
                    if hn not in chosen_negatives:
                        chosen_negatives.append(hn)
                    if len(chosen_negatives) == triplets_per_anchor:
                        break

        for n_idx in chosen_negatives:
            triplets.append((anchor_idx, positive_idx, int(n_idx)))

    return triplets

In [None]:
num_ids = n  # number of identities
all_labels = []

for i in range(num_ids):
    all_labels.append(i)  # S1
    all_labels.append(i)  # S2

all_labels = torch.tensor(all_labels)  
all_embeddings = torch.tensor(all_embeddings, dtype=torch.float32)

In [13]:
semi_hard_triplets = semi_hard_triplet_mining_multiple(all_embeddings, all_labels, margin=0.2, triplets_per_anchor=6)
print(f"Generated {len(semi_hard_triplets)} semi-hard triplets")

Generated 4710 semi-hard triplets


In [14]:
s1_images = sorted(glob.glob(os.path.join(s1_dir, "*")))
s2_images = sorted(glob.glob(os.path.join(s2_dir, "*")))
interleaved = [val for pair in zip(s1_images, s2_images) for val in pair]

In [15]:
triplet_paths = []
for a_idx, p_idx, n_idx in semi_hard_triplets:
    anchor_path   = interleaved[a_idx]
    positive_path = interleaved[p_idx]
    negative_path = interleaved[n_idx]
    triplet_paths.append((anchor_path, positive_path, negative_path))

print(f"Generated {len(triplet_paths)} triplets with paths")  

Generated 4710 triplets with paths


## Dataset creation for training

In [16]:
# Shuffle all triplets
random.seed(42)
random.shuffle(triplet_paths)

split = (0.7, 0.15, 0.15)
n = len(triplet_paths)
n_train = int(split[0] * n)
n_val   = int(split[1] * n)

train_triplets = triplet_paths[:n_train]
val_triplets   = triplet_paths[n_train:n_train+n_val]
test_triplets  = triplet_paths[n_train+n_val:]

print(f"Train triplets: {len(train_triplets)}, Val triplets: {len(val_triplets)}, Test triplets: {len(test_triplets)}")

Train triplets: 3297, Val triplets: 706, Test triplets: 707


In [None]:
transform_train = get_train_transforms(512)
transform_eval  = get_eval_transforms(512)

teacher_embeddings1 = np.load("<path_to_session1_teacher_emb>")
teacher_embeddings2 = np.load("<path_to_session1_teacher_emb>")

teacher_embeddings = np.concatenate([teacher_embeddings1, teacher_embeddings2], axis=1)
pca = PCA(n_components=256)
teacher_embeddings_256 = pca.fit_transform(teacher_embeddings1)  # [num_samples, 256]

dataset_train = TripletDatasetWithTeacher(train_triplets, s1_dir, s2_dir, teacher_embeddings_256, transform=transform_train)
dataset_val   = TripletDatasetWithTeacher(val_triplets, s1_dir, s2_dir, teacher_embeddings_256,   transform=transform_eval)
dataset_test  = TripletDatasetWithTeacher(test_triplets,  s1_dir, s2_dir, teacher_embeddings_256, transform=transform_eval)

loader_train = DataLoader(dataset_train, batch_size=16, shuffle=True, num_workers=0, drop_last=True)
loader_val   = DataLoader(dataset_val, batch_size=32, shuffle=False, num_workers=0, drop_last=False)
loader_test  = DataLoader(dataset_test, batch_size=32, shuffle=False, num_workers=0, drop_last=False)

(1570, 256)


In [None]:
def print_triplet_distribution(dataset, name="Dataset"):
    n = len(dataset)
    print(f"{name} contains {n} triplets (anchor, positive, negative)")

# Print distributions
print_triplet_distribution(dataset_train, "Train")
print_triplet_distribution(dataset_val, "Validation")
print_triplet_distribution(dataset_test, "Test")

Train contains 3297 triplets (anchor, positive, negative)
Validation contains 706 triplets (anchor, positive, negative)
Test contains 707 triplets (anchor, positive, negative)


# Train

In [None]:
train_losses = []
val_losses = []
def train_triplet_distill(model, train_loader, val_loader, device,
                           epochs=30, lr=1e-4, weight_decay=1e-5,
                           margin=0.3, alpha=0.7):
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)
    criterion = TripletDistillLoss(margin=margin, alpha=alpha)
    best_auc = 0.0

    for epoch in range(1, epochs+1):
        # Training
        model.train()
        running_loss = 0.0
        for a, p, n, t_a, t_p, t_n in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} [Train]"):
            a, p, n = a.to(device), p.to(device), n.to(device)
            t_a, t_p, t_n = t_a.to(device), t_p.to(device), t_n.to(device)

            s_a, s_p, s_n = model(a, p, n)
            loss = criterion(s_a, s_p, s_n, t_a, t_p, t_n)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * a.size(0)

        avg_train_loss = running_loss / len(train_loader.dataset)
        train_losses.append(avg_train_loss)

        # Validation
        model.eval()
        val_loss_total = 0.0
        with torch.no_grad():
            for a, p, n, t_a, t_p, t_n in tqdm(val_loader, desc=f"Epoch {epoch}/{epochs} [Val]"):
                a, p, n = a.to(device), p.to(device), n.to(device)
                t_a, t_p, t_n = t_a.to(device), t_p.to(device), t_n.to(device)

                s_a, s_p, s_n = model(a, p, n)
                val_loss = criterion(s_a, s_p, s_n, t_a, t_p, t_n)
                val_loss_total += val_loss.item() * a.size(0)

        avg_val_loss = val_loss_total / len(val_loader.dataset)
        val_losses.append(avg_val_loss)
        val_auc = evaluate_triplet_auc_with_teacher(model, val_loader, device)

        scheduler.step(val_auc)
        print(f"Epoch {epoch:03d} | TrainLoss={avg_train_loss:.4f} | ValLoss={avg_val_loss:.4f} | ValAUC={val_auc:.4f}")

        if val_auc > best_auc:
            best_auc = val_auc
            torch.save(model.state_dict(), "<save_model_path>.pt")
            print("saved best model")

    print("Training complete. Best AUC:", best_auc)


def evaluate_triplet_auc_with_teacher(model, loader, device):
    model.eval()
    labels = []
    scores = []
    with torch.no_grad():
        for a, p, n, t_a, t_p, t_n in loader:
            a, p, n = a.to(device), p.to(device), n.to(device)
            e_a, e_p, e_n = model(a, p, n)

            # positive pairs
            sim_pos = F.cosine_similarity(e_a, e_p).cpu().numpy()
            labels.extend([1]*len(sim_pos))
            scores.extend(sim_pos.tolist())

            # negative pairs
            sim_neg = F.cosine_similarity(e_a, e_n).cpu().numpy()
            labels.extend([0]*len(sim_neg))
            scores.extend(sim_neg.tolist())

    try:
        auc = roc_auc_score(np.array(labels), np.array(scores))
    except:
        auc = 0.5
    return auc

In [None]:
device = "mps"  # mention your gpu or cpu
model = TripletNet(embedding_dim=256, pretrained=True)

train_triplet_distill(
    model,
    train_loader=loader_train,
    val_loader=loader_val,
    device=device,
    epochs=30,
    lr=1e-4,
    margin=0.6,
    alpha=0.01  # weight for distillation
)

In [None]:
epochs = range(1, len(train_losses) + 1)

plt.figure(figsize=(6,4))
plt.plot(epochs, train_losses, '-', label='Train Loss', color='blue')
plt.plot(epochs, val_losses, '-', label='Validation Loss', color='orange')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Train vs Validation Loss')
plt.legend()
plt.show()

# Test

In [None]:
def test_triplet_cosine(model, loader_test, device):
    # Load best model
    model.load_state_dict(torch.load("<save_model_path>.pt", map_location=device))
    
    model.to(device)
    model.eval()

    labels = []
    scores = []

    with torch.no_grad():
        for a, p, n in tqdm(loader_test, desc="Testing"):
            a, p, n = a.to(device), p.to(device), n.to(device)
            e_a, e_p, e_n = model(a, p, n)

            # Cosine similarity for positive (anchor-positive) pairs
            sim_pos = F.cosine_similarity(e_a, e_p).cpu().numpy()
            labels.extend([1] * len(sim_pos))
            scores.extend(sim_pos.tolist())

            # Cosine similarity for negative (anchor-negative) pairs
            sim_neg = F.cosine_similarity(e_a, e_n).cpu().numpy()
            labels.extend([0] * len(sim_neg))
            scores.extend(sim_neg.tolist())

    labels = np.array(labels)
    scores = np.array(scores)

    #  Compute metrics 
    try:
        auc_val = roc_auc_score(labels, scores)
    except:
        auc_val = 0.5

    fpr, tpr, thresholds = roc_curve(labels, scores)
    fnr = 1 - tpr
    eer_idx = np.nanargmin(np.abs(fnr - fpr))
    eer = (fpr[eer_idx] + fnr[eer_idx]) / 2.0
    thresh = thresholds[eer_idx]

    preds = (scores >= thresh).astype(int)
    acc = np.mean(preds == labels)

    # Print metrics
    print("\n========== TEST RESULTS ==========")
    print(f"AUC:       {auc_val:.4f}")
    print(f"EER:       {eer*100:.2f}%")
    # print(f"Thr@EER:   {thresh:.4f}")
    print(f"Accuracy:  {acc*100:.2f}%")
    print("=================================\n")

    # Plot ROC
    plt.figure(figsize=(5, 5))
    plt.plot(fpr, tpr, label=f"ROC (AUC={auc_val:.4f})")
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Test ROC Curve")
    plt.legend()
    plt.grid(True)
    plt.show()

    return auc_val, eer, acc


In [None]:
model = TripletNet(embedding_dim=256, pretrained=True)
device = "mps"  
auc_val, eer, acc, df_preds = test_triplet_cosine(model, loader_test, device)

# Contact Finetuning

### Load data

In [None]:
class ContactTripletDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        #  Load files and group by ID -> would change accoridng to your naming
        self.id_to_images = {}  # { "1": ["1_1.tif", "1_2.tif"], ... }

        for fname in sorted(os.listdir(root_dir)):
            if not (fname.lower().endswith(".tif") or fname.lower().endswith(".png")):
                continue

            parts = fname.split("_")
            if len(parts) != 2:
                continue

            id_str = parts[0]  
            self.id_to_images.setdefault(id_str, []).append(fname)

        self.ids = list(self.id_to_images.keys())

        # Filter: only IDs with >=2 images can produce anchor + positive
        self.ids = [id_ for id_ in self.ids if len(self.id_to_images[id_]) >= 2]

    def __len__(self):
        return len(self.ids) * 3  # arbitrary multiplier for more triplets

    def load_image(self, path):
        img = Image.open(path).convert("L")
        if self.transform:
            img = self.transform(img)
        return img

    def __getitem__(self, index):
        #  Choose anchor ID 
        anchor_id = random.choice(self.ids)
        imgs = self.id_to_images[anchor_id]

        anchor_name, positive_name = random.sample(imgs, 2)

        neg_id = random.choice([i for i in self.ids if i != anchor_id])
        neg_name = random.choice(self.id_to_images[neg_id])

        # Load images
        anchor = self.load_image(os.path.join(self.root_dir, anchor_name))
        positive = self.load_image(os.path.join(self.root_dir, positive_name))
        negative = self.load_image(os.path.join(self.root_dir, neg_name))

        return anchor, positive, negative

In [None]:
class ContactlessTripletDataset(Dataset):
    """
    Creates triplets from two directories:
        dir1 = S1 images
        dir2 = S2 images
    Filenames must match: e.g., P10_LF2.tif in both dirs.
    
    Anchor  = img from S1
    Positive = matching img from S2
    Negative = img from S2 but from DIFFERENT person (different PXX)
    """

    def __init__(self, dir1, dir2, transform=None):
        self.dir1 = dir1
        self.dir2 = dir2
        self.transform = transform

        exts = (".tif", ".tiff", ".png", ".bmp", ".jpg")

        # load & sort
        files1 = [
            f for f in os.listdir(dir1)
            if f.lower().endswith(exts)
        ]
        files1 = sorted(files1)

        # store valid pairs
        self.pairs = []   # list of (id_string, path_S1, path_S2)

        for fname in files1:
            f1 = os.path.join(dir1, fname)
            f2 = os.path.join(dir2, fname)

            if os.path.exists(f2):
                id_name = os.path.splitext(fname)[0]
                self.pairs.append((id_name, f1, f2))

        # group by ID for sampling
        # ID = P<number>
        self.id_to_indices = {}   # { "P10": [idx1, idx2, ...] }

        for idx, (id_name, f1, f2) in enumerate(self.pairs):
            # extract "P10" out of "P10_LF2"
            base_id = id_name.split("_")[0]
            self.id_to_indices.setdefault(base_id, []).append(idx)

        self.ids = list(self.id_to_indices.keys())

    def __len__(self):
        return len(self.pairs)

    def _load_grayscale(self, path):
        img = Image.open(path)
        arr = np.array(img)
        arr = np.clip(arr, 0, 1)  
        arr = (arr * 255).astype(np.uint8)
        img = Image.fromarray(arr, mode='L')
        if self.transform:
            img = self.transform(img)
        return img

    def __getitem__(self, index):
        # Anchor and Positive pair (S1–S2 matching)
        id_name, f1, f2 = self.pairs[index]

        # parse person ID
        anchor_id = id_name.split("_")[0]

        anchor = self._load_grayscale(f1)
        positive = self._load_grayscale(f2)

        # Negative Sampling
        neg_id = random.choice([i for i in self.ids if i != anchor_id])
        neg_idx = random.choice(self.id_to_indices[neg_id])

        _, _, neg_path = self.pairs[neg_idx]  
        negative = self._load_grayscale(neg_path)

        return anchor, positive, negative

def get_eval_transforms(img_size=512):
    return transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

In [None]:
trans = get_eval_transforms(512)
dataset = ContactTripletDataset("<your_path>", transform=trans)
dataset_cl = ContactlessTripletDataset("<your_s1_path>", "<your_s2_path>", transform=trans)
loader = DataLoader(dataset, batch_size=8, shuffle=True, drop_last=True)
loader_cl = DataLoader(dataset_cl, batch_size=8, shuffle=True, drop_last=True)

### Define models

In [None]:
teacher = TripletNet(embedding_dim=256, pretrained=True).to(device)
teacher.load_state_dict(torch.load("<load_pt_path>"))

for p in teacher.parameters():
    p.requires_grad = False

teacher.eval()   # freeze completely
for name, p in teacher.named_parameters():
    print(name, p.requires_grad)

In [None]:
student = TripletNet(embedding_dim=256, pretrained=True).to(device)
student.load_state_dict(torch.load("<load_pt_path>"))
# freeze everything first
for name, p in student.named_parameters():
    p.requires_grad = False

# Unfreeze as required
for idx in [0, 1, 2, 3]:
    for p in student.backbone[idx].parameters():
        p.requires_grad = True

for i in range(4):
    student.backbone[i].train()
for i in range(4,8):
    student.backbone[i].eval()
student.head.eval()
for name, p in student.named_parameters():
    print(name, p.requires_grad)

### Finetune Loop

In [None]:
mse_loss = nn.MSELoss()
triplet_loss = nn.TripletMarginLoss(margin=0.4, p=2) 
lambda_distill = 0.9   
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, student.parameters()),
    lr=1e-4,
    weight_decay=1e-5
)
contactless_loader = loader_cl
contact_loader = loader

In [None]:
for epoch in range(10):
    total_loss = 0

    contactless_iter = iter(contactless_loader)

    for anchor, pos, neg in tqdm(contact_loader):

        # 1. Forward on student
        anchor = anchor.to(device)
        pos = pos.to(device)
        neg = neg.to(device)
        e_a, e_p, e_n = student(anchor, pos, neg)

        loss_triplet = triplet_loss(e_a, e_p, e_n)

        # 2. Get a batch of contactless and compute distillation loss
        try:
            cl_imgs, _, _ = next(contactless_iter)
        except StopIteration:
            contactless_iter = iter(contactless_loader)
            cl_imgs, _, _ = next(contactless_iter)

        cl_imgs = cl_imgs.to(device)

        with torch.no_grad():
            teacher_emb = teacher.forward_once(cl_imgs)

        student_emb = student.forward_once(cl_imgs)
        loss_distill = mse_loss(student_emb, teacher_emb)

        # 3. Total loss
        loss =  loss_triplet + lambda_distill * loss_distill

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}: Loss = {total_loss/len(contact_loader):.4f}")


In [None]:
torch.save(student.state_dict(), "<save_path>.pt")