In [1]:
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [2]:
# pip install numpy opencv-python matplotlib scikit-learn imutils

In [None]:
import random, os
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
import torch.nn.functional as F
import cv2
from matplotlib import pyplot as plt
from sklearn.preprocessing import LabelEncoder
from imutils import paths
from sklearn.model_selection import train_test_split
import pickle
import shutil


def set_seed(seed=0):
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
    if torch.cuda.is_available() and not torch.cuda.is_initialized():
        print("CUDA is available but not initialized properly")
    else:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)


# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

set_seed()

Using device: cuda
CUDA is available but not initialized properly


# Architecture

In [None]:
# U-net based architecture

import torch
import torch.nn as nn
import torch.nn.functional as F


class UNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNetBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class UNetBinaryClassifier_CE(nn.Module):
    def __init__(self, in_channels=3, base_channels=32):
        super(UNetBinaryClassifier_CE, self).__init__()

        # Encoder
        self.enc1 = UNetBlock(in_channels, base_channels)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = UNetBlock(base_channels, base_channels * 2)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = UNetBlock(base_channels * 2, base_channels * 4)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = UNetBlock(base_channels * 4, base_channels * 8)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = UNetBlock(base_channels * 8, base_channels * 16)

        # Decoder
        self.up4 = nn.ConvTranspose2d(
            base_channels * 16, base_channels * 8, kernel_size=2, stride=2
        )
        self.dec4 = UNetBlock(base_channels * 16, base_channels * 8)
        self.up3 = nn.ConvTranspose2d(
            base_channels * 8, base_channels * 4, kernel_size=2, stride=2
        )
        self.dec3 = UNetBlock(base_channels * 8, base_channels * 4)
        self.up2 = nn.ConvTranspose2d(
            base_channels * 4, base_channels * 2, kernel_size=2, stride=2
        )
        self.dec2 = UNetBlock(base_channels * 4, base_channels * 2)
        self.up1 = nn.ConvTranspose2d(
            base_channels * 2, base_channels, kernel_size=2, stride=2
        )
        self.dec1 = UNetBlock(base_channels * 2, base_channels)

        # Final classifier (global average pooling + linear layer)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(base_channels, 2),  # Output 2 logits for binary classification
            # Removed nn.Sigmoid() because CrossEntropyLoss expects raw logits
        )

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))

        # Bottleneck
        b = self.bottleneck(self.pool4(e4))

        # Decoder
        d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        # Global pooling + classification
        out = self.global_pool(d1)
        out = self.classifier(out)  # [batch_size, 2], raw logits

        return out

In [None]:
# U-net based architecture (AAM loss)

import torch
import torch.nn as nn
import torch.nn.functional as F


class UNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNetBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class ArcFaceLayer(nn.Module):
    def __init__(self, in_features, out_features, s=30.0, m=0.50):
        super(ArcFaceLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, embeddings, labels=None):
        # Normalize embeddings and weights
        embeddings = F.normalize(embeddings, p=2, dim=1)  # [batch, in_features]
        weights = F.normalize(self.weight, p=2, dim=1)  # [out_features, in_features]

        # Cosine similarity
        cos_theta = F.linear(embeddings, weights)  # [batch, out_features]

        if labels is None:
            # Inference: return scaled cosine similarities
            return self.s * cos_theta

        # Training: apply ArcFace margin
        theta = torch.acos(torch.clamp(cos_theta, -1.0 + 1e-7, 1.0 - 1e-7))
        one_hot = torch.zeros_like(cos_theta)
        one_hot.scatter_(1, labels.view(-1, 1).long(), 1)

        # Add margin to true class
        cos_theta_m = torch.cos(theta + self.m * one_hot)
        logits = self.s * (cos_theta * (1.0 - one_hot) + cos_theta_m * one_hot)

        return logits


class UNetBinaryClassifier(nn.Module):
    def __init__(
        self, in_channels=3, base_channels=32, embedding_size=512, num_classes=2
    ):
        super(UNetBinaryClassifier, self).__init__()

        # Encoder
        self.enc1 = UNetBlock(in_channels, base_channels)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.enc2 = UNetBlock(base_channels, base_channels * 2)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.enc3 = UNetBlock(base_channels * 2, base_channels * 4)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.enc4 = UNetBlock(base_channels * 4, base_channels * 8)
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        # Bottleneck
        self.bottleneck = UNetBlock(base_channels * 8, base_channels * 16)

        # Decoder
        self.up4 = nn.ConvTranspose2d(
            base_channels * 16, base_channels * 8, kernel_size=2, stride=2
        )
        self.dec4 = UNetBlock(base_channels * 16, base_channels * 8)
        self.up3 = nn.ConvTranspose2d(
            base_channels * 8, base_channels * 4, kernel_size=2, stride=2
        )
        self.dec3 = UNetBlock(base_channels * 8, base_channels * 4)
        self.up2 = nn.ConvTranspose2d(
            base_channels * 4, base_channels * 2, kernel_size=2, stride=2
        )
        self.dec2 = UNetBlock(base_channels * 4, base_channels * 2)
        self.up1 = nn.ConvTranspose2d(
            base_channels * 2, base_channels, kernel_size=2, stride=2
        )
        self.dec1 = UNetBlock(base_channels * 2, base_channels)

        # Embedding layer
        self.global_pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.embedding = nn.Sequential(
            nn.Flatten(),
            nn.Linear(base_channels, embedding_size),
            nn.BatchNorm1d(embedding_size),
            nn.ReLU(inplace=True),
        )

        # ArcFace layer
        self.arcface = ArcFaceLayer(embedding_size, num_classes, s=30.0, m=0.5)

    def forward(self, x, labels=None):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))

        # Bottleneck
        b = self.bottleneck(self.pool4(e4))

        # Decoder
        d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        # Embedding
        emb = self.embedding(self.global_pool(d1))  # [batch, embedding_size]

        # ArcFace logits
        logits = self.arcface(emb, labels)  # [batch, num_classes]

        # Probabilities for inference
        probs = F.softmax(logits, dim=1) if labels is None else None

        return logits, probs, emb

    def get_activation_map(self, x):
        with torch.no_grad():
            e1 = self.enc1(x)
            e2 = self.enc2(self.pool1(e1))
            e3 = self.enc3(self.pool2(e2))
            e4 = self.enc4(self.pool3(e3))
            b = self.bottleneck(self.pool4(e4))
            d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
            d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
            d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
            d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return d1

# AAM loss training

In [None]:
import os
import shutil

# === Define Paths ===
root_folder = "project_path\replay attack"

# Paths for Caffe face detector
protoPath = os.path.join(root_folder, "face_detector/deploy.prototxt.txt")
modelPath = os.path.join(
    root_folder, "face_detector/res10_300x300_ssd_iter_140000.caffemodel"
)

face_confidence = 0.75  # Confidence threshold for face detection

# Root directory for extracted Replay-Attack data
Replay_Attack_extracted = os.path.join(
    root_folder, "dataset/training/Replay-Attack extracted/"
)

# === Output Directories ===
# Output directories for detected faces (training and validation)
output_attack_training_folder = os.path.join(
    root_folder, "dataset/training/Replay_Attack/images/attack_training/"
)
output_attack_validation_folder = os.path.join(
    root_folder, "dataset/training/Replay_Attack/images/attack_validation/"
)
output_bonifade_training_folder = os.path.join(
    root_folder, "dataset/training/Replay_Attack/images/bonifade_training/"
)
output_bonifade_validation_folder = os.path.join(
    root_folder, "dataset/training/Replay_Attack/images/bonifade_validation/"
)

# Directory for augmented images (if used during training)
save_augmented_images = os.path.join(
    root_folder, "dataset/training/Replay_Attack/augmented_images/"
)

# Paths for saving model checkpoint, labels, metrics plot, and training array
save_model_pth = os.path.join(
    root_folder, "dataset/arcface/best_model_Replay_Attack_arcface.pth"
)
save_labels = os.path.join(
    root_folder, "dataset/arcface/model_Replay_Attack_labels_arcface.csv"
)
save_training_metrics_plot = os.path.join(
    root_folder, "dataset/arcface/Replay_Attack_plot_arcface.png"
)
output_np_training_array = os.path.join(
    root_folder, "dataset/arcface/Replay_Attack_training_array.csv"
)


# === Helper Functions ===
def get_folders(folder_types=["train"], image_type="attack"):
    """
    Retrieve .mov file paths for Replay-Attack dataset based on folder types and image type.

    Args:
        folder_types (list): List of folder types (e.g., ["train", "devel", "test"]).
        image_type (str): Type of images ("attack" or "real").

    Returns:
        list: List of .mov file paths.
    """
    final_files = []

    for folder_type in folder_types:
        init_folders = []
        new_root = os.path.join(Replay_Attack_extracted, folder_type, image_type)

        if not os.path.exists(new_root):
            print(f"Warning: Directory not found: {new_root}")
            continue

        if image_type == "attack":
            init_folders.append(os.path.join(new_root, "fixed"))
            init_folders.append(os.path.join(new_root, "hand"))
        else:
            init_folders.append(new_root)

        for folder in init_folders:
            if not os.path.exists(folder):
                print(f"Warning: Directory not found: {folder}")
                continue
            # Collect .mov files directly in the folder
            files = [
                f.path
                for f in os.scandir(folder)
                if f.is_file() and f.name.endswith(".mov")
            ]
            final_files.extend(files)
            if not files:
                print(f"No .mov files found in: {folder}")

    return final_files


# === Create Output Directories ===
output_directories = [
    output_attack_training_folder,
    output_attack_validation_folder,
    output_bonifade_training_folder,
    output_bonifade_validation_folder,
    save_augmented_images,
]

for output_dir in output_directories:
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"Created directory: {output_dir}")

# === Validate Caffe Model Files ===
assert os.path.exists(protoPath), f"Prototxt not found at: {protoPath}"
assert os.path.exists(modelPath), f"Caffe model not found at: {modelPath}"
print("Caffe model files validated successfully.")

In [None]:
import os
import shutil

# === Retrieve Training and Validation Files ===
# Get attack training and validation files
attack_folders_training = get_folders(["train", "devel"], "attack")
attack_folders_validation = get_folders(["test"], "attack")

# Get bonafide training and validation files
bonifade_folders_training = get_folders(["train", "devel"], "real")
bonifade_folders_validation = get_folders(["test"], "real")

# === Organize Directories ===
training_directories = [
    attack_folders_training,
    attack_folders_validation,
    bonifade_folders_training,
    bonifade_folders_validation,
]
output_directories = [
    output_attack_training_folder,
    output_attack_validation_folder,
    output_bonifade_training_folder,
    output_bonifade_validation_folder,
]

# === Create Output Directories ===
for output_dir in output_directories:
    if not os.path.exists(output_dir):
        print(f"Creating directory: {output_dir}")
        os.makedirs(output_dir)
    else:
        print(f"Directory already exists: {output_dir}")

    # Create mask subdirectory for storing mask images (e.g., face detection outputs)
    mask_dir = os.path.join(output_dir, "mask")
    if not os.path.exists(mask_dir):
        print(f"Creating mask subdirectory: {mask_dir}")
        os.makedirs(mask_dir)
    else:
        print(f"Mask subdirectory already exists: {mask_dir}")

# === Directory Structure Overview ===
# The Replay_Attack/images folder will contain:
# 1. Attack training images in output_attack_training_folder
#    - Original attack images extracted from .mov files
#    - Mask subdirectory for attack face detection outputs
# 2. Attack validation images in output_attack_validation_folder
#    - Original attack validation images extracted from .mov files
#    - Mask subdirectory for attack validation face detection outputs
# 3. Bonafide (real) training images in output_bonifade_training_folder
#    - Original bonafide images extracted from .mov files
#    - Mask subdirectory for bonafide face detection outputs
# 4. Bonafide (real) validation images in output_bonifade_validation_folder
#    - Original bonafide validation images extracted from .mov files
#    - Mask subdirectory for bonafide validation face detection outputs

# === Logging ===
# Verify that files are retrieved correctly
print(
    f"Attack files for training ({len(attack_folders_training)}):",
    attack_folders_training,
)
print(
    f"Attack files for validation ({len(attack_folders_validation)}):",
    attack_folders_validation,
)
print(
    f"Bonifide files for training ({len(bonifade_folders_training)}):",
    bonifade_folders_training,
)
print(
    f"Bonifide files for validation ({len(bonifade_folders_validation)}):",
    bonifade_folders_validation,
)

# Warn if any file list is empty
if not attack_folders_training:
    print(
        "Warning: No attack training files found. Check train/devel attack directories."
    )
if not attack_folders_validation:
    print("Warning: No attack validation files found. Check test attack directories.")
if not bonifade_folders_training:
    print(
        "Warning: No bonifide training files found. Check train/devel real directories."
    )
if not bonifade_folders_validation:
    print("Warning: No bonifide validation files found. Check test real directories.")

In [None]:
import cv2
import numpy as np

# === Hyperparameters for Training ===
final_x = 128
final_y = 128
n_channels = 3

# Training parameters
epochs = 20
batch_size = 32
init_learning_rate = 1e-4
# Note: For ArcFace, the learning rate may need tuning (e.g., 1e-4) due to sensitivity to margin and scaling.

# Target dimensions for resizing
dim = (final_x, final_y)


# === Preprocessing Function ===
def resize_normalize_image(image, dim=dim, value=255.0):
    """
    Resize and normalize an image for model input.

    Args:
        image (numpy.ndarray): Input image (H, W, C) in BGR format (from cv2).
        dim (tuple): Target dimensions (width, height) for resizing.
        value (float): Value to divide pixel values by for normalization (default: 255.0 for range [0, 1]).

    Returns:
        numpy.ndarray: Resized and normalized image in range [0, 1].

    Notes:
        - Designed for Replay-Attack .mov video frames extracted as RGB images.
        - Image size (32x32) is adopted from the original Replay-Attack pipeline but may be adjusted (e.g., to 128x128)
          for better feature extraction with UNet-based models.
        - Additional normalization (e.g., mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5) to scale to [-1, 1])
          will be applied in the dataloader using torchvision.transforms.Normalize.
    """
    # Validate input image
    if image is None or image.size == 0:
        raise ValueError("Input image is empty or None")
    if len(image.shape) != 3 or image.shape[2] != 3:
        raise ValueError(f"Expected 3-channel image, got shape {image.shape}")

    # Resize image to target dimensions
    image = cv2.resize(image, dim)

    # Normalize to [0, 1] by dividing by value
    normalized_image = image / value

    # Ensure values are in [0, 1]
    normalized_image = np.clip(normalized_image, 0.0, 1.0)

    return normalized_image

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os

# === Load Face Detection Model ===
net = cv2.dnn.readNetFromCaffe(protoPath, modelPath)


# Visualization Function ===
def im_show(image, size=(15, 15), output=None):
    """
    Display an image using Matplotlib and optionally save it.

    Args:
        image (numpy.ndarray): Image to display (H, W, C), expected in RGB format.
        size (tuple): Figure size (width, height) in inches.
        output (str, optional): Path to save the image (in BGR format for cv2).
    """
    # Validate image
    if image is None or image.size == 0:
        raise ValueError("Input image is empty or None")
    if len(image.shape) != 3 or image.shape[2] != 3:
        raise ValueError(f"Expected 3-channel RGB image, got shape:{image.shape}")

    # Display image
    fig = plt.figure(figsize=size)
    plt.imshow(image)
    plt.axis("off")
    plt.show()

    # Save image if output path is provided
    if output is not None:
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        cv2.imwrite(output, image_bgr)
        print(f"Saved image to: {output}")


# === Face Detection Function ===
def detect_save_face(color, output_path=None, multiple_output=False):
    """
    Detect faces in an image using the Caffe model and optionally save them.

    Args:
        color (numpy.ndarray): Input image (H, W, C) in BGR format (from cv2), typically a frame from Replay-Attack .mov videos.
        output_path (str, optional): Path to save the most confident face (in BGR format).
        multiple_output (bool): If True, detect all faces; if False, detect only the most confident face.

    Returns:
        list: List of detected faces (in BGR format).
        list: List of bounding box coordinates [startX, startY, endX, endY].
    """
    # Validate input image
    if color is None or color.size == 0:
        raise ValueError("Input image is empty or None")
    if len(color.shape) != 3 or color.shape[2] != 3:
        raise ValueError(f"Expected 3-channel BGR image, got shape {color.shape}")

    (h, w) = color.shape[:2]

    # Preprocess image for face detection
    blob = cv2.dnn.blobFromImage(
        cv2.resize(color, (300, 300)), 1.0, (300, 300), (104.0, 177.0, 123.0)
    )

    # Perform face detection
    net.setInput(blob)
    detections = net.forward()

    faces = []
    coordinates = []

    if detections.shape[2] == 0:
        print("Warning: No faces detected in the image")
        return faces, coordinates

    # Find the most confident face or all faces
    max_i = np.argmax(detections[0, 0, :, 2])
    min_range = 0 if multiple_output else max_i

    max_range = detections.shape[2] if multiple_output else max_i + 1

    for i in range(min_range, max_range):
        confidence = detections[0, 0, i, 2]
        if confidence >= face_confidence:
            box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
            (startX, startY, endX, endY) = box.astype("int")

            # Ensure bounding box is within image dimensions
            startX, startY = max(0, startX), max(0, startY)
            endX, endY = min(w, endX), min(h, endY)

            # Extract face
            face = color[startY:endY, startX:endX]
            if face.size == 0:
                print(
                    f"Warning: Empty face region at coordinates [{startX}, {startY}, {endX}, {endY}]"
                )
                continue

            faces.append(face)
            coordinates.append([startX, startY, endX, endY])

            # Save the face if output path is provided
            if output_path is not None:
                cv2.imwrite(output_path, face)
                print(f"Saved face to: {output_path} (Confidence: {confidence:.4f})")

    print(f"Detected {len(faces)} face(s) with confidence >= {face_confidence}")
    return faces, coordinates


# === File Handling Functions ===
def get_file_name(path):
    """
    Extract the file name (without extension) from a path.

    Args:
        path (str): File path.

    Returns:
        str: File name without extension.
    """
    if not path or not isinstance(path, str):
        raise ValueError("Path must be a non-empty string")
    base = os.path.basename(path)
    return os.path.splitext(base)[0]


def get_file_type(path):
    """
    Extract the file extension from a path.

    Args:
        path (str): File path.

    Returns:
        str: File extension (e.g., '.mov', '.jpg').
    """
    if not path or not isinstance(path, str):
        raise ValueError("Path must be a non-empty string")
    base = os.path.basename(path)
    return os.path.splitext(base)[1]  # e.g., .mov, .jpg")

In [None]:
import os
import shutil
import cv2
import numpy as np


def extract_frames_from_videos(file_paths):
    """
    Extract frames from .mov videos in the Replay-Attack dataset.

    Args:
        file_paths (list): List of .mov video file paths (e.g., from train/attack/fixed, train/real).

    Notes:
        - Extracts up to 18 frames per video, evenly distributed.
        - Saves RGB frames as JPG in an 'extracted_images' subdirectory of the input video folder
          (e.g., train/attack/fixed/extracted_images/).
        - Adds suffixes (_fixed, _hand, _real) to frame filenames to distinguish video types.
        - Frames will be processed for face detection and saved in output directories
          (e.g., output_attack_training_folder).
        - Frames are saved as frame_<filename>_<suffix>_<index>.jpg (e.g., frame_client001_fixed_0.jpg).
    """
    for video_path in file_paths:
        # Validate video file
        if not os.path.exists(video_path) or not video_path.endswith(".mov"):
            print(f"Error: Invalid or missing video file: {video_path}")
            continue

        # Determine video type for suffix
        parent_dir = os.path.dirname(video_path).lower()
        if "fixed" in parent_dir:
            suffix = "_fixed"
        elif "hand" in parent_dir:
            suffix = "_hand"
        else:
            suffix = "_real"

        # Output directory is 'extracted_images' subdirectory of the input video folder
        output_dir = os.path.join(os.path.dirname(video_path), "extracted_images")
        os.makedirs(
            output_dir, exist_ok=True
        )  # Create extracted_images if it doesn't exist

        # Open video
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"Error: Failed to open video: {video_path}")
            cap.release()
            continue

        # Get total frames and FPS
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))

        # Strategy parameters
        target_frames = 18  # Maximum number of frames to extract per video

        if total_frames > target_frames:
            # Take frames evenly distributed across the video
            frame_interval = total_frames // target_frames
        else:
            # If video is shorter, take every frame
            frame_interval = 1

        frame_count = 0
        saved_count = 0
        save_frame = False

        video_name = get_file_name(video_path)  # Extract filename without extension

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Strategy: Evenly distributed frames
            if frame_count % frame_interval == 0:
                save_frame = True

            # Save frame if flagged
            if save_frame:
                # Validate frame
                if frame is None:
                    print(
                        f"Warning: Invalid frame at index {frame_count} in {video_path}"
                    )
                    frame_count += 1
                    continue

                # Save RGB frame
                frame_path = os.path.join(
                    output_dir, f"frame_{video_name}{suffix}_{saved_count}.jpg"
                )
                cv2.imwrite(frame_path, frame)
                print(f"Saved RGB frame: {frame_path} (Shape: {frame.shape})")

                saved_count += 1
                save_frame = False

                # Stop if we've reached the target number of frames
                if saved_count >= target_frames:
                    break

            frame_count += 1

        cap.release()
        print(
            f"Extracted {saved_count} frames from {video_path} (total frames: {total_frames}, interval: {frame_interval})"
        )


# === Extract Frames from All Videos ===
print("Extracting frames from videos...")
for file_list in training_directories:
    print(f"Processing {len(file_list)} videos...")
    extract_frames_from_videos(file_list)

Extracting frames from videos...


In [None]:
import os
import random
import cv2
from glob import glob

# Maximum number of images per folder due to hardware constraints
max_images_per_folder = 3000


def take_images():
    """
    Process RGB images from extracted frames to detect and extract faces.
    For each color (RGB) image:
    - Detects and extracts the most confident face using detect_save_face (Cell 6).
    - Saves the face image to the appropriate output directory (from Cell 3).
    - Limits the number of images per folder to max_images_per_folder.

    Notes:
        - Images are shuffled with a fixed seed for reproducibility.
        - Faces are saved in BGR format and will be converted to RGB and resized to 32x32 in the dataloader.
        - Processes frames from extracted_images/ subdirectories (e.g., train/attack/fixed/extracted_images/)
          created in Cell 7 from .mov files.
    """
    directory_counter = 0  # Track which training directory we're processing
    file_counter = 0  # Track total files processed

    for main_list in training_directories:
        folder_file_counter = 0  # Track files processed in current folder

        print(
            f"Processing directory {directory_counter + 1}/{len(training_directories)} "
            f"({output_directories[directory_counter]})"
        )

        # Get unique extracted_images directories from video paths
        frame_dirs = set()
        for video_path in main_list:
            # Derive extracted_images directory from video path
            frames_dir = os.path.join(os.path.dirname(video_path), "extracted_images")
            if os.path.exists(frames_dir):
                frame_dirs.add(frames_dir)
            else:
                print(f"Error: Frames directory not found: {frames_dir}")

        if not frame_dirs:
            print(
                f"Warning: No frame directories found for directory {directory_counter + 1}"
            )
            directory_counter += 1
            continue

        for frames_dir in frame_dirs:
            # Get all image paths in current extracted_images directory
            list_images = glob(os.path.join(frames_dir, "*.jpg"))
            if not list_images:
                print(f"Warning: No images found in {frames_dir}")
                continue

            print(f"Found {len(list_images)} images in {frames_dir}")

            # Shuffle images with fixed seed for reproducibility
            random.Random(20).shuffle(list_images)

            for file_path in list_images:
                # Stop if we've hit the per-folder limit
                if folder_file_counter >= max_images_per_folder:
                    print(
                        f"Reached max images ({max_images_per_folder}) for folder {directory_counter + 1}"
                    )
                    break

                # Load image
                color = cv2.imread(file_path)
                if color is None:
                    print(f"Could not load image: {file_path}")
                    continue

                # Define output path for the face, preserving suffix (_fixed, _hand, _real)
                frame_name = get_file_name(file_path)
                output_path = os.path.join(
                    output_directories[directory_counter],
                    f"face_{frame_name}{get_file_type(file_path)}",
                )

                # Detect and save face
                faces, _ = detect_save_face(color, output_path)
                if not faces:
                    print(f"No faces detected in {file_path}")
                    continue

                folder_file_counter += 1
                file_counter += 1
                print(
                    f"Saved face {file_counter} from {file_path} to folder {directory_counter + 1}"
                )

        print(
            f"Total faces saved in folder {directory_counter + 1}: {folder_file_counter}"
        )
        directory_counter += 1


# Run the face extraction
take_images()

In [None]:
import cv2
import os
import numpy as np
import random
from collections import Counter
import torchvision.transforms as transforms
import torch
import pickle

# Constants (from Cell 5)
dim = (128, 128)
final_x, final_y, n_channels = 128, 128, 3


# === Image Preprocessing ===
def resize_image(image, dim=dim):
    """
    Resize and normalize image to target dimensions using resize_normalize_image.

    Args:
        image (numpy.ndarray): Input image (H, W, C) in BGR format.
        dim (tuple): Target dimensions (width, height).

    Returns:
        numpy.ndarray: Resized and normalized image in [0, 1] float32, BGR format.
    """
    if image is None or image.size == 0:
        raise ValueError("Input image is empty or None")
    image = resize_normalize_image(image, dim=dim, value=255.0)
    return image.astype(np.float32)  # Ensure float32 for OpenCV compatibility


# === Dataset Preparation ===
def get_images(directories, balance_dataset=True, undersampling=False):
    """
    Load images from directories, assign labels, and optionally balance the dataset.

    Args:
        directories (list): List of directories containing face images (bonafide, attack).
        balance_dataset (bool): If True, balance classes using oversampling or undersampling.
        undersampling (bool): If True and balance_dataset=True, use undersampling instead of oversampling.

    Returns:
        numpy.ndarray: Array of images (N, H, W, C) in [0, 1] float32, RGB format.
        numpy.ndarray: Array of labels (0: bonafide, 1: attack).
    """
    image_list = []
    label_list = []
    images_per_class = []

    # Set seeds for reproducibility
    torch.manual_seed(20)
    random.seed(20)
    np.random.seed(20)

    for class_idx, directory in enumerate(directories):
        if not os.path.exists(directory):
            print(f"Error: Directory not found: {directory}")
            continue

        list_images = [
            f
            for f in os.listdir(directory)
            if f.endswith(".jpg") and os.path.isfile(os.path.join(directory, f))
        ]
        if not list_images:
            print(f"Warning: No images found in {directory}")
            continue

        images_number = len(list_images)
        if balance_dataset and undersampling and images_per_class:
            # Limit to the size of the smallest class seen so far
            images_number = min(images_number, min(images_per_class))
        random.Random(20).shuffle(list_images)
        print(f"Found {images_number} images in {directory}")

        X = np.empty((images_number, final_x, final_y, n_channels), dtype=np.float32)
        L = np.empty(images_number, dtype=np.int64)

        ipp = 0
        for im_name in list_images:
            if ipp >= images_number:
                break
            im_path = os.path.join(directory, im_name)
            image = cv2.imread(im_path)
            if image is None:
                print(f"Failed to load image: {im_path}")
                continue

            # Resize and normalize to [0, 1], convert to RGB
            image = resize_image(image, dim=dim)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            X[ipp] = image
            L[ipp] = class_idx  # 0: bonafide, 1: attack
            ipp += 1

        image_list.append(X[:ipp])
        label_list.append(L[:ipp])
        images_per_class.append(ipp)

    if not image_list:
        raise ValueError("No images loaded from any directory")

    # Balance dataset
    if balance_dataset:
        if undersampling:
            # Undersample majority class to match minority
            min_class_size = min(images_per_class)
            for j in range(len(images_per_class)):
                if images_per_class[j] > min_class_size:
                    indices = np.random.choice(
                        images_per_class[j], min_class_size, replace=False
                    )
                    image_list[j] = image_list[j][indices]
                    label_list[j] = label_list[j][indices]
                    images_per_class[j] = min_class_size
                    print(f"Undersampled class {j} to {min_class_size} images")
        else:
            # Oversample minority class to match majority
            max_class_size = max(images_per_class)
            transform = transforms.Compose(
                [
                    transforms.ToPILImage(),
                    transforms.RandomRotation(10),
                    transforms.RandomHorizontalFlip(p=0.3),
                    transforms.ColorJitter(
                        brightness=0.1, contrast=0.1, saturation=0.1
                    ),
                    transforms.ToTensor(),  # [C, H, W], [0, 1]
                ]
            )

            for j in range(len(images_per_class)):
                if images_per_class[j] < max_class_size:
                    diff = max_class_size - images_per_class[j]
                    image_array = image_list[j]
                    label_array = label_list[j]

                    new_images = np.empty(
                        (diff, final_x, final_y, n_channels), dtype=np.float32
                    )
                    new_labels = np.full(diff, j, dtype=np.int64)

                    for k in range(diff):
                        idx = random.randint(0, len(image_array) - 1)
                        img = image_array[idx]  # [H, W, C], [0, 1], RGB
                        img = (img * 255).astype(
                            np.uint8
                        )  # [0, 255] uint8 for ToPILImage
                        aug_img = (
                            transform(img).permute(1, 2, 0).numpy()
                        )  # [H, W, C], [0, 1], RGB
                        new_images[k] = aug_img

                    image_list[j] = np.concatenate([image_array, new_images], axis=0)
                    label_list[j] = np.concatenate([label_array, new_labels], axis=0)
                    images_per_class[j] = max_class_size
                    print(f"Oversampled class {j} by {diff} images")

    # Combine classes
    new_image_list = np.concatenate(image_list, axis=0)
    new_label_list = np.concatenate(label_list, axis=0)

    # Verify class balance
    label_counts = Counter(new_label_list)
    print(f"Final dataset: {label_counts}")
    if label_counts[0] != label_counts[1] and balance_dataset:
        print("Warning: Classes are not balanced despite balancing attempt!")

    return new_image_list, new_label_list


# === Load or Generate Data ===
read_from_np_array = False
print(f"Reading from numpy array: {read_from_np_array}")
if read_from_np_array:
    with open(output_np_training_array, "rb") as f:
        train_X, test_X, train_Y, test_Y = pickle.load(f)
else:
    # Load training data
    train_X, train_Y = get_images(
        [output_bonifade_training_folder, output_attack_training_folder],
        balance_dataset=True,
        undersampling=False,  # Default to oversampling
    )

    # Load validation data
    test_X, test_Y = get_images(
        [output_bonifade_validation_folder, output_attack_validation_folder],
        balance_dataset=False,
    )

    # Save data to disk
    os.makedirs(os.path.dirname(save_labels), exist_ok=True)
    os.makedirs(os.path.dirname(output_np_training_array), exist_ok=True)

    # with open(output_np_training_array, "wb") as f:
    #     pickle.dump((train_X, test_X, train_Y, test_Y), f)
    # print(f"Saved dataset to {output_np_training_array}")

print("Shape training:", train_X.shape)
print("Shape validation:", test_X.shape)

In [None]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

# Settings
augmentation_online = True  # Online augmentation for training
generator_seed = 10
batch_size = 32

# Define transformations
base_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


# Custom Dataset class
class CustomImageDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        """
        Args:
            images (numpy.ndarray): [N, H, W, C], [0, 1], RGB format (32x32 for Replay-Attack).
            labels (numpy.ndarray): [N], int64, 0 (bonafide) or 1 (attack).
            transform: Optional transform to apply.
        """
        self.images = images  # [N, H, W, C], [0, 1], RGB
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]  # [H, W, C], [0, 1], RGB
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)  # [C, H, W], [-1, 1]
        else:
            image = (
                torch.from_numpy(image).permute(2, 0, 1).float()
            )  # [C, H, W], [0, 1]
            image = (image - 0.5) / 0.5  # Normalize to [-1, 1]

        label = torch.tensor(label, dtype=torch.long)  # For nn.CrossEntropyLoss
        return image, label


# Create generator with seed for reproducibility
g = torch.Generator()
g.manual_seed(generator_seed)

# Datasets and loaders
train_dataset = CustomImageDataset(
    train_X,
    train_Y,
    transform=base_transforms,
)
val_dataset = CustomImageDataset(test_X, test_Y, transform=base_transforms)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=g,
    num_workers=0,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

# Check lengths
print("Length of training batches:", len(train_loader))
print("Total training samples:", len(train_X))
print("Total validation samples:", len(test_X))

In [None]:
import numpy as np
import torch
import torchvision.transforms as transforms
from tqdm import tqdm
from collections import Counter

# Settings
augmentation_online = False  # Offline augmentation enabled
generator_seed = 10
torch.manual_seed(generator_seed)
np.random.seed(generator_seed)

# Constants (from Cell 5)
final_x, final_y, n_channels = 128, 128, 3

# Define offline augmentation transform (moderate)
offline_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomAffine(
            degrees=15,
            scale=(0.85, 1.15),
        ),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),  # [C, H, W], [0, 1]
    ]
)


# Function to augment dataset
def augment_dataset(X, Y, multiplier=2):
    """
    Augment the dataset by creating multiple augmented versions of each image.

    Args:
        X (numpy.ndarray): Images, [N, 128, 128, 3], [0, 1], RGB.
        Y (numpy.ndarray): Labels, [N], int64, 0 (bonafide) or 1 (attack).
        multiplier (int): Number of copies (original + augmented) per image.

    Returns:
        numpy.ndarray: Augmented images, [N*multiplier, 128, 128, 3], [0, 1], RGB, float32.
        numpy.ndarray: Augmented labels, [N*multiplier], int64.
    """
    if len(X) != len(Y):
        raise ValueError(f"Images ({len(X)}) and labels ({len(Y)}) length mismatch")

    print(f"Original dataset: {Counter(Y)}")
    to_augment = len(X) * multiplier
    new_X = np.empty((to_augment, final_x, final_y, n_channels), dtype=np.float32)
    new_Y = np.empty((to_augment), dtype=np.int64)

    idx = 0
    for i in tqdm(range(len(X)), desc="Augmenting"):
        img = X[i]  # [32, 32, 3], [0, 1], RGB
        label = Y[i]

        # Store original
        new_X[idx] = img
        new_Y[idx] = label
        idx += 1

        # Augment (multiplier-1) times
        for _ in range(multiplier - 1):
            img_uint8 = (img * 255.0).astype(np.uint8)  # [0, 255] for ToPILImage
            aug_img_tensor = offline_transform(img_uint8)  # [C, H, W], [0, 1]
            aug_img_np = aug_img_tensor.permute(1, 2, 0).numpy()  # [32, 32, 3], [0, 1]
            new_X[idx] = aug_img_np
            new_Y[idx] = label
            idx += 1

    print(f"Augmented dataset: {Counter(new_Y)}")
    if Counter(new_Y)[0] != Counter(new_Y)[1]:
        print("Warning: Augmented dataset is not balanced!")
    return new_X, new_Y


# Apply augmentation
if not augmentation_online:
    train_X, train_Y = augment_dataset(train_X, train_Y, multiplier=2)

print("Shape training:", train_X.shape)
print("Shape validation:", test_X.shape)

In [None]:
# Datasets and loaders
train_dataset = CustomImageDataset(train_X, train_Y, transform=base_transforms)
val_dataset = CustomImageDataset(test_X, test_Y, transform=base_transforms)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=g,
    num_workers=0,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

# Check lengths
print("Length of training batches:", len(train_loader))
print("Total training samples:", len(train_X))
print("Total validation samples:", len(test_X))

# Training (Unet architecture (CELoss))

In [None]:
import torch
import torch.nn as nn
import psutil
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, precision_recall_fscore_support, confusion_matrix
from tqdm import tqdm
import os
from collections import Counter


# Define MemoryUsage class for tracking memory usage
class MemoryUsage:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.max_RAM = []
        self.max_GPU = []

    def get_size(self, byte):
        factor = 1024
        for unit in ["", "K", "M", "GB", "T", "P"]:
            if byte < factor:
                return f"{byte:.2f} {unit}B"
            byte /= factor

    def on_epoch_end(self):
        svmem = psutil.virtual_memory()
        self.max_RAM.append((svmem.used, svmem.total, svmem.percent))
        if torch.cuda.is_available():
            self.max_GPU.append(torch.cuda.memory_allocated())  # Store in bytes
            torch.cuda.reset_peak_memory_stats()

    def on_train_end(self):
        i = np.argmax([r[0] for r in self.max_RAM])
        print(
            "MAX RAM USAGE: %s / %s (%s%%)"
            % (
                self.get_size(self.max_RAM[i][0]),
                self.get_size(self.max_RAM[i][1]),
                self.max_RAM[i][2],
            )
        )
        if self.max_GPU:
            print("MAX GPU USAGE: %s" % self.get_size(max(self.max_GPU)))


# Compute class-wise accuracy for binary classification (two-class output)
def compute_class_accuracy(outputs, targets):
    """
    Compute class-wise accuracy for binary classification.

    Args:
        outputs (torch.Tensor): [batch_size, 2], raw logits.
        targets (torch.Tensor): [batch_size], long, 0 or 1.

    Returns:
        torch.Tensor: Correct predictions per class.
        torch.Tensor: Total samples per class.
    """
    predicted = torch.argmax(outputs, dim=1)  # [batch_size], predicted class indices
    correct_per_class = torch.zeros(2, device=targets.device)
    total_per_class = torch.zeros(2, device=targets.device)
    for i in range(2):
        mask = targets == i
        correct_per_class[i] = (predicted[mask] == targets[mask]).sum().float()
        total_per_class[i] = mask.sum().float()
    return correct_per_class, total_per_class


# Training function
def train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    criterion,
    device,
    num_epochs=50,
    early_stop_patience=3,
    save_dir="./models",
    best_model_pth="best_model.pth",
    last_model_pth="last_model.pth",
    save_training_metrics_plot="training_plot.png",
):
    """
    Train the UNetBinaryClassifier_CE model with CrossEntropyLoss.

    Args:
        model: UNetBinaryClassifier_CE model.
        train_loader: DataLoader for training data.
        val_loader: DataLoader for validation data.
        optimizer: Optimizer (e.g., Adam).
        scheduler: Learning rate scheduler (e.g., ReduceLROnPlateau).
        criterion: Loss function (e.g., CrossEntropyLoss).
        device: Device to train on (cuda or cpu).
        num_epochs (int): Number of epochs to train.
        early_stop_patience (int): Patience for early stopping.
        save_dir (str): Directory to save models and plots.
        best_model_pth (str): Filename for best model checkpoint.
        last_model_pth (str): Filename for last model checkpoint.
        save_training_metrics_plot (str): Filename for training metrics plot.

    Returns:
        dict: Training history with metrics.
    """
    model = model.to(device)
    history = {
        "train_loss": [],
        "val_loss": [],
        "train_acc": [],
        "val_acc": [],
        "train_f1": [],
        "val_f1": [],
        "train_precision": [],
        "val_precision": [],
        "train_recall": [],
        "val_recall": [],
        "train_class_acc": [],
        "val_class_acc": [],
        "train_confusion_matrix": [],
        "val_confusion_matrix": [],
    }
    best_val_loss = float("inf")
    early_stop_counter = 0
    os.makedirs(save_dir, exist_ok=True)
    best_model_pth = os.path.join(save_dir, best_model_pth)
    last_model_pth = os.path.join(save_dir, last_model_pth)
    memory_tracker = MemoryUsage()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    for epoch in range(num_epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0
        y_true, y_pred = [], []
        train_correct_per_class = torch.zeros(2, device=device)
        train_total_per_class = torch.zeros(2, device=device)
        all_logits = []
        all_probs = []

        for batch_idx, (inputs, labels) in enumerate(
            tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training")
        ):
            inputs, labels = inputs.to(device), labels.to(device)
            labels = labels.view(-1).long()  # [batch_size], long, 0 or 1
            if not torch.all((labels == 0) | (labels == 1)):
                print(f"Invalid targets: {torch.unique(labels)}")
                break
            if torch.isnan(inputs).any() or torch.isinf(inputs).any():
                print(f"Invalid inputs at epoch {epoch}, batch {batch_idx}")
                break

            optimizer.zero_grad()
            outputs = model(inputs)  # [batch_size, 2], raw logits
            loss = criterion(outputs, labels)
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"Invalid loss at epoch {epoch}, batch {batch_idx}")
                break

            loss.backward()
            optimizer.step()

            probs = torch.softmax(outputs, dim=1)  # [batch_size, 2], probabilities
            predicted = torch.argmax(
                outputs, dim=1
            )  # [batch_size], predicted class indices
            running_loss += loss.item() * inputs.size(0)
            correct_per_class, total_per_class = compute_class_accuracy(outputs, labels)
            train_correct_per_class += correct_per_class
            train_total_per_class += total_per_class
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
            all_logits.extend(outputs.cpu().detach().numpy())
            all_probs.extend(probs.cpu().detach().numpy())

        # Log distribution of logits and probabilities
        all_logits = np.array(all_logits)  # [num_samples, 2]
        all_probs = np.array(all_probs)  # [num_samples, 2]
        print(
            f"Epoch {epoch+1} - Train Logits (Class 0): Min: {all_logits[:, 0].min():.4f}, Max: {all_logits[:, 0].max():.4f}, Mean: {all_logits[:, 0].mean():.4f}"
        )
        print(
            f"Epoch {epoch+1} - Train Logits (Class 1): Min: {all_logits[:, 1].min():.4f}, Max: {all_logits[:, 1].max():.4f}, Mean: {all_logits[:, 1].mean():.4f}"
        )
        print(
            f"Epoch {epoch+1} - Train Probabilities (Class 0): Min: {all_probs[:, 0].min():.4f}, Max: {all_probs[:, 0].max():.4f}, Mean: {all_probs[:, 0].mean():.4f}"
        )
        print(
            f"Epoch {epoch+1} - Train Probabilities (Class 1): Min: {all_probs[:, 1].min():.4f}, Max: {all_probs[:, 1].max():.4f}, Mean: {all_probs[:, 1].mean():.4f}"
        )
        print(
            f"Epoch {epoch+1} - Predicted Class Distribution: Class 0: {np.sum(np.array(y_pred) == 0)}, Class 1: {np.sum(np.array(y_pred) == 1)}"
        )

        epoch_loss = running_loss / total
        epoch_acc = correct / total
        epoch_f1 = f1_score(y_true, y_pred, average="macro")
        precision, recall, _, _ = precision_recall_fscore_support(
            y_true, y_pred, average="macro", zero_division=0
        )
        cm = confusion_matrix(y_true, y_pred)
        train_class_acc = (
            100.0 * train_correct_per_class / (train_total_per_class + 1e-8)
        )

        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        val_y_true, val_y_pred = [], []
        val_correct_per_class = torch.zeros(2, device=device)
        val_total_per_class = torch.zeros(2, device=device)
        val_all_logits = []
        val_all_probs = []

        with torch.no_grad():
            for val_inputs, val_labels in tqdm(
                val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"
            ):
                val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
                val_labels = val_labels.view(-1).long()
                outputs = model(val_inputs)  # [batch_size, 2], raw logits
                v_loss = criterion(outputs, val_labels)
                v_probs = torch.softmax(outputs, dim=1)
                v_predicted = torch.argmax(outputs, dim=1)
                val_loss += v_loss.item() * val_inputs.size(0)
                correct_per_class, total_per_class = compute_class_accuracy(
                    outputs, val_labels
                )
                val_correct_per_class += correct_per_class
                val_total_per_class += total_per_class
                val_correct += (v_predicted == val_labels).sum().item()
                val_total += val_labels.size(0)
                val_y_true.extend(val_labels.cpu().numpy())
                val_y_pred.extend(v_predicted.cpu().numpy())
                val_all_logits.extend(outputs.cpu().numpy())
                val_all_probs.extend(v_probs.cpu().numpy())

        # Log validation distribution
        val_all_logits = np.array(val_all_logits)  # [num_samples, 2]
        val_all_probs = np.array(val_all_probs)  # [num_samples, 2]
        print(
            f"Epoch {epoch+1} - Val Logits (Class 0): Min: {val_all_logits[:, 0].min():.4f}, Max: {val_all_logits[:, 0].max():.4f}, Mean: {val_all_logits[:, 0].mean():.4f}"
        )
        print(
            f"Epoch {epoch+1} - Val Logits (Class 1): Min: {val_all_logits[:, 1].min():.4f}, Max: {val_all_logits[:, 1].max():.4f}, Mean: {val_all_logits[:, 1].mean():.4f}"
        )
        print(
            f"Epoch {epoch+1} - Val Probabilities (Class 0): Min: {val_all_probs[:, 0].min():.4f}, Max: {val_all_probs[:, 0].max():.4f}, Mean: {val_all_probs[:, 0].mean():.4f}"
        )
        print(
            f"Epoch {epoch+1} - Val Probabilities (Class 1): Min: {val_all_probs[:, 1].min():.4f}, Max: {val_all_probs[:, 1].max():.4f}, Mean: {val_all_probs[:, 1].mean():.4f}"
        )
        print(
            f"Epoch {epoch+1} - Val Predicted Class Distribution: Class 0: {np.sum(np.array(val_y_pred) == 0)}, Class 1: {np.sum(np.array(val_y_pred) == 1)}"
        )

        val_loss /= val_total
        val_acc = val_correct / val_total
        val_f1 = f1_score(
            val_y_true, val_y_pred, average="weighted"
        )  # Weighted due to imbalance
        val_precision, val_recall, _, _ = precision_recall_fscore_support(
            val_y_true, val_y_pred, average="weighted", zero_division=0
        )
        val_cm = confusion_matrix(val_y_true, val_y_pred)
        val_class_acc = 100.0 * val_correct_per_class / (val_total_per_class + 1e-8)

        print(
            f"Epoch {epoch+1}/{num_epochs} - Train Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}, "
            f"F1: {epoch_f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}"
        )
        print(
            f"Train Class-wise Acc: Class 0: {train_class_acc[0]:.2f}%, Class 1: {train_class_acc[1]:.2f}%"
        )
        print(f"Train Confusion Matrix:\n{cm}")
        print(
            f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}, "
            f"Precision: {val_precision:.4f}, Recall: {val_recall:.4f}"
        )
        print(
            f"Val Class-wise Acc: Class 0: {val_class_acc[0]:.2f}%, Class 1: {val_class_acc[1]:.2f}%"
        )
        print(f"Val Confusion Matrix:\n{val_cm}")
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")

        history["train_loss"].append(epoch_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(epoch_acc)
        history["val_acc"].append(val_acc)
        history["train_f1"].append(epoch_f1)
        history["val_f1"].append(val_f1)
        history["train_precision"].append(precision)
        history["val_precision"].append(val_precision)
        history["train_recall"].append(recall)
        history["val_recall"].append(val_recall)
        history["train_class_acc"].append(train_class_acc.tolist())
        history["val_class_acc"].append(val_class_acc.tolist())
        history["train_confusion_matrix"].append(cm.tolist())
        history["val_confusion_matrix"].append(val_cm.tolist())

        if scheduler:
            scheduler.step(val_loss)

        memory_tracker.on_epoch_end()

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stop_counter = 0
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": (
                        scheduler.state_dict() if scheduler else None
                    ),
                    "train_loss": epoch_loss,
                    "val_loss": val_loss,
                    "train_acc": epoch_acc,
                    "val_acc": val_acc,
                    "train_f1": epoch_f1,
                    "val_f1": val_f1,
                    "train_precision": precision,
                    "val_precision": val_precision,
                    "train_recall": recall,
                    "val_recall": val_recall,
                    "train_class_acc": train_class_acc.tolist(),
                    "val_class_acc": val_class_acc.tolist(),
                },
                best_model_pth,
            )
        else:
            early_stop_counter += 1

        if early_stop_counter >= early_stop_patience:
            print("Early stopping triggered.")
            break

    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict() if scheduler else None,
            "train_loss": epoch_loss,
            "val_loss": val_loss,
            "train_acc": epoch_acc,
            "val_acc": val_acc,
            "train_f1": epoch_f1,
            "val_f1": val_f1,
            "train_precision": precision,
            "val_precision": val_precision,
            "train_recall": recall,
            "val_recall": val_recall,
            "train_class_acc": train_class_acc.tolist(),
            "val_class_acc": val_class_acc.tolist(),
        },
        last_model_pth,
    )

    memory_tracker.on_train_end()

    def plot_metric_subplot(index, metrics, title, ylabel, is_classwise=False):
        plt.subplot(4, 1, index)

        if is_classwise:
            for cls_idx in range(2):  # Assuming binary classification: Class 0 and 1
                plt.plot(
                    [x[cls_idx] for x in history[f"train_{metrics}"]],
                    label=f"Train Class {cls_idx} Acc",
                )
                plt.plot(
                    [x[cls_idx] for x in history[f"val_{metrics}"]],
                    label=f"Val Class {cls_idx} Acc",
                )
        else:
            for phase in ["train", "val"]:
                for metric in metrics:
                    plt.plot(
                        history[f"{phase}_{metric}"],
                        label=f"{phase.capitalize()} {metric.capitalize()}",
                    )

        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid()

    # === Plotting ===
    plt.figure(figsize=(12, 16))

    plot_metric_subplot(1, ["loss"], "Training and Validation Loss", "Loss")
    plot_metric_subplot(2, ["acc"], "Training and Validation Accuracy", "Accuracy")
    plot_metric_subplot(
        3, ["precision", "recall"], "Training and Validation Precision/Recall", "Score"
    )
    plot_metric_subplot(
        4, "class_acc", "Class-wise Accuracy", "Accuracy (%)", is_classwise=True
    )

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, save_training_metrics_plot))
    plt.close()

    np.save(os.path.join(save_dir, "training_metrics.npy"), history)
    return history


# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Function to get learning rate for logging
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"]


# Inspect data distribution
def inspect_dataloader(loader, name="DataLoader"):
    labels = []
    for _, lbls in loader:
        labels.extend(lbls.cpu().numpy())
    label_counts = Counter(labels)
    total = len(labels)
    print(f"{name} Class Distribution:")
    print(f"Class 0: {label_counts[0]} ({100 * label_counts[0] / total:.2f}%)")
    print(f"Class 1: {label_counts[1]} ({100 * label_counts[1] / total:.2f}%)")
    return label_counts

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Function to get learning rate for logging
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"]


# Inspect data distribution
def inspect_dataloader(loader, name="DataLoader"):
    labels = []
    for _, lbls in loader:
        labels.extend(lbls.cpu().numpy())
    label_counts = Counter(labels)
    total = len(labels)
    print(f"{name} Class Distribution:")
    print(f"Class 0: {label_counts[0]} ({100 * label_counts[0] / total:.2f}%)")
    print(f"Class 1: {label_counts[1]} ({100 * label_counts[1] / total:.2f}%)")
    return label_counts


# Inspect train and validation loaders
print("Inspecting Data Loaders...")
train_label_counts = inspect_dataloader(train_loader, "Training")
val_label_counts = inspect_dataloader(val_loader, "Validation")


# Apply weight initialization
def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)


# Initialize model and training components
model = UNetBinaryClassifier_CE(in_channels=3, base_channels=32).to(device)
model.apply(init_weights)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-3)
criterion = nn.CrossEntropyLoss()  #  # For binary classification with two-class output
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=3
)

# Verify input normalization
sample_batch, sample_labels = next(iter(train_loader))
print("Input range:", sample_batch.min().item(), "to", sample_batch.max().item())
if sample_batch.max() > 1.0 or sample_batch.min() < -1.0:
    print("Warning: Inputs are not normalized to [-1, 1].")

# Verify model output shape for a batch
sample_batch = torch.randn(32, 3, 128, 128).to(device)
with torch.no_grad():
    outputs = model(sample_batch)
    print("Model output shape for batch:", outputs.shape)  # Expected: [32, 2]

# Train the model with custom save paths
history = train_model(
    model,
    train_loader,
    val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    device=device,
    num_epochs=50,
    early_stop_patience=3,
    save_dir="model_save_path",
    best_model_pth="model_best_checkpoint.pth",
    last_model_pth="model_last_checkpoint.pth",
    save_training_metrics_plot="training_metrics.png",
)

# Visualize sample images to check for data leakage
sample_batch, sample_labels = next(iter(val_loader))
plt.figure(figsize=(10, 10))
for i in range(10):
    plt.subplot(2, 5, i + 1)
    img = sample_batch[i].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5  # Denormalize
    plt.imshow(img)
    plt.title(f"Label: {sample_labels[i].item()}")
    plt.axis("off")
plt.savefig(os.path.join("figure_save_path", "sample_validation_images.png"))
plt.close()

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
from scipy.interpolate import interp1d
from scipy.optimize import brentq
import itertools
import os
from time import time
import logging

# Set up logging
log_dir = "log_path"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"eval_{int(time())}.log")
logging.basicConfig(
    filename=log_file, level=logging.INFO, format="%(asctime)s - %(message)s"
)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
logging.getLogger().addHandler(console)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")

# Load the saved model
checkpoint_path = "model_save_path/model_best_checkpoint.pth"
try:
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
except FileNotFoundError:
    logging.error(f"Checkpoint file not found at {checkpoint_path}")
    raise

# Initialize model
model = UNetBinaryClassifier_CE(in_channels=3, base_channels=32).to(device)

# Load checkpoint
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
    model.load_state_dict(checkpoint["model_state_dict"])
    logging.info(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
    val_acc = checkpoint.get("val_acc", "unknown")
    if isinstance(val_acc, (int, float)):
        logging.info(f"Validation accuracy: {val_acc:.3f}%")
    else:
        logging.info(f"Validation accuracy: {val_acc}")
else:
    model.load_state_dict(checkpoint)
    logging.info("Loaded model state dictionary")

model.eval()

# Collect predictions and true labels for evaluation
y_score = []
y_pred = []
y_true = []

with torch.no_grad():
    for batch_idx, (inputs, labels) in enumerate(val_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        labels = labels.view(-1).long()

        # Validate inputs and labels
        if inputs.shape[1] != 3 or inputs.shape[2:] != (128, 128):
            logging.error(f"Invalid input shape at batch {batch_idx}: {inputs.shape}")
            break
        if not torch.all((labels == 0) | (labels == 1)):
            logging.error(
                f"Invalid labels at batch {batch_idx}: {torch.unique(labels)}"
            )
            break
        if torch.isnan(inputs).any() or torch.isinf(inputs).any():
            logging.error(f"NaN/Inf in inputs at batch {batch_idx}")
            break

        # Forward pass (model returns raw logits)
        outputs = model(inputs)  # Shape: [batch_size, 2], raw logits
        probs = torch.softmax(outputs, dim=1)  # Shape: [batch_size, 2], probabilities
        preds = torch.argmax(
            outputs, dim=1
        )  # Shape: [batch_size], predicted class indices

        if torch.isnan(probs).any() or torch.isinf(probs).any():
            logging.error(f"NaN/Inf in outputs at batch {batch_idx}")
            break

        y_score.extend(probs[:, 1].cpu().numpy())  # Probability of Class 1 (Attack)
        y_pred.extend(preds.cpu().numpy())
        y_true.extend(labels.cpu().numpy())

# Convert evaluation results to numpy arrays
y_score = np.array(y_score)
y_pred = np.array(y_pred)
y_true = np.array(y_true)


# Compute EER
def compute_eer(y_true, y_score):
    fpr, tpr, _ = roc_curve(y_true, y_score)
    fnr = 1 - tpr
    eer_threshold = brentq(lambda x: 1.0 - x - interp1d(fpr, fnr)(x), 0.0, 1.0)
    eer = interp1d(fpr, fnr)(eer_threshold)
    return eer


# Compute HTER
def compute_hter(y_true, y_pred, y_score, threshold=0.5):
    far = np.sum((y_pred == 1) & (y_true == 0)) / (np.sum(y_true == 0) + 1e-8)
    frr = np.sum((y_pred == 0) & (y_true == 1)) / (np.sum(y_true == 1) + 1e-8)
    hter = (far + frr) / 2
    return hter, far, frr


# Confusion Matrix Function
def plot_confusion_matrix(cm, classes, normalize=False, title="Confusion Matrix"):
    plt.figure(figsize=(6, 6), dpi=80)
    if normalize:
        cm = cm.astype("float") / (cm.sum(axis=1)[:, np.newaxis] + 1e-8)
        logging.info("Normalized confusion matrix")
    else:
        logging.info("Confusion matrix, without normalization")

    plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = ".2f" if normalize else "d"
    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(
            j,
            i,
            format(cm[i, j], fmt),
            horizontalalignment="center",
            color="white" if cm[i, j] > thresh else "black",
        )

    plt.tight_layout()
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")
    fname = os.path.join(log_dir, f"confusion_matrix_{int(time())}.png")
    plt.savefig(fname)
    logging.info(f"Saved confusion matrix to {fname}")
    plt.close()


# Classification Report
class_names = ["Bonafide", "Attack"]
logging.info("\nClassification Report:")
report = classification_report(y_true, y_pred, target_names=class_names)
logging.info("\n" + report)

# Confusion Matrix
cnf_matrix = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=False)
plot_confusion_matrix(
    cnf_matrix, classes=class_names, normalize=True, title="Normalized Confusion Matrix"
)

# ROC Curve
fpr, tpr, _ = roc_curve(y_true, y_score)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(5, 5), dpi=80)
plt.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (area = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic")
plt.legend(loc="lower right")
fname = os.path.join(log_dir, f"roc_curve_{int(time())}.png")
plt.savefig(fname)
logging.info(f"Saved ROC curve to {fname}")
plt.close()

# EER and HTER
eer = compute_eer(y_true, y_score)
hter, far, frr = compute_hter(y_true, y_pred, y_score)
logging.info(f"FAR: {far:.4f}")
logging.info(f"FRR: {frr:.4f}")
logging.info(f"Equal Error Rate (EER): {eer:.4f}")
logging.info(f"Half Total Error Rate (HTER): {hter:.4f}")

# Score Distribution
plt.figure(figsize=(8, 6), dpi=80)
plt.hist(y_score[y_true == 0], bins=50, alpha=0.5, label="Bonafide", color="blue")
plt.hist(y_score[y_true == 1], bins=50, alpha=0.5, label="Attack", color="red")
plt.xlabel("Prediction Score (Class 1 Probability)")
plt.ylabel("Frequency")
plt.title("Score Distribution")
plt.legend(loc="upper right")
fname = os.path.join(log_dir, f"score_distribution_{int(time())}.png")
plt.savefig(fname)
logging.info(f"Saved score distribution to {fname}")
plt.close()

# Training (Unet with AAM Loss architecture)

In [None]:
import torch
import torch.nn as nn
import psutil
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, precision_recall_fscore_support, confusion_matrix
from tqdm import tqdm
import os
import time


class MemoryUsage:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.max_RAM = []
        self.max_GPU = []

    def get_size(self, byte):
        factor = 1024
        for unit in ["", "K", "M", "GB", "T", "P"]:
            if byte < factor:
                return f"{byte:.2f} {unit}B"
            byte /= factor

    def on_epoch_end(self):
        svmem = psutil.virtual_memory()
        self.max_RAM.append((svmem.used, svmem.total, svmem.percent))
        if torch.cuda.is_available():
            self.max_GPU.append(torch.cuda.memory_allocated())
            torch.cuda.reset_peak_memory_stats()

    def on_train_end(self):
        i = np.argmax([r[0] for r in self.max_RAM])
        print(
            "MAX RAM USAGE: %s / %s (%s%%)"
            % (
                self.get_size(self.max_RAM[i][0]),
                self.get_size(self.max_RAM[i][1]),
                self.max_RAM[i][2],
            )
        )
        if self.max_GPU:
            print("MAX GPU USAGE: %s" % self.get_size(max(self.max_GPU)))


def compute_class_accuracy(outputs, targets):
    """
    Compute class-wise accuracy for binary classification.

    Args:
        outputs (torch.Tensor): [batch_size, 2], raw logits.
        targets (torch.Tensor): [batch_size], long, 0 (bonafide) or 1 (attack).

    Returns:
        torch.Tensor: Correct predictions per class.
        torch.Tensor: Total samples per class.
    """
    predicted = torch.argmax(outputs, dim=1)  # [batch_size]
    correct_per_class = torch.zeros(2, device=targets.device)
    total_per_class = torch.zeros(2, device=targets.device)
    for i in range(2):
        mask = targets == i
        correct_per_class[i] = (predicted[mask] == targets[mask]).sum().float()
        total_per_class[i] = mask.sum().float()
    return correct_per_class, total_per_class


def train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    criterion,
    device,
    num_epochs=30,
    early_stop_patience=3,
    save_dir="model_save_path",
    best_model_pth="best_model.pth",
    last_model_pth="last_model.pth",
    save_training_metrics_plot="training_plot.png",
):
    """
    Train the ArcFace-based UNetBinaryClassifier model for Replay-Attack.

    Args:
        model: UNetBinaryClassifier model (from Cell 1).
        train_loader: DataLoader for training data (21504 samples, 32x32 images).
        val_loader: DataLoader for validation data (8598 samples, 32x32 images).
        optimizer: Optimizer (e.g., Adam).
        scheduler: Learning rate scheduler (e.g., ReduceLROnPlateau).
        criterion: Loss function (e.g., CrossEntropyLoss).
        device: Device to train on (cuda or cpu).
        num_epochs (int): Number of epochs to train.
        early_stop_patience (int): Patience for early stopping.
        save_dir (str): Directory to save models and plots.
        best_model_pth (str): Filename for best model checkpoint.
        last_model_pth (str): Filename for last model checkpoint.
        save_training_metrics_plot (str): Filename for training metrics plot.

    Returns:
        dict: Training history with metrics.
    """
    start_time = time.time()
    model = model.to(device)
    history = {
        "train_loss": [],
        "val_loss": [],
        "train_acc": [],
        "val_acc": [],
        "train_f1": [],
        "val_f1": [],
        "train_precision": [],
        "val_precision": [],
        "train_recall": [],
        "val_recall": [],
        "train_class_acc": [],
        "val_class_acc": [],
        "train_confusion_matrix": [],
        "val_confusion_matrix": [],
    }
    best_val_loss = float("inf")
    early_stop_counter = 0
    os.makedirs(save_dir, exist_ok=True)
    best_model_pth = os.path.join(save_dir, best_model_pth)
    last_model_pth = os.path.join(save_dir, last_model_pth)
    memory_tracker = MemoryUsage()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    for epoch in range(num_epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0
        y_true, y_pred = [], []
        train_correct_per_class = torch.zeros(2, device=device)
        train_total_per_class = torch.zeros(2, device=device)

        for batch_idx, (inputs, labels) in enumerate(
            tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training")
        ):
            inputs, labels = inputs.to(device), labels.to(
                device
            )  # inputs: [batch_size, 3, 32, 32], labels: [batch_size], long
            if not torch.all((labels == 0) | (labels == 1)):
                print(f"Invalid targets: {torch.unique(labels)}")
                break
            if torch.isnan(inputs).any() or torch.isinf(inputs).any():
                print(f"Invalid inputs at epoch {epoch}, batch {batch_idx}")
                break

            optimizer.zero_grad()
            logits, _, _ = model(inputs)  # [batch_size, 2]
            loss = criterion(logits, labels)
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"Invalid loss at epoch {epoch}, batch {batch_idx}")
                break

            loss.backward()
            optimizer.step()

            predicted = torch.argmax(logits, dim=1)  # [batch_size]
            running_loss += loss.item() * inputs.size(0)
            correct_per_class, total_per_class = compute_class_accuracy(logits, labels)
            train_correct_per_class += correct_per_class
            train_total_per_class += total_per_class
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

        epoch_loss = running_loss / total
        epoch_acc = correct / total
        epoch_f1 = f1_score(y_true, y_pred, average="macro")
        precision, recall, _, _ = precision_recall_fscore_support(
            y_true, y_pred, average="macro"
        )
        cm = confusion_matrix(y_true, y_pred)
        train_class_acc = (
            100.0 * train_correct_per_class / (train_total_per_class + 1e-8)
        )

        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        val_y_true, val_y_pred = [], []
        val_correct_per_class = torch.zeros(2, device=device)
        val_total_per_class = torch.zeros(2, device=device)

        with torch.no_grad():
            for val_inputs, val_labels in tqdm(
                val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"
            ):
                val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
                logits, _, _ = model(val_inputs)  # [batch_size, 2]
                v_loss = criterion(logits, val_labels)
                v_predicted = torch.argmax(logits, dim=1)
                val_loss += v_loss.item() * val_inputs.size(0)
                correct_per_class, total_per_class = compute_class_accuracy(
                    logits, val_labels
                )
                val_correct_per_class += correct_per_class
                val_total_per_class += total_per_class
                val_correct += (v_predicted == val_labels).sum().item()
                val_total += val_labels.size(0)
                val_y_true.extend(val_labels.cpu().numpy())
                val_y_pred.extend(v_predicted.cpu().numpy())

        val_loss /= val_total
        val_acc = val_correct / val_total
        val_f1 = f1_score(val_y_true, val_y_pred, average="macro")
        val_precision, val_recall, _, _ = precision_recall_fscore_support(
            val_y_true, val_y_pred, average="macro"
        )
        val_cm = confusion_matrix(val_y_true, val_y_pred)
        val_class_acc = 100.0 * val_correct_per_class / (val_total_per_class + 1e-8)

        print(
            f"Epoch {epoch+1}/{num_epochs} - Train Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}, "
            f"F1: {epoch_f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}"
        )
        print(
            f"Train Class-wise Acc: Class 0: {train_class_acc[0]:.2f}%, Class 1: {train_class_acc[1]:.2f}%"
        )
        print(f"Train Confusion Matrix:\n{cm}")
        print(
            f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}, "
            f"Precision: {val_precision:.4f}, Recall: {val_recall:.4f}"
        )
        print(
            f"Val Class-wise Acc: Class 0: {val_class_acc[0]:.2f}%, Class 1: {val_class_acc[1]:.2f}%"
        )
        print(f"Val Confusion Matrix:\n{val_cm}")
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")

        history["train_loss"].append(epoch_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(epoch_acc)
        history["val_acc"].append(val_acc)
        history["train_f1"].append(epoch_f1)
        history["val_f1"].append(val_f1)
        history["train_precision"].append(precision)
        history["val_precision"].append(val_precision)
        history["train_recall"].append(recall)
        history["val_recall"].append(val_recall)
        history["train_class_acc"].append(train_class_acc.tolist())
        history["val_class_acc"].append(val_class_acc.tolist())
        history["train_confusion_matrix"].append(cm.tolist())
        history["val_confusion_matrix"].append(val_cm.tolist())

        if scheduler:
            scheduler.step(val_loss)

        memory_tracker.on_epoch_end()

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stop_counter = 0
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": (
                        scheduler.state_dict() if scheduler else None
                    ),
                    "train_loss": epoch_loss,
                    "val_loss": val_loss,
                    "train_acc": epoch_acc,
                    "val_acc": val_acc,
                    "train_f1": epoch_f1,
                    "val_f1": val_f1,
                    "train_precision": precision,
                    "val_precision": val_precision,
                    "train_recall": recall,
                    "val_recall": val_recall,
                    "train_class_acc": train_class_acc.tolist(),
                    "val_class_acc": val_class_acc.tolist(),
                },
                best_model_pth,
            )
        else:
            early_stop_counter += 1

        if early_stop_counter >= early_stop_patience:
            print("Early stopping triggered.")
            break

    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict() if scheduler else None,
            "train_loss": epoch_loss,
            "val_loss": val_loss,
            "train_acc": epoch_acc,
            "val_acc": val_acc,
            "train_f1": epoch_f1,
            "val_f1": val_f1,
            "train_precision": precision,
            "val_precision": val_precision,
            "train_recall": recall,
            "val_recall": val_recall,
            "train_class_acc": train_class_acc.tolist(),
            "val_class_acc": val_class_acc.tolist(),
        },
        last_model_pth,
    )

    memory_tracker.on_train_end()
    training_time = time.time() - start_time
    print("TRAINING TIME")
    print(f"--- {training_time:.2f} seconds ---")

    plt.figure(figsize=(12, 16))
    plt.subplot(4, 1, 1)
    plt.plot(history["train_loss"], label="Train Loss")
    plt.plot(history["val_loss"], label="Val Loss")
    plt.title("Training and Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid()
    plt.subplot(4, 1, 2)
    plt.plot(history["train_acc"], label="Train Acc")
    plt.plot(history["val_acc"], label="Val Acc")
    plt.title("Training and Validation Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.grid()
    plt.subplot(4, 1, 3)
    plt.plot(history["train_precision"], label="Train Precision")
    plt.plot(history["val_precision"], label="Val Precision")
    plt.plot(history["train_recall"], label="Train Recall")
    plt.plot(history["val_recall"], label="Val Recall")
    plt.title("Training and Validation Precision/Recall")
    plt.xlabel("Epoch")
    plt.ylabel("Score")
    plt.legend()
    plt.grid()
    plt.subplot(4, 1, 4)
    plt.plot([x[0] for x in history["train_class_acc"]], label="Train Class 0 Acc")
    plt.plot([x[1] for x in history["train_class_acc"]], label="Train Class 1 Acc")
    plt.plot([x[0] for x in history["val_class_acc"]], label="Val Class 0 Acc")
    plt.plot([x[1] for x in history["val_class_acc"]], label="Val Class 1 Acc")
    plt.title("Class-wise Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy (%)")
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, save_training_metrics_plot))
    plt.close()

    np.save(os.path.join(save_dir, "training_metrics.npy"), history)
    return history

'\ndevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")\nmodel = UNetBinaryClassifier()  # From code cell 1\noptimizer = torch.optim.Adam(model.parameters(), lr=0.0001)\nscheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=2)\ncriterion = nn.CrossEntropyLoss()\nhistory = train_model(\n    model,\n    train_loader,\n    val_loader,\n    optimizer,\n    scheduler,\n    criterion,\n    device,\n    num_epochs=30,\n    early_stop_patience=3,\n    save_dir="D:/.../dataset/training/Replay_Attack/images",\n    best_model_pth="best_model.pth",\n    last_model_pth="last_model.pth",\n    save_training_metrics_plot="training_plot.png",\n)\n'

In [None]:
import torch
import torch.nn as nn
import os
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Function to get learning rate for logging
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"]


# Inspect data distribution
def inspect_dataloader(loader, name="DataLoader"):
    """
    Inspect class distribution in a DataLoader.

    Args:
        loader: DataLoader for training or validation.
        name (str): Name of the DataLoader (e.g., 'Training', 'Validation').

    Returns:
        Counter: Label counts.
    """
    labels = []
    for _, lbls in loader:
        labels.extend(lbls.cpu().numpy())
    label_counts = Counter(labels)
    total = len(labels)
    print(f"{name} Class Distribution:")
    print(
        f"Class 0 (Bonafide): {label_counts[0]} ({100 * label_counts[0] / total:.2f}%)"
    )
    print(f"Class 1 (Attack): {label_counts[1]} ({100 * label_counts[1] / total:.2f}%)")
    return label_counts


# Inspect train and validation loaders
print("Inspecting Data Loaders...")
train_label_counts = inspect_dataloader(train_loader, "Training")
val_label_counts = inspect_dataloader(val_loader, "Validation")

# Initialize model and training components
model = UNetBinaryClassifier(
    in_channels=3, base_channels=32, embedding_size=512, num_classes=2
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-3)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=3
)

# Verify input normalization
sample_batch, sample_labels = next(iter(train_loader))
print("Input range:", sample_batch.min().item(), "to", sample_batch.max().item())
if sample_batch.max() > 1.0 or sample_batch.min() < -1.0:
    print("Warning: Inputs are not normalized to [-1, 1].")

# Verify model output shape for a batch
sample_batch = torch.randn(32, 3, 128, 128).to(device)  # Batch size 32, 32x32 images
with torch.no_grad():
    logits, _, _ = model(sample_batch)
    print("Model output shape for batch:", logits.shape)  # Expected: [32, 2]

# Train the model with custom save paths
history = train_model(
    model,
    train_loader,
    val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    device=device,
    num_epochs=50,
    early_stop_patience=3,
    save_dir="model_save_path",
    best_model_pth="model_best_checkpoint.pth",
    last_model_pth="model_last_checkpoint.pth",
    save_training_metrics_plot="training_metrics.png",
)

# Visualize sample images to check for data leakage
sample_batch, sample_labels = next(iter(val_loader))
plt.figure(figsize=(10, 10))
for i in range(10):
    plt.subplot(2, 5, i + 1)
    img = sample_batch[i].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5  # Denormalize
    plt.imshow(img)
    plt.title(f"Label: {sample_labels[i].item()}")
    plt.axis("off")
plt.savefig(
    os.path.join(
        "images_save_path",
        "/validation_images.png",
    )
)
plt.close()

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
from sklearn.manifold import TSNE
from scipy.interpolate import interp1d
from scipy.optimize import brentq
import itertools
import os
from time import time
import logging

# Set up logging
log_dir = "log_dir_path"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"eval_{int(time())}.log")
logging.basicConfig(
    filename=log_file, level=logging.INFO, format="%(asctime)s - %(message)s"
)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
logging.getLogger().addHandler(console)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")

# Load the saved model
checkpoint_path = "model_best_checkpoint.pth"

try:
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
except FileNotFoundError:
    logging.error(f"Checkpoint file not found at {checkpoint_path}")
    raise

# Initialize model
model = UNetBinaryClassifier(
    in_channels=3, base_channels=32, embedding_size=512, num_classes=2
).to(device)

# Load checkpoint
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
    model.load_state_dict(checkpoint["model_state_dict"])
    logging.info(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
    val_acc = checkpoint.get("val_acc", "unknown")
    if isinstance(val_acc, (int, float)):
        logging.info(f"Validation accuracy: {val_acc:.3f}%")
    else:
        logging.info(f"Validation accuracy: {val_acc}")
else:
    model.load_state_dict(checkpoint)
    logging.info("Loaded model state dictionary")

model.eval()

# Collect predictions and true labels for evaluation
y_score = []
y_pred = []
y_true = []
all_images = []
all_labels = []

with torch.no_grad():
    for batch_idx, (inputs, labels) in enumerate(val_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        labels = labels.view(-1).float()

        # Validate inputs and labels
        if inputs.shape[1] != 3 or inputs.shape[2:] != (128, 128):
            logging.error(f"Invalid input shape at batch {batch_idx}: {inputs.shape}")
            break
        if not torch.all((labels == 0) | (labels == 1)):
            logging.error(
                f"Invalid labels at batch {batch_idx}: {torch.unique(labels)}"
            )
            break
        if torch.isnan(inputs).any() or torch.isinf(inputs).any():
            logging.error(f"NaN/Inf in inputs at batch {batch_idx}")
            break

        # Forward pass (model returns logits, probs, embeddings)
        outputs, probs, _ = model(
            inputs
        )  # outputs: [batch_size, 2], probs: [batch_size, 2]
        if probs is None:
            probs = torch.softmax(outputs, dim=1)  # Compute softmax if not returned
        probs = probs[:, 1]  # Probability for attack class (class 1)
        preds = (probs >= 0.5).float()  # Threshold at 0.5

        if torch.isnan(probs).any() or torch.isinf(probs).any():
            logging.error(f"NaN/Inf in outputs at batch {batch_idx}")
            break

        y_score.extend(probs.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())
        y_true.extend(labels.cpu().numpy())

        # Collect images and labels for t-SNE
        all_images.append(inputs.cpu())
        all_labels.append(labels.cpu())

# Convert lists to tensors for t-SNE
all_images = torch.cat(all_images, dim=0)
all_labels = torch.cat(all_labels, dim=0)

# Convert evaluation results to numpy arrays
y_score = np.array(y_score)
y_pred = np.array(y_pred)
y_true = np.array(y_true)


# Compute EER
def compute_eer(y_true, y_score):
    fpr, tpr, _ = roc_curve(y_true, y_score)
    fnr = 1 - tpr
    eer_threshold = brentq(lambda x: 1.0 - x - interp1d(fpr, fnr)(x), 0.0, 1.0)
    eer = interp1d(fpr, fnr)(eer_threshold)
    return eer


# Compute HTER
def compute_hter(y_true, y_pred, y_score, threshold=0.5):
    far = np.sum((y_pred == 1) & (y_true == 0)) / (np.sum(y_true == 0) + 1e-8)
    frr = np.sum((y_pred == 0) & (y_true == 1)) / (np.sum(y_true == 1) + 1e-8)

    hter = (far + frr) / 2
    return hter, far, frr


# Confusion Matrix Function
def plot_confusion_matrix(cm, classes, normalize=False, title="Confusion Matrix"):
    plt.figure(figsize=(6, 6), dpi=80)
    if normalize:
        cm = cm.astype("float") / (cm.sum(axis=1)[:, np.newaxis] + 1e-8)
        logging.info("Normalized confusion matrix")
    else:
        logging.info("Confusion matrix, without normalization")

    plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = ".2f" if normalize else "d"
    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(
            j,
            i,
            format(cm[i, j], fmt),
            horizontalalignment="center",
            color="white" if cm[i, j] > thresh else "black",
        )

    plt.tight_layout()
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")
    fname = os.path.join(log_dir, f"replayattack_confusion_matrix_{int(time())}.png")
    plt.savefig(fname)
    logging.info(f"Saved confusion matrix to {fname}")
    plt.show()


# t-SNE Visualization for all validation images at once
def plot_tsne_validation_embeddings(model, images, labels):
    model.eval()
    with torch.no_grad():
        images = images.to(device)
        # Process all images at once
        _, _, embeddings = model(images)
        embeddings = embeddings.cpu().numpy()
        labels = labels.cpu().numpy()

    # Apply t-SNE
    tsne = TSNE(n_components=2, perplexity=30, random_state=42)
    embeddings_2d = tsne.fit_transform(embeddings)  # Shape: [num_samples, 2]

    # Plot t-SNE embeddings
    plt.figure(figsize=(10, 8))
    plt.scatter(
        embeddings_2d[labels == 1, 0],
        embeddings_2d[labels == 1, 1],
        c="red",
        label="Attack",
        alpha=0.6,
    )
    plt.scatter(
        embeddings_2d[labels == 0, 0],
        embeddings_2d[labels == 0, 1],
        c="blue",
        label="Bonafide",
        alpha=0.6,
    )
    plt.title(
        "t-SNE Visualization of Validation Embeddings (Replay-Attack, AAM-loss)",
        fontsize=14,
        fontweight="bold",
    )
    plt.xlabel("t-SNE Dimension 1")
    plt.ylabel("t-SNE Dimension 2")
    plt.legend()
    plt.grid(True)

    # Save the plot
    save_path = os.path.join(log_dir, "tsne_validation_embeddings.png")
    plt.savefig(save_path, bbox_inches="tight")
    logging.info(f"Saved t-SNE plot to: {save_path}")
    plt.show()


# Classification Report
class_names = ["Bonafide", "Attack"]
logging.info("\nClassification Report:")
report = classification_report(y_true, y_pred, target_names=class_names)
logging.info("\n" + report)

# Confusion Matrix
cnf_matrix = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=False)
plot_confusion_matrix(
    cnf_matrix, classes=class_names, normalize=True, title="Normalized Confusion Matrix"
)

# ROC Curve
fpr, tpr, _ = roc_curve(y_true, y_score)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(5, 5), dpi=80)
plt.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (area = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic (Replay-Attack)")
plt.legend(loc="lower right")
fname = os.path.join(log_dir, f"replayattack_roc_curve_{int(time())}.png")
plt.savefig(fname)
logging.info(f"Saved ROC curve to {fname}")
plt.show()

# EER and HTER
eer = compute_eer(y_true, y_score)
hter, far, frr = compute_hter(y_true, y_pred, y_score)
logging.info(f"FAR : {far:.4f}")
logging.info(f"FRR: {frr:.4f}")
logging.info(f"Equal Error Rate (EER): {eer:.4f}")
logging.info(f"Half Total Error Rate (HTER): {hter:.4f}")

# Score Distribution
plt.figure(figsize=(8, 6), dpi=80)
plt.hist(y_score[y_true == 0], bins=50, alpha=0.5, label="Bonafide", color="blue")
plt.hist(y_score[y_true == 1], bins=50, alpha=0.5, label="Attack", color="red")
plt.xlabel("Prediction Score")
plt.ylabel("Frequency")
plt.title("Score Distribution (Replay-Attack)")
plt.legend(loc="upper right")
fname = os.path.join(log_dir, f"replayattack_score_distribution_{int(time())}.png")
plt.savefig(fname)
logging.info(f"Saved score distribution to {fname}")
plt.show()

# t-SNE Visualization (commented out as in 3DMAD)
# plot_tsne_validation_embeddings(model, all_images, all_labels)

# Attention Map generation


In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os
from sklearn.manifold import TSNE

# === Path Configuration ===
BASE_PATH = r"Your_dataset_path"
IMAGES_DIR = "images"
ATTACK_DIR = os.path.join(IMAGES_DIR, "attack_training")
BONAFIDE_DIR = os.path.join(IMAGES_DIR, "bonifade_validation")
ATTENTION_MAPS_DIR = os.path.join(IMAGES_DIR, "attention_maps/experimental")


# === Heatmap Utility ===
def generate_heatmap(feature_map):
    """
    Generate a heatmap from a feature map by averaging across channels.

    Args:
        feature_map (torch.Tensor): Feature map of shape [1, C, H, W].

    Returns:
        np.ndarray: Heatmap with JET colormap, shape [H, W, 3].
    """
    heatmap = torch.mean(feature_map.squeeze(0), dim=0).cpu().numpy()
    heatmap = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
    heatmap = np.uint8(heatmap)
    return cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)


def overlay_heatmap(image_path, model, transform):
    """
    Overlay a heatmap on the original image.

    Args:
        image_path (str): Path to the input image.
        model: UNetBinaryClassifier model with get_activation_map method.
        transform: torchvision transforms for preprocessing.

    Returns:
        np.ndarray: Image with heatmap overlay, shape [128, 128, 3].
    """
    image = Image.open(image_path).convert("RGB")
    input_tensor = transform(image).unsqueeze(0)
    feature_map = model.get_activation_map(input_tensor)
    heatmap = generate_heatmap(feature_map)
    original = np.array(image.resize((128, 128)))
    overlayed = cv2.addWeighted(original, 0.7, heatmap, 0.3, 0)
    return overlayed


# === t-SNE Visualization ===
def plot_tsne_embeddings(model, attack_images, bonafide_images, transform):
    """
    Plot t-SNE visualization of hypersphere embeddings for attack and bonafide images.

    Args:
        model: UNetBinaryClassifier model.
        attack_images (list): List of paths to attack images.
        bonafide_images (list): List of paths to bonafide images.
        transform: torchvision transforms for preprocessing.
    """
    model.eval()
    embeddings = []
    labels = []

    # Process attack images
    for img_path in attack_images:
        image = Image.open(img_path).convert("RGB")
        input_tensor = transform(image).unsqueeze(0)
        with torch.no_grad():
            _, _, emb = model(input_tensor)
        embeddings.append(emb.squeeze().cpu().numpy())
        labels.append(1)  # 1 for attack

    # Process bonafide images
    for img_path in bonafide_images:
        image = Image.open(img_path).convert("RGB")
        input_tensor = transform(image).unsqueeze(0)
        with torch.no_grad():
            _, _, emb = model(input_tensor)
        embeddings.append(emb.squeeze().cpu().numpy())
        labels.append(0)  # 0 for bonafide

    # Convert embeddings and labels to numpy arrays
    embeddings = np.array(embeddings)  # Shape: [num_samples, 256]
    labels = np.array(labels)  # Shape: [num_samples]

    # Apply t-SNE to reduce dimensionality to 2D
    tsne = TSNE(n_components=2, perplexity=5, random_state=42)
    embeddings_2d = tsne.fit_transform(embeddings)  # Shape: [num_samples, 2]

    # Plot t-SNE embeddings
    plt.figure(figsize=(10, 8))
    plt.scatter(
        embeddings_2d[labels == 1, 0],
        embeddings_2d[labels == 1, 1],
        c="red",
        label="Attack",
        alpha=0.6,
        s=200,
    )
    plt.scatter(
        embeddings_2d[labels == 0, 0],
        embeddings_2d[labels == 0, 1],
        c="blue",
        label="Bonafide",
        alpha=0.6,
        s=200,
    )
    plt.title(
        "t-SNE Visualization of Hypersphere Embeddings (Replay-Attack, ArcFace)",
        fontsize=12,
        fontweight="bold",
    )
    plt.xlabel("t-SNE Dimension 1", fontsize=24)
    plt.ylabel("t-SNE Dimension 2", fontsize=24)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.legend(fontsize=20)
    plt.grid(True)

    # Save the plot
    os.makedirs(os.path.join(BASE_PATH, ATTENTION_MAPS_DIR), exist_ok=True)
    save_path = os.path.join(
        BASE_PATH, ATTENTION_MAPS_DIR, "replayattack_tsne_hypersphere_embeddings.png"
    )
    plt.savefig(save_path, bbox_inches="tight")
    print(f"Saved t-SNE plot to: {save_path}")
    plt.show()


# === Main Execution ===
def main():
    # Model and image paths
    model_path = os.path.join(
        BASE_PATH, IMAGES_DIR, "replayattack_model_best_checkpoint.pth"
    )

    def get_image_paths(base_path, sub_dir, filenames):
        return [os.path.join(base_path, sub_dir, fname) for fname in filenames]

    # Add the required paths below
    attack_filenames = [
        "File_name.jpg",
    ]

    bonafide_filenames = [
        "file_namejpg",
    ]

    attack_images = get_image_paths(BASE_PATH, ATTACK_DIR, attack_filenames)
    bonafide_images = get_image_paths(BASE_PATH, BONAFIDE_DIR, bonafide_filenames)

    image_paths = attack_images + bonafide_images

    # Load model
    model = UNetBinaryClassifier(
        in_channels=3, base_channels=32, embedding_size=512, num_classes=2
    )
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    # Define transform
    transform = transforms.Compose(
        [
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    # Generating and saving individual attention maps
    os.makedirs(os.path.join(BASE_PATH, ATTENTION_MAPS_DIR), exist_ok=True)
    for idx, path in enumerate(image_paths):
        result = overlay_heatmap(path, model, transform)
        save_path = os.path.join(
            BASE_PATH, ATTENTION_MAPS_DIR, f"replayattack_overlay_{idx+1}.png"
        )
        cv2.imwrite(save_path, cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
        print(f"Saved: {save_path}")

    # Ploting 3x3 grids for attack and bonafide images
    def plot_grid_with_probs(model, image_paths, transform, title, start_idx):
        fig, axes = plt.subplots(3, 3, figsize=(9, 9))
        fig.suptitle(title, fontsize=16)

        for i, ax in enumerate(axes.flat):
            if i >= len(image_paths):
                ax.axis("off")
                continue

            image_path = image_paths[i]
            image = Image.open(image_path).convert("RGB")
            input_tensor = transform(image).unsqueeze(0)
            with torch.no_grad():
                logits = model(input_tensor)
                if isinstance(logits, tuple):
                    logits = logits[0]
                probs = torch.softmax(logits, dim=1).squeeze().numpy()

            feature_map = model.get_activation_map(input_tensor)
            heatmap = generate_heatmap(feature_map)
            original = np.array(image.resize((128, 128)))
            overlayed = cv2.addWeighted(original, 0.7, heatmap, 0.3, 0)

            overlayed_rgb = cv2.cvtColor(overlayed, cv2.COLOR_BGR2RGB)
            ax.imshow(overlayed_rgb)
            label = "Attack" if probs[1] > probs[0] else "Bonafide"
            prob = probs[1] if label == "Attack" else probs[0]
            ax.set_title(f"{label} ({prob:.2f})", fontsize=9)
            ax.axis("off")

        return fig

    # Plot attack and bonafide grids
    attack_fig = plot_grid_with_probs(
        model, attack_images[:9], transform, "Attack (Replay-Attack)", 0
    )
    plt.tight_layout()
    attack_save_path = os.path.join(BASE_PATH, IMAGES_DIR, "attack_grid.png")
    attack_fig.savefig(attack_save_path)
    print(f"Saved: {attack_save_path}")

    bonafide_fig = plot_grid_with_probs(
        model, bonafide_images[:9], transform, "Bonafide (Replay-Attack)", 0
    )
    plt.tight_layout()
    bonafide_save_path = os.path.join(BASE_PATH, IMAGES_DIR, "bonafide_grid.png")
    bonafide_fig.savefig(bonafide_save_path)
    print(f"Saved: {bonafide_save_path}")

    # Combine both grids into a single figure
    combined_fig, combined_axs = plt.subplots(1, 2, figsize=(18, 9))
    combined_fig.suptitle(
        "Attention Map of Replay-Attack RGB Images: Attack vs Bonafide (ArcFace)",
        fontsize=18,
        fontweight="bold",
    )
    attack_img = plt.imread(attack_save_path)
    bonafide_img = plt.imread(bonafide_save_path)

    combined_axs[0].imshow(attack_img)
    combined_axs[0].axis("off")
    combined_axs[1].imshow(bonafide_img)
    combined_axs[1].axis("off")

    plt.tight_layout()
    combined_save_path = os.path.join(
        BASE_PATH, ATTENTION_MAPS_DIR, "replayattack_final_3x3_grid.png"
    )
    combined_fig.savefig(combined_save_path)
    print(f"Saved: {combined_save_path}")

    # # Generate t-SNE plot for hypersphere embeddings
    # plot_tsne_embeddings(model, attack_images, bonafide_images, transform)


if __name__ == "__main__":
    main()