# Setup Dependencies

In [1]:
!pip install timm==0.9.12
!git clone https://gitlab.idiap.ch/bob/bob.paper.tbiom2023_edgeface.git
%cd bob.paper.tbiom2023_edgeface

Collecting timm==0.9.12
  Downloading timm-0.9.12-py3-none-any.whl.metadata (60 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.7->timm==0.9.12)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.7->timm==0.9.12)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=1.7->timm==0.9.12)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch>=1.7->timm==0.9.12)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch>=1.7->timm==0.9.12)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manyl

In [2]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
from backbones import get_model , replace_linear_with_lowrank_2, get_timmfrv2
import numpy as np
import os

# LFW Dataset Filtration 

In [3]:
import os
from torchvision import datasets
from torchvision import transforms
from collections import defaultdict
import shutil

def filter_classes_by_min_images(root_dir, write_dir, min_images=5):
    filtered_dir = os.path.join(write_dir, f'filtered_min{min_images}')
    os.makedirs(filtered_dir, exist_ok=True)

    for class_name in os.listdir(root_dir):
        class_path = os.path.join(root_dir, class_name)
        if not os.path.isdir(class_path):
            continue
        images = os.listdir(class_path)
        if len(images) >= min_images:
            os.makedirs(os.path.join(filtered_dir, class_name), exist_ok=True)
            for img in images:
                src = os.path.join(class_path, img)
                dst = os.path.join(filtered_dir, class_name, img)
                if not os.path.exists(dst):  # Avoid overwrite
                    shutil.copy2(src, dst)

    return filtered_dir

# Example usage:
root_dir = "/kaggle/input/lfw-dataset/lfw-deepfunneled/lfw-deepfunneled/"
write_dir= "/kaggle/working/"
min_required_images = 10
filtered_root = filter_classes_by_min_images(root_dir,write_dir, min_required_images)


In [13]:
filtered_root = '/kaggle/working/filtered_min10'

In [4]:
!ls /kaggle/working/

bob.paper.tbiom2023_edgeface  filtered_min10


# DataLoaders Setup

In [5]:
def getDataset(root_dir):
    transform = transforms.Compose([
            transforms.Resize((112, 112)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
        ])
        
    dataset = datasets.ImageFolder(root=root_dir, transform=transform)
    return dataset

## Class-Aware Dataset Partitioning

In [6]:
from collections import defaultdict
from torch.utils.data import Subset, DataLoader
import random

def stratified_split(dataset, val_ratio=0.1, test_ratio=0.1, seed=42):
    label_to_indices = defaultdict(list)
    for idx, (_, label) in enumerate(dataset):
        label_to_indices[label].append(idx)

    train_indices, val_indices, test_indices = [], [], []
    rng = random.Random(seed)

    for label, indices in label_to_indices.items():
        rng.shuffle(indices)
        total = len(indices)

        val_count = max(2, int(total * val_ratio))
        test_count = max(2, int(total * test_ratio))
        train_count = total - val_count - test_count

        if train_count < 1:
            train_count, val_count, test_count = max(1, total - 2), 1, 1

        train_indices += indices[:train_count]
        val_indices += indices[train_count:train_count + val_count]
        test_indices += indices[train_count + val_count:]

    return Subset(dataset, train_indices), Subset(dataset, val_indices), Subset(dataset, test_indices)

## Triplet Dataset

In [7]:
import torch
import random
from collections import defaultdict
from torch.utils.data import Dataset

class TripletDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        self.label_to_indices = self._build_label_index()

    def _build_label_index(self):
        label_to_indices = defaultdict(list)
        for idx, (_, label) in enumerate(self.dataset):
            label_to_indices[label].append(idx)
        return label_to_indices

    def __getitem__(self, index):
        anchor_img, anchor_label = self.dataset[index]

        # Positive
        pos_idx = index
        while pos_idx == index:
            pos_idx = random.choice(self.label_to_indices[anchor_label])
        positive_img, _ = self.dataset[pos_idx]

        # Negative
        neg_label = anchor_label
        while neg_label == anchor_label:
            neg_label = random.choice(list(self.label_to_indices.keys()))
        neg_idx = random.choice(self.label_to_indices[neg_label])
        negative_img, _ = self.dataset[neg_idx]

        return anchor_img, positive_img, negative_img

    def __len__(self):
        return len(self.dataset)

## Pairs Dataset

In [8]:
class PairDataset(Dataset):
    def __init__(self, dataset, num_pairs=1000):
        self.dataset = dataset
        self.label_to_indices = self._build_label_index()
        self.num_pairs = num_pairs
        self.pairs = self._generate_pairs()

    def _build_label_index(self):
        label_to_indices = defaultdict(list)
        for idx, (_, label) in enumerate(self.dataset):
            label_to_indices[label].append(idx)
        return label_to_indices
        
    def _generate_pairs(self):
        pairs = []
        for _ in range(self.num_pairs // 2):
            label = random.choice(list(self.label_to_indices.keys()))
            i1, i2 = random.sample(self.label_to_indices[label], 2)
            pairs.append((i1, i2, 1))  # positive pair

            label1, label2 = random.sample(list(self.label_to_indices.keys()), 2)
            i1 = random.choice(self.label_to_indices[label1])
            i2 = random.choice(self.label_to_indices[label2])
            pairs.append((i1, i2, 0))  # negative pair

        return pairs

    def __getitem__(self, index):
        i1, i2, label = self.pairs[index]
        img1, _ = self.dataset[i1]
        img2, _ = self.dataset[i2]
        return img1, img2, torch.tensor(label, dtype=torch.long)

    def __len__(self):
        return len(self.pairs)

In [9]:
def get_triplet_dataloaders(
    dataset,
    batch_size,
    num_workers=2,
    val_ratio=0.1,
    test_ratio=0.1,
    seed=42
):
    train_set, val_set, test_set = stratified_split(dataset, val_ratio, test_ratio, seed)

    triplet_train = TripletDataset(train_set)
    pair_val = PairDataset(val_set, num_pairs=1000)
    pair_test = PairDataset(test_set, num_pairs=1000)

    train_loader = DataLoader(triplet_train, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True, drop_last=True)

    val_loader = DataLoader(pair_val, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True, drop_last=False)

    test_loader = DataLoader(pair_test, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True, drop_last=False)

    return train_loader, val_loader, test_loader

In [10]:
def get_classif_dataloaders(
    dataset,
    batch_size,
    num_workers=2,
    val_ratio=0.1,
    test_ratio=0.1,
    seed=42
):
    train_set, val_set, test_set = stratified_split(dataset, val_ratio, test_ratio, seed)

    triplet_train = train_set
    pair_val = PairDataset(val_set, num_pairs=1000)
    pair_test = PairDataset(test_set, num_pairs=1000)

    train_loader = DataLoader(triplet_train, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True, drop_last=True)

    val_loader = DataLoader(pair_val, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True, drop_last=False)

    test_loader = DataLoader(pair_test, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True, drop_last=False)

    return train_loader, val_loader, test_loader

In [11]:
dataset = getDataset(root_dir=filtered_root)
triplet_train_loader, triplet_val_loader, triplet_test_loader = get_triplet_dataloaders(
    dataset, batch_size=128)

train_loader, val_loader, test_loader = get_classif_dataloaders(
    dataset, batch_size=128)

In [12]:
len(triplet_train_loader), len(triplet_val_loader), len(triplet_test_loader)

(26, 8, 8)

# Model Architecture Overview

In [13]:
model = replace_linear_with_lowrank_2(
            get_timmfrv2('edgenext_small',featdim=512), rank_ratio=0.5)
model

TimmFRWrapperV2(
  (model): EdgeNeXt(
    (stem): Sequential(
      (0): Conv2d(3, 48, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((48,), eps=1e-06, elementwise_affine=True)
    )
    (stages): Sequential(
      (0): EdgeNeXtStage(
        (downsample): Identity()
        (blocks): Sequential(
          (0): ConvBlock(
            (conv_dw): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48)
            (norm): LayerNorm((48,), eps=1e-06, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): LoRaLin(
                (linear1): Linear(in_features=48, out_features=24, bias=False)
                (linear2): Linear(in_features=24, out_features=192, bias=True)
              )
              (act): GELU(approximate='none')
              (drop1): Dropout(p=0.0, inplace=False)
              (norm): Identity()
              (fc2): LoRaLin(
                (linear1): Linear(in_features=192, out_features=24, bias=False)
                (linea

# Loss Function Comparison

In [43]:
model_name='edgeface_s_gamma_05'
embedding_size=512
margin_list = (1.0, 0.0, 0.4)
num_classes = len(dataset.classes)
sample_rate = 1
lr = 1e-3
weight_decay = 0.05

In [44]:
import torch
from torch.nn.functional import normalize, linear
from typing import Callable
import math


class CombinedMarginLoss(torch.nn.Module):
    def __init__(self, s: float, m1: float, m2: float, m3: float):
        super().__init__()
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.m3 = m3

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        one_hot = torch.zeros_like(logits)
        one_hot.scatter_(1, labels.view(-1, 1), 1)

        # cosine similarity values (logits) must be clamped for arccos stability
        cosine = logits.clamp(-1 + 1e-7, 1 - 1e-7)
        theta = cosine.acos()

        # Apply angular margin (ArcFace or SphereFace)
        if self.m1 != 1.0 or self.m2 != 0.0:
            theta = self.m1 * theta + self.m2
            target_logits = theta.cos()
        else:
            target_logits = cosine

        # Apply additive cosine margin (CosFace)
        if self.m3 > 0.0:
            target_logits -= self.m3

        # Update logits for the ground-truth classes
        logits = logits.clone()
        logits[one_hot.bool()] = target_logits[one_hot.bool()]

        # Apply scale
        logits *= self.s
        return logits


class SimplePartialFC(torch.nn.Module):
    def __init__(
        self,
        margin_loss: Callable,
        embedding_size: int,
        num_classes: int,
        sample_rate: float = 1.0,
        fp16: bool = False,
    ):
        super().__init__()
        self.embedding_size = embedding_size
        self.num_classes = num_classes
        self.sample_rate = sample_rate
        self.fp16 = fp16

        self.weight = torch.nn.Parameter(torch.randn(num_classes, embedding_size) * 0.01)
        self.margin_softmax = margin_loss
        self.ce_loss = torch.nn.CrossEntropyLoss()

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor):
        labels = labels.long().view(-1)

        if self.sample_rate < 1.0:
            with torch.no_grad():
                positive = torch.unique(labels)
                num_sample = int(self.sample_rate * self.num_classes)
                all_indices = torch.randperm(self.num_classes, device=embeddings.device)
                neg_sample = all_indices[~torch.isin(all_indices, positive)][: max(0, num_sample - len(positive))]
                sample_indices = torch.cat([positive, neg_sample])
                sample_indices, _ = sample_indices.sort()
                weight = self.weight[sample_indices]
                label_map = {old.item(): new for new, old in enumerate(sample_indices)}
                labels = torch.tensor([label_map[l.item()] for l in labels], device=labels.device)
            logits = linear(normalize(embeddings), normalize(weight))
        else:
            logits = linear(normalize(embeddings), normalize(self.weight))

        logits = logits.clamp(-1, 1)
        logits = self.margin_softmax(logits, labels)
        return self.ce_loss(logits, labels)

In [68]:
class LFWTrainer:
    def __init__(self, model_name, 
                 embedding_size, train_loader, 
                 val_loader, margin_list=margin_list):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Model
        self.model = get_model(model_name, dropout=0.0, num_features=embedding_size)
        self.model = self.model.to(self.device)

        self.num_classes = num_classes
        self.train_loader = train_loader
        self.val_loader = val_loader
  
        # Loss
        # self.criterion = ArcFace(64, 0.5).to(self.device)
        self.criterion = CombinedMarginLoss(64, margin_list[0], margin_list[1], margin_list[2])
        
        self.module_partial_fc = SimplePartialFC(self.criterion, embedding_size, 
                                         num_classes, sample_rate, False)
        self.module_partial_fc.train().cuda()

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            params=[{"params": self.model.parameters()}, 
                    {"params": self.module_partial_fc.parameters()}],
            lr=lr, weight_decay=weight_decay)
         
        # self.optimizer = optim.AdamW(self.model.parameters(), lr=lr, weight_decay=weight_decay)

        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=0.5, patience=12, verbose=True)

        # Transform
        self.transform = transforms.Compose([
            transforms.Resize((112, 112)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        
    def train(self, num_epochs=10):
        for epoch in range(num_epochs):
            self.model.train()
            running_loss = 0.0
            
            for i, (images, labels) in enumerate(self.train_loader):
                images = images.to(self.device)
                labels = labels.to(self.device)
                
                embeddings = self.model(images)
                loss: torch.Tensor = self.module_partial_fc(embeddings, labels)
                
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                
                running_loss += loss.item()
                
                # if i % 10 == 0:
                #     print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(self.train_loader)}], Loss: {loss.item():.4f}')
            
            # Validate
            val_acc = self.validate()
            current_lr = self.optimizer.param_groups[0]['lr']
            epoch_loss = running_loss / len(self.train_loader)
            print(f'Epoch [{epoch+1}/{num_epochs}] Loss: {epoch_loss:.4f}, Val Acc: {val_acc:.4f},'
            f', LR: {current_lr:.6f}')

            self.scheduler.step(val_acc)
            
    def validate(self):
        self.model.eval()
        all_sims = []
        all_labels = []
    
        with torch.no_grad():
            for img1, img2, label in self.val_loader:
                img1, img2 = img1.to(self.device), img2.to(self.device)
                label = label.to(self.device)
    
                emb1 = F.normalize(self.model(img1))
                emb2 = F.normalize(self.model(img2))
                sim = F.cosine_similarity(emb1, emb2)
    
                all_sims.extend(sim.cpu().numpy())
                all_labels.extend(label.cpu().numpy())
    
        all_sims = np.array(all_sims)
        all_labels = np.array(all_labels)
    
        # Find the best threshold
        best_acc = 0.0
        best_thresh = 0.0
        for thresh in np.arange(0, 1.01, 0.01):
            preds = (all_sims > thresh).astype(int)
            acc = (preds == all_labels).mean()
            if acc > best_acc:
                best_acc = acc
                best_thresh = thresh
        return best_acc


In [46]:
def test(model, loader):
    model.eval()
    all_sims = []
    all_labels = []
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():
        for img1, img2, label in loader:
            img1, img2 = img1.to(device), img2.to(device)
            label = label.to(device)

            emb1 = F.normalize(model(img1))
            emb2 = F.normalize(model(img2))
            sim = F.cosine_similarity(emb1, emb2)

            all_sims.extend(sim.cpu().numpy())
            all_labels.extend(label.cpu().numpy())

    all_sims = np.array(all_sims)
    all_labels = np.array(all_labels)

    # Find the best threshold
    best_acc = 0.0
    best_thresh = 0.0
    for thresh in np.arange(0, 1.01, 0.01):
        preds = (all_sims > thresh).astype(int)
        acc = (preds == all_labels).mean()
        if acc > best_acc:
            best_acc = acc
            best_thresh = thresh
    return best_acc

In [47]:
trainer=LFWTrainer(model_name = model_name, 
                    embedding_size = embedding_size, 
                    train_loader = train_loader, 
                     val_loader = val_loader,
                  
                  )

trainer.train(num_epochs=120)

Epoch [1/120] Loss: 33.0116, Val Acc: 0.5150,, LR: 0.001000
Epoch [2/120] Loss: 30.9069, Val Acc: 0.5010,, LR: 0.001000
Epoch [3/120] Loss: 30.6641, Val Acc: 0.5040,, LR: 0.001000
Epoch [4/120] Loss: 30.5675, Val Acc: 0.5120,, LR: 0.001000
Epoch [5/120] Loss: 30.4162, Val Acc: 0.5440,, LR: 0.001000
Epoch [6/120] Loss: 30.1792, Val Acc: 0.5880,, LR: 0.001000
Epoch [7/120] Loss: 29.8838, Val Acc: 0.6170,, LR: 0.001000
Epoch [8/120] Loss: 29.5270, Val Acc: 0.6420,, LR: 0.001000
Epoch [9/120] Loss: 29.1677, Val Acc: 0.6210,, LR: 0.001000
Epoch [10/120] Loss: 28.9309, Val Acc: 0.6200,, LR: 0.001000
Epoch [11/120] Loss: 28.2442, Val Acc: 0.6480,, LR: 0.001000
Epoch [12/120] Loss: 27.3225, Val Acc: 0.6010,, LR: 0.001000
Epoch [13/120] Loss: 26.0842, Val Acc: 0.6190,, LR: 0.001000
Epoch [14/120] Loss: 24.6807, Val Acc: 0.6240,, LR: 0.001000
Epoch [15/120] Loss: 23.9559, Val Acc: 0.6240,, LR: 0.001000
Epoch [16/120] Loss: 23.2427, Val Acc: 0.6260,, LR: 0.001000
Epoch [17/120] Loss: 22.0934, Val

In [48]:
test(trainer.model, test_loader)

0.837

## CosFace Loss Model with Test Acc: 83.7%

In [64]:
model_name='edgeface_s_gamma_05'
embedding_size=512
margin_list = (1.0, 0.25, 0.0)
num_classes = len(dataset.classes)
sample_rate = 1
lr = 1e-3
weight_decay = 0.05

In [69]:
trainer=LFWTrainer(model_name = model_name, 
                    embedding_size = embedding_size, 
                    train_loader = train_loader, 
                     val_loader = val_loader,
                      margin_list= margin_list
                  )

trainer.train(num_epochs=120)

Epoch [1/120] Loss: 23.6632, Val Acc: 0.5000,, LR: 0.001000
Epoch [2/120] Loss: 21.4760, Val Acc: 0.5040,, LR: 0.001000
Epoch [3/120] Loss: 20.9318, Val Acc: 0.5070,, LR: 0.001000
Epoch [4/120] Loss: 20.7911, Val Acc: 0.5260,, LR: 0.001000
Epoch [5/120] Loss: 20.6231, Val Acc: 0.5260,, LR: 0.001000
Epoch [6/120] Loss: 20.2912, Val Acc: 0.5750,, LR: 0.001000
Epoch [7/120] Loss: 20.0190, Val Acc: 0.5870,, LR: 0.001000
Epoch [8/120] Loss: 19.6621, Val Acc: 0.5990,, LR: 0.001000
Epoch [9/120] Loss: 19.3801, Val Acc: 0.6380,, LR: 0.001000
Epoch [10/120] Loss: 18.8653, Val Acc: 0.6200,, LR: 0.001000
Epoch [11/120] Loss: 17.9028, Val Acc: 0.6190,, LR: 0.001000
Epoch [12/120] Loss: 16.8983, Val Acc: 0.6170,, LR: 0.001000
Epoch [13/120] Loss: 15.9060, Val Acc: 0.6470,, LR: 0.001000
Epoch [14/120] Loss: 15.2050, Val Acc: 0.6520,, LR: 0.001000
Epoch [15/120] Loss: 14.4358, Val Acc: 0.6760,, LR: 0.001000
Epoch [16/120] Loss: 13.5192, Val Acc: 0.6800,, LR: 0.001000
Epoch [17/120] Loss: 12.2151, Val

In [70]:
test(trainer.model, test_loader)

0.844

## ArcFace Loss Model with Test Acc: 84.4%

In [75]:
model_name='edgeface_s_gamma_05'
embedding_size=512
margin_list = (1.5, 0.0, 0.0)
num_classes = len(dataset.classes)
sample_rate = 1
lr = 1e-3
weight_decay = 0.05

In [77]:
trainer=LFWTrainer(model_name = model_name, 
                    embedding_size = embedding_size, 
                    train_loader = train_loader, 
                     val_loader = val_loader,
                      margin_list= margin_list
                  )

trainer.train(num_epochs=100)

Epoch [1/100] Loss: 51.5502, Val Acc: 0.5030,, LR: 0.001000
Epoch [2/100] Loss: 41.1902, Val Acc: 0.5050,, LR: 0.001000
Epoch [3/100] Loss: 21.3510, Val Acc: 0.5020,, LR: 0.001000
Epoch [4/100] Loss: 8.7982, Val Acc: 0.5000,, LR: 0.001000
Epoch [5/100] Loss: 5.8497, Val Acc: 0.5000,, LR: 0.001000
Epoch [6/100] Loss: 5.2681, Val Acc: 0.5000,, LR: 0.001000
Epoch [7/100] Loss: 5.1279, Val Acc: 0.5000,, LR: 0.001000
Epoch [8/100] Loss: 5.0865, Val Acc: 0.5000,, LR: 0.001000
Epoch [9/100] Loss: 5.0703, Val Acc: 0.5000,, LR: 0.001000
Epoch [10/100] Loss: 5.0623, Val Acc: 0.5000,, LR: 0.001000
Epoch [11/100] Loss: 5.0612, Val Acc: 0.5000,, LR: 0.001000
Epoch [12/100] Loss: 5.0547, Val Acc: 0.5000,, LR: 0.001000
Epoch [13/100] Loss: 5.0336, Val Acc: 0.5250,, LR: 0.001000
Epoch [14/100] Loss: 5.0123, Val Acc: 0.5280,, LR: 0.001000
Epoch [15/100] Loss: 4.9940, Val Acc: 0.5580,, LR: 0.001000
Epoch [16/100] Loss: 4.9248, Val Acc: 0.5680,, LR: 0.001000
Epoch [17/100] Loss: 4.9039, Val Acc: 0.6200,,

In [84]:
test(trainer.model, test_loader)

0.812

## SphereFace Loss Model with Test Acc: 81.2%

In [52]:
import os
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader
from typing import Callable


class LFWTripletTrainer:
    def __init__(self, model_name, embedding_size,
                 train_loader, val_loader,
                 lr=0.001, weight_decay=5e-4, margin=1):
        
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Model (embedding-only)
        self.model = get_model(model_name, dropout=0.0, num_features=embedding_size)
        self.model = self.model.to(self.device)

        self.train_loader = train_loader
        self.val_loader = val_loader

        # Triplet Loss
        self.criterion = nn.TripletMarginLoss(margin=margin, p=2)

        # Optimizer
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=weight_decay)

        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=0.5, patience=20, verbose=True)

        # Transform
        self.transform = transforms.Compose([
            transforms.Resize((112, 112)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def train(self, num_epochs=10, save_path='checkpoints'):
        os.makedirs(save_path, exist_ok=True)

        for epoch in range(num_epochs):
            self.model.train()
            running_loss = 0.0

            for i, (anchor, positive, negative) in enumerate(self.train_loader):
                anchor = anchor.to(self.device)
                positive = positive.to(self.device)
                negative = negative.to(self.device)

                anchor_emb = F.normalize(self.model(anchor))
                positive_emb = F.normalize(self.model(positive))
                negative_emb = F.normalize(self.model(negative))

                loss = self.criterion(anchor_emb, positive_emb, negative_emb)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                running_loss += loss.item()

                # if i % 10 == 0:
                #     print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(self.train_loader)}], Loss: {loss.item():.4f}')

            val_acc = self.validate()
            current_lr = self.optimizer.param_groups[0]['lr']
            epoch_loss = running_loss / len(self.train_loader)
            print(f'Epoch [{epoch+1}/{num_epochs}] Loss: {epoch_loss:.4f}, Val Acc: {val_acc:.4f},'
            f', LR: {current_lr:.6f}')

            self.scheduler.step(val_acc)
            
    def validate(self):
        self.model.eval()
        all_sims = []
        all_labels = []
    
        with torch.no_grad():
            for img1, img2, label in self.val_loader:
                img1, img2 = img1.to(self.device), img2.to(self.device)
                label = label.to(self.device)
    
                emb1 = F.normalize(self.model(img1))
                emb2 = F.normalize(self.model(img2))
                sim = F.cosine_similarity(emb1, emb2)
    
                all_sims.extend(sim.cpu().numpy())
                all_labels.extend(label.cpu().numpy())
    
        all_sims = np.array(all_sims)
        all_labels = np.array(all_labels)
    
        # Find the best threshold
        best_acc = 0.0
        best_thresh = 0.0
        for thresh in np.arange(0, 1.01, 0.01):
            preds = (all_sims > thresh).astype(int)
            acc = (preds == all_labels).mean()
            if acc > best_acc:
                best_acc = acc
                best_thresh = thresh
        return best_acc

In [53]:
triplet_trainer= LFWTripletTrainer(
        model_name= model_name,
        embedding_size=embedding_size,
        train_loader=triplet_train_loader,
        val_loader=triplet_val_loader,
        lr=lr,
        weight_decay=weight_decay,
        margin = 1.5
)

triplet_trainer.train(num_epochs=120)


Epoch [1/120] Loss: 1.4538, Val Acc: 0.5000,, LR: 0.001000
Epoch [2/120] Loss: 1.3598, Val Acc: 0.5390,, LR: 0.001000
Epoch [3/120] Loss: 1.2776, Val Acc: 0.5870,, LR: 0.001000
Epoch [4/120] Loss: 1.2408, Val Acc: 0.6130,, LR: 0.001000
Epoch [5/120] Loss: 1.2192, Val Acc: 0.6090,, LR: 0.001000
Epoch [6/120] Loss: 1.1745, Val Acc: 0.6400,, LR: 0.001000
Epoch [7/120] Loss: 1.1742, Val Acc: 0.6380,, LR: 0.001000
Epoch [8/120] Loss: 1.2008, Val Acc: 0.6400,, LR: 0.001000
Epoch [9/120] Loss: 1.1283, Val Acc: 0.6460,, LR: 0.001000
Epoch [10/120] Loss: 1.1372, Val Acc: 0.6430,, LR: 0.001000
Epoch [11/120] Loss: 1.1899, Val Acc: 0.6200,, LR: 0.001000
Epoch [12/120] Loss: 1.0933, Val Acc: 0.6680,, LR: 0.001000
Epoch [13/120] Loss: 1.0439, Val Acc: 0.6330,, LR: 0.001000
Epoch [14/120] Loss: 1.0810, Val Acc: 0.6530,, LR: 0.001000
Epoch [15/120] Loss: 1.0521, Val Acc: 0.6450,, LR: 0.001000
Epoch [16/120] Loss: 1.0323, Val Acc: 0.6610,, LR: 0.001000
Epoch [17/120] Loss: 0.9957, Val Acc: 0.6520,, LR

In [60]:
test(triplet_trainer.model, test_loader)

0.698

## Triplet Loss Model with Test Acc: 69.8%

# Gamma Values Comparison

In [14]:
model_name='edgenext_small'
gamma = 0.2
embedding_size=512
margin_list = (1.0, 0.0, 0.4)
num_classes = len(dataset.classes)
sample_rate = 1
lr = 1e-3
weight_decay = 0.05

In [15]:
from torch.nn.functional import normalize, linear
from typing import Callable
import math

class CombinedMarginLoss(torch.nn.Module):
    def __init__(self, s: float, m1: float, m2: float, m3: float):
        super().__init__()
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.m3 = m3

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        one_hot = torch.zeros_like(logits)
        one_hot.scatter_(1, labels.view(-1, 1), 1)

        # cosine similarity values (logits) must be clamped for arccos stability
        cosine = logits.clamp(-1 + 1e-7, 1 - 1e-7)
        theta = cosine.acos()

        # Apply angular margin (ArcFace or SphereFace)
        if self.m1 != 1.0 or self.m2 != 0.0:
            theta = self.m1 * theta + self.m2
            target_logits = theta.cos()
        else:
            target_logits = cosine

        # Apply additive cosine margin (CosFace)
        if self.m3 > 0.0:
            target_logits -= self.m3

        # Update logits for the ground-truth classes
        logits = logits.clone()
        logits[one_hot.bool()] = target_logits[one_hot.bool()]

        # Apply scale
        logits *= self.s
        return logits


class SimplePartialFC(torch.nn.Module):
    def __init__(
        self,
        margin_loss: Callable,
        embedding_size: int,
        num_classes: int,
        sample_rate: float = 1.0,
        fp16: bool = False,
    ):
        super().__init__()
        self.embedding_size = embedding_size
        self.num_classes = num_classes
        self.sample_rate = sample_rate
        self.fp16 = fp16

        self.weight = torch.nn.Parameter(torch.randn(num_classes, embedding_size) * 0.01)
        self.margin_softmax = margin_loss
        self.ce_loss = torch.nn.CrossEntropyLoss()

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor):
        labels = labels.long().view(-1)

        if self.sample_rate < 1.0:
            with torch.no_grad():
                positive = torch.unique(labels)
                num_sample = int(self.sample_rate * self.num_classes)
                all_indices = torch.randperm(self.num_classes, device=embeddings.device)
                neg_sample = all_indices[~torch.isin(all_indices, positive)][: max(0, num_sample - len(positive))]
                sample_indices = torch.cat([positive, neg_sample])
                sample_indices, _ = sample_indices.sort()
                weight = self.weight[sample_indices]
                label_map = {old.item(): new for new, old in enumerate(sample_indices)}
                labels = torch.tensor([label_map[l.item()] for l in labels], device=labels.device)
            logits = linear(normalize(embeddings), normalize(weight))
        else:
            logits = linear(normalize(embeddings), normalize(self.weight))

        logits = logits.clamp(-1, 1)
        logits = self.margin_softmax(logits, labels)
        return self.ce_loss(logits, labels)

In [33]:
class LFWTrainer:
    def __init__(self, model_name, 
                 embedding_size, train_loader, 
                 val_loader, gamma):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Model
        # self.model = get_model(model_name, num_features=embedding_size)
        if gamma == 1:
            self.model = get_timmfrv2(model_name,featdim=embedding_size)
        else:
            self.model = replace_linear_with_lowrank_2(
                get_timmfrv2(model_name,featdim=embedding_size), rank_ratio=gamma)
        
        self.model = self.model.to(self.device)

        self.num_classes = num_classes
        self.train_loader = train_loader
        self.val_loader = val_loader
  
        # Loss
        self.criterion = CombinedMarginLoss(64, margin_list[0], margin_list[1], margin_list[2])
        
        self.module_partial_fc = SimplePartialFC(self.criterion, embedding_size, 
                                         num_classes, sample_rate, False)
        self.module_partial_fc.train().cuda()

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            params=[{"params": self.model.parameters()}, 
                    {"params": self.module_partial_fc.parameters()}],
            lr=lr, weight_decay=weight_decay)
         
        # self.optimizer = optim.AdamW(self.model.parameters(), lr=lr, weight_decay=weight_decay)

        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=0.5, patience=12, verbose=True)

        # Transform
        self.transform = transforms.Compose([
            transforms.Resize((112, 112)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        
    def train(self, num_epochs=10):
        for epoch in range(num_epochs):
            self.model.train()
            running_loss = 0.0
            
            for i, (images, labels) in enumerate(self.train_loader):
                images = images.to(self.device)
                labels = labels.to(self.device)
                
                embeddings = self.model(images)
                loss: torch.Tensor = self.module_partial_fc(embeddings, labels)
                
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                
                running_loss += loss.item()

            # Validate
            val_acc = self.validate()
            current_lr = self.optimizer.param_groups[0]['lr']
            epoch_loss = running_loss / len(self.train_loader)
            print(f'Epoch [{epoch+1}/{num_epochs}] Loss: {epoch_loss:.4f}, Val Acc: {val_acc:.4f},'
            f', LR: {current_lr:.6f}')

            self.scheduler.step(val_acc)
            
    def validate(self):
        self.model.eval()
        all_sims = []
        all_labels = []
    
        with torch.no_grad():
            for img1, img2, label in self.val_loader:
                img1, img2 = img1.to(self.device), img2.to(self.device)
                label = label.to(self.device)
    
                emb1 = F.normalize(self.model(img1))
                emb2 = F.normalize(self.model(img2))
                sim = F.cosine_similarity(emb1, emb2)
    
                all_sims.extend(sim.cpu().numpy())
                all_labels.extend(label.cpu().numpy())
    
        all_sims = np.array(all_sims)
        all_labels = np.array(all_labels)
    
        # Find the best threshold
        best_acc = 0.0
        best_thresh = 0.0
        for thresh in np.arange(0, 1.01, 0.01):
            preds = (all_sims > thresh).astype(int)
            acc = (preds == all_labels).mean()
            if acc > best_acc:
                best_acc = acc
                best_thresh = thresh
        return best_acc

In [21]:
trainer=LFWTrainer(model_name = model_name, 
                    embedding_size = embedding_size, 
                    train_loader = train_loader, 
                     val_loader = val_loader,
                     gamma = gamma
                  )

trainer.train(num_epochs=200)



Epoch [1/200] Loss: 32.5131, Val Acc: 0.5230,, LR: 0.001000
Epoch [2/200] Loss: 30.8665, Val Acc: 0.5180,, LR: 0.001000
Epoch [3/200] Loss: 30.6194, Val Acc: 0.5160,, LR: 0.001000
Epoch [4/200] Loss: 30.4827, Val Acc: 0.5360,, LR: 0.001000
Epoch [5/200] Loss: 30.2382, Val Acc: 0.5900,, LR: 0.001000
Epoch [6/200] Loss: 29.8767, Val Acc: 0.6210,, LR: 0.001000
Epoch [7/200] Loss: 29.4180, Val Acc: 0.6380,, LR: 0.001000
Epoch [8/200] Loss: 29.0263, Val Acc: 0.6120,, LR: 0.001000
Epoch [9/200] Loss: 28.1760, Val Acc: 0.6270,, LR: 0.001000
Epoch [10/200] Loss: 27.1271, Val Acc: 0.6300,, LR: 0.001000
Epoch [11/200] Loss: 25.8964, Val Acc: 0.6260,, LR: 0.001000
Epoch [12/200] Loss: 24.6477, Val Acc: 0.6630,, LR: 0.001000
Epoch [13/200] Loss: 23.9950, Val Acc: 0.6410,, LR: 0.001000
Epoch [14/200] Loss: 23.3927, Val Acc: 0.6730,, LR: 0.001000
Epoch [15/200] Loss: 22.5402, Val Acc: 0.6790,, LR: 0.001000
Epoch [16/200] Loss: 21.6448, Val Acc: 0.7040,, LR: 0.001000
Epoch [17/200] Loss: 20.7114, Val

KeyboardInterrupt: 

In [32]:
test(trainer.model, test_loader)

0.81

## Gamma 0.2 model with Test Acc: 81%

In [51]:
# CosFace Loss Model above, is also Gamma 0.5 Model

## Gamma 0.5 model with Test Acc: 83.7%

In [35]:
trainer_full=LFWTrainer(model_name = model_name, 
                    embedding_size = embedding_size, 
                    train_loader = train_loader, 
                     val_loader = val_loader,
                     gamma = 1
                  )

trainer_full.train(num_epochs=120)

Epoch [1/120] Loss: 33.7074, Val Acc: 0.5330,, LR: 0.001000
Epoch [2/120] Loss: 30.9313, Val Acc: 0.5070,, LR: 0.001000
Epoch [3/120] Loss: 30.6163, Val Acc: 0.5120,, LR: 0.001000
Epoch [4/120] Loss: 30.4814, Val Acc: 0.5260,, LR: 0.001000
Epoch [5/120] Loss: 30.2050, Val Acc: 0.5630,, LR: 0.001000
Epoch [6/120] Loss: 29.7476, Val Acc: 0.6090,, LR: 0.001000
Epoch [7/120] Loss: 29.2451, Val Acc: 0.6510,, LR: 0.001000
Epoch [8/120] Loss: 28.6550, Val Acc: 0.6290,, LR: 0.001000
Epoch [9/120] Loss: 27.2374, Val Acc: 0.6330,, LR: 0.001000
Epoch [10/120] Loss: 25.9500, Val Acc: 0.6140,, LR: 0.001000
Epoch [11/120] Loss: 24.4452, Val Acc: 0.6330,, LR: 0.001000
Epoch [12/120] Loss: 23.0406, Val Acc: 0.6490,, LR: 0.001000
Epoch [13/120] Loss: 21.5031, Val Acc: 0.6440,, LR: 0.001000
Epoch [14/120] Loss: 20.2032, Val Acc: 0.6620,, LR: 0.001000
Epoch [15/120] Loss: 19.1258, Val Acc: 0.6910,, LR: 0.001000
Epoch [16/120] Loss: 17.0902, Val Acc: 0.7610,, LR: 0.001000
Epoch [17/120] Loss: 15.5891, Val

In [37]:
test(trainer_full.model, test_loader)

0.839

## Gamma 1 model with Test Acc: 83.9%