In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
#for dirname, _, filenames in os.walk('/kaggle/input'):
    #for filename in filenames:
      #  print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# 🧹 Cleanup Script for Kaggle Working Directory

This script removes **all directories** inside the specified working directory (`/kaggle/working/`).

### Workflow:

1. Lists all items in the `working_dir`.
2. Iterates through each item:
   - If it is a **directory**, attempts to delete it and prints confirmation.
   - Handles any errors gracefully, printing an error message if removal fails.
3. (Optional) You can enable file removal by uncommenting the corresponding block.

### Usage:

- Keeps the Kaggle working directory clean by removing residual directories between runs.
- Helps avoid clutter or conflicts from previous outputs or temporary data.

---

**Note:**  
Be cautious when enabling file removal to avoid accidental deletion of important files.



In [None]:
import os
import shutil

# Define the working directory
working_dir = '/kaggle/working/'

# Snapshot the directory's entries up front so the deletions below cannot
# disturb the iteration. When the notebook runs outside Kaggle the directory
# may not exist; treat that as "nothing to clean" instead of letting
# os.listdir raise FileNotFoundError and abort a "Run All".
if os.path.isdir(working_dir):
    items_in_working = os.listdir(working_dir)
else:
    print(f"Working directory {working_dir} does not exist. Nothing to clean.")
    items_in_working = []

# Iterate through the items and remove directories (plain files are kept)
print(f"Cleaning up {working_dir}...")
for item in items_in_working:
    item_path = os.path.join(working_dir, item)
    if os.path.isdir(item_path):
        try:
            shutil.rmtree(item_path)
            print(f"Removed directory: {item_path}")
        except OSError as e:
            # Best effort: report the failure and continue with the next item.
            print(f"Error removing directory {item_path}: {e}")
    # Optional: If you also want to remove files, uncomment the else if below
    # elif os.path.isfile(item_path):
    #     try:
    #         os.remove(item_path)
    #         print(f"Removed file: {item_path}")
    #     except OSError as e:
    #         print(f"Error removing file {item_path}: {e}")

print("Cleanup complete.")

# 📦 Install OpenCV-Python

To install the OpenCV Python package, run:

```bash
pip install opencv-python
```


In [None]:
pip install opencv-python

# 📥 Download Face Detection Model Files

This snippet performs the following:

1. **Creates a directory** at `/kaggle/working/face-detection-model-files/` if it does not already exist.
2. **Downloads two essential files** for OpenCV’s DNN face detector into that directory:
   - `deploy.prototxt`: The model architecture definition.
   - `res10_300x300_ssd_iter_140000.caffemodel`: The pre-trained weights.

---

### Bash commands used inside Python with `wget`:

```bash
wget -P /kaggle/working/face-detection-model-files/ https://raw.githubusercontent.com/opencv/opencv/4.x/samples/dnn/face_detector/deploy.prototxt
wget -P /kaggle/working/face-detection-model-files/ https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel
```


In [None]:
import os

# Create the directory if it doesn't exist
model_dir = "/kaggle/working/face-detection-model-files/"
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
    print(f"Created directory: {model_dir}")

# Download deploy.prototxt
!wget -P {model_dir} https://raw.githubusercontent.com/opencv/opencv/4.x/samples/dnn/face_detector/deploy.prototxt

# Download res10_300x300_ssd_iter_140000.caffemodel
!wget -P {model_dir} https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel

# Face Preprocessing and Single-Face Detection Pipeline

This script preprocesses images from an input directory, detects a single face using a pretrained Caffe SSD model, crops the detected face, and saves it to an output directory. If no high-confidence face is found, it selects the detection with the highest confidence. If no face is detected at all, a dummy black image is saved as a placeholder.

---

## Features

- **Image preprocessing:**  
  - Grayscale conversion  
  - Normalization  
  - Gaussian blur for noise reduction  
  - Laplacian filter for edge enhancement  
- **Face detection:**  
  - Uses OpenCV's DNN face detector with Caffe model (ResNet SSD)  
  - Selects the best face based on confidence threshold (default 0.5)  
- **Robust handling:**  
  - If no high-confidence detection, fallback to best detection overall  
  - Creates dummy black image if no face is detected  
- **Batch processing:**  
  - Processes images grouped by dataset splits (`train`, `val`) and gender (`male`, `female`)

---

## Dependencies

- Python 3.x  
- OpenCV (`opencv-python`)  
- NumPy  

Install OpenCV if not installed:

```bash
pip install opencv-python
```


In [None]:
import cv2
import os
import numpy as np

def preprocess_and_detect_face(input_folder, output_folder):
    """
    Detect a single face in every image of a folder and save the cropped face.

    For each supported image file directly inside ``input_folder`` the image
    is converted to grayscale, scaled to [0, 1], Gaussian-blurred and passed
    through a Laplacian filter. The min-max normalised result is expanded to
    three channels and fed to OpenCV's Caffe SSD face detector. The detection
    with the highest confidence is always the one used; it is only *reported*
    differently depending on whether it exceeds the 0.5 threshold. The
    corresponding region is cropped from the original colour image and written
    to ``output_folder`` as ``preprocessed_face_<filename>``. A 100x100 black
    placeholder is written instead when the detector returns no detections or
    when the crop is empty.

    The function prints a message and returns early when the output folder
    cannot be created, when the model files are missing, or when the input
    folder does not exist.

    Args:
        input_folder (str): Path to the folder containing input images.
            Sub-directories and files with other extensions are skipped.
        output_folder (str): Path to the folder where preprocessed and
                              cropped face images will be saved. It is
                              created if it does not exist.

    Returns:
        None
    """

    # Create the output folder if it doesn't exist
    try:
        if not os.path.exists(output_folder):
            os.makedirs(output_folder)
            print(f"Created output folder: {output_folder}")
        else:
            print(f"Output folder already exists: {output_folder}")
    except OSError as e:
        print(f"Error creating output folder '{output_folder}': {e}")
        print("Please check your permissions or the specified path.")
        return # Exit if folder cannot be created

    # --- Load pre-trained face detection model (Caffe model - ResNet SSD) ---
    # These files are crucial for face detection.
    prototxt_path = "/kaggle/working/face-detection-model-files/deploy.prototxt"
    caffemodel_path = "/kaggle/working/face-detection-model-files/res10_300x300_ssd_iter_140000.caffemodel"

    if not os.path.exists(prototxt_path) or not os.path.exists(caffemodel_path):
        print("Error: Pre-trained Caffe model files (deploy.prototxt and res10_300x300_ssd_iter_140000.caffemodel) not found.")
        print("Please ensure they are correctly placed at the specified Kaggle input path:")
        print(f"Prototxt Path: {prototxt_path}")
        print(f"Caffemodel Path: {caffemodel_path}")
        print("Download links:")
        print("deploy.prototxt download: https://github.com/opencv/opencv/blob/4.x/samples/dnn/face_detector/deploy.prototxt")
        print("res10_300x300_ssd_iter_140000.caffemodel download: https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel")
        return

    net = cv2.dnn.readNetFromCaffe(prototxt_path, caffemodel_path)
    print("Pre-trained face detection model loaded successfully.")

    if not os.path.exists(input_folder):
        print(f"Error: Input folder '{input_folder}' does not exist. Please check the path.")
        return

    # Detections above this confidence are reported as "high-confidence".
    # The best detection is used either way, so this only affects the log line.
    confidence_threshold = 0.5
    supported_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    for filename in os.listdir(input_folder):
        image_path = os.path.join(input_folder, filename)

        # Only regular files with a supported image extension are processed.
        if os.path.isdir(image_path):
            continue
        if not filename.lower().endswith(supported_extensions):
            continue

        print(f"\nProcessing image: {image_path}")

        # --- Read the image ---
        # None of the steps below modify this array in place, so it can be
        # used directly for the final crop.
        original_image = cv2.imread(image_path)
        if original_image is None:
            print(f"Warning: Could not read image {filename}. Skipping.")
            continue
        h, w = original_image.shape[:2]

        # --- Phase 1: Image Pre-processing for Enhanced Detection ---
        # 1. Grayscale conversion.
        gray = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)

        # 2. Normalization: convert to float and scale to [0, 1].
        normalized_gray = gray.astype(np.float32) / 255.0

        # 3. Noise reduction with a 5x5 Gaussian blur.
        blurred_image = cv2.GaussianBlur(normalized_gray, (5, 5), 0)

        # 4. Edge enhancement with a Laplacian filter. The filter can produce
        #    negative values, so the response is min-max rescaled to [0, 1]
        #    and then converted back to 8-bit.
        laplacian = cv2.Laplacian(blurred_image, cv2.CV_32F)
        sharpened_image = cv2.normalize(laplacian, None, 0, 1, cv2.NORM_MINMAX)
        preprocessed_image = np.uint8(sharpened_image * 255)

        # The Caffe model expects a 3-channel input.
        preprocessed_image_bgr = cv2.cvtColor(preprocessed_image, cv2.COLOR_GRAY2BGR)

        # --- Phase 2: Single-Face Detection ---
        # blobFromImage arguments as passed here:
        #   scalefactor = 1.0 (pixel values stay in 0-255),
        #   size        = (300, 300),
        #   mean        = (104.0, 177.0, 123.0) subtracted per channel,
        #   swapRB      = False (channel order is left as is),
        #   crop        = False (no cropping after resize).
        blob = cv2.dnn.blobFromImage(cv2.resize(preprocessed_image_bgr, (300, 300)), 1.0,
                                     (300, 300), (104.0, 177.0, 123.0), False, False)
        net.setInput(blob)
        detections = net.forward()

        # --- Post-processing for Single Face Extraction ---
        # Keep the first detection with the strictly highest confidence.
        best_confidence = -1.0
        best_box = None
        for i in range(detections.shape[2]):
            confidence = float(detections[0, 0, i, 2])
            if confidence > best_confidence:
                best_confidence = confidence
                # Box coordinates are relative; scale them to pixel units.
                best_box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])

        output_path = os.path.join(output_folder, f"preprocessed_face_{filename}")

        if best_box is None:
            # The detector returned no detections at all.
            print(f"No face detections found for {filename} whatsoever. Creating a dummy black image as a placeholder.")
            image_to_save = np.zeros((100, 100, 3), dtype=np.uint8)
            saved_message = f"Saved dummy black image to: {output_path}"
        else:
            if best_confidence > confidence_threshold:
                print(f"Using high-confidence detection (confidence: {best_confidence:.2f}).")
            else:
                print(f"No high-confidence face found. Using best overall detection (confidence: {best_confidence:.2f}).")

            startX, startY, endX, endY = (int(v) for v in best_box.astype("int"))

            # Clamp every coordinate into the image on BOTH sides. A negative
            # end coordinate must not reach the slice below: NumPy would read
            # it as "count from the end" and return a wrong region instead of
            # an empty one.
            startX = min(max(0, startX), w)
            startY = min(max(0, startY), h)
            endX = min(max(0, endX), w)
            endY = min(max(0, endY), h)

            # Crop the face from the original image (not the preprocessed one)
            image_to_save = original_image[startY:endY, startX:endX]

            if image_to_save.size == 0:
                print(f"Warning: Cropped face from {filename} is empty or invalid. This might be due to bad bbox coordinates.")
                print(f"Creating a dummy black image for {filename} as a placeholder.")
                image_to_save = np.zeros((100, 100, 3), dtype=np.uint8) # Create a 100x100 black image
            saved_message = f"Saved preprocessed and cropped face to: {output_path}"

        # cv2.imwrite reports failure through its return value, so check it
        # before claiming the file was saved.
        if cv2.imwrite(output_path, image_to_save):
            print(saved_message)
        else:
            print(f"Warning: Failed to write image to: {output_path}")


# --- Main execution block ---
if __name__ == "__main__":
    # Root folders: raw Task_A images in, cropped faces out.
    # IMPORTANT: Adjust the input path to your actual base image folder on Kaggle
    base_input_folder = "/kaggle/input/comsys-taska/Task_A"
    base_output_folder = "/kaggle/working/preprocessed_faces"

    # Every dataset split is combined with every gender sub-folder.
    sets = ["train", "val"]
    genders = ["male", "female"]
    split_gender_pairs = [(data_set, gender) for data_set in sets for gender in genders]

    for data_set, gender in split_gender_pairs:
        # The output tree mirrors the <split>/<gender> layout of the input tree.
        current_input_folder = os.path.join(base_input_folder, data_set, gender)
        current_output_folder = os.path.join(base_output_folder, data_set, gender)

        print(f"\n--- Processing {data_set}/{gender} images ---")
        preprocess_and_detect_face(current_input_folder, current_output_folder)
        print(f"--- Finished processing {data_set}/{gender} images ---")

    print("\nAll processing complete. Check the specified output folder for preprocessed images.")

# PyTorch Lightning & EfficientNet Setup Overview

This document summarizes the key Python libraries and modules used for training an image classification model using PyTorch Lightning and EfficientNet architectures.

---

## Libraries and Their Roles

- **os**: For handling file and directory operations.

- **torch, torch.nn, torch.nn.functional**: Core PyTorch modules for tensor computations, neural network layers, and functions.

- **torch.utils.data.Dataset, DataLoader, WeightedRandomSampler**: Tools for creating datasets, loading data in batches, and balancing classes during training.

- **torchvision.transforms, models**: Utilities for image preprocessing, augmentation, and pretrained models.

- **PIL.Image**: Image reading and manipulation.

- **pytorch_lightning**: High-level framework to simplify training loops, logging, and checkpointing in PyTorch.

- **pytorch_lightning.loggers.CSVLogger**: For logging training metrics into CSV files.

- **pytorch_lightning.callbacks (ModelCheckpoint, EarlyStopping, LearningRateMonitor)**: Callbacks to save the best model, stop training early when performance plateaus, and monitor learning rate changes.

- **numpy**: Numerical computations and array manipulations.

- **collections.Counter**: For counting frequency of labels, useful in handling imbalanced datasets.

- **sklearn.metrics**: Metrics for model evaluation, including F1 score, ROC AUC, precision-recall curves, classification reports, confusion matrices, and accuracy.

- **warnings**: To suppress specific non-critical warnings during execution for cleaner output.

- **matplotlib.pyplot, seaborn**: Visualization libraries for plotting training metrics and confusion matrices.

- **tqdm**: Displays progress bars during loops for better tracking of long-running processes.

- **timm**: Provides access to EfficientNet models and other state-of-the-art pretrained architectures.

---

## Additional Notes

- The combination of these libraries enables efficient, scalable training and evaluation of deep learning models with strong support for transfer learning and model interpretability.

- Class balancing is handled via `WeightedRandomSampler` to mitigate dataset imbalance.

- The use of PyTorch Lightning streamlines the training pipeline, reducing boilerplate and increasing reproducibility.

---


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
import numpy as np
from collections import Counter
from sklearn.metrics import f1_score, roc_auc_score, precision_recall_curve, auc, classification_report, confusion_matrix, accuracy_score
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
# Import timm library for EfficientNet
import timm

# Suppress specific warnings if needed, e.g., about PIL image conversion or small batches
warnings.filterwarnings("ignore", category=UserWarning, module='torchvision.transforms.functional_tensor')


# Custom Modules Documentation

---

## 0. Squeeze-and-Excitation (SE) Block

The SE Block is designed to recalibrate channel-wise feature responses by explicitly modelling interdependencies between channels. It works by:

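- Treating each element of the flattened `[batch, features]` vector as a channel. Because no spatial dimension is left to average over, the global-average-pooling ("squeeze") step leaves the values unchanged in this 1-D setting.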
- Passing these statistics through two fully connected layers with a ReLU activation in between.
- Using a sigmoid activation to generate per-channel gating weights.
- Multiplying the original features by these gating weights to emphasize important channels adaptively.

This mechanism helps the network focus on the most informative features, improving performance in tasks involving feature extraction.

---

## 0. Focal Loss

Focal Loss is a loss function that addresses the problem of class imbalance by:

- Down-weighting easy examples so the model focuses more on hard, misclassified examples.
- Incorporating a tunable focusing parameter (gamma) that controls the rate at which easy examples are down-weighted.
- Supporting optional class weighting (alpha) to balance the importance of different classes.
- Allowing label smoothing to reduce overconfidence and improve generalization.
- Providing flexible reduction modes (mean, sum, or none) for the output loss.

This loss is particularly effective in tasks such as object detection or classification where there is a significant class imbalance.

---


In [None]:
# --- 0. Squeeze-and-Excitation (SE) Block Implementation ---
class SEBlock(nn.Module):
    """Squeeze-and-Excitation style gating for flat feature vectors.

    Every element of a ``[batch_size, channel]`` input is treated as a
    channel. A bottleneck MLP (Linear -> ReLU -> Linear -> Sigmoid) produces
    one gate in (0, 1) per channel and the input is multiplied element-wise by
    those gates.

    Args:
        channel (int): Size of the feature vector (number of channels).
        reduction (int): Bottleneck reduction factor. The hidden width is
            ``channel // reduction``, clamped to at least 1 so the bottleneck
            never collapses to zero units when ``channel < reduction``.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        # For a [B, C] input this pool runs over a length-1 axis (see forward)
        # and therefore leaves the values unchanged.
        self.avg_pool = nn.AdaptiveAvgPool1d(1)

        # Without the clamp, channel < reduction gives a zero-width bottleneck;
        # the bias-free second Linear then outputs zeros and every gate is a
        # constant sigmoid(0) = 0.5.
        hidden = max(1, channel // reduction)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Scale ``x`` by its learned per-channel gates.

        Args:
            x (torch.Tensor): Tensor of shape ``[batch_size, channel]``; any
                other rank fails at the two-value unpacking below.

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``.
        """
        b, c = x.size()
        # [B, C] -> [B, C, 1] -> pool over the length-1 axis -> back to [B, C].
        y = self.avg_pool(x.unsqueeze(-1)).squeeze(-1)
        y = self.fc(y).view(b, c)  # per-channel gates in (0, 1)
        return x * y.expand_as(x)  # element-wise scale

# --- 0. Focal Loss Implementation ---
class FocalLoss(nn.Module):
    """Multi-class focal loss with optional class weights and label smoothing.

    Per sample the loss is ``alpha_t * (1 - p_t) ** gamma * CE_smooth`` where
    ``p_t`` is the softmax probability of the true class, ``CE_smooth`` is the
    cross-entropy against the (optionally label-smoothed) target distribution
    and ``alpha_t`` is the weight of the true class (omitted when ``alpha`` is
    None).

    Args:
        alpha (sequence or torch.Tensor, optional): Per-class weights indexed
            by class id. Non-tensor values are converted to a float32 tensor.
        gamma (float): Focusing exponent; 0 disables the focusing term.
        reduction (str): ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-sample loss.
        epsilon (float): Stored on the module; not used by ``forward``.
        label_smoothing (float): Probability mass moved from the true class to
            the other classes (spread uniformly). 0 disables smoothing.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean', epsilon=1e-12, label_smoothing=0.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.epsilon = epsilon
        self.label_smoothing = label_smoothing

        if self.alpha is not None and not isinstance(self.alpha, torch.Tensor):
            self.alpha = torch.tensor(self.alpha, dtype=torch.float32)

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs (torch.Tensor): Raw logits of shape ``(N, C)``.
            targets (torch.Tensor): Integer class labels of shape ``(N,)``.
                Any integer dtype is accepted; labels are cast to int64.

        Returns:
            torch.Tensor: Scalar for ``'mean'``/``'sum'`` reduction, otherwise
            a tensor of shape ``(N,)``.
        """
        num_classes = inputs.shape[1]

        # gather / scatter_ / one_hot all require int64 indices, so cast once
        # here instead of in only some of the call sites.
        targets = targets.long()
        target_index = targets.unsqueeze(1)

        if self.label_smoothing > 0:
            # Spread the smoothing mass over the other classes. The max()
            # guards the single-class case against a division by zero (the
            # off-value is then unused, as the only column is overwritten).
            off_value = self.label_smoothing / max(num_classes - 1, 1)
            smoothed_targets = torch.full_like(inputs, off_value)
            smoothed_targets.scatter_(1, target_index, 1.0 - self.label_smoothing)
        else:
            # Standard one-hot encoding if no smoothing
            smoothed_targets = F.one_hot(targets, num_classes=num_classes).float()

        # Log-probabilities and probabilities for every class.
        log_pt = F.log_softmax(inputs, dim=1)
        pt = torch.exp(log_pt)

        # Probability of the true (hard) class; squeeze only the gathered
        # column so a batch of one keeps its batch dimension.
        pt_true_class = pt.gather(1, target_index).squeeze(1)

        # Cross-entropy against the (possibly smoothed) target distribution.
        base_loss = -(smoothed_targets * log_pt).sum(dim=1)

        # Focusing mechanism: down-weight samples whose true class is already
        # predicted with high probability.
        focal_term = (1 - pt_true_class).pow(self.gamma)
        loss = focal_term * base_loss

        # Alpha weighting (applies to the loss per sample before reduction)
        if self.alpha is not None:
            if self.alpha.device != inputs.device:
                # Move once and keep the moved tensor for later calls.
                self.alpha = self.alpha.to(inputs.device)
            alpha_t = self.alpha.gather(0, targets)
            loss = alpha_t * loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else: # 'none'
            return loss

# Enhanced Supervised Contrastive Loss with Class Reweighting and Hard Negative Mining

This loss function extends the supervised contrastive learning objective with additional features to improve training robustness and handle class imbalance effectively.

## Key Concepts:

- **Contrastive Learning:** Encourages features of samples from the same class (positives) to be closer in the feature space, while pushing apart features from different classes (negatives).

- **Multiple Views:** Can operate on multiple augmented views of each sample to create stronger representations. Supports two modes:
  - `'all'`: Uses all views for contrastive pairs.
  - `'one'`: Uses a single anchor and one contrastive view.

- **Temperature Scaling:** Controls the sharpness of the similarity distribution used in the contrastive loss, helping with training stability.

## Enhancements:

1. **Class Reweighting:**
   - Allows assigning different weights to each class.
   - Helps mitigate class imbalance by giving more importance to under-represented classes.
   - Weights are applied to the positive pairs' contribution to the loss.

2. **Hard Negative Mining:**
   - Instead of treating all negatives equally, it selects a fraction (`hard_mining_ratio`) of the hardest negatives (those most similar to the anchor).
   - This focuses the training on the most confusing negative examples.
   - A margin is used to push negatives away further, enhancing separation.

3. **Dynamic Masking:**
   - Constructs masks to identify positive pairs (same class) and exclude self-contrast pairs.
   - Applies hard negative mining mask to include only selected negatives for the loss denominator.

## Workflow Overview:

- Features are normalized and cosine similarities between anchors and contrasts are computed.
- Logits are adjusted for numerical stability.
- Positive and negative pairs are masked.
- Margin is applied to negatives to increase their separation.
- Hard negatives are mined based on similarity scores.
- Class weights are applied to positive pairs.
- The final loss is computed as the weighted, temperature-scaled average of positive log probabilities.

## Use Cases:

- Suitable for imbalanced classification tasks where contrastive representation learning is beneficial.
- Enhances robustness by focusing learning on challenging negative samples.
- Useful in scenarios involving multi-view or augmented data representations.

---


In [None]:
# --- Enhanced Supervised Contrastive Loss with Class Reweighting and Hard Negative Mining ---
class EnhancedSupConLoss(nn.Module):
    """Supervised contrastive loss with class re-weighting, a negative margin and hard-negative mining.

    Embeddings are L2-normalised, so the logits are cosine similarities divided by
    ``temperature``. For every anchor the loss is the negative mean log-probability
    of its positive pairs, where the softmax denominator runs over the positives
    plus either all negatives or only the hardest ones (see ``hard_mining_ratio``).

    Args:
        temperature: divisor applied to the cosine similarities.
        base_temperature: the loss is scaled by ``temperature / base_temperature``.
        contrast_mode: ``'all'`` uses every view as both anchor and contrast;
            ``'one'`` uses view 0 as the anchor and view 1 as the contrast.
        hard_mining_ratio: when strictly between 0 and 1, each anchor keeps only its
            ``max(int(num_anchors * ratio), 1)`` most similar negatives in the
            denominator. Any other value disables mining (all negatives are kept).
        margin: when positive, subtracted from the temperature-scaled logits of
            negative pairs.
    """

    def __init__(self, temperature=0.05, base_temperature=0.07, contrast_mode='all',
                 hard_mining_ratio=0.35, margin=0.2):
        super(EnhancedSupConLoss, self).__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.contrast_mode = contrast_mode
        self.hard_mining_ratio = hard_mining_ratio  # Per-anchor hard-negative budget, as a fraction of the anchor count
        self.margin = margin  # Subtracted from the logits of negative pairs

    def forward(self, features, labels=None, mask=None, class_weights=None):
        """Compute the loss for one batch.

        Args:
            features: tensor of shape [bsz, n_views, feature_dim]. A 2-D tensor
                [bsz, feature_dim] is treated as a single view; trailing dimensions
                beyond the third are flattened into the feature dimension.
            labels: optional class indices of shape [bsz].
            mask: optional positive-pair mask, mask[i, j] = 1 if sample j is a
                positive for sample i. Shape [bsz, bsz]; in ``'all'`` mode it is
                expanded to every view (a mask already given per view is used as is).
                Takes precedence over ``labels`` when both are given.
            class_weights: optional dict mapping class index -> weight (missing
                classes default to 1.0). Each anchor's loss term is multiplied by
                the weight of its class. Ignored when ``labels`` is None.

        Returns:
            A scalar loss tensor. If neither ``labels`` nor ``mask`` is given, the
            only positives are the other views of the same sample.

        Raises:
            ValueError: unknown ``contrast_mode``, fewer than two views in ``'one'``
                mode, or a label count that does not match the batch size.
        """
        device = features.device

        # Bring features to [bsz, n_views, feature_dim].
        if len(features.shape) < 3:
            features = features.unsqueeze(1)  # single view (e.g. validation)
        elif len(features.shape) > 3:
            features = features.reshape(features.shape[0], features.shape[1], -1)

        num_samples = features.shape[0]
        n_views = features.shape[1]

        if labels is not None:
            labels = labels.reshape(-1).to(device)
            if labels.shape[0] != num_samples:
                raise ValueError('Num of labels does not match num of features')

        if self.contrast_mode == 'one':
            if n_views < 2:
                raise ValueError("`contrast_mode='one'` requires at least 2 views (e.g., [bsz, 2, feature_dim])")
            anchor_feature = features[:, 0]
            contrast_feature = features[:, 1]
            views_per_sample = 1  # anchors stay one-per-sample; labels are not repeated
        elif self.contrast_mode == 'all':
            # reshape (not view) so non-contiguous inputs are accepted; rows are ordered
            # sample-major: s0v0, s0v1, ..., s1v0, ...
            anchor_feature = features.reshape(-1, features.shape[-1])
            contrast_feature = anchor_feature
            views_per_sample = n_views
        else:
            raise ValueError('Unknown contrast mode: {}'.format(self.contrast_mode))

        num_anchors = anchor_feature.shape[0]

        if labels is not None and views_per_sample > 1:
            # repeat_interleave matches the sample-major row order above
            labels = labels.repeat_interleave(views_per_sample)

        # ---- positive-pair mask, one row per anchor --------------------------------
        if mask is not None:
            mask = mask.float().to(device)
            expand = views_per_sample > 1 and mask.shape[0] == num_samples
        elif labels is not None:
            mask = torch.eq(labels.unsqueeze(1), labels.unsqueeze(0)).float()
            expand = False  # labels were already repeated per view
        else:
            # No supervision: a sample is only positive with its own views.
            mask = torch.eye(num_samples, dtype=torch.float32, device=device)
            expand = views_per_sample > 1
        if expand:
            mask = mask.repeat_interleave(views_per_sample, dim=0)
            mask = mask.repeat_interleave(views_per_sample, dim=1)

        # In 'all' mode anchor i and contrast i are the same vector, so the diagonal
        # is neither a positive nor a negative and must not enter the denominator.
        if self.contrast_mode == 'all':
            logits_mask = 1 - torch.eye(num_anchors, device=device)
        else:
            logits_mask = torch.ones(num_anchors, num_anchors, device=device)
        mask = mask * logits_mask
        neg_mask = (1 - mask) * logits_mask

        # ---- temperature-scaled cosine similarities --------------------------------
        anchor_feature = F.normalize(anchor_feature, dim=1)
        contrast_feature = F.normalize(contrast_feature, dim=1)
        logits = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature)

        # Subtract the row maximum before exponentiating (softmax is shift-invariant).
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)
        logits = logits - logits_max.detach()

        # NOTE(review): the margin lowers the logits of negatives, which makes the
        # softmax task easier rather than harder; kept as-is - confirm this is the
        # intended direction.
        if self.margin > 0:
            logits = logits - self.margin * neg_mask

        # ---- hard-negative mining --------------------------------------------------
        if 0 < self.hard_mining_ratio < 1.0:
            k = max(int(num_anchors * self.hard_mining_ratio), 1)
            k = min(k, logits.shape[1])
            # Non-negatives are set to -inf so top-k can only pick real negatives.
            # Rows with fewer than k negatives return -inf fillers, which are dropped
            # by the isfinite() test below.
            neg_scores = logits.detach().masked_fill(neg_mask <= 0, float('-inf'))
            top_scores, top_indices = torch.topk(neg_scores, k, dim=1)
            hard_neg_mask = torch.zeros_like(neg_mask)
            hard_neg_mask.scatter_(1, top_indices, torch.isfinite(top_scores).to(neg_mask.dtype))
            # Denominator = all positives + the selected hard negatives.
            logits_mask = torch.clamp(mask + hard_neg_mask, max=1.0)

        # ---- log-probabilities of the positives ------------------------------------
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)

        sum_positive_log_probs = (log_prob * mask).sum(1)
        # The epsilon keeps anchors without any positive at 0 instead of NaN.
        count_positive_pairs = mask.sum(1) + 1e-12

        if class_weights is not None and labels is not None:
            # labels already holds one entry per anchor (repeated per view in 'all' mode)
            weight_values = torch.tensor(
                [class_weights.get(label, 1.0) for label in labels.tolist()],
                dtype=sum_positive_log_probs.dtype, device=device)
            weighted_mean_log_prob_pos = (sum_positive_log_probs * weight_values) / count_positive_pairs
        else:
            weighted_mean_log_prob_pos = sum_positive_log_probs / count_positive_pairs

        loss = -(self.temperature / self.base_temperature) * weighted_mean_log_prob_pos
        return loss.mean()  # average over anchors

# Custom Dataset for Multi-View Augmentation and Combined Images

## Dataset Paths

- `KAGGLE_INPUT_PATH`: Root directory containing the original dataset.
- `PREPROCESSED_PATH`: Directory containing preprocessed face images.
- Separate folders for training and validation data both in original and preprocessed forms:
  - `TRAIN_PATH` and `VAL_PATH` for original images.
  - `TRAIN_PREPROCESSED_PATH` and `VAL_PREPROCESSED_PATH` for preprocessed images.

## GenderDataset Class Overview

This custom PyTorch `Dataset` is designed to handle paired inputs consisting of original and preprocessed face images along with their labels.

### Key Features

- **Data Loading:**
  - Loads images from class-specific folders: `"female"` and `"male"`.
  - Maintains consistent mapping: `female` → 0, `male` → 1.
  - Loads both original and corresponding preprocessed images.
  - Verifies existence of both original and preprocessed images to ensure pairing.
  - Issues warnings for any missing directories or missing preprocessed images.

- **Data Augmentation and Views:**
  - Supports multi-view augmentations during training for contrastive learning.
    - Applies two different random transformations to both original and preprocessed images, producing four augmented views per sample.
  - For validation, applies a single deterministic transform to both images.
  
- **Handling Corrupted Images:**
  - Includes exception handling when loading images.
  - If loading fails, replaces images with a black placeholder image to prevent crashes.

- **Label and Class Counts:**
  - Stores labels alongside images.
  - Keeps track of class counts for potential use in weighting or analysis.

### Return Format

- **Training mode:** Returns a tuple of four augmented views — two from original images, two from preprocessed images — along with the label.
- **Validation mode:** Returns one transformed original image and one transformed preprocessed image with the label.

---

This setup is useful for training models that leverage multiple views of both raw and preprocessed images, such as contrastive or multi-input networks for gender classification.


In [None]:
# Dataset locations: raw Kaggle input and the preprocessed face crops in the working directory
KAGGLE_INPUT_PATH = '/kaggle/input/comsys-taska/Task_A'
PREPROCESSED_PATH = '/kaggle/working/preprocessed_faces'

# Each root holds a 'train' and a 'val' split
TRAIN_PATH, VAL_PATH = (
    os.path.join(KAGGLE_INPUT_PATH, split) for split in ('train', 'val')
)
TRAIN_PREPROCESSED_PATH, VAL_PREPROCESSED_PATH = (
    os.path.join(PREPROCESSED_PATH, split) for split in ('train', 'val')
)

# --- 1. Custom Dataset for Multi-View Augmentation and Combined Images ---
class GenderDataset(Dataset):
    """Paired dataset of original and preprocessed face images with gender labels.

    Images are read from ``<data_dir>/<gender>/`` for ``gender`` in ``female`` (label 0)
    and ``male`` (label 1). An image is kept only if its counterpart
    ``<preprocessed_data_dir>/<gender>/preprocessed_face_<name>`` exists.
    File names are processed in sorted order so that a given index refers to the
    same sample on every run.

    Args:
        data_dir: root directory of the original images.
        preprocessed_data_dir: root directory of the preprocessed images.
        transform: optional callable applied to each PIL image.
        is_train: if True (and ``transform`` is set), ``__getitem__`` returns two
            independently transformed views of each image.

    Attributes set by ``__init__``: ``image_paths``, ``preprocessed_image_paths``,
    ``labels`` (parallel lists) and ``class_counts`` (a ``Counter`` over ``labels``).
    """

    def __init__(self, data_dir, preprocessed_data_dir, transform=None, is_train=True):
        self.data_dir = data_dir
        self.preprocessed_data_dir = preprocessed_data_dir
        self.transform = transform
        self.is_train = is_train
        self.image_paths = []
        self.preprocessed_image_paths = []
        self.labels = [] # 0 for female, 1 for male (consistent mapping)
        self.class_to_idx = {'female': 0, 'male': 1}
        self.idx_to_class = {0: 'female', 1: 'male'}
        
        print(f"Loading dataset from: {data_dir} (is_train={is_train})")

        for gender in ['female', 'male']:
            gender_path = os.path.join(data_dir, gender)
            preprocessed_gender_path = os.path.join(preprocessed_data_dir, gender)
            class_idx = self.class_to_idx[gender]
            
            if not os.path.exists(gender_path):
                print(f"Warning: Directory not found: {gender_path}. Skipping.")
                continue
            if not os.path.exists(preprocessed_gender_path):
                print(f"Warning: Directory not found: {preprocessed_gender_path}. Skipping.")
                continue

            # os.listdir returns entries in arbitrary order; sort for a reproducible index -> sample mapping
            for img_name in sorted(os.listdir(gender_path)):
                if not img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
                    continue

                original_img_path = os.path.join(gender_path, img_name)
                # Naming convention assumed to come from the preprocessing step
                preprocessed_img_name = f"preprocessed_face_{img_name}"
                preprocessed_img_path = os.path.join(preprocessed_gender_path, preprocessed_img_name)

                if os.path.exists(preprocessed_img_path):
                    self.image_paths.append(original_img_path)
                    self.preprocessed_image_paths.append(preprocessed_img_path)
                    self.labels.append(class_idx)
                else:
                    print(f"Warning: Corresponding preprocessed image not found for {original_img_path}. Skipping.")

        self.class_counts = Counter(self.labels)

    def __len__(self):
        """Number of (original, preprocessed) image pairs."""
        return len(self.image_paths)

    @staticmethod
    def _load_rgb(path, description):
        """Load ``path`` as an RGB PIL image, or a black 224x224 placeholder if loading fails."""
        try:
            # convert() returns an independent in-memory image, so the file handle
            # can be released as soon as the block exits.
            with Image.open(path) as source:
                return source.convert('RGB')
        except Exception as e:
            # Deliberate best-effort: one unreadable file must not abort an epoch.
            print(f"Error loading {description} image {path}: {e}")
            return Image.new('RGB', (224, 224), color='black')

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        Returns:
            * with a transform and ``is_train=True``:
              ``((orig_view1, orig_view2, proc_view1, proc_view2), label)``
            * with a transform and ``is_train=False``:
              ``((orig, proc), label)``
            * without a transform: ``((orig_pil, proc_pil), label)`` (untransformed PIL images)
        """
        label = self.labels[idx]
        img = self._load_rgb(self.image_paths[idx], 'original')
        preprocessed_img = self._load_rgb(self.preprocessed_image_paths[idx], 'preprocessed')

        if not self.transform:
            return (img, preprocessed_img), label

        if self.is_train:
            # Two independent random augmentations of each image for contrastive learning
            img1_original = self.transform(img)
            img2_original = self.transform(img)

            img1_processed = self.transform(preprocessed_img)
            img2_processed = self.transform(preprocessed_img)

            # Order: both original views first, then both preprocessed views
            return (img1_original, img2_original, img1_processed, img2_processed), label

        # Validation: a single transform per image
        img_transformed = self.transform(img)
        preprocessed_img_transformed = self.transform(preprocessed_img)
        return (img_transformed, preprocessed_img_transformed), label


# PyTorch Lightning DataModule for Gender Classification

## Purpose

This `GenderDataModule` class is designed to streamline data loading, preprocessing, and batching for training and validation in a PyTorch Lightning workflow. It handles paired datasets of original and preprocessed face images with support for class imbalance correction.

---

## Key Components

### Initialization

- Takes directories for:
  - Training original images
  - Validation original images
  - Training preprocessed images
  - Validation preprocessed images
- Configurable batch size, number of workers, and image size.
- Defines data augmentation and normalization pipelines for training and validation:
  - **Training augmentations** include resizing, cropping, flips, color jitter, rotations, blurring, perspective distortions, and random erasing — enhancing robustness.
  - **Validation transforms** apply standard resizing and center cropping, with normalization.

### Setup Method

- Loads training and validation datasets using the custom `GenderDataset` class.
- Calculates class counts and computes **class weights** using inverse frequency:
  - This weighting mitigates class imbalance by giving higher loss weight to minority classes.
- Stores these weights both as a dictionary and a tensor for later use.

### Data Loaders

- **Training DataLoader:**
  - Uses a `WeightedRandomSampler` (sampling with replacement) whose per-sample weights are the inverse-frequency class weights.
  - Classes are therefore drawn in roughly equal proportion, in expectation, each epoch even though the underlying dataset is imbalanced.
- **Validation DataLoader:**
  - Provides deterministic batches without shuffling.

---

## Benefits

- Automated multi-view augmentation suited for contrastive or multi-input models.
- Handles class imbalance effectively during sampling and loss computation.
- Clean integration with PyTorch Lightning for easy training loops.

---

This DataModule is a solid base for training robust gender classification models using both raw and preprocessed face images.


In [None]:
# --- 2. PyTorch Lightning DataModule ---
class GenderDataModule(pl.LightningDataModule):
    """Lightning DataModule serving paired original / preprocessed face images.

    Args:
        train_data_dir: root of the original training images.
        val_data_dir: root of the original validation images.
        train_preprocessed_data_dir: root of the preprocessed training images.
        val_preprocessed_data_dir: root of the preprocessed validation images.
        batch_size: batch size for both loaders.
        num_workers: worker processes for both loaders.
        image_size: output size as ``(height, width)``; a single int means a square.

    After ``setup('fit')`` the inverse-frequency class weights are available as
    ``class_weights_for_loss`` (dict) and ``class_weights_tensor`` (float32 tensor).
    """

    def __init__(self, train_data_dir, val_data_dir, train_preprocessed_data_dir, val_preprocessed_data_dir, batch_size=64, num_workers=4, image_size=(224, 224)):
        super().__init__()
        self.train_data_dir = train_data_dir
        self.val_data_dir = val_data_dir
        self.train_preprocessed_data_dir = train_preprocessed_data_dir
        self.val_preprocessed_data_dir = val_preprocessed_data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        if isinstance(image_size, int):
            # Accept a bare int; the validation resize below indexes image_size[0]
            image_size = (image_size, image_size)
        self.image_size = image_size

        # ImageNet mean and std for normalization as recommended by models
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Define transformations
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(0.7, 1.0), ratio=(0.75, 1.33)), # Simulates scale and crop
            transforms.RandomHorizontalFlip(), # Geometric transformation
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1), # Color space transformation
            transforms.RandomRotation(degrees=20), # Geometric transformation
            transforms.GaussianBlur(kernel_size=3), # Add some blur/noise
            transforms.RandomPerspective(distortion_scale=0.2, p=0.5), # Perspective distortion
            transforms.ToTensor(), # Convert PIL Image to PyTorch Tensor
            self.normalize, # Normalize pixel values
            transforms.RandomErasing(p=0.2, scale=(0.02, 0.1), ratio=(0.3, 3.3)) # Simulates occlusions (operates on tensors, hence after ToTensor)
        ])
        
        self.val_transform = transforms.Compose([
            transforms.Resize(int(image_size[0] / 0.875)), # Resize the shorter side, then crop the centre
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize
        ])

        self.train_dataset = None
        self.val_dataset = None
        self.class_weights_for_loss = None
        self.class_weights_tensor = None

    def setup(self, stage=None):
        """Build the datasets required for ``stage``.

        ``'fit'`` / ``None`` builds both datasets and the class weights;
        ``'validate'`` builds only the validation dataset so that validation can be
        run on its own. Other stages are ignored.
        """
        needs_train = stage == 'fit' or stage is None
        needs_val = needs_train or stage == 'validate'

        if needs_train:
            self.train_dataset = GenderDataset(self.train_data_dir, self.train_preprocessed_data_dir, transform=self.train_transform, is_train=True)
        if needs_val:
            self.val_dataset = GenderDataset(self.val_data_dir, self.val_preprocessed_data_dir, 
                                             transform=self.val_transform, is_train=False)
        if needs_train:
            self._compute_class_weights()

    def _compute_class_weights(self):
        """Derive inverse-frequency class weights from the training dataset."""
        total_samples = len(self.train_dataset)
        num_classes = len(self.train_dataset.class_to_idx)
        
        female_count = self.train_dataset.class_counts.get(0, 0)
        male_count = self.train_dataset.class_counts.get(1, 0)

        # Inverse frequency weighting: total_samples / (num_classes * class_count);
        # a class with no samples falls back to a neutral weight of 1.0
        weight_female = total_samples / (num_classes * female_count) if female_count > 0 else 1.0
        weight_male = total_samples / (num_classes * male_count) if male_count > 0 else 1.0
        
        self.class_weights_for_loss = {0: weight_female, 1: weight_male}
        self.class_weights_tensor = torch.tensor([weight_female, weight_male], dtype=torch.float32)

        print(f"Calculated class weights for loss (Female: {self.class_weights_for_loss[0]:.2f}, Male: {self.class_weights_for_loss[1]:.2f})")
        print(f"Train dataset class counts: {self.train_dataset.class_counts}")

    def train_dataloader(self):
        """Training loader that samples with replacement, weighted by inverse class frequency.

        NOTE(review): the same class weights are also exposed for the loss; applying
        both compensates for the imbalance twice - confirm this is intended.
        """
        sample_weights = np.array(
            [self.class_weights_for_loss[label] for label in self.train_dataset.labels],
            dtype=np.float64)
        # Guard against inf/NaN weights
        sample_weights[~np.isfinite(sample_weights)] = 1.0

        sampler = WeightedRandomSampler(
            weights=sample_weights.tolist(),
            num_samples=len(sample_weights),
            replacement=True
        )
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def val_dataloader(self):
        """Validation loader: fixed order, no sampling."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

# GenderClassificationModel Overview

## Model Architecture

- **Dual EfficientNet-B3 Backbones**  
  Extract feature embeddings independently from:
  - Original images  
  - Preprocessed images  
  These embeddings are concatenated to form a richer combined representation.

- **Squeeze-and-Excitation (SE) Block**  
  Applies channel-wise attention to recalibrate the concatenated features, improving representational power.

- **Projection Head**  
  Maps SE-recalibrated features to a lower-dimensional embedding space (128 dims) used for **Supervised Contrastive Loss** (SupCon).

- **Classification Head**  
  A strengthened multi-layer perceptron with dropout, producing logits for the two-class gender classification task.

- **Learnable Loss Weight Parameter**  
  Dynamically balances the importance between the SupCon loss and the classification loss during training.

---

## Loss Functions

- **Enhanced Supervised Contrastive Loss (SupCon)**  
  Encourages the model to learn discriminative embeddings by pulling together samples of the same class and pushing apart samples of different classes.

- **Focal Loss with Label Smoothing**  
  Combats class imbalance by focusing training on hard-to-classify samples and softening hard labels to improve generalization.

- Class-wise weighting can be incorporated in focal loss based on dataset imbalance statistics.

---

## Forward Pass Modes

- **Training (4-tensor tuple input)**  
  Receives pairs of original and preprocessed images with two augmented views each.  
  Extracts features, recalibrates, projects for contrastive loss, and outputs classification logits.

- **Validation/Inference (2-tensor tuple input)**  
  Processes single views of original and preprocessed images to output classification logits.

---

## Training & Validation Steps

- Calculates both SupCon loss and focal classification loss, combined dynamically with the learnable weight parameter.  
- Logs losses and metrics (accuracy, F1, ROC-AUC, PR-AUC) at both batch and epoch levels.  
- Accumulates raw predictions and labels across batches for comprehensive epoch-end metrics.

---

## Optimizer & Scheduler

- AdamW optimizer with differentiated learning rates:
  - Backbone models get a base LR.
  - Projection and classification heads get a higher LR.
  - The learnable loss weighting parameter has the highest LR for faster adaptation.

- **Cosine Annealing Warm Restarts** scheduler to cyclically adjust learning rates, enhancing convergence.

---

## Utility: Visualization of Misclassifications

- Visualizes misclassified validation images by displaying original images alongside their true and predicted labels.  
- Supports showing a user-specified number of misclassifications or all if unspecified.

---

# Summary

This model elegantly fuses state-of-the-art CNN feature extractors with contrastive and classification objectives, balanced dynamically to leverage complementary training signals. It addresses data imbalance with focal loss and class weighting, and includes mechanisms for robust training monitoring and error analysis.


In [None]:
# --- 3. PyTorch Lightning Model with EfficientNetB3 and EnhancedSupConLoss ---
class GenderClassificationModel(pl.LightningModule):
    """Two-stream EfficientNet-B3 classifier trained with SupCon + focal loss.

    Two separate EfficientNet-B3 backbones embed the original and the
    preprocessed version of each image. The two embeddings are concatenated,
    recalibrated by an SE block and fed to:

    * a projection head whose output is used by ``EnhancedSupConLoss``
      (training only), and
    * a classification head whose logits are scored with ``FocalLoss``.

    The two losses are blended with a learnable scalar ``loss_weight_param``
    that is kept inside ``[0, 1]`` (see ``on_before_zero_grad``).

    Args:
        num_classes: Number of output classes of the classification head.
        learning_rate: Base learning rate (backbones and SE block); the heads
            use 2x and the loss weight 5x this value.
        weight_decay: AdamW weight decay.
        supcon_temp: ``temperature`` passed to ``EnhancedSupConLoss``.
        supcon_base_temp: ``base_temperature`` passed to ``EnhancedSupConLoss``.
        supcon_hard_mining_ratio: ``hard_mining_ratio`` passed to
            ``EnhancedSupConLoss``.
        supcon_margin: ``margin`` passed to ``EnhancedSupConLoss``.
        class_weights_for_loss: Per-class weights indexable by class id
            (``weights[0]``, ``weights[1]``, ...). Forwarded to the SupCon
            loss and, when ``class_weights_tensor`` is given, used as the
            focal-loss ``alpha``.
        class_weights_tensor: When not ``None``, enables class-wise ``alpha``
            in the focal loss.
        max_epochs: Planned number of epochs; sets the length of the first
            cosine-annealing cycle (``max_epochs // 5``, at least 1).
        label_smoothing: ``label_smoothing`` passed to ``FocalLoss``.
        gamma: ``gamma`` passed to ``FocalLoss``.
    """

    def __init__(self, num_classes=2, learning_rate=1e-4, weight_decay=1e-5, 
                 supcon_temp=0.07, supcon_base_temp=0.07, supcon_hard_mining_ratio=0.35, 
                 supcon_margin=0.2, class_weights_for_loss=None, class_weights_tensor=None,
                 max_epochs: int = 50, label_smoothing: float = 0.0, gamma: float = 2.0):
        super().__init__()
        self.save_hyperparameters()  # Stores every init argument in self.hparams

        # Backbone for the original images. num_classes=0 makes timm return
        # pooled features instead of classification logits.
        self.feature_extractor_original = timm.create_model('efficientnet_b3', pretrained=True, num_classes=0)
        efficientnet_feature_dim = self.feature_extractor_original.num_features

        # Independent backbone for the preprocessed images.
        self.feature_extractor_processed = timm.create_model('efficientnet_b3', pretrained=True, num_classes=0)

        # Both streams are concatenated before the heads.
        combined_feature_dim = efficientnet_feature_dim * 2

        # Squeeze-and-Excitation block for channel-wise feature recalibration.
        self.se_block = SEBlock(channel=combined_feature_dim)

        # Learnable mixing weight between contrastive and classification loss.
        self.loss_weight_param = nn.Parameter(torch.tensor(0.5, dtype=torch.float32))

        # Projection head for the supervised contrastive loss.
        self.projection_head = nn.Sequential(
            nn.Linear(combined_feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 128)  # Embedding dimension used by the contrastive loss
        )

        # Classification head.
        self.classification_head = nn.Sequential(
            nn.Linear(combined_feature_dim, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

        self.supcon_loss_fn = EnhancedSupConLoss(
            temperature=supcon_temp,
            base_temperature=supcon_base_temp,
            contrast_mode='all',
            hard_mining_ratio=supcon_hard_mining_ratio, 
            margin=supcon_margin 
        )

        # Class-wise alpha for the focal loss. The original code gated on
        # class_weights_tensor but indexed class_weights_for_loss, which raised
        # a TypeError when only the tensor was supplied.
        if class_weights_tensor is None:
            alpha_for_focal = None
        elif class_weights_for_loss is not None:
            alpha_for_focal = torch.tensor(
                [class_weights_for_loss[class_id] for class_id in range(num_classes)],
                dtype=torch.float32,
            )
        else:
            # NOTE(review): presumably class_weights_tensor holds the same
            # per-class weights in tensor form -- confirm against GenderDataModule.
            alpha_for_focal = torch.as_tensor(class_weights_tensor, dtype=torch.float32).detach().clone()

        self.focal_loss_fn = FocalLoss(
            alpha=alpha_for_focal,
            gamma=self.hparams.gamma,
            label_smoothing=label_smoothing
        )

        self.class_weights_for_loss = class_weights_for_loss
        self.class_weights_tensor = class_weights_tensor

        # Per-epoch buffers of softmax outputs / labels (numpy arrays), emptied
        # by the epoch-end hooks once the metrics have been logged.
        self.train_raw_preds = []
        self.train_labels = []
        self.val_raw_preds = []
        self.val_labels = []

    def _recalibrate(self, features_original, features_processed):
        """Concatenate both streams' features and apply the SE block."""
        combined = torch.cat((features_original, features_processed), dim=1)
        return self.se_block(combined)

    def forward(self, x):
        """Run the network in training (4 inputs) or inference (2 inputs) mode.

        Args:
            x: Tuple or list of tensors. Either
                ``(img1_original, img2_original, img1_processed, img2_processed)``
                (two augmented views, training) or
                ``(img_original, img_processed)`` (validation / inference).

        Returns:
            ``(supcon_features, logits)`` for 4 inputs, where ``supcon_features``
            stacks the two projected views along dim 1 and ``logits`` come from
            the first view; ``logits`` only for 2 inputs.

        Raises:
            ValueError: If ``x`` is not a tuple/list of 2 or 4 elements.
        """
        # Lists are accepted as well because the default DataLoader collate
        # turns tuple samples into lists.
        is_sequence = isinstance(x, (tuple, list))

        if is_sequence and len(x) == 4:
            img1_original, img2_original, img1_processed, img2_processed = x

            features1_original = self.feature_extractor_original(img1_original)
            features2_original = self.feature_extractor_original(img2_original)
            features1_processed = self.feature_extractor_processed(img1_processed)
            features2_processed = self.feature_extractor_processed(img2_processed)

            features1_combined_se = self._recalibrate(features1_original, features1_processed)
            features2_combined_se = self._recalibrate(features2_original, features2_processed)

            proj_features1 = self.projection_head(features1_combined_se)
            proj_features2 = self.projection_head(features2_combined_se)
            supcon_features = torch.stack((proj_features1, proj_features2), dim=1)

            # Only the first view is classified.
            logits = self.classification_head(features1_combined_se)
            return supcon_features, logits

        if is_sequence and len(x) == 2:
            img_original, img_processed = x
            features_combined_se = self._recalibrate(
                self.feature_extractor_original(img_original),
                self.feature_extractor_processed(img_processed),
            )
            return self.classification_head(features_combined_se)

        raise ValueError("Unexpected input format for forward pass. Expected tuple of 2 or 4 tensors.")

    def training_step(self, batch, batch_idx):
        """Compute the blended SupCon + focal loss for one training batch."""
        (imgs1_original, imgs2_original, imgs1_processed, imgs2_processed), labels = batch

        supcon_features, logits = self((imgs1_original, imgs2_original, imgs1_processed, imgs2_processed))

        supcon_loss = self.supcon_loss_fn(
            features=supcon_features,
            labels=labels,
            class_weights=self.class_weights_for_loss
        )
        ce_loss = self.focal_loss_fn(logits, labels)

        # Convex combination of both losses; loss_weight_param is projected back
        # into [0, 1] after every optimizer step (see on_before_zero_grad).
        total_loss = self.loss_weight_param * supcon_loss + (1 - self.loss_weight_param) * ce_loss

        self.log('train_supcon_loss', supcon_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_ce_loss', ce_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_loss', total_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('loss_weight_param', self.loss_weight_param, on_step=True, on_epoch=True, prog_bar=True)

        self.train_raw_preds.append(F.softmax(logits, dim=1).detach().cpu().numpy())
        self.train_labels.append(labels.cpu().numpy())

        return total_loss

    def on_before_zero_grad(self, optimizer):
        """Project the learnable loss weight back into ``[0, 1]``.

        The gradient of the blended loss w.r.t. the weight is
        ``supcon_loss - ce_loss``, so without a bound the optimizer can push
        the weight outside ``[0, 1]``, which turns one of the two terms into a
        negatively weighted loss.
        """
        with torch.no_grad():
            self.loss_weight_param.clamp_(0.0, 1.0)

    def validation_step(self, batch, batch_idx):
        """Compute the focal loss on one validation batch and buffer predictions."""
        (imgs_original, imgs_processed), labels = batch
        logits = self((imgs_original, imgs_processed))

        loss = self.focal_loss_fn(logits, labels)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)

        self.val_raw_preds.append(F.softmax(logits, dim=1).detach().cpu().numpy())
        self.val_labels.append(labels.cpu().numpy())

        return loss

    def _log_epoch_metrics(self, stage, raw_pred_batches, label_batches):
        """Log accuracy, F1, ROC-AUC and PR-AUC for one epoch and empty the buffers.

        Args:
            stage: Metric-name prefix, ``'train'`` or ``'val'``.
            raw_pred_batches: List of per-batch softmax arrays.
            label_batches: List of per-batch label arrays.

        ROC-AUC and PR-AUC use column 1 of the softmax output as the score and
        are logged as 0.0 when only one class is present or scikit-learn raises
        ``ValueError``.
        """
        if len(raw_pred_batches) == 0:
            return

        all_raw_preds = np.concatenate(raw_pred_batches)
        all_labels = np.concatenate(label_batches)
        all_preds_classes = np.argmax(all_raw_preds, axis=1)

        accuracy = accuracy_score(all_labels, all_preds_classes)
        self.log(f'{stage}_accuracy_epoch', accuracy, prog_bar=True)

        f1 = f1_score(all_labels, all_preds_classes, average='weighted', zero_division=0)
        self.log(f'{stage}_f1_epoch', f1, prog_bar=True)

        roc_auc = 0.0
        pr_auc = 0.0
        if len(np.unique(all_labels)) > 1:
            try:
                roc_auc = roc_auc_score(all_labels, all_raw_preds[:, 1], average='weighted')
            except ValueError:
                roc_auc = 0.0
            try:
                precision, recall, _ = precision_recall_curve(all_labels, all_raw_preds[:, 1])
                pr_auc = auc(recall, precision)
            except ValueError:
                pr_auc = 0.0
        self.log(f'{stage}_roc_auc_epoch', roc_auc, prog_bar=True)
        self.log(f'{stage}_pr_auc_epoch', pr_auc, prog_bar=True)

        raw_pred_batches.clear()
        label_batches.clear()

    def on_train_epoch_end(self):
        """Log the epoch-level training metrics and reset the training buffers."""
        self._log_epoch_metrics('train', self.train_raw_preds, self.train_labels)

    def on_validation_epoch_end(self):
        """Log the epoch-level validation metrics and reset the validation buffers."""
        self._log_epoch_metrics('val', self.val_raw_preds, self.val_labels)

    def configure_optimizers(self):
        """Build AdamW with per-group learning rates and a warm-restart cosine schedule."""
        base_lr = self.hparams.learning_rate
        optimizer_params = [
            {'params': self.feature_extractor_original.parameters(), 'lr': base_lr},
            {'params': self.feature_extractor_processed.parameters(), 'lr': base_lr},
            {'params': self.se_block.parameters(), 'lr': base_lr},
            {'params': self.projection_head.parameters(), 'lr': base_lr * 2},  # Higher LR for heads
            {'params': self.classification_head.parameters(), 'lr': base_lr * 2},  # Higher LR for heads
            {'params': self.loss_weight_param, 'lr': base_lr * 5}  # Highest LR for the loss weight
        ]

        optimizer = torch.optim.AdamW(optimizer_params, weight_decay=self.hparams.weight_decay)

        # T_0 must be a positive integer; max_epochs // 5 is 0 for max_epochs < 5.
        first_cycle_epochs = max(1, self.hparams.max_epochs // 5)

        # `verbose` is intentionally not passed: it is deprecated in recent
        # PyTorch releases.
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer,
                T_0=first_cycle_epochs,  # Length of the first restart cycle, in epochs
                T_mult=2,  # Each following cycle is twice as long
                eta_min=base_lr / 100,  # Minimum learning rate
            ),
            'monitor': 'val_loss',  # Kept from the original config; this scheduler does not depend on a metric
            'interval': 'epoch',
            'frequency': 1
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

# --- Utility Function for Visualizing Misclassified Images ---
def visualize_misclassifications(model, dataloader, class_names, num_images=None):
    """Plot misclassified images with their true and predicted labels.

    The model is switched to evaluation mode and run over ``dataloader``;
    every sample whose arg-max prediction differs from its label is collected
    (at most ``num_images`` of them) and shown in a 5-column grid. The
    original (not the preprocessed) image is displayed after reversing the
    ImageNet mean/std normalization.

    Args:
        model: Module that maps ``(imgs_original, imgs_processed)`` to logits.
        dataloader: Iterable yielding ``((imgs_original, imgs_processed), labels)``.
        class_names: Sequence mapping class index to display name.
        num_images: Maximum number of misclassified images to show;
            ``None`` shows all of them. Values <= 0 show nothing.

    Returns:
        None. Output is printed and drawn with matplotlib.
    """
    # Ensure model is in evaluation mode
    model.eval()

    if num_images is not None and num_images <= 0:
        # A zero-sized grid cannot be drawn; bail out instead of failing in matplotlib.
        print("\nnum_images must be positive to visualize misclassifications; nothing to show.")
        return

    device = next(model.parameters()).device

    # ImageNet statistics used to undo the input normalization (built once).
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    misclassified_samples = []
    limit_reached = False

    with torch.no_grad():
        for imgs_tuple, labels in dataloader:
            imgs_original, imgs_processed = imgs_tuple
            imgs_original = imgs_original.to(device)
            imgs_processed = imgs_processed.to(device)

            logits = model((imgs_original, imgs_processed))

            # One host transfer per batch instead of one per sample. Softmax is
            # monotonic, so the arg-max of the logits is the predicted class.
            predicted_ids = logits.argmax(dim=1).tolist()
            true_ids = labels.tolist()

            for i, (true_id, predicted_id) in enumerate(zip(true_ids, predicted_ids)):
                if true_id == predicted_id:
                    continue

                # Denormalize the ORIGINAL image for display and clamp to [0, 1].
                img_display = torch.clamp(imgs_original[i].cpu() * std + mean, 0, 1)
                misclassified_samples.append({
                    'image': img_display,
                    'true_label': class_names[true_id],
                    'predicted_label': class_names[predicted_id]
                })

                # Stop as soon as enough samples were gathered.
                if num_images is not None and len(misclassified_samples) >= num_images:
                    limit_reached = True
                    break

            if limit_reached:
                break

    if not misclassified_samples:
        print("\nNo misclassified images found in the validation set (or accuracy is very high)!")
        return

    display_count = len(misclassified_samples)
    print(f"\n--- Visualizing {display_count} Misclassified Images ---")

    # Fixed 5-column grid with as many rows as needed.
    cols = 5
    rows = (display_count + cols - 1) // cols

    fig = plt.figure(figsize=(15, 4 * rows))  # Height grows with the number of rows
    for i, sample in enumerate(misclassified_samples):
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.imshow(sample['image'].permute(1, 2, 0).numpy())  # C,H,W -> H,W,C for imshow
        ax.set_title(f"True: {sample['true_label']}\nPred: {sample['predicted_label']}")
        ax.axis('off')
    fig.tight_layout()
    plt.show()



# Detailed Model Evaluation Report Generator

This document explains the function `generate_detailed_reports` which loads a trained classification model, performs inference on both training and validation datasets, and produces comprehensive evaluation reports including accuracy, classification reports, confusion matrices, and ROC/PR AUC scores.

---

## Function Overview

`generate_detailed_reports(model_path, data_module_instance, class_names)`

- **Purpose:**  
  Load the best checkpoint of a gender classification model, run inference on training and validation sets, and generate detailed performance metrics and visualizations.

- **Inputs:**  
  - `model_path` : Path to the saved model checkpoint (.ckpt) file.  
  - `data_module_instance` : An instance of a data module providing data loaders and transforms.  
  - `class_names` : List of class names corresponding to the classification labels.

- **Outputs:**  
  Prints detailed classification metrics and shows confusion matrix heatmaps for both datasets.

---

## Detailed Steps

### 1. Re-initialize Data Module  
- To obtain the class weights required by the model's constructor, the data module is re-instantiated with the same parameters and set up for the `'fit'` stage.

### 2. Load the Best Model Checkpoint  
- The model checkpoint is loaded with necessary hyperparameters such as learning rate, weight decay, temperature parameters for supervised contrastive loss, class weights, etc.  
- The model is set to evaluation mode and frozen to disable training behavior.

### 3. Inference on Training Set  
- A DataLoader for the training set is created without augmentation and shuffling, ensuring a consistent evaluation.  
- Predictions (softmax probabilities) and true labels are collected for the entire training set.

### 4. Training Set Evaluation  
- Overall accuracy is computed.  
- A detailed classification report is printed including precision, recall, and F1-score for each class.  
- Confusion matrix is generated and visualized as a heatmap.  
- ROC-AUC and Precision-Recall AUC scores are calculated if multiple classes are present.

### 5. Inference on Validation Set  
- Similar steps are followed as for the training set but using the validation dataloader provided by the data module.  
- Predictions and labels are collected for the validation set.

### 6. Validation Set Evaluation  
- Accuracy, classification report, confusion matrix visualization, ROC-AUC, and PR-AUC scores are computed and displayed similarly.

---

In [None]:
def _collect_probabilities(model, dataloader, device, desc):
    """Run `model` over `dataloader` and return (labels, softmax probabilities) as numpy arrays."""
    label_chunks = []
    prob_chunks = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            imgs_original, imgs_processed = batch[0]  # (original, preprocessed) image batches
            labels = batch[1]

            logits = model((imgs_original.to(device), imgs_processed.to(device)))  # Pass as tuple
            prob_chunks.append(F.softmax(logits, dim=1).cpu().numpy())
            label_chunks.append(labels.cpu().numpy())

    if not label_chunks:
        # Empty dataloader: return empty arrays so the caller can skip the report.
        return np.empty((0,), dtype=np.int64), np.empty((0, 0), dtype=np.float32)

    return np.concatenate(label_chunks), np.concatenate(prob_chunks)


def _print_split_report(split_name, labels, probs, class_names):
    """Print accuracy, classification report, confusion matrix and ROC/PR AUC for one split."""
    if labels.size == 0:
        print(f"\nNo samples found for the {split_name.lower()} set; skipping its report.")
        return

    pred_classes = np.argmax(probs, axis=1)
    # Passing the full label set keeps the report / matrix aligned with
    # class_names even when a class is missing from this split.
    class_ids = list(range(len(class_names)))

    print(f"\n--- Comprehensive Evaluation Report ({split_name} Set) ---")
    accuracy = accuracy_score(labels, pred_classes)
    print(f"Overall Accuracy ({split_name} Set): {accuracy:.4f}")
    print(f"\nDetailed Classification Report ({split_name} Set):")
    print(classification_report(labels, pred_classes, labels=class_ids,
                                target_names=class_names, zero_division=0))

    print(f"\n--- Confusion Matrix ({split_name} Set) ---")
    cm = confusion_matrix(labels, pred_classes, labels=class_ids)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title(f'Confusion Matrix ({split_name} Set)')
    plt.show()

    has_both_classes = len(np.unique(labels)) > 1

    # Both AUC metrics use column 1 of the probabilities as the score.
    # Failures are reported but do not abort the rest of the report.
    try:
        if has_both_classes:
            roc_auc = roc_auc_score(labels, probs[:, 1], average='weighted')
            print(f"Final {split_name} ROC-AUC (weighted): {roc_auc:.4f}")
        else:
            print(f"Cannot compute ROC-AUC: Only one class present in {split_name.lower()} labels.")
    except Exception as e:
        print(f"Error computing final {split_name} ROC-AUC: {e}")

    try:
        if has_both_classes:
            precision, recall, _ = precision_recall_curve(labels, probs[:, 1])
            pr_auc = auc(recall, precision)
            print(f"Final {split_name} PR-AUC (weighted): {pr_auc:.4f}")
        else:
            print(f"Cannot compute PR-AUC: Only one class present in {split_name.lower()} labels.")
    except Exception as e:
        print(f"Error computing final {split_name} PR-AUC: {e}")


def generate_detailed_reports(model_path, data_module_instance, class_names):
    """
    Loads the model checkpoint at `model_path`, runs inference on both the
    training and validation sets, and prints a classification report,
    confusion-matrix heatmap, ROC-AUC and PR-AUC for each.

    Args:
        model_path: Path to the checkpoint passed to `load_from_checkpoint`.
        data_module_instance: Data module providing the data directories,
            `batch_size`, `num_workers`, `val_transform` and `val_dataloader()`.
        class_names: List of display names, indexed by class id.

    Note:
        The hyperparameters forwarded to `load_from_checkpoint` are read from
        the module-level constants (LEARNING_RATE, WEIGHT_DECAY, SUPCON_*,
        MAX_EPOCHS, LABEL_SMOOTHING_EPSILON, FOCAL_LOSS_GAMMA), so those must
        be defined before this function is called.
    """
    print(f"\n--- Generating detailed reports using model: {model_path} ---")

    # Re-instantiate DataModule to get its class weights, which are needed by the model's init.
    temp_data_module = GenderDataModule(
        train_data_dir=data_module_instance.train_data_dir,
        val_data_dir=data_module_instance.val_data_dir,
        train_preprocessed_data_dir=data_module_instance.train_preprocessed_data_dir,
        val_preprocessed_data_dir=data_module_instance.val_preprocessed_data_dir,
        batch_size=data_module_instance.batch_size,
        num_workers=data_module_instance.num_workers,
    )
    temp_data_module.setup('fit')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    best_model = GenderClassificationModel.load_from_checkpoint(
        model_path,
        map_location=device,
        num_classes=2,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY, 
        supcon_temp=SUPCON_TEMP, 
        supcon_base_temp=SUPCON_BASE_TEMP, 
        supcon_hard_mining_ratio=SUPCON_HARD_MINING_RATIO, 
        supcon_margin=SUPCON_MARGIN, 
        class_weights_for_loss=temp_data_module.class_weights_for_loss,
        class_weights_tensor=temp_data_module.class_weights_tensor,
        max_epochs=MAX_EPOCHS,
        label_smoothing=LABEL_SMOOTHING_EPSILON,
        gamma=FOCAL_LOSS_GAMMA
    )
    best_model.eval()  # Set model to evaluation mode
    best_model.freeze()  # Freeze layers for inference
    best_model.to(device)  # Move model to appropriate device

    # --- Evaluate on Training Set ---
    # The training set is re-read with the validation transform and
    # is_train=False so each image yields a single view, in a fixed order.
    train_eval_dataset = GenderDataset(data_module_instance.train_data_dir, 
                                       data_module_instance.train_preprocessed_data_dir,
                                       transform=data_module_instance.val_transform,
                                       is_train=False)
    train_eval_dataloader = DataLoader(
        train_eval_dataset,
        batch_size=data_module_instance.batch_size,
        shuffle=False,
        num_workers=data_module_instance.num_workers,
        pin_memory=True
    )

    print("\nCollecting training set predictions for report...")
    all_train_labels, all_train_preds_probs = _collect_probabilities(
        best_model, train_eval_dataloader, device, desc="Training Set Inference"
    )
    _print_split_report("Training", all_train_labels, all_train_preds_probs, class_names)

    # --- Evaluate on Validation Set ---
    val_dataloader = data_module_instance.val_dataloader()

    print("\nCollecting validation set predictions for report...")
    all_val_labels, all_val_preds_probs = _collect_probabilities(
        best_model, val_dataloader, device, desc="Validation Set Inference"
    )
    _print_split_report("Validation", all_val_labels, all_val_preds_probs, class_names)


# Main Training Script Overview

## Hyperparameters

- **Batch Size:** 16 (recommended low due to two EfficientNet models)  
- **Number of Workers:** 4 (increase if CPU resources allow)  
- **Learning Rate:** 5e-5 (fine-tuned for EfficientNet)  
- **Weight Decay:** 1e-5 (L2 regularization)  
- **Max Epochs:** 50 (enables scheduler effectiveness)  
- **Label Smoothing Epsilon:** 0.1 (helps generalization by smoothing labels)  
- **Focal Loss Gamma:** 3.0 (focuses training on hard examples)

## Supervised Contrastive Loss Parameters

- **Temperature:** 0.07  
- **Base Temperature:** 0.07  
- **Hard Mining Ratio:** 0.5 (aggressive mining of difficult negatives)  
- **Margin:** 0.3 (encourages wider separation between classes)

## Visualization Option

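- Visualization of all misclassified images after training is enabled (`VISUALIZE_ALL_MISCLASSIFICATIONS = True`); when the flag is disabled, only 10 images are shown.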

## Data Module Initialization

- Loads training and validation datasets with specified directories and batch settings.  
- Sets up data preprocessing and augmentation pipelines.

## Model Initialization

- Defines a binary classification model for gender classification.  
- Applies hyperparameters including learning rate, weight decay, contrastive loss parameters, class weights, max epochs, label smoothing, and focal loss gamma.

## Callbacks Configuration

- Model Checkpoint: Saves the best-performing model based on validation accuracy.  
- Early Stopping: Stops training if validation accuracy does not improve for 15 consecutive epochs.  
- Learning Rate Monitor: Logs learning rate adjustments every epoch.

## Logger Setup

- Uses a CSV logger to store training logs under a designated folder with the experiment name.

## Trainer Setup

- Selects GPU for acceleration if available; otherwise CPU.  
- Uses mixed precision training on GPU for efficiency.  
- Logs training progress every 10 steps.  
- Accumulates gradients over 2 batches to simulate a larger batch size.

## Training Execution

- Starts the training process using the configured trainer and data module.  
- Prints messages indicating start and completion of training.

## Post-training Evaluation

- Loads the best model checkpoint after training completion.  
- Generates detailed evaluation reports including classification metrics and confusion matrices on both training and validation sets.  
- Visualizes misclassified images from the validation dataset, either all or a limited number based on the visualization flag.

## Final Notes

- Confirms when the full training, evaluation, and visualization pipeline has completed successfully.


In [None]:

# --- Main Training Script ---
if __name__ == '__main__':
    # Hyperparameters (you can tune these)
    BATCH_SIZE = 16 # Keep this at 16 or lower due to two EfficientNet models
    NUM_WORKERS = 4
    LEARNING_RATE = 5e-5 # Adjusted for fine-tuning EfficientNet
    WEIGHT_DECAY = 1e-5 # L2 regularization
    MAX_EPOCHS = 50 # Increased epochs to allow Cosine Annealing to work
    LABEL_SMOOTHING_EPSILON = 0.1 # Epsilon for label smoothing (common values are 0.05, 0.1, 0.2)
    FOCAL_LOSS_GAMMA = 3.0 # Increased gamma for Focal Loss to focus more on hard examples

    # Supervised Contrastive Loss parameters
    SUPCON_TEMP = 0.07 
    SUPCON_BASE_TEMP = 0.07 
    SUPCON_HARD_MINING_RATIO = 0.5 # Increased for more aggressive hard mining
    SUPCON_MARGIN = 0.3 # Increased for wider separation of negative pairs

    # Flag to visualize all misclassified images
    VISUALIZE_ALL_MISCLASSIFICATIONS = True 

    # Initialize DataModule
    data_module = GenderDataModule(
        train_data_dir=TRAIN_PATH,
        val_data_dir=VAL_PATH,
        train_preprocessed_data_dir=TRAIN_PREPROCESSED_PATH,
        val_preprocessed_data_dir=VAL_PREPROCESSED_PATH,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS, # Consider increasing this if CPU is underutilized and you have more cores
    )
    data_module.setup('fit')

    # Constructor arguments for GenderClassificationModel, collected in ONE place.
    # The same dict is used to build the model that gets trained and to rebuild
    # the model when the best checkpoint is re-loaded further down, so the two
    # call sites cannot drift apart when a hyperparameter is added or changed.
    model_init_kwargs = dict(
        num_classes=2,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        supcon_temp=SUPCON_TEMP,
        supcon_base_temp=SUPCON_BASE_TEMP,
        supcon_hard_mining_ratio=SUPCON_HARD_MINING_RATIO,
        supcon_margin=SUPCON_MARGIN,
        class_weights_for_loss=data_module.class_weights_for_loss,
        class_weights_tensor=data_module.class_weights_tensor,
        max_epochs=MAX_EPOCHS,
        label_smoothing=LABEL_SMOOTHING_EPSILON,
        gamma=FOCAL_LOSS_GAMMA,
    )

    # Initialize Model
    model = GenderClassificationModel(**model_init_kwargs)

    # Callbacks
    # Both the checkpoint and the early-stopping callback must watch the same
    # metric, so its name is defined once.
    monitored_metric = 'val_accuracy_epoch'

    # Checkpoint Callback: keep only the single checkpoint (save_top_k=1) with
    # the highest value of the monitored metric (mode='max').
    checkpoint_callback = ModelCheckpoint(
        dirpath='checkpoints',
        filename='best_model',
        monitor=monitored_metric,
        mode='max',
        save_top_k=1,
        verbose=True
    )

    # Early Stopping Callback: stop training once the monitored metric has not
    # improved (mode='max') within the configured patience of 15.
    early_stopping_callback = EarlyStopping(
        monitor=monitored_metric,
        patience=15,
        mode='max',
        verbose=True
    )
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # Logger
    logger = CSVLogger("logs", name="gender_classification")

    # Query CUDA availability once; the accelerator, the numeric precision and
    # the checkpoint map_location below all depend on the same answer.
    use_gpu = torch.cuda.is_available()

    # Initialize Trainer
    trainer = pl.Trainer(
        max_epochs=MAX_EPOCHS,
        logger=logger,
        callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
        accelerator="gpu" if use_gpu else "cpu",
        devices=1,  # a single device of the selected accelerator type
        precision=16 if use_gpu else 32,  # 16-bit precision on GPU, 32-bit on CPU
        log_every_n_steps=10,
        accumulate_grad_batches=2  # accumulate gradients over 2 batches per optimizer step
    )

    print("\nStarting model training...")
    trainer.fit(model, datamodule=data_module)
    print("\nTraining complete!")

    # --- Final Evaluation and Report Generation ---
    # best_model_path is an empty string when no checkpoint was written, in
    # which case there is nothing to evaluate.
    best_checkpoint_path = checkpoint_callback.best_model_path
    if best_checkpoint_path:
        # Materialise the class names once as a list instead of converting the
        # dict view separately at every call site.
        class_names = list(data_module.train_dataset.idx_to_class.values())
        generate_detailed_reports(
            best_checkpoint_path,
            data_module,
            class_names
        )

        # Determine how many misclassified images to visualize
        # (None is passed through when all of them are requested).
        num_images_to_viz = None if VISUALIZE_ALL_MISCLASSIFICATIONS else 10

        # Re-load the best checkpoint for visualization so the plots come from
        # the best-scoring weights rather than from the last training step.
        # The constructor arguments are the very same ones used for training.
        best_model_for_viz = GenderClassificationModel.load_from_checkpoint(
            best_checkpoint_path,
            map_location=torch.device('cuda' if use_gpu else 'cpu'),
            **model_init_kwargs
        )
        val_dataloader_for_viz = data_module.val_dataloader()
        visualize_misclassifications(
            best_model_for_viz,
            val_dataloader_for_viz,
            class_names,
            num_images=num_images_to_viz
        )
    else:
        print("No best model checkpoint found to generate final report or visualize misclassifications.")

    print("\nPipeline execution complete!")
