In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory
import os
'''
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))
'''
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# We begin by importing all the necessary libraries for building, training, and evaluating a deep learning model using PyTorch.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
# Using the high-performing EfficientNet-B4
from torchvision.models import efficientnet_b4, EfficientNet_B4_Weights

import os
import shutil
import random
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
import collections
import math
from sklearn.metrics import f1_score # Import F1 score utility

# 🛠️ 1. Setup and Configuration

We define a configuration dictionary that stores all the relevant hyperparameters and file paths used throughout the training process.

In [None]:
# --- 1. Setup and Configuration ---

# Configuration dictionary for hyperparameters and paths
CONFIG = {
    # Root of the read-only competition data (expects "train" and "val" sub-folders).
    "BASE_PATH": "/kaggle/input/comsys/Comys_Hackathon5/Task_B",
    # Writable working directory that receives the re-split dataset and checkpoints.
    "OUTPUT_PATH": "/kaggle/working/data",
    "BEST_MODEL_PATH": "/kaggle/working/data/best_embedding_model.pth",
    "TRAIN_SPLIT_RATIO": 0.8,  # Fraction of identities kept for training
    "BATCH_SIZE": 16,  # Kept smaller for the large EfficientNet model
    "EPOCHS": 25,
    "LEARNING_RATE": 0.001,
    "EMBEDDING_DIM": 512,
    # --- ArcFace Hyperparameters ---
    "ARCFACE_SCALE": 30.0,
    "ARCFACE_MARGIN": 0.5,
    # --- Reproducibility ---
    "SEED": 42,
}

# Seed every RNG the pipeline draws from (class shuffling, weight init,
# DataLoader shuffling) so that "Restart & Run All" reproduces the same run.
random.seed(CONFIG["SEED"])
np.random.seed(CONFIG["SEED"])
torch.manual_seed(CONFIG["SEED"])

# Ensure the output directory exists
os.makedirs(CONFIG["OUTPUT_PATH"], exist_ok=True)

# Set device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- Using device: {DEVICE} ---")

# 📦 2. Data Preparation

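We define a function that splits the original `train` folder into class-exclusive **train** and **validation** sets (no identity appears in both), and copies the original `val` folder to serve as the held-out **test** set.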

In [None]:
# --- 2. Data Preparation ---

def prepare_datasets(base_path, output_path, train_split_ratio, seed=None):
    """
    Build class-exclusive train/validation folders plus a test folder.

    The identities found under ``<base_path>/train`` are shuffled and split by
    ``train_split_ratio``; the first part is copied to ``train_final`` and the
    remainder to ``val_final``. Every identity under ``<base_path>/val`` is
    copied to ``test``. Any existing output folders are deleted first.

    Args:
        base_path: Directory containing the original ``train`` and ``val`` folders.
        output_path: Directory that receives ``train_final``, ``val_final`` and ``test``.
        train_split_ratio: Fraction of training identities kept for training.
        seed: Optional seed for a private RNG used for the shuffle. When
            ``None`` (the default) the global ``random`` state is used.

    Returns:
        Tuple ``(train_path, val_path, test_path, num_train_classes)``.
    """
    print("--- Starting Data Preparation ---")

    original_train_path = os.path.join(base_path, "train")
    original_val_path = os.path.join(base_path, "val")

    final_train_path = os.path.join(output_path, "train_final")
    final_val_path = os.path.join(output_path, "val_final")
    final_test_path = os.path.join(output_path, "test")

    # Start from empty folders so a re-run never mixes two different splits.
    for path in (final_train_path, final_val_path, final_test_path):
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)

    # os.listdir returns entries in arbitrary order; sorting first makes the
    # shuffle (and therefore the split) reproducible for a given RNG state.
    person_classes = sorted(
        d for d in os.listdir(original_train_path)
        if os.path.isdir(os.path.join(original_train_path, d))
    )
    shuffler = random.Random(seed) if seed is not None else random
    shuffler.shuffle(person_classes)

    split_index = int(len(person_classes) * train_split_ratio)
    train_classes = person_classes[:split_index]
    val_classes = person_classes[split_index:]

    print(f"Splitting original train data: {len(train_classes)} classes for training, {len(val_classes)} for validation.")

    for class_name in tqdm(train_classes, desc="Copying train classes"):
        shutil.copytree(os.path.join(original_train_path, class_name), os.path.join(final_train_path, class_name))

    for class_name in tqdm(val_classes, desc="Copying validation classes"):
        shutil.copytree(os.path.join(original_train_path, class_name), os.path.join(final_val_path, class_name))

    print("Preparing test set from original 'val' folder...")
    test_classes = sorted(
        d for d in os.listdir(original_val_path)
        if os.path.isdir(os.path.join(original_val_path, d))
    )
    for class_name in tqdm(test_classes, desc="Copying test classes"):
        shutil.copytree(os.path.join(original_val_path, class_name), os.path.join(final_test_path, class_name))

    print("--- Data Preparation Complete ---")
    return final_train_path, final_val_path, final_test_path, len(train_classes)

# 🧠 3. Advanced Model Architecture with ArcFace

## 🔍 EmbeddingNet

The `EmbeddingNet` class uses the pretrained **EfficientNet-B4** model as a backbone for feature extraction. Instead of outputting class probabilities, the final classification layer is replaced with a fully connected layer that maps to a fixed-size embedding vector (e.g., 512 dimensions). This embedding is then used as input to a metric-learning-based classification layer such as ArcFace.

The model also uses built-in input transforms from the EfficientNet weights, which ensure that the input images are normalized and resized appropriately.

---

## 🧩 ArcMarginProduct: ArcFace Layer

The `ArcMarginProduct` class implements the **ArcFace** margin layer: it produces margin-penalized logits that are then passed to a standard cross-entropy loss. It is designed for classification tasks that benefit from learning **discriminative features** — such as face recognition or fine-grained visual identification.

Instead of using raw logits for classification, ArcFace computes the cosine similarity between normalized embeddings and class weights. Then it adds an **angular margin penalty** to the similarity score of the correct class. This makes the decision boundary more strict and forces the network to produce embeddings that are:

- **Close together** for samples from the same class
- **Far apart** for samples from different classes

The logits are scaled by a constant factor before applying softmax, which helps stabilize training.

---

## ✅ Summary

This architecture combines a strong image feature extractor (EfficientNet-B4) with a metric learning head (ArcFace), enabling more robust and meaningful embeddings for classification tasks. This is particularly effective in scenarios with a large number of classes and subtle differences between them.



In [None]:
# --- 3. Advanced Model Architecture with ArcFace ---

class EmbeddingNet(nn.Module):
    """EfficientNet-B4 backbone whose classifier is swapped for an embedding projection."""

    def __init__(self, embedding_dim):
        super().__init__()
        pretrained = EfficientNet_B4_Weights.DEFAULT
        backbone = efficientnet_b4(weights=pretrained)
        # Index 1 of the stock classifier is read for its input width; the
        # whole classifier is then replaced by a single linear projection.
        head_width = backbone.classifier[1].in_features
        backbone.classifier = nn.Linear(head_width, embedding_dim)
        self.backbone = backbone
        # Preprocessing pipeline bundled with the chosen pretrained weights.
        self.transforms = pretrained.transforms()

    def forward(self, x):
        """Return the backbone output for the batch ``x``."""
        return self.backbone(x)

class ArcMarginProduct(nn.Module):
    """
    ArcFace margin layer.

    Computes the cosine between L2-normalized embeddings and L2-normalized
    class weights, applies the margin ``m`` to the entry of the labelled
    class, and multiplies all logits by the scale ``s``.

    Args:
        in_features: Dimensionality of the input embeddings.
        out_features: Number of classes (rows of the weight matrix).
        s: Scale applied to the output logits.
        m: Margin (in radians) applied to the target-class angle.
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.50):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.s, self.m = s, m
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.cos_m, self.sin_m = math.cos(m), math.sin(m)
        # Threshold and fallback offset used where cosine <= cos(pi - m).
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, embedding, label):
        """
        Return scaled logits with the margin applied to each sample's label column.

        Args:
            embedding: Batch of embeddings, one row per sample.
            label: Class index per sample; reshaped to a column internally.
        """
        cosine = F.linear(F.normalize(embedding), F.normalize(self.weight))
        sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(0, 1))
        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # Build the mask on the same device as the logits instead of relying
        # on a module-level DEVICE global, so the layer works wherever it runs.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.to(cosine.device).view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return output * self.s

# 🗂️ 4. Dataset for Classification

## Custom Dataset: `FaceClassificationDataset`

This class implements a PyTorch `Dataset` designed specifically for ArcFace training.

### Key Features:

- **Data Organization:**
  - Assumes the dataset is organized by class folders inside a root directory (`data_dir`).
  - Each folder corresponds to a person/class and contains the images belonging to that class.
  - Additionally, if a subfolder named `distortion` exists inside a class folder, images inside it are also included under the same class label.

- **Class Label Mapping:**
  - Classes (person names) are sorted alphabetically and mapped to integer indices.
  - This mapping is stored in `class_to_idx` for consistent label encoding.

- **Data Access:**
  - `__len__` returns the total number of images.
  - `__getitem__` loads an image by its index, applies optional transformations, and returns the `(image, class_label)` pair.

### Usage:

This dataset can be used with PyTorch DataLoaders to efficiently load and batch images during training. It supports data augmentation or normalization via the `transform` parameter.

---

This design ensures that images — including those in distortion folders — are correctly grouped by class for effective training with ArcFace loss.


In [None]:
# --- 4. Dataset for Classification ---

class FaceClassificationDataset(Dataset):
    """
    Custom dataset for ArcFace training. Returns (image, class_label).

    ``data_dir`` holds one sub-folder per class. Image files directly inside a
    class folder and inside its optional ``distortion`` sub-folder are
    collected under the same label. Labels are the positions of the class
    names in alphabetical order.
    """

    # Same extensions that prepare_evaluation_sets accepts for clean images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        person_classes = sorted(
            d for d in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, d))
        )
        self.class_to_idx = {name: i for i, name in enumerate(person_classes)}
        self.image_paths = []
        for class_name, idx in self.class_to_idx.items():
            class_path = os.path.join(self.data_dir, class_name)
            for folder in (class_path, os.path.join(class_path, 'distortion')):
                self.image_paths.extend((path, idx) for path in self._list_images(folder))

    @classmethod
    def _list_images(cls, folder):
        """Return sorted paths of the image files directly inside ``folder`` (empty if absent)."""
        if not os.path.isdir(folder):
            return []
        # Skip sub-folders and non-image files so Image.open never receives
        # something it cannot decode halfway through an epoch.
        return [
            os.path.join(folder, name)
            for name in sorted(os.listdir(folder))
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(folder, name))
        ]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path, label = self.image_paths[idx]
        # Close the file as soon as the pixels are decoded; convert() returns
        # an independent in-memory image.
        with Image.open(img_path) as handle:
            image = handle.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# 🚀 5. Training and Evaluation

## Data Preparation for Evaluation

- **prepare_evaluation_sets** scans the dataset directory once to gather:
  - A **reference gallery**: clean images per class used as a stable set of embeddings for comparison.
  - A **query set**: images (e.g., distorted or augmented) that need to be identified by the model.
  - This avoids repeatedly scanning disk during evaluation, improving efficiency.

---

## Model Evaluation

- The evaluation function uses the pre-collected reference gallery and query images.
- For each class in the reference gallery, embeddings are computed and averaged to form a class prototype vector.
- Each query image embedding is compared against all class prototypes using the Euclidean distance between L2-normalized vectors, which ranks matches identically to cosine distance.
- The closest prototype determines the predicted class.
- Metrics computed:
  - **Top-1 Accuracy**: Percentage of correct predictions.
  - **Macro F1-Score**: F1 score averaged across classes, robust to class imbalance.

---

## Model Training Loop

- The training loop jointly trains the feature extractor and the ArcFace classification head.
- For each batch:
  - Images are fed through the feature extractor to obtain embeddings.
  - The ArcFace head produces margin-penalized logits from embeddings and labels.
  - The loss is computed, gradients are backpropagated, and parameters updated.
- After each epoch:
  - The model is evaluated on the validation set.
  - Learning rate scheduler adjusts based on validation accuracy.
  - The best performing model checkpoint (based on accuracy) is saved.

---

## Summary

This setup trains the model to learn discriminative embeddings with ArcFace margin penalties, and periodically evaluates using a reference-query retrieval approach to monitor performance with meaningful metrics like accuracy and macro F1.


In [None]:
# --- 5. Training and Evaluation ---

def prepare_evaluation_sets(data_path, image_extensions=('.png', '.jpg', '.jpeg')):
    """
    Scan ``data_path`` once and collect reference and query image paths.

    Args:
        data_path: Directory with one sub-folder per class.
        image_extensions: Lower-case file suffixes accepted as images.

    Returns:
        Tuple ``(reference_gallery_paths, query_set, person_classes)`` where
        ``reference_gallery_paths`` maps each class that has at least one
        image directly in its folder to those image paths, ``query_set`` is a
        list of ``(path, class_name)`` for images in each class's
        ``distortion`` sub-folder, and ``person_classes`` is the sorted list
        of class folder names.
    """
    extensions = tuple(image_extensions)

    def list_images(folder):
        # Sorted for a deterministic order; non-image entries are skipped so
        # they cannot break image loading later on.
        return [
            os.path.join(folder, name)
            for name in sorted(os.listdir(folder))
            if name.lower().endswith(extensions)
            and os.path.isfile(os.path.join(folder, name))
        ]

    reference_gallery_paths = {}
    query_set = []
    person_classes = sorted(
        d for d in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, d))
    )

    for class_name in person_classes:
        class_path = os.path.join(data_path, class_name)
        distortion_path = os.path.join(class_path, 'distortion')

        clean_images = list_images(class_path)
        if clean_images:
            reference_gallery_paths[class_name] = clean_images

        if os.path.isdir(distortion_path):
            # Apply the same extension filter to queries as to references.
            query_set.extend((path, class_name) for path in list_images(distortion_path))

    return reference_gallery_paths, query_set, person_classes

def evaluate(model, ref_gallery_paths, query_set, person_classes, transform, device):
    """
    Evaluate the model by nearest-prototype matching on pre-collected paths.

    For every class in ``ref_gallery_paths`` the embeddings of its images are
    averaged into one prototype. Each query embedding is assigned to the
    prototype with the smallest distance between L2-normalized vectors.

    Args:
        model: Embedding network; moved to ``device`` and put in eval mode.
        ref_gallery_paths: Mapping ``class_name -> list of image paths``.
        query_set: List of ``(image_path, class_name)`` pairs.
        person_classes: Class names; their positions are used as label ids.
        transform: Callable turning a PIL image into a tensor.
        device: Device on which inference runs.

    Returns:
        Tuple ``(accuracy_percent, macro_f1)``; ``(0.0, 0.0)`` when there are
        no query images or no reference images.
    """
    print("\n--- Evaluating Model ---")
    model.to(device)
    model.eval()

    if not query_set:
        print("No query images found to evaluate.")
        return 0.0, 0.0
    if not ref_gallery_paths:
        # torch.stack below cannot handle an empty gallery.
        print("No reference images found to evaluate against.")
        return 0.0, 0.0

    y_true, y_pred = [], []
    class_to_idx = {name: i for i, name in enumerate(person_classes)}

    def embed(path):
        """Load one image, close its file handle, and return its embedding."""
        with Image.open(path) as handle:
            img_tensor = transform(handle.convert("RGB")).unsqueeze(0).to(device)
        return model(img_tensor)

    print("Creating reference embedding gallery...")
    avg_reference_embeddings = {}
    with torch.no_grad():
        for class_name, img_paths in tqdm(ref_gallery_paths.items(), desc="Processing reference images"):
            # The model is run here, so embeddings are always up-to-date
            embeddings = [embed(p) for p in img_paths]
            avg_reference_embeddings[class_name] = torch.mean(torch.cat(embeddings), dim=0)

    ref_labels = list(avg_reference_embeddings.keys())
    ref_embeds = torch.stack(list(avg_reference_embeddings.values()))
    # The gallery does not change per query, so normalize it once.
    ref_embeds_normalized = F.normalize(ref_embeds)

    print("Matching query images against gallery...")
    with torch.no_grad():
        for query_path, true_label in tqdm(query_set, desc="Processing query images"):
            query_embedding = embed(query_path)
            distances = torch.cdist(F.normalize(query_embedding), ref_embeds_normalized)
            best_match_idx = torch.argmin(distances, dim=1).item()
            predicted_label = ref_labels[best_match_idx]
            y_true.append(class_to_idx[true_label])
            y_pred.append(class_to_idx[predicted_label])

    accuracy = (np.sum(np.array(y_true) == np.array(y_pred)) / len(y_true)) * 100
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    print("Evaluation Complete:")
    print(f"  - Top-1 Accuracy: {accuracy:.2f}%")
    print(f"  - Macro Avg F1-Score: {macro_f1:.4f}")
    return accuracy, macro_f1


def train_model(feature_extractor, arcface_head, optimizer, criterion, scheduler, train_loader, val_gallery_paths, val_query_set, val_person_classes, transform, device, epochs, best_model_path):
    """
    Jointly train the feature extractor and the ArcFace head.

    After every epoch the feature extractor is evaluated with ``evaluate`` on
    the validation gallery/query sets, the scheduler is stepped with the
    validation accuracy, and the feature extractor's ``state_dict`` is saved
    to ``best_model_path`` whenever the accuracy improves.

    Args:
        feature_extractor: Network mapping image batches to embeddings.
        arcface_head: Layer mapping ``(embeddings, labels)`` to logits.
        optimizer: Optimizer over both modules' parameters.
        criterion: Loss applied to ``(logits, labels)``.
        scheduler: Scheduler whose ``step`` receives the validation accuracy.
        train_loader: Iterable of ``(images, labels)`` batches.
        val_gallery_paths, val_query_set, val_person_classes: Validation data
            passed through to ``evaluate``.
        transform: Image transform passed through to ``evaluate``.
        device: Device used for training and evaluation.
        epochs: Number of passes over ``train_loader``.
        best_model_path: Destination of the best checkpoint.
    """
    print("\n--- Starting Model Training with ArcFace ---")
    feature_extractor.to(device)
    arcface_head.to(device)
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; otherwise a run stuck at 0.0% never saves a model and the
    # later load of best_model_path fails.
    best_val_acc = float("-inf")

    for epoch in range(epochs):
        feature_extractor.train()
        arcface_head.train()
        running_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            embeddings = feature_extractor(images)
            logits = arcface_head(embeddings, labels)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        # Guard against an empty loader instead of dividing by zero.
        epoch_loss = running_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch+1}/{epochs}, Average Training Loss: {epoch_loss:.4f}")

        # --- Validation after each epoch ---
        # Pass the pre-computed paths to the evaluate function
        val_acc, val_f1 = evaluate(feature_extractor, val_gallery_paths, val_query_set, val_person_classes, transform, device)
        scheduler.step(val_acc)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            print(f"*** New best model found! Acc: {val_acc:.2f}%. Saving to {best_model_path} ***")
            torch.save(feature_extractor.state_dict(), best_model_path)

    print("--- Model Training Complete ---")

# ▶️ 6. Main Execution Block

This is the entry point of the training and evaluation pipeline. It orchestrates the entire process as follows:

## Dataset Preparation

- Calls `prepare_datasets` to split the original data into training, validation, and test sets with class-exclusive splits.
- Receives the number of training classes to configure the ArcFace head.

## Model and Optimizer Setup

- Instantiates the `EmbeddingNet` feature extractor and the `ArcMarginProduct` ArcFace head with appropriate dimensions and hyperparameters from `CONFIG`.
- Defines the loss function (`CrossEntropyLoss`) and optimizer (`Adam`).
- Sets up a learning rate scheduler (`ReduceLROnPlateau`) that reduces learning rate when validation accuracy plateaus.

## Data Loading

- Creates a PyTorch DataLoader for the training set, applying the EfficientNet-specific data transforms.

## Validation Set Preprocessing

- Prepares validation gallery and query sets once before training to avoid repeated disk I/O during validation.
- This optimization accelerates validation and evaluation phases.

## Training

- Calls the `train_model` function to train over the specified number of epochs.
- The best model (based on validation accuracy) is saved to disk.

## Final Testing

- Prepares test set gallery and query paths similarly.
- Loads the best saved model weights.
- Runs final evaluation on the test set, reporting accuracy and F1-score.

---

## Summary

This main block connects data preparation, model setup, training, and final evaluation seamlessly, ensuring efficient execution and robust performance monitoring.


In [None]:
# --- 6. Main Execution Block ---

if __name__ == '__main__':
    # Re-split the raw data; the number of training identities sizes the ArcFace head.
    train_path, val_path, test_path, num_train_classes = prepare_datasets(
        CONFIG["BASE_PATH"], CONFIG["OUTPUT_PATH"], CONFIG["TRAIN_SPLIT_RATIO"]
    )

    feature_extractor = EmbeddingNet(embedding_dim=CONFIG["EMBEDDING_DIM"])
    arcface_head = ArcMarginProduct(
        in_features=CONFIG["EMBEDDING_DIM"], out_features=num_train_classes,
        s=CONFIG["ARCFACE_SCALE"], m=CONFIG["ARCFACE_MARGIN"]
    )

    # Reuse the preprocessing bundled with the pretrained weights for every split.
    data_transform = feature_extractor.transforms
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        list(feature_extractor.parameters()) + list(arcface_head.parameters()),
        lr=CONFIG["LEARNING_RATE"]
    )
    # mode='max' because the scheduler is stepped with validation accuracy.
    # The `verbose` argument is deprecated since PyTorch 2.2 and is therefore
    # no longer passed.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=2)

    train_dataset = FaceClassificationDataset(data_dir=train_path, transform=data_transform)
    train_loader = DataLoader(train_dataset, batch_size=CONFIG["BATCH_SIZE"], shuffle=True, num_workers=2)

    # --- OPTIMIZATION: Prepare validation set paths ONCE before training ---
    print("\n--- Pre-calculating validation set file paths to optimize validation loop ---")
    val_gallery_paths, val_query_set, val_person_classes = prepare_evaluation_sets(val_path)
    print(f"Found {len(val_gallery_paths)} reference classes and {len(val_query_set)} query images in the validation set.")

    train_model(
        feature_extractor, arcface_head, optimizer, criterion, scheduler,
        train_loader, val_gallery_paths, val_query_set, val_person_classes,
        data_transform, DEVICE, CONFIG["EPOCHS"], CONFIG["BEST_MODEL_PATH"]
    )

    # --- Final evaluation using the BEST saved model ---
    print("\n--- Pre-calculating test set file paths for final evaluation ---")
    test_gallery_paths, test_query_set, test_person_classes = prepare_evaluation_sets(test_path)
    print(f"Found {len(test_gallery_paths)} reference classes and {len(test_query_set)} query images in the test set.")

    print("\n--- Loading best model for final evaluation on Test Set ---")
    if os.path.exists(CONFIG["BEST_MODEL_PATH"]):
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only session.
        feature_extractor.load_state_dict(torch.load(CONFIG["BEST_MODEL_PATH"], map_location=DEVICE))
    else:
        print(f"No checkpoint found at {CONFIG['BEST_MODEL_PATH']}; evaluating the final-epoch weights instead.")
    evaluate(feature_extractor, test_gallery_paths, test_query_set, test_person_classes, data_transform, DEVICE)