In [1]:
import pandas as pd 

# Show every column and the full text of each cell (long file paths) without wrapping.
pd.set_option(
    "display.max_columns", None,    # all columns
    "display.max_colwidth", None,   # full cell content
    "display.width", None,          # no fixed line width
)

In [2]:
!unzip -q /kaggle/input/create-pairs-from-processed-morph-cacd-agedb-fgnet/_output_.zip -d /kaggle/working

In [3]:
import os 
path = "/kaggle/working/"
# Alphabetical listing of what the archive extracted into the working directory.
files = os.listdir(path)
files.sort()
print(files)

['AgeDB_processed', 'AgeDB_processed.csv', 'CACD_processed', 'CACD_processed.csv', 'FGNET_processed', 'FGNET_processed.csv', 'MORPH_processed', 'MORPH_processed.csv', '__notebook__.ipynb', '__pycache__', 'morph_cacd_agedb_fgnet.csv', 'morph_cacd_agedb_fgnet_updated.csv', 'train_data']


In [4]:
# Load the merged MORPH/CACD/AgeDB/FGNET metadata CSV extracted from the archive above.
morph_cacd_agedb_fgnet = pd.read_csv("/kaggle/working/morph_cacd_agedb_fgnet_updated.csv")

In [5]:
# Preview the first five metadata rows.
morph_cacd_agedb_fgnet.head(5)

Unnamed: 0,identity,age,gender,filename,filepath,dataset
0,AgeDB_BurtLancaster,56,m,BurtLancaster_56_m_000000.jpg,/kaggle/working/AgeDB_processed/BurtLancaster/BurtLancaster_56_m_000000.jpg,AgeDB
1,AgeDB_GordonThomson,64,m,GordonThomson_64_m_000001.jpg,/kaggle/working/AgeDB_processed/GordonThomson/GordonThomson_64_m_000001.jpg,AgeDB
2,AgeDB_angelamerkel,20,f,angelamerkel_20_f_000002.jpg,/kaggle/working/AgeDB_processed/angelamerkel/angelamerkel_20_f_000002.jpg,AgeDB
3,AgeDB_LawrenceTierney,34,m,LawrenceTierney_34_m_000003.jpg,/kaggle/working/AgeDB_processed/LawrenceTierney/LawrenceTierney_34_m_000003.jpg,AgeDB
4,AgeDB_JamesWoods,54,m,JamesWoods_54_m_000004.jpg,/kaggle/working/AgeDB_processed/JamesWoods/JamesWoods_54_m_000004.jpg,AgeDB


In [6]:
# Number of metadata rows (images).
morph_cacd_agedb_fgnet.shape[0]

174051

In [7]:
# Column dtypes: `age` should be numeric; the remaining columns are strings (object).
morph_cacd_agedb_fgnet.dtypes

identity    object
age          int64
gender      object
filename    object
filepath    object
dataset     object
dtype: object

In [8]:
import os
import pandas as pd

exists_count = 0
missing_files = []

# A path counts as present only if it is a real (non-NaN) value that exists on disk;
# NaN entries are recorded as missing without touching the filesystem.
for path in morph_cacd_agedb_fgnet['filepath']:
    if not pd.isna(path) and os.path.exists(str(path)):
        exists_count += 1
    else:
        missing_files.append(path)

missing_count = len(missing_files)

print(f"Total files       : {len(morph_cacd_agedb_fgnet)}")
print(f"Existing files    : {exists_count}")
print(f"Missing files     : {missing_count}")


Total files       : 174051
Existing files    : 174014
Missing files     : 37


In [9]:
# First five rows whose filepath is missing entirely.
morph_cacd_agedb_fgnet.loc[morph_cacd_agedb_fgnet['filepath'].isna()].head(5)

Unnamed: 0,identity,age,gender,filename,filepath,dataset
169938,FGNET_nan,33,,,,FGNET
170066,FGNET_nan,40,,,,FGNET
170108,FGNET_nan,4,,,,FGNET
170123,FGNET_nan,8,,,,FGNET
170137,FGNET_nan,22,,,,FGNET


In [10]:
import os
import pandas as pd

# Keep only rows with valid existing file paths
# Keep only rows whose filepath is present (non-NaN) and points at a file on disk.
mask_exists = morph_cacd_agedb_fgnet['filepath'].notna() & morph_cacd_agedb_fgnet['filepath'].map(
    lambda x: os.path.exists(str(x))
)

morph_cacd_agedb_fgnet_clean = morph_cacd_agedb_fgnet.loc[mask_exists].reset_index(drop=True)

print(f"Before cleaning : {len(morph_cacd_agedb_fgnet)}")
print(f"After cleaning  : {len(morph_cacd_agedb_fgnet_clean)}")
print(f"Dropped rows    : {len(morph_cacd_agedb_fgnet) - len(morph_cacd_agedb_fgnet_clean)}")

# Later cells keep using the original name, now bound to the cleaned frame.
morph_cacd_agedb_fgnet = morph_cacd_agedb_fgnet_clean

Before cleaning : 174051
After cleaning  : 174014
Dropped rows    : 37


In [11]:
import os

import pandas as pd

# Re-verify the cleaned `morph_cacd_agedb_fgnet` frame: every remaining
# `filepath` should exist on disk, so the missing count is expected to be 0.
exists_count = 0
missing_count = 0
missing_files = []

for path in morph_cacd_agedb_fgnet['filepath']:
    # NaN counts as missing instead of making os.path.exists raise TypeError.
    if not pd.isna(path) and os.path.exists(str(path)):
        exists_count += 1
    else:
        missing_count += 1
        missing_files.append(path)

print(f"Total files       : {len(morph_cacd_agedb_fgnet)}")
print(f"Existing files    : {exists_count}")
print(f"Missing files     : {missing_count}")


Total files       : 174014
Existing files    : 174014
Missing files     : 0


In [12]:
output_path = "/kaggle/working/morph_cacd_agedb_fgnet_clean.csv"
# Written without the index; config.py and train.py re-read this exact path.
morph_cacd_agedb_fgnet_clean.to_csv(path_or_buf=output_path, index=False)

print(f"Saved cleaned file to: {output_path}")
print(f"Final number of rows: {len(morph_cacd_agedb_fgnet_clean)}")

Saved cleaned file to: /kaggle/working/morph_cacd_agedb_fgnet_clean.csv
Final number of rows: 174014


In [13]:
total_paths = len(morph_cacd_agedb_fgnet)

# Pair files are split on whitespace downstream, so a space inside a path would break parsing.
has_space = morph_cacd_agedb_fgnet["filepath"].str.contains(" ", regex=False)
paths_with_space = has_space.sum()
paths_without_space = total_paths - paths_with_space

print(f"Total paths           : {total_paths}")
print(f"Paths WITH spaces     : {paths_with_space}")
print(f"Paths WITHOUT spaces  : {paths_without_space}")

Total paths           : 174014
Paths WITH spaces     : 0
Paths WITHOUT spaces  : 174014


In [14]:
# Number of distinct identities (NaN counted as a value, like len(unique())).
morph_cacd_agedb_fgnet['identity'].nunique(dropna=False)

3236

In [15]:
!pip install -q onnx onnx2pytorch accelerate scikit-learn scipy matplotlib tqdm

[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m58.3/58.3 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [16]:
!pip install -q gdown

In [17]:
file_path = "/kaggle/working/train_data/all_pairs.txt"

bad_rows = []
line_num = 0  # stays 0 for an empty file, so the report below cannot hit a NameError

with open(file_path, "r", encoding="utf-8") as f:
    for line_num, line in enumerate(f, start=1):
        line = line.strip()

        # Blank lines are recorded as bad rows (0 values), not silently skipped
        if not line:
            bad_rows.append((line_num, 0, "EMPTY LINE"))
            continue

        values = line.split()  # split on any whitespace
        num_values = len(values)

        # Expected layout: img1_path img2_path same_person same_age_group age_gap
        if num_values != 5:
            bad_rows.append((line_num, num_values, line))

# ---------- REPORT ----------
if bad_rows:
    print(f"\n‚ùå Found {len(bad_rows)} bad rows:\n")
    # for row in bad_rows:
    #     print(f"Line {row[0]} -> {row[1]} values:")
    #     print(row[2])
    #     print("-" * 80)
else:
    print("\n‚úÖ All rows have exactly 5 values. File is CLEAN!")

print(f"\n‚úÖ Total checked rows: {line_num}")
print(f"‚ùå Total bad rows: {len(bad_rows)}")


‚úÖ All rows have exactly 5 values. File is CLEAN!

‚úÖ Total checked rows: 277828
‚ùå Total bad rows: 0


In [18]:
!head -n 5 /kaggle/working/train_data/all_pairs.txt

/kaggle/working/MORPH_processed/69975/69975_54_m_045703.jpg /kaggle/working/MORPH_processed/8332/8332_16_m_049868.jpg 0 0 38.0
/kaggle/working/CACD_processed/1995/1995_19_f_154594.jpg /kaggle/working/CACD_processed/454/454_49_m_031592.jpg 0 0 30.0
/kaggle/working/AgeDB_processed/MaureenOHara/MaureenOHara_78_f_002181.jpg /kaggle/working/CACD_processed/685/685_48_m_048949.jpg 0 0 30.0
/kaggle/working/FGNET_processed/020A36.JPG /kaggle/working/CACD_processed/505/505_47_m_035080.jpg 0 0 11.0
/kaggle/working/CACD_processed/628/628_46_m_044396.jpg /kaggle/working/CACD_processed/628/628_48_m_044415.jpg 1 1 2.0


In [19]:
import math

# Pair files to validate, keyed by split name (iteration order: all, train, val).
files = dict(
    all="/kaggle/working/train_data/all_pairs.txt",
    train="/kaggle/working/train_data/train_pairs.txt",
    val="/kaggle/working/train_data/val_pairs.txt",
)

def check_file(file_path):
    """Count malformed rows in a whitespace-separated pair file.

    A valid row has exactly five fields:
        img1_path img2_path same_person same_age_group age_gap
    where same_person and same_age_group parse as int (train.py reads them
    with int()) and age_gap parses as a non-NaN float. A literal "nan" in
    any field also marks the row invalid.

    Args:
        file_path: path of the pair file to scan.

    Returns:
        (total, bad): number of lines read and number of invalid lines.
    """
    total = 0
    bad = 0

    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            total += 1
            parts = line.split()

            # 1) Exactly 5 columns; train.py indexes fields positionally, so extra
            #    columns would be silently misparsed rather than rejected
            if len(parts) != 5:
                bad += 1
                continue

            # 2) Literal 'nan' in ANY column (split() never yields empty strings)
            if any(p.lower() == "nan" for p in parts):
                bad += 1
                continue

            try:
                # 3) Label columns must be integers; age gap a non-NaN float
                int(parts[2])
                int(parts[3])
                if math.isnan(float(parts[4])):
                    bad += 1
            except ValueError:
                # Non-numeric values
                bad += 1

    return total, bad


# Report row validity for each pair file.
for name, path in files.items():
    total, bad = check_file(path)
    print(
        f"\nüìÇ {name.upper()} FILE",
        f"Total rows   : {total}",
        f"Invalid rows : {bad}",
        f"Valid rows   : {total - bad}",
        sep="\n",
    )



üìÇ ALL FILE
Total rows   : 277828
Invalid rows : 121
Valid rows   : 277707

üìÇ TRAIN FILE
Total rows   : 222262
Invalid rows : 98
Valid rows   : 222164

üìÇ VAL FILE
Total rows   : 55566
Invalid rows : 23
Valid rows   : 55543


In [20]:
import math
import os
import shutil

# Pair files to clean in place, keyed by split name (iteration order: all, train, val).
files = dict(
    all="/kaggle/working/train_data/all_pairs.txt",
    train="/kaggle/working/train_data/train_pairs.txt",
    val="/kaggle/working/train_data/val_pairs.txt",
)

def is_valid_row(parts):
    """Return True if a split pair-file row is well-formed.

    Expected layout (exactly five fields):
        img1_path img2_path same_person same_age_group age_gap
    same_person and same_age_group must parse as int (train.py reads them with
    int()); age_gap must parse as a non-NaN float. Empty fields or a literal
    "nan" anywhere make the row invalid.

    Args:
        parts: list of string fields, e.g. ``line.strip().split()``.
    """
    # Exactly five fields; extra fields would be silently misparsed downstream
    if len(parts) != 5:
        return False

    # Empty string (possible if the caller did not use split()) or literal 'nan'
    if any(p.strip() == "" or p.lower() == "nan" for p in parts):
        return False

    try:
        int(parts[2])  # same_person
        int(parts[3])  # same_age_group
        age_gap = float(parts[4])
    except ValueError:
        return False

    return not math.isnan(age_gap)


# Rewrite each pair file keeping only valid rows. The filtered copy goes to a
# temporary file first and replaces the original only once it is complete.
for name, path in files.items():
    temp_path = path + ".tmp"

    total = 0
    kept = 0

    try:
        with open(path, "r", encoding="utf-8") as fin, open(temp_path, "w", encoding="utf-8") as fout:
            for line in fin:
                total += 1
                parts = line.split()

                if is_valid_row(parts):
                    fout.write(line)
                    kept += 1
    except BaseException:
        # Don't leave a half-written .tmp next to the untouched original
        if os.path.exists(temp_path):
            os.remove(temp_path)
        raise

    # Atomic rename within the same directory
    os.replace(temp_path, path)

    print(f"\nüìÇ {name.upper()} FILE CLEANED")
    print(f"Original rows : {total}")
    print(f"Kept rows     : {kept}")
    print(f"Dropped rows  : {total - kept}")



üìÇ ALL FILE CLEANED
Original rows : 277828
Kept rows     : 277707
Dropped rows  : 121

üìÇ TRAIN FILE CLEANED
Original rows : 222262
Kept rows     : 222164
Dropped rows  : 98

üìÇ VAL FILE CLEANED
Original rows : 55566
Kept rows     : 55543
Dropped rows  : 23


In [21]:
!git clone https://github.com/AbdoTW/AQUAFace.git

Cloning into 'AQUAFace'...
remote: Enumerating objects: 77, done.[K
remote: Counting objects: 100% (77/77), done.[K
remote: Compressing objects: 100% (69/69), done.[K
remote: Total 77 (delta 2), reused 73 (delta 2), pack-reused 0 (from 0)[K
Receiving objects: 100% (77/77), 3.66 MiB | 36.00 MiB/s, done.
Resolving deltas: 100% (2/2), done.


In [22]:
!gdown --id 1dWZb0SLcdzr-toUzsVZ1zogn9dEIW1Dk -O R18_MS1MV3.onnx

Downloading...
From: https://drive.google.com/uc?id=1dWZb0SLcdzr-toUzsVZ1zogn9dEIW1Dk
To: /kaggle/working/R18_MS1MV3.onnx
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 96.1M/96.1M [00:01<00:00, 69.0MB/s]


In [23]:
!gdown --id 1Gh8C-bwl2B90RDrvKJkXafvZC3q4_H_z -O R100_Glint360K.onnx

Downloading...
From (original): https://drive.google.com/uc?id=1Gh8C-bwl2B90RDrvKJkXafvZC3q4_H_z
From (redirected): https://drive.google.com/uc?id=1Gh8C-bwl2B90RDrvKJkXafvZC3q4_H_z&confirm=t&uuid=4f866ca7-c0c2-464f-aba4-95be61aa43d1
To: /kaggle/working/R100_Glint360K.onnx
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 261M/261M [00:02<00:00, 115MB/s]


In [24]:
# Create the folder
!mkdir -p /kaggle/working/AQUAFace/pretrained_models

# Move the ONNX models into it
!mv /kaggle/working/R100_Glint360K.onnx /kaggle/working/AQUAFace/pretrained_models/
!mv /kaggle/working/R18_MS1MV3.onnx /kaggle/working/AQUAFace/pretrained_models/

!ls /kaggle/working/AQUAFace/pretrained_models

R100_Glint360K.onnx  R18_MS1MV3.onnx


In [25]:
# Optional reset (disabled): uncomment to delete pair files left in the repo's
# train_data/ from an earlier run. The next cell's cp overwrites them anyway.
# !rm -f /kaggle/working/AQUAFace/train_data/all_pairs.txt
# !rm -f /kaggle/working/AQUAFace/train_data/train_pairs.txt
# !rm -f /kaggle/working/AQUAFace/train_data/val_pairs.txt

In [26]:
!cp /kaggle/working/train_data/all_pairs.txt /kaggle/working/AQUAFace/train_data/
!cp /kaggle/working/train_data/train_pairs.txt /kaggle/working/AQUAFace/train_data/
!cp /kaggle/working/train_data/val_pairs.txt /kaggle/working/AQUAFace/train_data/

In [27]:
file_path = "/kaggle/working/AQUAFace/train_data/all_pairs.txt"

bad_rows = []
line_num = 0  # stays 0 for an empty file, so the report below cannot hit a NameError

with open(file_path, "r", encoding="utf-8") as f:
    for line_num, line in enumerate(f, start=1):
        line = line.strip()

        # Blank lines are recorded as bad rows (0 values), not silently skipped
        if not line:
            bad_rows.append((line_num, 0, "EMPTY LINE"))
            continue

        values = line.split()  # split on any whitespace
        num_values = len(values)

        # Expected layout: img1_path img2_path same_person same_age_group age_gap
        if num_values != 5:
            bad_rows.append((line_num, num_values, line))

# ---------- REPORT ----------
if bad_rows:
    print(f"\n‚ùå Found {len(bad_rows)} bad rows:\n")
    # for row in bad_rows:
    #     print(f"Line {row[0]} -> {row[1]} values:")
    #     print(row[2])
    #     print("-" * 80)
else:
    print("\n‚úÖ All rows have exactly 5 values. File is CLEAN!")

print(f"\n‚úÖ Total checked rows: {line_num}")
print(f"‚ùå Total bad rows: {len(bad_rows)}")



‚úÖ All rows have exactly 5 values. File is CLEAN!

‚úÖ Total checked rows: 277707
‚ùå Total bad rows: 0


In [28]:
%%bash
cat << 'EOF' > /kaggle/working/AQUAFace/config/config.py
import os
import pandas as pd

class Config(object):
    """Training/evaluation settings for AQUAFace on Kaggle.

    Args:
        test_model_path: checkpoint path used for testing.
        lfw_test_list: optional LFW pair list (LFW evaluation is disabled on Kaggle).
        metadata_path: CSV with an ``identity`` column; its number of unique
            identities becomes ``num_classes``. When the file does not exist,
            ``DEFAULT_NUM_CLASSES`` is used instead.
    """

    # Fallback class count used when the metadata CSV is not available.
    DEFAULT_NUM_CLASSES = 19385
    DEFAULT_METADATA_PATH = '/kaggle/working/morph_cacd_agedb_fgnet_clean.csv'

    def __init__(
        self,
        test_model_path='checkpoints/resnet18_110.pth',
        lfw_test_list=None,
        metadata_path=DEFAULT_METADATA_PATH,
    ):
        # =====================================================
        # Base directory (AQUAFace root)
        # =====================================================
        self.base_dir = os.path.abspath(
            os.path.join(os.path.dirname(__file__), '..')
        )
        
        self.env = 'kaggle'  # Environment identifier
        self.backbone = 'r100'  
        self.classify = 'softmax'
        
        # ===== Auto-calculate num_classes from metadata =====
        # Counted the same way train.py builds subject_id_map (unique identities of
        # this CSV), so class ids and the classifier size stay in sync.
        if os.path.exists(metadata_path):
            # Only the identity column is needed; skip parsing the rest of the CSV
            identities = pd.read_csv(metadata_path, usecols=['identity'])['identity']
            self.num_classes = len(identities.unique())
            print(f"Auto-detected num_classes: {self.num_classes}")
        else:
            # Fallback if file doesn't exist
            self.num_classes = self.DEFAULT_NUM_CLASSES
            print(f"Warning: Metadata file not found, using default num_classes: {self.num_classes}")
        
        self.metric = 'arc_margin'
        self.easy_margin = False
        self.use_se = True
        self.loss = 'con_loss'
        
        self.display = False
        self.finetune = True  # We're fine-tuning pretrained model
        
        # =====================================================
        # Training data paths
        # =====================================================
        # NOTE(review): pair files hold absolute image paths, so these AgeDB roots
        # look unused for the merged dataset — confirm before relying on them.
        self.train_root = os.path.join(self.base_dir, 'dataset', 'AgeDB')
        self.train_list = os.path.join(self.base_dir, 'train_data', 'train_pairs.txt')
        self.val_list = os.path.join(self.base_dir, 'train_data', 'val_pairs.txt')
        
        # Test data
        self.test_root = os.path.join(self.base_dir, 'dataset', 'AgeDB')
        self.test_list = os.path.join(self.base_dir, 'test.txt')
        
        # LFW (disabled on Kaggle)
        self.lfw_root = None
        self.lfw_test_list = lfw_test_list
        
        # =====================================================
        # Checkpoints & pretrained models
        # =====================================================
        self.checkpoints_path = os.path.join(self.base_dir, 'checkpoints')
        
        # Pretrained model path
        self.load_model_path = os.path.join(
            self.base_dir,
            'pretrained_models',
            'R100_Glint360K.onnx'
        )
        
        self.test_model_path = test_model_path
        self.save_interval = 1  # Save every epoch
        
        # =====================================================
        # OPTIMIZED FOR KAGGLE 16GB GPU
        # =====================================================
        self.train_batch_size = 128  # Reduced from 512 for stability
        self.test_batch_size = 128
        
        # NOTE(review): 1 channel here, but train.py loads images as RGB (3 channels)
        # at 112x112 — confirm which consumers still read input_shape.
        self.input_shape = (1, 112, 112)
        self.optimizer = 'adam'  # Use adam like the old code
        
        self.lr = 2e-3  # Learning rate for Adam optimizer (matching old code)
        self.lr_step = 10
        self.lr_decay = 0.95
        self.weight_decay = 5e-4
        self.momentum = 0.9  # only meaningful for SGD-style optimizers
        
        self.max_epoch = 15  # Reduced for Kaggle
        
        self.num_workers = 2  # Kaggle-friendly
        self.pin_memory = True

EOF

In [29]:
%%bash
cat << 'EOF' > /kaggle/working/AQUAFace/train.py

from __future__ import print_function
import os
import sys
from models import get_model
from loss.focal import *
from PIL import Image
import torch
from torch.utils import data
import torch.nn.functional as F
from models import *
from models import metrics
from utils.visualizer import *
import torchvision
import numpy as np
import random
random.seed(16)
import time
from config.config import *
from sklearn.metrics import roc_auc_score, roc_curve, auc
from torch.nn import DataParallel
from torch.optim.lr_scheduler import StepLR
from collections import defaultdict
from scipy.stats import hmean
from sklearn.metrics import accuracy_score
from sklearn.mixture import GaussianMixture
from sklearn.metrics.pairwise import cosine_similarity
from scipy import spatial
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from sklearn.model_selection import train_test_split
from accelerate import Accelerator
import matplotlib.pyplot as plt
from tqdm import tqdm
import re
import pandas as pd
torch.manual_seed(16)
np.random.seed(16)  # seed numpy alongside random/torch for reproducible runs

# Anomaly detection adds extra checks to every backward pass and slows training
# considerably; enable it only when debugging via AQUAFACE_DETECT_ANOMALY=1.
torch.autograd.set_detect_anomaly(os.environ.get("AQUAFACE_DETECT_ANOMALY", "0") == "1")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# =====================================================
# LOAD METADATA AND CREATE LOOKUP DICTIONARY
# =====================================================
METADATA_PATH = '/kaggle/working/morph_cacd_agedb_fgnet_clean.csv'
print("Loading metadata...")
metadata_df = pd.read_csv(METADATA_PATH)
print(f"Loaded {len(metadata_df)} rows from metadata")

# Fast filepath -> {identity, age, gender} lookup. Zipping the columns avoids
# building a pandas Series per row as iterrows() does (slow on ~170k rows).
# If a filepath repeats, the last row wins, as with any dict comprehension.
metadata_lookup = {
    filepath: {
        'identity': identity,
        'age': age,
        'gender': gender
    }
    for filepath, identity, age, gender in zip(
        metadata_df['filepath'],
        metadata_df['identity'],
        metadata_df['age'],
        metadata_df['gender'],
    )
}
print(f"Created lookup dictionary with {len(metadata_lookup)} entries")

# Create subject ID mapping from unique identities (sorted for stable class ids)
unique_identities = sorted(metadata_df['identity'].unique())
subject_id_map = {identity: idx for idx, identity in enumerate(unique_identities)}
print(f"Created subject_id_map with {len(subject_id_map)} unique identities")


class FixedDropout(torch.nn.Module):
    """Dropout layer built on the functional op; active only in training mode."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p  # probability of zeroing each element

    def forward(self, x):
        return F.dropout(x, p=self.p, training=self.training)


def save_model(model, save_path, name, iter_cnt, best_metric, current_metric, metric_name="loss"):
    """Checkpoint ``model`` when ``current_metric`` beats ``best_metric``.

    A metric named "loss" is lower-is-better; any other name is higher-is-better.
    The checkpoint (state dict of the unwrapped model, epoch, metric value) goes to
    ``<save_path>/<name>_<iter_cnt>.pth``.

    Returns:
        The value to keep tracking as the best metric.
    """
    save_name = os.path.join(save_path, name + '_' + str(iter_cnt) + '.pth')
    os.makedirs(save_path, exist_ok=True)

    print(f'Best {metric_name}: {best_metric:.4f}, Current {metric_name}: {current_metric:.4f}')

    if metric_name == "loss":
        improved = current_metric < best_metric
    else:
        improved = current_metric > best_metric

    if not improved:
        print('Metric did not improve, not saving')
        return best_metric

    print(f'Saving checkpoint to {save_name}')
    # Strip a DataParallel/DDP-style wrapper so keys load into a bare model
    core = model.module if hasattr(model, 'module') else model
    checkpoint = {
        'model_state_dict': core.state_dict(),
        'epoch': iter_cnt,
        metric_name: current_metric,
    }
    torch.save(checkpoint, save_name)
    return current_metric


class FDataset(Dataset):
    """Dataset of image pairs read from whitespace-separated pair lines.

    Each element of ``data`` is one line of the form
    ``img1_path img2_path same_person same_age_group age_gap``. Identity, age and
    gender for each image come from the module-level ``metadata_lookup`` (keyed by
    filepath); identities map to integer class ids via the module-level
    ``subject_id_map``.
    """

    def __init__(self, data, transforms=None):
        self.data = data
        self.transforms = transforms
        
        # Use the global subject_id_map (already created from metadata)
        self.subject_id_map = subject_id_map
        print(f"FDataset initialized with {len(self.data)} samples")
        print(f"Using subject_id_map with {len(self.subject_id_map)} unique identities")

    def __len__(self):
        return len(self.data)

    def get_metadata(self, filepath):
        """Return (identity, age, gender) for ``filepath``; raise ValueError if unknown."""
        metadata = metadata_lookup.get(filepath)
        if metadata is None:
            raise ValueError(f"Filepath not found in metadata: {filepath}")
        return metadata['identity'], metadata['age'], metadata['gender']

    @staticmethod
    def _load_rgb(path):
        """Open an image, convert it to RGB, and close the underlying file right away."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        """Return both images, pair labels, subject ids and ages for pair ``index``."""
        # Bound before the try so the error report below cannot raise
        # UnboundLocalError when self.data[index] itself is what failed.
        sample = None
        try:
            sample = self.data[index]
            splits = sample.strip().split()
            
            # Parse the pair information
            img1_path = splits[0].strip()
            img2_path = splits[1].strip()
            same_person = int(splits[2])
            same_age_group = int(splits[3])
            age_gap = float(splits[4])
            
            # Get metadata for both images
            identity1, age1, gender1 = self.get_metadata(img1_path)
            identity2, age2, gender2 = self.get_metadata(img2_path)
            
            # Get subject IDs
            subject_id1 = self.subject_id_map[identity1]
            subject_id2 = self.subject_id_map[identity2]
            
            # Load images (file handles are closed inside _load_rgb)
            img1 = self._load_rgb(img1_path)
            img2 = self._load_rgb(img2_path)
            
            if self.transforms is not None:
                img1 = self.transforms(img1)
                img2 = self.transforms(img2)
            
            return {
                'img1': img1,
                'img2': img2,
                'same_person': same_person,
                'same_age_group': same_age_group,
                'age_gap': age_gap,
                'subject_id1': subject_id1,
                'subject_id2': subject_id2,
                'age1': age1,
                'age2': age2
            }
            
        except Exception as e:
            print(f"Error processing sample at index {index}: {sample}")
            print(f"Error: {e}")
            raise


def load_data(file_path):
    """Return every line of ``file_path`` as a list (trailing newlines kept)."""
    with open(file_path, 'r') as handle:
        return list(handle)


def get_metrics(embedding1, embedding2, same_person, same_age_group, age_gaps):
    """Verification metrics for a batch of embedding pairs.

    Scores each pair by cosine similarity and returns the ROC AUC against the
    same-person and same-age-group labels, plus the accuracy at the threshold
    maximising ``tpr - fpr`` on the same-person ROC curve.
    """
    scores = F.cosine_similarity(embedding1, embedding2).detach().cpu().numpy()
    person_labels = same_person.cpu().numpy()
    age_group_labels = same_age_group.cpu().numpy()
    _age_gaps = age_gaps.cpu().numpy()  # converted but not used by any metric below

    person_auc = roc_auc_score(person_labels, scores)
    age_group_auc = roc_auc_score(age_group_labels, scores)

    fpr, tpr, thresholds = roc_curve(person_labels, scores)
    best_threshold = thresholds[np.argmax(tpr - fpr)]
    accuracy = accuracy_score(person_labels, (scores >= best_threshold).astype(int))

    return dict(
        person_auc=person_auc,
        age_group_auc=age_group_auc,
        accuracy=accuracy,
        best_threshold=best_threshold,
    )


def train_one_epoch(model, train_loader, criterion, optimizer, accelerator, epoch):
    """Train for one epoch on image pairs.

    Both images of every pair are embedded by ``model`` and classified by its
    ``metric_fc`` head against their subject ids; the loss is the mean of the two
    ``criterion`` values. Cosine similarities between the pair embeddings are
    collected to report verification metrics for the epoch.

    Args:
        model: embedding network exposing ``metric_fc`` directly or through
            ``.module`` when wrapped (e.g. DataParallel).
        train_loader: iterable of dict batches with keys 'img1', 'img2',
            'same_person', 'same_age_group', 'subject_id1', 'subject_id2'.
        criterion: loss applied to ``metric_fc`` outputs and subject ids.
        optimizer: optimizer stepped once per batch.
        accelerator: ``Accelerator`` used for the backward pass.
        epoch: epoch number, used for the progress-bar label.

    Returns:
        dict with 'loss' (mean batch loss), 'person_auc', 'age_group_auc' and
        'accuracy' (at the threshold maximising tpr - fpr); a metric that cannot
        be computed is reported as 0.0.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    # The classification head sits on the underlying module when the model is
    # wrapped; fall back to the model itself otherwise (as save_model does).
    metric_fc = getattr(model, 'module', model).metric_fc

    total_loss = 0.0
    num_batches = 0
    all_cos_sim = []
    all_same_person = []
    all_same_age_group = []

    pbar = tqdm(train_loader, desc=f"Epoch {epoch} Training")

    for batch in pbar:
        img1 = batch['img1']
        img2 = batch['img2']
        same_person = batch['same_person']
        same_age_group = batch['same_age_group']
        subject_id1 = batch['subject_id1']
        subject_id2 = batch['subject_id2']

        # Get embeddings
        embedding1 = model(img1)
        embedding2 = model(img2)

        # Apply metric learning through the attached metric_fc
        output1 = metric_fc(embedding1, subject_id1)
        output2 = metric_fc(embedding2, subject_id2)

        # Calculate loss
        loss1 = criterion(output1, subject_id1)
        loss2 = criterion(output2, subject_id2)
        loss = (loss1 + loss2) / 2

        # Backward pass
        optimizer.zero_grad()
        accelerator.backward(loss)
        optimizer.step()

        batch_loss = loss.item()  # single device sync per batch
        total_loss += batch_loss
        num_batches += 1

        with torch.no_grad():
            cos_sim = F.cosine_similarity(embedding1, embedding2)
            all_cos_sim.append(cos_sim.cpu())
            all_same_person.append(same_person.cpu())
            all_same_age_group.append(same_age_group.cpu())

        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")

    # Calculate epoch metrics
    avg_loss = total_loss / num_batches

    all_cos_sim = torch.cat(all_cos_sim).numpy()
    all_same_person = torch.cat(all_same_person).numpy()
    all_same_age_group = torch.cat(all_same_age_group).numpy()

    # Calculate person AUC
    try:
        person_auc = roc_auc_score(all_same_person, all_cos_sim)
    except ValueError:
        person_auc = 0.0
        print("Warning: Could not calculate person AUC (only one class present)")

    # Calculate age group AUC
    try:
        age_group_auc = roc_auc_score(all_same_age_group, all_cos_sim)
    except ValueError:
        age_group_auc = 0.0
        print("Warning: Could not calculate age group AUC (only one class present)")

    # Accuracy at the threshold maximising tpr - fpr
    try:
        fpr, tpr, thresholds = roc_curve(all_same_person, all_cos_sim)
        best_threshold = thresholds[np.argmax(tpr - fpr)]
        accuracy = accuracy_score(all_same_person, (all_cos_sim >= best_threshold).astype(int))
    except Exception as e:
        print(f"Warning: Could not calculate accuracy: {e}")
        accuracy = 0.0

    return {
        'loss': avg_loss,
        'person_auc': person_auc,
        'age_group_auc': age_group_auc,
        'accuracy': accuracy
    }


def validate(model, val_loader, criterion, accelerator):
    """Evaluate the model on validation pairs without updating weights.

    Computes the same pair loss as training (mean of the ``metric_fc``
    classification losses of both images) and verification metrics from the
    cosine similarity of the pair embeddings.

    Args:
        model: embedding network exposing ``metric_fc`` directly or through
            ``.module`` when wrapped (e.g. DataParallel).
        val_loader: iterable of dict batches with keys 'img1', 'img2',
            'same_person', 'same_age_group', 'subject_id1', 'subject_id2'.
        criterion: loss applied to ``metric_fc`` outputs and subject ids.
        accelerator: unused; kept for signature compatibility.

    Returns:
        dict with 'loss', 'person_auc', 'age_group_auc' and 'accuracy'; a metric
        that cannot be computed is reported as 0.0.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()
    # Unwrap DataParallel/DDP-style wrappers when present (as save_model does)
    metric_fc = getattr(model, 'module', model).metric_fc

    total_loss = 0.0
    num_batches = 0
    all_cos_sim = []
    all_same_person = []
    all_same_age_group = []

    with torch.no_grad():
        pbar = tqdm(val_loader, desc="Validation")
        for batch in pbar:
            img1 = batch['img1']
            img2 = batch['img2']
            same_person = batch['same_person']
            same_age_group = batch['same_age_group']
            subject_id1 = batch['subject_id1']
            subject_id2 = batch['subject_id2']

            # Get embeddings
            embedding1 = model(img1)
            embedding2 = model(img2)

            # Apply metric learning
            output1 = metric_fc(embedding1, subject_id1)
            output2 = metric_fc(embedding2, subject_id2)

            # Calculate loss
            loss1 = criterion(output1, subject_id1)
            loss2 = criterion(output2, subject_id2)
            loss = (loss1 + loss2) / 2

            total_loss += loss.item()
            num_batches += 1

            # Calculate cosine similarity
            cos_sim = F.cosine_similarity(embedding1, embedding2)
            all_cos_sim.append(cos_sim.cpu())
            all_same_person.append(same_person.cpu())
            all_same_age_group.append(same_age_group.cpu())

    if num_batches == 0:
        raise ValueError("val_loader yielded no batches")

    # Calculate metrics
    avg_loss = total_loss / num_batches

    all_cos_sim = torch.cat(all_cos_sim).numpy()
    all_same_person = torch.cat(all_same_person).numpy()
    all_same_age_group = torch.cat(all_same_age_group).numpy()

    # Calculate person AUC
    try:
        person_auc = roc_auc_score(all_same_person, all_cos_sim)
    except ValueError:
        person_auc = 0.0
        print("Warning: Could not calculate person AUC (only one class present)")

    # Calculate age group AUC
    try:
        age_group_auc = roc_auc_score(all_same_age_group, all_cos_sim)
    except ValueError:
        age_group_auc = 0.0
        print("Warning: Could not calculate age group AUC (only one class present)")

    # Accuracy at the threshold maximising tpr - fpr
    try:
        fpr, tpr, thresholds = roc_curve(all_same_person, all_cos_sim)
        best_threshold = thresholds[np.argmax(tpr - fpr)]
        accuracy = accuracy_score(all_same_person, (all_cos_sim >= best_threshold).astype(int))
    except Exception as e:
        print(f"Warning: Could not calculate accuracy: {e}")
        accuracy = 0.0

    return {
        'loss': avg_loss,
        'person_auc': person_auc,
        'age_group_auc': age_group_auc,
        'accuracy': accuracy
    }


def evaluate_base_model(model, val_loader, device):
    """
    Evaluate base model before training.

    Pairs are scored by cosine similarity of their embeddings (negative
    similarities clamped to 0). Reports ROC AUC, accuracy at the threshold
    maximising tpr - fpr, and TPR at fixed FPR budgets.

    Args:
        model: embedding network; called as ``model(images)``.
        val_loader: iterable of dict batches with 'img1', 'img2', 'same_person'.
        device: device the images are moved to before the forward pass.

    Returns:
        dict with 'roc_auc', 'accuracy', 'threshold' and 'tpr_at_fpr'
        (FPR budget -> TPR at the last ROC point whose FPR does not exceed it).
    """
    print(f"\n{'='*80}")
    print("üîç EVALUATING BASE MODEL (BEFORE TRAINING)")
    print("="*80)
    
    model.eval()
    y_true = []
    similarities = []
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating Base Model"):
            img1 = batch['img1'].to(device)
            img2 = batch['img2'].to(device)
            same_person = batch['same_person']
            
            # Get embeddings
            feature1 = model(img1)
            feature2 = model(img2)
            
            # Compute cosine similarity
            feature1_np = feature1.data.cpu().numpy()
            feature2_np = feature2.data.cpu().numpy()
            
            for i in range(feature1_np.shape[0]):
                sim = 1 - spatial.distance.cosine(feature1_np[i].flatten(), feature2_np[i].flatten())
                if sim < 0:
                    sim = 0.0
                similarities.append(sim)
                y_true.append(same_person[i].item())
    
    y_true = np.array(y_true)
    similarities = np.array(similarities)
    
    # Compute ROC curve
    fpr, tpr, thresholds = roc_curve(y_true, similarities)
    roc_auc = auc(fpr, tpr)
    
    # Find optimal threshold
    optimal_idx = np.argmax(tpr - fpr)
    optimal_threshold = thresholds[optimal_idx]
    optimal_accuracy = np.mean((similarities >= optimal_threshold) == y_true)
    
    # TPR at fixed FPR budgets: use the last ROC point whose FPR stays within the
    # budget (fpr from roc_curve is non-decreasing), instead of the first point
    # that exceeds it, which overstated the operating FPR.
    tpr_at_fpr = {}
    for fpr_value in [0.0001, 0.001, 0.01, 0.1]:
        idx = np.searchsorted(fpr, fpr_value, side='right') - 1
        if idx >= 0:
            tpr_at_fpr[fpr_value] = tpr[idx]
        else:
            tpr_at_fpr[fpr_value] = 0.0
    
    # Print results
    print(f"\nüéØ Validation Results:")
    print(f"  Validation pairs: {len(y_true)}")
    print(f"  ROC AUC: {roc_auc:.4f}")
    print(f"  Optimal Threshold: {optimal_threshold:.4f}")
    print(f"  Accuracy at Optimal Threshold: {optimal_accuracy:.4f}")
    print(f"  TPR at FPR=0.01%: {tpr_at_fpr.get(0.0001, 0):.4f}")
    print(f"  TPR at FPR=0.1%:  {tpr_at_fpr.get(0.001, 0):.4f}")
    print(f"  TPR at FPR=1%:    {tpr_at_fpr.get(0.01, 0):.4f}")
    print(f"{'='*80}\n")
    
    print("\nüìä BASE MODEL PERFORMANCE:")
    print(f"  ROC AUC: {roc_auc:.4f}")
    print(f"  Accuracy: {optimal_accuracy:.4f}")
    print(f"  TPR@FPR=0.01%: {tpr_at_fpr.get(0.0001, 0):.4f}")
    print(f"  TPR@FPR=1%: {tpr_at_fpr.get(0.01, 0):.4f}")
    
    print("\n" + "="*80)
    print("üöÄ NOW STARTING TRAINING...")
    print("="*80 + "\n")
    
    return {
        'roc_auc': roc_auc,
        'accuracy': optimal_accuracy,
        'threshold': optimal_threshold,
        'tpr_at_fpr': tpr_at_fpr
    }


def _build_metric_head(opt):
    """Create the classification head selected by ``opt.metric``.

    Supported values are 'add_margin', 'arc_margin' and 'sphere'; any other
    value falls back to a plain ``torch.nn.Linear`` classifier. The input
    width 512 is assumed to equal the backbone's embedding size
    (NOTE(review): confirm against the ONNX model's output dimension).
    """
    if opt.metric == 'add_margin':
        return metrics.AddMarginProduct(512, opt.num_classes, s=30, m=0.35)
    if opt.metric == 'arc_margin':
        return metrics.ArcMarginProduct(512, opt.num_classes, s=30, m=0.5, easy_margin=opt.easy_margin)
    if opt.metric == 'sphere':
        return metrics.SphereProduct(512, opt.num_classes, m=4)
    return torch.nn.Linear(512, opt.num_classes)


def _freeze_backbone(model, trainable_keys=('bn4', 'fc5', 'bn5', 'metric_fc')):
    """Freeze every parameter whose name contains none of ``trainable_keys``.

    Parameters whose names do contain one of the keys are left untouched.

    Returns:
        list[str]: names of the parameters that still require gradients.
    """
    for name, param in model.named_parameters():
        if not any(key in name for key in trainable_keys):
            param.requires_grad = False
    return [name for name, param in model.named_parameters() if param.requires_grad]


def _build_optimizer(opt, params):
    """Build the optimizer named by ``opt.optimizer`` over ``params``.

    'sgd' selects SGD with momentum; any other value selects Adam.
    """
    if opt.optimizer == 'sgd':
        return torch.optim.SGD(
            params,
            lr=opt.lr,
            momentum=opt.momentum,
            weight_decay=opt.weight_decay
        )
    return torch.optim.Adam(
        params,
        lr=opt.lr,
        weight_decay=opt.weight_decay
    )


def main():
    """Fine-tune the pretrained ONNX face model on the pair lists.

    Steps:
      1. load the config and data loaders,
      2. convert the ONNX model to PyTorch and attach a metric head,
      3. evaluate the untouched base model on the validation pairs,
      4. freeze the backbone except the layers selected by name,
      5. train for ``opt.max_epoch`` epochs, validating each epoch and
         saving every ``opt.save_interval`` epochs.
    """
    # Initialize config and accelerator
    opt = Config()
    accelerator = Accelerator()

    print(f"Accelerator device: {accelerator.device}")
    print()
    print("=" * 60)
    print("Kaggle Configuration Loaded")
    print("=" * 60)
    print(f"Batch Size (Train): {opt.train_batch_size}")
    print(f"Batch Size (Test):  {opt.test_batch_size}")
    print(f"Max Epochs:         {opt.max_epoch}")
    print(f"Num Classes:        {opt.num_classes}")
    print(f"Num Workers:        {opt.num_workers}")
    print(f"Pretrained Model:   {os.path.basename(opt.load_model_path)}")
    print("=" * 60)
    print()

    # Data transforms
    transform = transforms.Compose([
        transforms.Resize((112, 112)),  # Standard size for face recognition
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Load training data
    print("=" * 60)
    print("Loading Training Data (MORPH + AgeDB + CACD + FGNET)")
    print("=" * 60)

    train_pairs_file = '/kaggle/working/train_data/train_pairs.txt'
    val_pairs_file = '/kaggle/working/train_data/val_pairs.txt'

    print(f"USING TRAIN LIST FILE: {train_pairs_file}")

    # Load pairs
    all_pairs = load_data(train_pairs_file)
    print(f"Total train pairs loaded: {len(all_pairs)}")

    val_pairs = load_data(val_pairs_file)
    print(f"Total val pairs loaded: {len(val_pairs)}")

    # Create datasets
    train_dataset = FDataset(data=all_pairs, transforms=transform)
    val_dataset = FDataset(data=val_pairs, transforms=transform)

    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Val dataset size: {len(val_dataset)}")
    print()

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=opt.train_batch_size,
        shuffle=True,
        num_workers=opt.num_workers,
        pin_memory=opt.pin_memory
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=opt.test_batch_size,
        shuffle=False,
        num_workers=opt.num_workers,
        pin_memory=opt.pin_memory
    )

    # Initialize model - Load ONNX pretrained model
    print("=" * 60)
    print("Loading Pretrained Model")
    print("=" * 60)

    try:
        import onnx
        from onnx2pytorch import ConvertModel

        onnx_path = opt.load_model_path
        print(f"Loading ONNX model from: {onnx_path}")
        onnx_model = onnx.load(onnx_path)
        model = ConvertModel(onnx_model)
        print("‚úÖ ONNX model loaded successfully")
    except Exception as e:
        print(f"‚ùå Could not load pretrained ONNX model: {e}")
        print("Please install: pip install onnx onnx2pytorch")
        raise

    # Metric learning head
    metric_fc = _build_metric_head(opt)

    # Wrap model with DataParallel and attach metric_fc so that its
    # parameters are visible through model.parameters()/named_parameters().
    # `device` is the module-level device defined elsewhere in this script.
    model = DataParallel(model).to(device)
    metric_fc = metric_fc.to(device)
    model.module.metric_fc = metric_fc

    print(f"Model: {opt.backbone}")
    print(f"Metric FC: {opt.metric}")
    print("=" * 60)
    print()

    # ===== EVALUATE BASE MODEL BEFORE TRAINING =====
    base_results = evaluate_base_model(model, val_loader, device)

    # ===== FREEZE BACKBONE LAYERS (Fine-tuning strategy) =====
    print("=" * 60)
    print("Freezing Backbone - Fine-tuning Strategy")
    print("=" * 60)

    # Only train: bn4, fc5, bn5, and metric_fc
    trainable_names = _freeze_backbone(model)

    print("Trainable parameters:")
    for name in trainable_names:
        print(f"  {name}")
    # Layer names of an onnx2pytorch-converted graph come from the ONNX node
    # names, which may not contain 'bn4'/'fc5'/'bn5'. Surface that case instead
    # of silently training the metric head alone.
    if not any('metric_fc' not in name for name in trainable_names):
        print("WARNING: no backbone parameter matched 'bn4'/'fc5'/'bn5'; "
              "only the metric head will be trained.")
    print("=" * 60)
    print()

    # Loss function
    criterion = torch.nn.CrossEntropyLoss()

    # Optimizer: hand it only the parameters that are actually trained.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = _build_optimizer(opt, trainable_params)

    scheduler = StepLR(optimizer, step_size=opt.lr_step, gamma=opt.lr_decay)

    # Prepare with accelerator
    model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, val_loader, scheduler
    )

    print("=" * 60)
    print("Training Configuration")
    print("=" * 60)
    print(f"Optimizer: {opt.optimizer}")
    print(f"Learning rate: {opt.lr}")
    print("Loss function: CrossEntropyLoss")
    print("=" * 60)
    print()

    # Training loop.
    # Initialize with base model AUC. NOTE(review): confirm that validate()'s
    # 'person_auc' is on the same scale as evaluate_base_model()'s 'roc_auc',
    # otherwise this starting value may block or trivially allow saving.
    best_val_auc = base_results['roc_auc'] if base_results else 0.0

    for epoch in range(1, opt.max_epoch + 1):
        print(f"\n{'='*60}")
        print(f"Epoch {epoch}/{opt.max_epoch}")
        print(f"{'='*60}")

        # Train
        train_metrics = train_one_epoch(model, train_loader, criterion, optimizer, accelerator, epoch)

        print("\nTrain Metrics:")
        print(f"  Loss: {train_metrics['loss']:.4f}")
        print(f"  Person AUC: {train_metrics['person_auc']:.4f}")
        print(f"  Age Group AUC: {train_metrics['age_group_auc']:.4f}")
        print(f"  Accuracy: {train_metrics['accuracy']:.4f}")

        # Validate
        val_metrics = validate(model, val_loader, criterion, accelerator)

        print("\nValidation Metrics:")
        print(f"  Loss: {val_metrics['loss']:.4f}")
        print(f"  Person AUC: {val_metrics['person_auc']:.4f}")
        print(f"  Age Group AUC: {val_metrics['age_group_auc']:.4f}")
        print(f"  Accuracy: {val_metrics['accuracy']:.4f}")

        # Step scheduler
        scheduler.step()

        # Save model; save_model returns the updated best metric value.
        if epoch % opt.save_interval == 0:
            best_val_auc = save_model(
                model,
                opt.checkpoints_path,
                opt.backbone,
                epoch,
                best_val_auc,
                val_metrics['person_auc'],
                metric_name="person_auc"
            )

    print("\n" + "="*60)
    print("Training completed!")
    print("="*60)


if __name__ == "__main__":
    # Script entry point, e.g. `python train.py` from the notebook cell below.
    main()

EOF

In [30]:
cd AQUAFace

/kaggle/working/AQUAFace


In [31]:
!pip install -q visdom

[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m1.4/1.4 MB[0m [31m27.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for visdom (setup.py) ... [?25l[?25hdone


In [32]:
# cat /kaggle/working/AQUAFace/models/__init__.py

In [33]:
!python train.py

Using device: cuda
Loading metadata...
Loaded 174014 rows from metadata
Created lookup dictionary with 174014 entries
Created subject_id_map with 3236 unique identities
Auto-detected num_classes: 3236
Accelerator device: cuda

Kaggle Configuration Loaded
Batch Size (Train): 128
Batch Size (Test):  128
Max Epochs:         15
Num Classes:        3236
Num Workers:        2
Pretrained Model:   R100_Glint360K.onnx

Loading Training Data (MORPH + AgeDB + CACD + FGNET)
USING TRAIN LIST FILE: /kaggle/working/train_data/train_pairs.txt
Total train pairs loaded: 222164
Total val pairs loaded: 55543
FDataset initialized with 222164 samples
Using subject_id_map with 3236 unique identities
FDataset initialized with 55543 samples
Using subject_id_map with 3236 unique identities
Train dataset size: 222164
Val dataset size: 55543

Loading Pretrained Model
Loading ONNX model from: /kaggle/working/AQUAFace/pretrained_models/R100_Glint360K.onnx
‚úÖ ONNX model loaded successfu