# Falsified Image Detection using PyTorch

This notebook implements a pipeline for detecting falsified images (e.g., forged signatures) using various pre-trained deep learning models with PyTorch. The process includes:

1.  **Environment Setup**: Installs necessary packages and configures the environment (Kaggle, Colab, or local) for API access and path management.
2.  **Data Preparation**:
    *   Downloads and sets up the handwritten signature verification dataset.
    *   Splits the dataset into training, validation, and test sets.
    *   Cleans the dataset by removing any corrupted or invalid images.
3.  **Data Loading**: Defines PyTorch `Dataset` and `DataLoader` instances with image augmentations using Albumentations for efficient data handling.
4.  **Model Definition**:
    *   Lists the pre-trained models (e.g., EfficientNet, ConvNeXt, ViT) to experiment with.
    *   Implements a `CustomModel` class that allows using these pre-trained backbones with a custom classification head and supports fine-tuning by unfreezing specified layers.
5.  **Training and Evaluation**:
    *   Defines functions for model training (`train_model`) and evaluation (`evaluate_model`). The training function includes features like early stopping, checkpointing, and integration with Optuna for hyperparameter pruning.
    *   Implements an `objective` function for Optuna to perform hyperparameter optimization (HPO) for learning rate, dense layer units, dropout, optimizer type, batch size, and the number of unfrozen layers.
6.  **Experimentation Loop**:
    *   Iterates through the selected models.
    *   For each model:
        *   Performs HPO using Optuna.
        *   Trains a final model using the best hyperparameters found.
        *   Evaluates the final model on the test set.
        *   Saves training history, test results, and model checkpoints.
7.  **Results and Checkpoint Management**:
    *   Prints a summary of the test results for all models.
    *   Visualizes training history (loss/accuracy curves) and confusion matrices.
    *   Provides functionality to package model checkpoints, metrics, and plots into a zip file for download or transfer to cloud storage (e.g., Google Drive).

The goal is to identify the best-performing model and hyperparameter configuration for the task of falsified image detection.

## Project body

### 0. Code configuration

This cell installs the Python packages the notebook requires. It uses `pip install` with the `--quiet` flag to suppress verbose output during installation. The packages include:
-   `torch`, `torchvision`, `torchaudio`: Core PyTorch libraries for deep learning, computer vision, and audio processing.
-   `optuna`: A hyperparameter optimization framework.
-   `kaggle`: The Kaggle API for interacting with Kaggle datasets and competitions.
-   `transformers`: Hugging Face Transformers library for pre-trained models, particularly Vision Transformers (ViT).
-   `matplotlib`, `seaborn`, `scikit-learn`: Libraries for plotting, statistical visualization, and machine learning utilities (metrics, model selection).
-   `pillow`: Python Imaging Library (PIL) fork for image manipulation.
-   `timm`: PyTorch Image Models library, providing a wide range of pre-trained computer vision models.
-   `albumentations`: A library for fast and flexible image augmentations.

Confirmation messages are printed before and after the installations.

In [1]:
print("--- Installing necessary packages ---")
!pip install torch torchvision torchaudio --quiet
!pip install optuna --quiet
!pip install kaggle --quiet
!pip install transformers --quiet
!pip install matplotlib seaborn scikit-learn --quiet
!pip install pillow --quiet
!pip install --upgrade timm --quiet
!pip install -U albumentations --quiet
print("--- Installations complete ---")

--- Installing necessary packages ---
--- Installations complete ---


### 1. Import necessary libraries

This cell imports all the Python libraries and modules that will be used throughout the notebook. This includes:
-   PyTorch modules (`torch`, `torch.nn`, `torchvision`, `DataLoader`, `Dataset`).
-   Hugging Face Transformers (`ViTModel`, `AutoImageProcessor`).
-   Standard-library modules (`os`, `shutil`, `subprocess`, `time`, `datetime`, `copy`, `pathlib`, `logging`, `sqlite3`, `glob`, `sys`, `json`, `zipfile`, `random`).
-   Data science and visualization libraries (`numpy` for array handling, `cv2` (OpenCV), `sklearn` for metrics and model selection, `PIL` for image handling, `matplotlib` and `seaborn` for plotting).
-   `timm` for pre-trained image models.
-   `optuna` for hyperparameter optimization.
-   `albumentations` for image augmentations.
-   `IPython.display` for creating download links in Kaggle/Colab.
-   `tqdm` for progress bars.

It also sets up basic logging configuration to display informational messages with timestamps.

In [2]:
print("--- Importing libraries ---")
import torch
import torch.nn as nn
import torchvision
import os
import shutil
import numpy as np
import optuna
import subprocess
import matplotlib.pyplot as plt
import seaborn as sns
import time
import timm
import glob
import sys
import json 
import zipfile
import random
import cv2
import albumentations as A
import copy
import logging
import sqlite3 
from torchvision import transforms, datasets, models as torchvision_models
from transformers import ViTModel, AutoImageProcessor
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix, accuracy_score
from PIL import Image, UnidentifiedImageError
from torch.utils.data import DataLoader, Dataset
from datetime import timedelta
from pathlib import Path
from tqdm.auto import tqdm
from albumentations.pytorch import ToTensorV2
from IPython.display import FileLink

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s', 
    datefmt='%m/%d %H:%M:%S', 
    handlers=[
        logging.StreamHandler(sys.stdout)
    ],
    force=True  
)
print("--- Library imports complete ---")

--- Importing libraries ---
--- Library imports complete ---


### Authentication

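This cell defines a function `configure_kaggle_credentials` that detects the execution environment (Kaggle, Colab, or unknown) and prepares it accordingly. Despite its name, it does not write Kaggle API credentials itself; on Colab, the `kaggle` CLI used in Section 3 still needs them to be supplied separately (for example, as a `kaggle.json` file).
-   If running in a **Kaggle environment** (detected by the presence of `/kaggle/working/`), it sets `ENVIRONMENT` to "kaggle".
-   If running in a **Colab environment** (detected by the presence of `/content`), it sets `ENVIRONMENT` to "colab", mounts Google Drive if it is not already mounted, and creates a shared-drive directory for storing checkpoints.
-   If the environment is **unknown**, it sets `ENVIRONMENT` to "unknown" and warns the user to configure credentials manually.
The function returns `False` when setup fails; it is then called, and the detected `ENVIRONMENT` is displayed as the cell output.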

In [3]:
def configure_kaggle_credentials():
    """Detect the runtime (Kaggle, Colab, or unknown) and prepare it.

    Sets the global ENVIRONMENT and returns True on success, False otherwise.
    """
    global ENVIRONMENT
    if os.path.exists("/kaggle/working/"):
        ENVIRONMENT = "kaggle"
        logger.info("Detected Kaggle environment.")
    elif os.path.exists("/content"):
        ENVIRONMENT = "colab"
        logger.info("Detected Colab environment.")

        try:
            from google.colab import drive
            if not os.path.exists('/content/drive'):
                logger.info("Mounting Google Drive...")
                drive.mount('/content/drive')

            # Shared-drive folder for storing checkpoints
            drive_dir = '/content/drive/Shareddrives/PCD/checkpoints/SignatureVerification'
            os.makedirs(drive_dir, exist_ok=True)

        except ImportError as e:
            logger.error(f"Colab libraries unavailable: {e}")
            print("Colab environment detected but required libraries are missing.")
            return False
    else:
        ENVIRONMENT = "unknown"
        logger.warning("Unknown environment detected.")
        print("Unknown environment detected. Please configure Kaggle credentials manually.")
        return False
    return True

success = configure_kaggle_credentials()
ENVIRONMENT

05/07 22:50:47 - Detected Kaggle environment.


'kaggle'
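`configure_kaggle_credentials` detects the runtime but never checks that Kaggle API credentials are actually present, which the Colab download in Section 3 depends on. The sketch below is a hypothetical helper, `kaggle_credentials_available`, that checks the classic lookup locations of the Kaggle CLI: the `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables, or a `kaggle.json` file in `KAGGLE_CONFIG_DIR` (defaulting to `~/.kaggle`). It only tests for presence and does not validate the key; newer CLI releases may look in additional places, so treat it as a pre-check under those assumptions.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def kaggle_credentials_available(environ: Optional[Mapping[str, str]] = None,
                                 home: Optional[Path] = None) -> bool:
    """Return True if Kaggle API credentials appear to be configured.

    Hypothetical helper. Checks, in order:
      1. the KAGGLE_USERNAME and KAGGLE_KEY environment variables;
      2. a kaggle.json file in KAGGLE_CONFIG_DIR, or in ~/.kaggle if unset.
    Only presence is checked; the key itself is not validated.
    """
    env = os.environ if environ is None else environ
    if env.get("KAGGLE_USERNAME") and env.get("KAGGLE_KEY"):
        return True

    home_dir = Path.home() if home is None else home
    config_dir = Path(env["KAGGLE_CONFIG_DIR"]) if env.get("KAGGLE_CONFIG_DIR") else home_dir / ".kaggle"
    return (config_dir / "kaggle.json").is_file()
```

On Colab, a check such as `if ENVIRONMENT == "colab" and not kaggle_credentials_available(): ...` before the download step could prompt for a `kaggle.json` upload instead of letting `subprocess.run` fail later.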

### 2. Checkpoint and tuner paths

This cell sets up the paths for storing model checkpoints and the Optuna database file based on the `ENVIRONMENT` detected in the previous cell.
-   `CHECKPOINT_BASE_DIR`: The base directory where model checkpoints will be saved.
-   `OPTUNA_DB_PATH`: The SQLite database path for Optuna to store hyperparameter tuning trial results.
It ensures that the checkpoint base directory exists by creating it if necessary and logs the configured paths.

In [4]:
if ENVIRONMENT == 'kaggle':
    CHECKPOINT_BASE_DIR = "/kaggle/working/checkpoints/SignatureVerification"
    OPTUNA_DB_PATH = "sqlite:////kaggle/working/optuna_signature_verification.db"
elif ENVIRONMENT == 'colab':
    CHECKPOINT_BASE_DIR = "/content/checkpoints/SignatureVerification"
    OPTUNA_DB_PATH = "sqlite:////content/optuna_signature_verification.db"
else:
    CHECKPOINT_BASE_DIR = "./checkpoints/SignatureVerification"
    OPTUNA_DB_PATH = "sqlite:///optuna_signature_verification.db"
os.makedirs(CHECKPOINT_BASE_DIR, exist_ok=True)
logger.info(f"Checkpoint directory set to: {CHECKPOINT_BASE_DIR}")
logger.info(f"Optuna DB path set to: {OPTUNA_DB_PATH}")

05/07 22:50:48 - Checkpoint directory set to: /kaggle/working/checkpoints/SignatureVerification
05/07 22:50:48 - Optuna DB path set to: sqlite:////kaggle/working/optuna_signature_verification.db
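The `OPTUNA_DB_PATH` values are SQLAlchemy-style URLs: `sqlite:///` followed by a file path. An absolute path therefore produces four slashes (`sqlite:////kaggle/...`), while the local fallback `sqlite:///optuna_signature_verification.db` is relative to the working directory. Section 5 recovers the file's directory with `str.replace`; the sketch below shows a slightly stricter variant (the helper names `sqlite_url_to_path` and `ensure_sqlite_parent_dir` are hypothetical) that only strips a leading prefix and rejects non-SQLite URLs.

```python
import os

SQLITE_PREFIX = "sqlite:///"


def sqlite_url_to_path(url: str) -> str:
    """Convert an SQLAlchemy-style SQLite URL into a filesystem path.

    'sqlite:////abs/dir/x.db' -> '/abs/dir/x.db'  (absolute: four slashes)
    'sqlite:///x.db'          -> 'x.db'           (relative to the working directory)
    """
    if not url.startswith(SQLITE_PREFIX):
        raise ValueError(f"Not an SQLite URL: {url!r}")
    return url[len(SQLITE_PREFIX):]


def ensure_sqlite_parent_dir(url: str) -> str:
    """Create the database file's parent directory if needed; return the file path."""
    path = sqlite_url_to_path(url)
    parent = os.path.dirname(path)
    if parent:  # a bare relative filename has no parent directory to create
        os.makedirs(parent, exist_ok=True)
    return path
```

The URL itself can be passed unchanged as Optuna's `storage` argument, for example `optuna.create_study(study_name=..., storage=OPTUNA_DB_PATH, load_if_exists=True)`, which lets an interrupted HPO run resume from the existing database.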


### 3. Dataset setup

This cell configures the primary dataset (`handwritten-signature-verification` by `tienen`) and handles its download and extraction based on the execution environment.
-   It defines the dataset name, owner, and zip file name.
-   **Kaggle Environment**: It sets `BASE_DIR` to the expected path of the dataset within the Kaggle input directory. It raises an error if the dataset is not found.
-   **Colab Environment**: It sets `UNZIP_TARGET_DIR` and `BASE_DIR`. If the data is not found at the expected path, it downloads the dataset using the Kaggle API (`kaggle datasets download`) and unzips it using the `unzip` command. It includes error handling for download and unzipping processes and removes the zip file after successful extraction.
-   **Unknown/Local Environment**: It assumes the dataset is already present locally at a predefined path and sets `BASE_DIR` accordingly. It raises an error if the dataset is not found.
The cell logs the actions taken and the final `BASE_DIR` used for accessing the dataset.

In [5]:
DATASET_NAME = "handwritten-signature-verification"
DATASET_OWNER = "tienen"
ZIP_FILE = f"{DATASET_NAME}.zip"

# Detect environment and set up dataset
print("\n--- Setting and Dataset ---")
if ENVIRONMENT == 'kaggle':
    BASE_DIR = "/kaggle/input/handwritten-signature-verification/data/data"
    if not os.path.exists(BASE_DIR):
        raise FileNotFoundError(
            f"Kaggle dataset not found at expected path: {BASE_DIR}. "
            "Please verify the dataset is attached to the notebook."
        )
    logger.info(f"Using Kaggle dataset path: {BASE_DIR}")
elif ENVIRONMENT == 'colab':
    UNZIP_TARGET_DIR = "/content/data"
    CORRECT_DATA_PATH = os.path.join(UNZIP_TARGET_DIR, "data/data")
    BASE_DIR = CORRECT_DATA_PATH

    if not os.path.exists(CORRECT_DATA_PATH):
        logger.info(f"Data not found at {CORRECT_DATA_PATH}. Downloading and unzipping...")
        if os.path.exists(UNZIP_TARGET_DIR):
            logger.warning(f"Removing existing directory {UNZIP_TARGET_DIR} before unzipping.")
            shutil.rmtree(UNZIP_TARGET_DIR)
        os.makedirs(UNZIP_TARGET_DIR, exist_ok=True)

        # Download dataset (the kaggle CLI needs API credentials on Colab)
        try:
            cmd = ["kaggle", "datasets", "download", "-d", f"{DATASET_OWNER}/{DATASET_NAME}", "-p", "/content"]
            subprocess.run(cmd, check=True, text=True, capture_output=True)
            logger.info("Dataset download complete.")
        except subprocess.CalledProcessError as e:
            logger.error(f"Error during dataset download: {e.stderr}")
            # Raise instead of exit(): in a notebook, exit() stops the kernel rather than the cell.
            raise RuntimeError(f"Download failed: {e.stderr}") from e
        except Exception as e:
            logger.error(f"Error during dataset download: {e}")
            raise

        # Unzip dataset
        zip_path = os.path.join("/content", ZIP_FILE)
        logger.info(f"Unzipping {ZIP_FILE} to {UNZIP_TARGET_DIR}...")
        try:
            subprocess.run(
                ["unzip", "-q", zip_path, "-d", UNZIP_TARGET_DIR],
                capture_output=True,
                text=True,
                check=True
            )
        except subprocess.CalledProcessError as e:
            logger.error(f"Unzipping failed: {e.stderr}")
            raise RuntimeError(f"Unzipping failed: {e.stderr}") from e
        logger.info("Unzipping successful.")
        os.remove(zip_path)
        logger.info("Deleted zip file.")

        # Verify that the archive produced the expected folder layout
        if not os.path.exists(CORRECT_DATA_PATH):
            logger.error(
                f"Unzipped successfully, but expected data path {CORRECT_DATA_PATH} not found. "
                "Check dataset structure."
            )
            raise FileNotFoundError(f"Expected path {CORRECT_DATA_PATH} not found after unzipping.")
        logger.info(f"Verified data path exists: {CORRECT_DATA_PATH}")
    else:
        logger.info(f"Dataset already exists at {CORRECT_DATA_PATH}.")
else:
    logger.warning("Unknown environment detected. Assuming local dataset.")
    BASE_DIR = "./data/handwritten-signature-verification/data/data"
    if not os.path.exists(BASE_DIR):
        raise FileNotFoundError(
            f"Dataset not found at expected local path: {BASE_DIR}. "
            "Please place the dataset manually or adjust BASE_DIR."
        )
    logger.info(f"Using local dataset path: {BASE_DIR}")


--- Setting and Dataset ---
05/07 22:50:48 - Using Kaggle dataset path: /kaggle/input/handwritten-signature-verification/data/data
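The Colab branch shells out to the external `unzip` binary. Since `zipfile` is already imported, a pure-Python alternative avoids depending on that tool. The sketch below is a hypothetical helper, `extract_zip`, that extracts the archive, verifies the expected sub-folder (here `data/data`, as in `CORRECT_DATA_PATH`), and only then deletes the archive. The Kaggle CLI's `--unzip` flag for `kaggle datasets download` is another option.

```python
import os
import zipfile


def extract_zip(zip_path: str, target_dir: str, expected_subpath: str = "",
                delete_archive: bool = True) -> str:
    """Extract zip_path into target_dir and return the verified data path.

    Hypothetical helper: a pure-Python stand-in for `unzip -q <zip> -d <dir>`.
    ZipFile.extractall strips absolute paths and ".." components from member
    names, and reading a member whose CRC does not match raises BadZipFile.
    """
    os.makedirs(target_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(target_dir)

    # Verify that the archive had the folder layout the notebook expects
    data_path = os.path.join(target_dir, expected_subpath)
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Expected {data_path} after extraction; check the archive layout.")

    if delete_archive:
        os.remove(zip_path)  # free disk space once extraction is verified
    return data_path
```

In the Colab branch this would correspond to `extract_zip(zip_path, UNZIP_TARGET_DIR, "data/data")`.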


### 4. GPU check and Configuration

This cell checks for the availability of a CUDA-enabled GPU and configures PyTorch to use it.
-   It prints the current PyTorch version.
-   If a GPU is available (`torch.cuda.is_available()` is true):
    -   It sets the `device` to `torch.device("cuda")`.
    -   It logs the name of the GPU being used.
    -   It clears the CUDA cache using `torch.cuda.empty_cache()`, releasing unused cached memory held by PyTorch's allocator.
-   If no GPU is available, it sets the `device` to `torch.device("cpu")` and logs that the CPU will be used.

In [6]:
print("\n--- Checking GPU Configuration ---")
print("PyTorch version:", torch.__version__)
if torch.cuda.is_available():
    device = torch.device("cuda")
    logging.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
    torch.cuda.empty_cache()
else:
    device = torch.device("cpu")
    logging.info("Using CPU")


--- Checking GPU Configuration ---
PyTorch version: 2.5.1+cu121
05/07 22:50:48 - Using GPU: Tesla T4


### 5. Define constants

This cell defines several constants and paths used throughout the notebook:
-   `model_img_sizes`: A dictionary mapping model names (e.g., "EfficientNetV2-S", "ConvNeXt_Base") to their required input image dimensions (height, width). This is crucial for preprocessing images correctly for each model.
-   `SPLIT_BASE_DIR`: The base directory where the split dataset (train, validation, test) will be stored. It is only set if not already defined, and its value is chosen from `ENVIRONMENT`.
-   `TRAIN_DIR`, `VAL_DIR`, `TEST_DIR`: Full paths to the training, validation, and test subdirectories within `SPLIT_BASE_DIR`.
-   `OPTUNA_DB_PATH`: Defined here only as a fallback if Section 2 did not set it. The cell also creates the directory that will hold the Optuna database file.
-   It creates the `TRAIN_DIR`, `VAL_DIR`, and `TEST_DIR` directories if they don't already exist.
-   Finally, it verifies that `BASE_DIR` (the raw dataset path set in Section 3) is defined, raising an error if not.

In [7]:
print("\n--- Defining Constants ---")
model_img_sizes = {
    "EfficientNetV2-S": (384, 384), 
    "ConvNeXt_Base": (224, 224),    
    "DeiT_Base": (224, 224),        
    "BEiT_Large": (224, 224),       
    "EfficientNet_B7": (600, 600),  
    "ResNetRS50": (224, 224),       
    "InceptionV3": (299, 299),      
    "Xception": (299, 299),         
    "ViT_Base": (224, 224),         
    "MobileNetV3_Large": (224, 224) 
}

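if 'SPLIT_BASE_DIR' not in globals():
    if ENVIRONMENT == 'colab':
        SPLIT_BASE_DIR = '/content/split_dataset'
    elif ENVIRONMENT == 'kaggle':
        SPLIT_BASE_DIR = '/kaggle/working/split_dataset'
    else:
        SPLIT_BASE_DIR = './split_dataset'  # local runs: keep the split next to the notebook
    logging.info(f"SPLIT_BASE_DIR defined as: {SPLIT_BASE_DIR}")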


TRAIN_DIR = os.path.join(SPLIT_BASE_DIR, 'train')
VAL_DIR = os.path.join(SPLIT_BASE_DIR, 'val')
TEST_DIR = os.path.join(SPLIT_BASE_DIR, 'test')

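if 'OPTUNA_DB_PATH' not in globals():
    if ENVIRONMENT == 'colab':
        OPTUNA_DB_PATH = "sqlite:////content/optuna_signature_verification.db"
    elif ENVIRONMENT == 'kaggle':
        OPTUNA_DB_PATH = "sqlite:////kaggle/working/optuna_signature_verification.db"
    else:
        OPTUNA_DB_PATH = "sqlite:///optuna_signature_verification.db"  # relative to the working directory
    logging.info(f"OPTUNA_DB_PATH defined as: {OPTUNA_DB_PATH}")

# Strip the "sqlite:///" prefix to get the file path, then make sure its folder exists
db_dir = os.path.dirname(OPTUNA_DB_PATH.replace("sqlite:///", ""))
if db_dir and not os.path.exists(db_dir):
    os.makedirs(db_dir, exist_ok=True)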


# Create split directories
os.makedirs(TRAIN_DIR, exist_ok=True)
os.makedirs(VAL_DIR, exist_ok=True)
os.makedirs(TEST_DIR, exist_ok=True)

if 'BASE_DIR' not in globals():
    raise NameError("CRITICAL ERROR: BASE_DIR was not set correctly in Section 3!")
else:
    logging.info(f"Verified BASE_DIR : {BASE_DIR}")


--- Defining Constants ---
05/07 22:50:48 - SPLIT_BASE_DIR defined  as: /kaggle/working/split_dataset
05/07 22:50:48 - Verified BASE_DIR : /kaggle/input/handwritten-signature-verification/data/data


### 6. Function to split the dataset into train, validation, and test sets

This cell defines and then calls the `split_dataset` function.
The function `split_dataset` takes the raw dataset directory (`base_dir`), target directories for train, validation, and test sets, and splitting ratios as input.
Its key functionalities are:
-   **Ratio Validation**: Ensures that `train_ratio`, `val_ratio`, and `test_ratio` are valid (between 0 and 1, and sum to 1).
-   **Skip if Already Split**: Checks whether the `real` subfolder of each split directory (`train_dir`, `val_dir`, `test_dir`) already exists and is non-empty. If so, and `force_resplit` is `False`, it skips the splitting process and prints a summary of the existing split.
-   **Force Resplit**: If `force_resplit` is `True`, it clears any existing split directories before proceeding.
-   **Directory Creation**: Creates subdirectories for each class ('real', 'forged') within the train, validation, and test directories.
-   **Data Splitting**:
    -   Iterates through each class ('real', 'forged') in the `base_dir`.
    -   Collects all image file paths for the current class.
    -   Shuffles the image list for randomness (with a fixed seed for reproducibility).
    -   Splits the shuffled list into train, validation, and test sets based on the provided ratios.
    -   Copies the image files to their respective target directories (e.g., `train_dir/real`, `val_dir/forged`). It handles potential filename collisions by appending a counter if a file with the same name already exists in the destination.
-   **Summary Print**: After splitting, it prints a summary of the number of 'real' and 'forged' images in each split (train, validation, test).

The function is then called with `BASE_DIR` (raw data), `TRAIN_DIR`, `VAL_DIR`, `TEST_DIR`, and `force_resplit=False` (meaning it will use existing splits if available, otherwise it will perform the split).

In [8]:
def split_dataset(base_dir, train_dir, val_dir, test_dir, train_ratio=0.7, val_ratio=0.15, force_resplit=False):
    logging.info("--- Starting Dataset Split ---")
    test_ratio = 1.0 - train_ratio - val_ratio
    if not (0 < train_ratio < 1 and 0 < val_ratio < 1 and 0 < test_ratio < 1 and abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6):
        raise ValueError("Ratios must be between 0 and 1, and sum to 1.")

    # Check if splitting is needed
    if not force_resplit and \
       os.path.exists(os.path.join(train_dir, 'real')) and len(os.listdir(os.path.join(train_dir, 'real'))) > 0 and \
       os.path.exists(os.path.join(val_dir, 'real')) and len(os.listdir(os.path.join(val_dir, 'real'))) > 0 and \
       os.path.exists(os.path.join(test_dir, 'real')) and len(os.listdir(os.path.join(test_dir, 'real'))) > 0:
        logging.info("Split dataset already exists. Skipping split.")
        # Print summary of existing split
        print("Existing Dataset split summary:")
        for split, dir_path in [('Train', train_dir), ('Validation', val_dir), ('Test', test_dir)]:
            try:
                real_count = len(os.listdir(os.path.join(dir_path, 'real'))) if os.path.exists(os.path.join(dir_path, 'real')) else 0
                forged_count = len(os.listdir(os.path.join(dir_path, 'forged'))) if os.path.exists(os.path.join(dir_path, 'forged')) else 0
                print(f"  {split}: {real_count} real, {forged_count} forged images")
            except FileNotFoundError:
                 print(f"  {split}: Class directory not found (real or forged).")
        return

    logging.info(f"Splitting data from {base_dir} with ratios: Train={train_ratio:.2f}, Val={val_ratio:.2f}, Test={test_ratio:.2f}")

    # Clear existing split directories if force_resplit is True
    if force_resplit:
        logging.warning("Force resplit enabled. Clearing existing split directories.")
        for split_dir in [train_dir, val_dir, test_dir]:
            if os.path.exists(split_dir):
                shutil.rmtree(split_dir)

    # Create class subfolders inside each of the split directories passed in
    for split_dir in [train_dir, val_dir, test_dir]:
        for class_name in ['real', 'forged']:
            os.makedirs(os.path.join(split_dir, class_name), exist_ok=True)

    # Split data for each class
    all_copied_files = set() 
    for class_name in ['real', 'forged']:
        class_path = os.path.join(base_dir, class_name)
        if not os.path.isdir(class_path):
            logging.error(f"Class directory not found: {class_path}. Check BASE_DIR and dataset structure.")
            continue 

        images = []
        for root, _, files in os.walk(class_path):
            for file in files:
                if file.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.gif')):
                    full_path = os.path.join(root, file)
                    if full_path not in all_copied_files: 
                        images.append(full_path)


        logging.info(f"Found {len(images)} images in class '{class_name}'.")
        if not images:
            logging.warning(f"No images found for class '{class_name}'. Check directory structure and file types.")
            continue

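        # Ensure reproducibility: os.walk order depends on the filesystem,
        # so sort before shuffling with a fixed seed
        images.sort()
        np.random.seed(42)
        np.random.shuffle(images)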

        # Calculate split indices
        n_total = len(images)
        n_train = int(n_total * train_ratio)
        n_val = int(n_total * val_ratio)
        # The remainder (n_total - n_train - n_val) goes to test

        train_images = images[:n_train]
        val_images = images[n_train : n_train + n_val]
        test_images = images[n_train + n_val :]

        # Function to copy files
        def copy_files(file_list, target_dir, class_n):
            target_class_dir = os.path.join(target_dir, class_n)
            copied_count = 0
            for img_path in file_list:
                try:
                    # Use a simpler destination name (original filename)
                    dest_name = os.path.basename(img_path)
                    dest_path = os.path.join(target_class_dir, dest_name)
                    # Avoid overwriting: images from different subfolders can share a filename
                    counter = 1
                    while os.path.exists(dest_path):
                        name, ext = os.path.splitext(dest_name)
                        dest_path = os.path.join(target_class_dir, f"{name}_{counter}{ext}")
                        counter += 1

                    shutil.copy2(img_path, dest_path) 
                    all_copied_files.add(img_path) 
                    copied_count += 1
                except Exception as e:
                    logging.error(f"Failed to copy {img_path} to {dest_path}: {e}")
            return copied_count

        # Copy images
        logging.info(f"Copying {class_name} images...")
        train_copied = copy_files(train_images, train_dir, class_name)
        val_copied = copy_files(val_images, val_dir, class_name)
        test_copied = copy_files(test_images, test_dir, class_name)
        logging.info(f"  Copied to Train: {train_copied}, Val: {val_copied}, Test: {test_copied}")


    # Print final split summary
    logging.info("--- Dataset split completed ---")
    print("Final Dataset split summary:")
    for split, dir_path in [('Train', train_dir), ('Validation', val_dir), ('Test', test_dir)]:
        try:
            real_count = len(os.listdir(os.path.join(dir_path, 'real'))) if os.path.exists(os.path.join(dir_path, 'real')) else 0
            forged_count = len(os.listdir(os.path.join(dir_path, 'forged'))) if os.path.exists(os.path.join(dir_path, 'forged')) else 0
            print(f"  {split}: {real_count} real, {forged_count} forged images (Total: {real_count + forged_count})")
        except FileNotFoundError:
             print(f"  {split}: Class directory not found (real or forged).")


# Split the dataset (force_resplit=False by default)
split_dataset(BASE_DIR, TRAIN_DIR, VAL_DIR, TEST_DIR, force_resplit=False)

05/07 22:50:48 - --- Starting Dataset Split ---
05/07 22:50:48 - Splitting data from /kaggle/input/handwritten-signature-verification/data/data with ratios: Train=0.7, Val=0.15, Test=0.15000000000000005
05/07 22:50:54 - Found 3188 images in class 'real'.
05/07 22:50:54 - Copying real images...
05/07 22:51:22 -   Copied to Train: 2231, Val: 478, Test: 479
05/07 22:51:24 - Found 2984 images in class 'forged'.
05/07 22:51:24 - Copying forged images...
05/07 22:51:46 -   Copied to Train: 2088, Val: 447, Test: 449
05/07 22:51:46 - --- Dataset split completed ---
Final Dataset split summary:
  Train: 2231 real, 2088 forged images (Total: 4319)
  Validation: 478 real, 447 forged images (Total: 925)
  Test: 479 real, 449 forged images (Total: 928)
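The per-class counts follow directly from the index arithmetic in `split_dataset`: `n_train = int(n * train_ratio)`, `n_val = int(n * val_ratio)`, and the remainder goes to test. For the 3,188 real images this gives `int(2231.6) = 2231`, `int(478.2) = 478`, and `3188 - 2231 - 478 = 479`, matching the summary above. The sketch below (the helper `split_list` is hypothetical) reproduces that arithmetic with a local `random.Random(seed)` on a sorted list, so the result does not depend on directory-listing order or on NumPy's global RNG. It will not assign the same individual files as the `np.random.seed(42)` shuffle used in the notebook.

```python
import random
from typing import List, Sequence, Tuple


def split_list(items: Sequence[str], train_ratio: float = 0.7, val_ratio: float = 0.15,
               seed: int = 42) -> Tuple[List[str], List[str], List[str]]:
    """Shuffle items reproducibly and split them into train/val/test lists.

    Same index arithmetic as split_dataset: floor for train and val,
    remainder to test.
    """
    if not (0 < train_ratio < 1 and 0 < val_ratio < 1 and train_ratio + val_ratio < 1):
        raise ValueError("train_ratio and val_ratio must be in (0, 1) and sum to less than 1.")
    shuffled = sorted(items)               # sort first: listing order is filesystem-dependent
    random.Random(seed).shuffle(shuffled)  # local RNG; no global state is touched
    n_train = int(len(shuffled) * train_ratio)
    n_val = int(len(shuffled) * val_ratio)
    return shuffled[:n_train], shuffled[n_train:n_train + n_val], shuffled[n_train + n_val:]
```

Because the split is done per class, each split keeps roughly the original real/forged proportion.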


### 7. Removing damaged images

This cell defines and uses the `check_and_remove_images_quiet` function to ensure data integrity by removing corrupted or unreadable images from the dataset splits.
The function `check_and_remove_images_quiet`:
-   Takes a directory path (e.g., `TRAIN_DIR`) as input.
-   Iterates through 'real' and 'forged' subdirectories within the given directory.
-   Collects all image files (common extensions like .png, .jpg, etc.).
-   For each image file:
    -   Tries to open and verify it using `PIL.Image.open()` and `img.verify()`, then `img.load()` to ensure it's fully readable.
    -   If any `UnidentifiedImageError`, `IOError`, `SyntaxError`, `FileNotFoundError`, or other exception occurs during this process, the image is considered invalid.
    -   Invalid images are removed from the filesystem using `os.remove()`.
-   It prints an in-place progress line (overwritten with a carriage return) showing the number of files checked and removed.
-   Prints a summary of checked and removed files for the processed directory.

The cell then iterates through `TRAIN_DIR`, `VAL_DIR`, and `TEST_DIR`, calling `check_and_remove_images_quiet` on each to clean all dataset splits. A final summary of the total number of files removed across all sets is printed.

In [9]:
def check_and_remove_images_quiet(directory):
    print(f"--- Checking for invalid images in '{directory}' ---")

    removed_files = 0
    checked_files = 0
    files_to_check = []

    # 1. Collect files first for accurate progress
    for class_name in ['real', 'forged']:
        class_dir = os.path.join(directory, class_name)
        if not os.path.isdir(class_dir):
            continue
        try:
            for filename in os.listdir(class_dir):
                file_path = os.path.join(class_dir, filename)
                if os.path.isfile(file_path) and filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff')):
                     files_to_check.append(file_path)
        except OSError as e:
             print(f"\nError listing directory '{class_dir}': {e}") 

    total_files = len(files_to_check)
    if total_files == 0:
        print("No image files found to check in this directory.")
        return 0

    # 2. Process files and update progress bar
    for idx, file_path in enumerate(files_to_check):
        checked_files += 1
        is_valid = True
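        try:
            # verify() checks the file structure without decoding pixels and leaves the
            # image unusable, so it is reopened before load() decodes it fully.
            with Image.open(file_path) as img:
                img.verify()
            with Image.open(file_path) as img:
                img.load()
        except FileNotFoundError:
            # Must come before the OSError handler: FileNotFoundError is an OSError
            # subclass. A file that vanished is not counted as checked.
            is_valid = False
            checked_files -= 1
        except (UnidentifiedImageError, OSError, SyntaxError):
            # Unreadable, truncated, or otherwise corrupt image data
            is_valid = False
        except Exception:
            is_valid = False

        if not is_valid:
            try:
                if os.path.exists(file_path):
                    os.remove(file_path)
                    removed_files += 1
            except OSError:
                pass  # could not delete (e.g., permissions); leave the file in place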


        progress_message = f"Checked: {idx+1}/{total_files} files | Removed: {removed_files}"
        print(f"\r{progress_message:<80}", end="") 
        sys.stdout.flush()

    # 3. Print final summary
    print() 
    print(f"Finished check for '{directory}'. Checked {checked_files} files. Removed {removed_files} invalid files.")
    return removed_files



print("\n--- Starting Image Validity Check ---")
total_removed_quiet = 0
for data_dir in [TRAIN_DIR, VAL_DIR, TEST_DIR]:
    if os.path.exists(data_dir):
        total_removed_quiet += check_and_remove_images_quiet(data_dir)
    else:
        print(f"Directory '{data_dir}' not found. Skipping validity check.")

print(f"--- Image validity check complete. Total files removed across all sets: {total_removed_quiet} ---")


--- Starting Image Validity Check ---
--- Checking for invalid images in '/kaggle/working/split_dataset/train' ---
Checked: 4319/4319 files | Removed: 1                                           
Finished check for '/kaggle/working/split_dataset/train'. Checked 4319 files. Removed 1 invalid files.
--- Checking for invalid images in '/kaggle/working/split_dataset/val' ---
Checked: 925/925 files | Removed: 0                                             
Finished check for '/kaggle/working/split_dataset/val'. Checked 925 files. Removed 0 invalid files.
--- Checking for invalid images in '/kaggle/working/split_dataset/test' ---
Checked: 928/928 files | Removed: 0                                             
Finished check for '/kaggle/working/split_dataset/test'. Checked 928 files. Removed 0 invalid files.
--- Image validity check complete. Total files removed across all sets: 1 ---
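The order of the `except` clauses in the validity check matters. Python tries them from top to bottom, and `FileNotFoundError` is a subclass of `OSError` (which `IOError` aliases in Python 3). A broad `OSError`/`IOError` clause listed first therefore catches missing files before a later `FileNotFoundError` clause can run. The two hypothetical functions below are not part of the notebook; they demonstrate this using only the standard library.

```python
def classify_specific_first(exc):
    """Raise `exc` and report which clause catches it (specific clause first)."""
    try:
        raise exc
    except FileNotFoundError:
        return "missing"
    except OSError:
        return "unreadable"
    except Exception:
        return "other"


def classify_broad_first(exc):
    """Same clauses, but the broad OSError clause is listed first."""
    try:
        raise exc
    except OSError:
        return "unreadable"
    except FileNotFoundError:  # unreachable: OSError already matched it
        return "missing"
    except Exception:
        return "other"


print(issubclass(FileNotFoundError, OSError))               # True
print(IOError is OSError)                                   # True
print(classify_specific_first(FileNotFoundError("x.png")))  # missing
print(classify_broad_first(FileNotFoundError("x.png")))     # unreadable
```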


### 8. Data Loaders

This cell defines the necessary components for loading and transforming image data for PyTorch models.
It consists of two main parts:

1.  **`AlbumentationsDataset` Class**:
    *   A custom PyTorch `Dataset` class that wraps a list of image paths and their corresponding labels.
    *   It takes an Albumentations `transform` pipeline as an argument.
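    *   In `__getitem__`, it loads an image from a path, converts it to RGB and then to a NumPy array, and applies the specified Albumentations transformations. If no transform is given, it falls back to a default pipeline of resizing, ImageNet normalization, and tensor conversion.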
    *   Includes error handling for `FileNotFoundError`, `UnidentifiedImageError`, and other exceptions during image loading, returning a placeholder tensor and a special label (-1) in case of an error.

2.  **`get_data_loaders` Function**:
    *   This function creates and returns PyTorch `DataLoader` instances for the training, validation, and test sets.
    *   **Transforms**: It defines separate Albumentations transform pipelines for training (`train_transform`) and validation/testing (`val_test_transform`).
        *   `train_transform` includes resizing, various augmentations like rotation, flips, affine transformations, perspective changes, color jitter, Gaussian blur, normalization, and conversion to a PyTorch tensor.
        *   `val_test_transform` includes only resizing, normalization, and tensor conversion.
    *   **Data Collection**: It uses a helper function `collect_paths_labels` to gather image file paths and their corresponding labels (0 for 'real', 1 for 'forged') from the `TRAIN_DIR`, `VAL_DIR`, and `TEST_DIR`.
    *   **Dataset Instantiation**: It creates instances of `AlbumentationsDataset` for train, validation, and test sets using the collected paths, labels, and the appropriate transform pipelines.
    *   **Logging**: It logs dataset sizes and class distributions. It also checks for empty datasets and raises a `ValueError` if any split is empty.
    *   **DataLoader Creation**: It creates `DataLoader` objects for each dataset, specifying batch size, shuffle behavior (True for training, False for validation/test), `num_workers` (set to 0 in this version for compatibility, but can be increased for parallel data loading), and `pin_memory`.
    *   It logs a success message upon completion.
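A sample that fails to load comes back as a zero image with label -1, so any code that consumes a batch must drop those rows before computing the loss. The NumPy sketch below shows the masking step. It assumes labels have shape `(batch, 1)`, as they do after `.unsqueeze(1)` in the training loop. `drop_placeholders` and the `demo_*` arrays are hypothetical. The sketch squeezes only axis 1, because a bare `squeeze()` turns a single-sample mask into a 0-d array.

```python
import numpy as np


def drop_placeholders(inputs, labels):
    """Hypothetical helper: drop samples whose label is the -1 placeholder.

    inputs: array of shape (batch, C, H, W); labels: array of shape (batch, 1).
    """
    # Squeeze only axis 1 so the mask stays 1-D even when batch == 1.
    valid = (labels != -1).squeeze(axis=1)
    return inputs[valid], labels[valid]


demo_images = np.zeros((4, 3, 8, 8), dtype=np.float32)
demo_label_batch = np.array([[0.0], [-1.0], [1.0], [-1.0]])
kept_images, kept_labels = drop_placeholders(demo_images, demo_label_batch)
print(kept_images.shape, kept_labels.ravel())  # (2, 3, 8, 8) [0. 1.]

# A bare squeeze() on a single-sample mask removes the batch axis as well.
single_label = np.array([[1.0]])
print((single_label != -1).squeeze().ndim)  # 0
```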

In [10]:
class AlbumentationsDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None, img_size=(224, 224)):
        if len(image_paths) != len(labels):
            raise ValueError(f"Number of image paths ({len(image_paths)}) does not match number of labels ({len(labels)}).")

        self.image_paths = image_paths
        self.labels = labels
        self.img_size = img_size
        # Default pipeline (resize + ImageNet normalization) when no transform is given;
        # built once here instead of on every __getitem__ call.
        self.transform = transform if transform is not None else A.Compose([
            A.Resize(img_size[0], img_size[1]),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

    def __len__(self):
        return len(self.image_paths)

    def _placeholder(self):
        # Zero image with label -1 so the training loop can filter the sample out.
        image = torch.zeros((3, self.img_size[0], self.img_size[1]), dtype=torch.float32)
        return image, torch.tensor(-1, dtype=torch.long)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            # The context manager closes the file handle after decoding.
            with Image.open(image_path) as image:
                image_np = np.array(image.convert('RGB'))
            image_tensor = self.transform(image=image_np)['image']
            return image_tensor, torch.tensor(label, dtype=torch.long)
        except FileNotFoundError:
            logging.error(f"File not found in __getitem__: {image_path}. Check dataset integrity.")
        except UnidentifiedImageError:
            logging.error(f"Corrupted or unidentified image format in __getitem__: {image_path}.")
        except Exception as e:
            logging.error(f"Error processing image {image_path} in __getitem__: {e}", exc_info=True)
        return self._placeholder()

def get_data_loaders(model_name, batch_size, img_size, num_workers=0):
    logging.info(f"Creating DataLoaders for {model_name} (Img Size: {img_size}, Batch Size: {batch_size}, Num Workers: {num_workers})...")
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Augmentation probabilities (v_flip_p = 0.0 disables vertical flips).
    rotation_p = 0.3
    h_flip_p = 0.5
    v_flip_p = 0.0
    affine_p = 0.3
    color_jitter_p = 0.5
    blur_p = 0.3
    perspective_p = 0.1

    train_transform = A.Compose([
        A.Resize(img_size[0], img_size[1]),
        A.Rotate(limit=20, p=rotation_p, border_mode=cv2.BORDER_REFLECT_101),
        A.HorizontalFlip(p=h_flip_p),
        A.VerticalFlip(p=v_flip_p),
        A.Affine(scale=(0.85, 1.15), translate_percent=(-0.1, 0.1), shear=(-10, 10), p=affine_p, keep_ratio=True),
        A.Perspective(scale=(0.03, 0.08), p=perspective_p, keep_size=True),
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=color_jitter_p),
        A.GaussianBlur(blur_limit=(3, 7), sigma_limit=(0.1, 1.5), p=blur_p),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ])

    val_test_transform = A.Compose([
        A.Resize(img_size[0], img_size[1]),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ])

    # --- Collect image paths and labels (0 = real, 1 = forged) ---
    def collect_paths_labels(split_dir):
        paths = []
        labels = []
        for class_idx, class_name in enumerate(['real', 'forged']):
            class_path = os.path.join(split_dir, class_name)
            if not os.path.isdir(class_path):
                logging.warning(f"Class directory not found: {class_path}. Skipping.")
                continue
            try:
                for filename in os.listdir(class_path):
                    if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')):
                        image_path = os.path.join(class_path, filename)
                        if os.path.isfile(image_path):
                            paths.append(image_path)
                            labels.append(class_idx)
            except OSError as e:
                logging.error(f"Error listing files in {class_path}: {e}")
        return paths, labels

    train_image_paths, train_labels = collect_paths_labels(TRAIN_DIR)
    val_image_paths, val_labels = collect_paths_labels(VAL_DIR)
    test_image_paths, test_labels = collect_paths_labels(TEST_DIR)

    # --- Create datasets ---
    train_dataset = AlbumentationsDataset(train_image_paths, train_labels, transform=train_transform, img_size=img_size)
    val_dataset = AlbumentationsDataset(val_image_paths, val_labels, transform=val_test_transform, img_size=img_size)
    test_dataset = AlbumentationsDataset(test_image_paths, test_labels, transform=val_test_transform, img_size=img_size)

    # --- Log dataset sizes and class distributions ---
    base_train_size, base_val_size, base_test_size = len(train_dataset), len(val_dataset), len(test_dataset)
    logging.info(f"Dataset sizes: Train={base_train_size}, Validation={base_val_size}, Test={base_test_size}")
    for split_name, split_labels in [("Train", train_labels), ("Validation", val_labels), ("Test", test_labels)]:
        real_count = sum(1 for label in split_labels if label == 0)
        forged_count = len(split_labels) - real_count
        logging.info(f"  {split_name} class distribution: Real={real_count}, Forged={forged_count}")

    # --- Check for empty datasets ---
    split_sizes = {"train": base_train_size, "validation": base_val_size, "test": base_test_size}
    empty_splits = [name for name, size in split_sizes.items() if size == 0]
    if empty_splits:
        error_msg = f"Empty dataset split(s): {', '.join(empty_splits)}"
        logging.error(error_msg)
        raise ValueError(error_msg)

    # --- Create DataLoaders (num_workers=0 loads batches in the main process) ---
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory, drop_last=False
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory
    )

    logging.info(f"DataLoaders created successfully for {model_name}.")
    return train_loader, val_loader, test_loader
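With `drop_last=False`, a PyTorch `DataLoader` yields every full batch plus a final batch that holds the remainder. That remainder matters here because the classifier head in `CustomModel` contains `BatchNorm1d`. In training mode, `BatchNorm1d` raises an error when a batch contains a single sample, and filtering placeholder samples can also shrink a batch. The hypothetical `batch_sizes` helper below reproduces only the arithmetic of sequential batching and does not use PyTorch.

```python
def batch_sizes(num_samples, batch_size, drop_last=False):
    """Sizes of the batches a DataLoader-style loader would yield (illustrative only)."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        sizes.append(remainder)  # the incomplete final batch
    return sizes


print(batch_sizes(10, 4))                 # [4, 4, 2]
print(batch_sizes(9, 4))                  # [4, 4, 1]  <- a size-1 batch
print(batch_sizes(9, 4, drop_last=True))  # [4, 4]
```

If a size-1 final training batch ever occurs, setting `drop_last=True` on the training loader is one way to avoid it. The validation and test loaders are not affected, because `model.eval()` makes `BatchNorm1d` use its running statistics.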

### 9. Define Pre-trained Models

This cell defines a list named `models_to_try`. This list contains the names of various pre-trained model architectures that the notebook will experiment with for the image classification task.

After defining the list, it iterates through `models_to_try` to verify that the required input image sizes for each of these models are defined in the `model_img_sizes` dictionary (which was defined in Cell 5). If a model's image size is not found, it raises a `KeyError`, ensuring that the notebook has all necessary configuration before proceeding to model building.
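The check described above stops at the first missing entry. The variant sketched below reports every missing model in a single `KeyError` instead. `example_img_sizes`, `example_models` and `missing_img_sizes` are hypothetical, the sizes are illustrative only, and none of them replaces the notebook's `model_img_sizes`.

```python
# Illustrative values only; the notebook's own model_img_sizes (Cell 5) is authoritative.
example_img_sizes = {"ViT_Base": (224, 224), "ConvNeXt_Base": (224, 224)}


def missing_img_sizes(models, sizes):
    """Return every model name that has no entry in `sizes`, preserving order."""
    return [name for name in models if name not in sizes]


example_models = ["ViT_Base", "ConvNeXt_Base"]
missing_names = missing_img_sizes(example_models, example_img_sizes)
if missing_names:
    # One error that lists all gaps, instead of stopping at the first.
    raise KeyError(f"Image size not defined in 'model_img_sizes' for: {', '.join(missing_names)}")
print("All example models have image sizes.")
```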

In [11]:
print("\n--- Defining Base Models ---")
models_to_try = [
    # Uncomment entries to include more backbones in the experiment.
    # "ConvNeXt_Base",
    # "EfficientNetV2-S",
    # "DeiT_Base",
    # "BEiT_Base",
    # "ResNetRS50",
    # "Xception",
    # "MobileNetV3_Large",
    "ViT_Base",
    # "EfficientNet_B7",
    # "InceptionV3",
]

# Verify image sizes are defined for all selected models
for model_name in models_to_try:
    if model_name not in model_img_sizes:
        raise KeyError(f"Image size for model '{model_name}' not defined in 'model_img_sizes'.")


--- Defining Base Models ---


### 10. Build the Model

This cell defines the `CustomModel` class, a PyTorch `nn.Module` that wraps various pre-trained backbones to build image classification models. It applies architecture-specific logic to decide which layers to unfreeze.

Key features of the `CustomModel` class:
-   **Initialization (`__init__`)**:
    -   Takes `model_name`, `dense_units` (for the classifier head), `dropout` probability, `pretrained` flag, and `unfreeze_layers` (number of layer groups/blocks/stages to unfreeze in the base model) as arguments.
    -   **Base Model Loading**:
        -   If `model_name` is "ViT_Base", it loads a Vision Transformer from Hugging Face Transformers (`google/vit-base-patch16-224`) with `add_pooling_layer=False`, because the `[CLS]` token embedding is used directly as the feature vector.
        -   For other model names, it uses `timm.create_model` to load the corresponding pre-trained model. A mapping (`timm_model_name_map`) translates each name into its `timm` model identifier, and `num_classes=0` removes the classification head so that the model returns pooled features.
    -   **Feature Dimension Verification**: It runs a test forward pass on a dummy `1 x 3 x 224 x 224` tensor to determine the actual number of output features from the base model, so that the classifier head is sized correctly. If this pass fails, it falls back to the reported feature count.
    -   **Parameter Freezing/Unfreezing**:
        -   Initially, all parameters in the loaded base model are frozen (`param.requires_grad = False`).
        -   If `unfreeze_layers` is greater than 0, it unfreezes the specified number of layers/blocks from the end of the model. The unfreezing logic is architecture-specific:
            -   **ViT_Base**: Unfreezes the last `unfreeze_layers` transformer encoder layers (`self.base_model.encoder.layer`).
            -   **EfficientNet models**: Unfreezes the last `unfreeze_layers` entries of `self.base_model.blocks`. If the `blocks` attribute is not found, it falls back to `_generic_unfreeze`.
            -   **ConvNeXt models**: Unfreezes the last `unfreeze_layers` entries of `self.base_model.stages`. If the `stages` attribute is not found, it falls back to `_generic_unfreeze`.
            -   **ResNet models**: Identifies the top-level layer groups (e.g., `layer1`, `layer2`), sorts them by name, and unfreezes the last `unfreeze_layers` of these groups. It also unfreezes any top-level module whose name contains 'norm', 'bn', or 'head'.
            -   **Other models**: Uses the `_generic_unfreeze` helper method.
        -   Logs the number of initially frozen parameters and the final counts of trainable and frozen parameters after the unfreezing process.
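    -   **`_generic_unfreeze` Helper Method**: This private fallback collects every named submodule that owns parameters directly. It sorts them by nesting depth (the number of dot-separated parts in the module name). The sort is stable, so registration order is preserved within each depth level. It then unfreezes the direct parameters of the last `unfreeze_layers` modules in that order. This selects the most deeply nested modules, which are not necessarily the ones closest to the output.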
    -   **Classifier Head**: Defines a sequential classifier head consisting of a `Linear` layer, `ReLU` activation, `BatchNorm1d`, `Dropout`, and a final `Linear` layer with a single output unit (for binary classification logits).
-   **Forward Pass (`forward`)**:
    -   Takes an input tensor `x`.
    -   Passes `x` through the base model to get features.
        -   For ViT, it extracts the embedding of the `[CLS]` token (`last_hidden_state[:, 0]`).
        -   For `timm` models, it uses the direct output (pooled features).
    -   Passes the extracted features through the custom classifier head.
    -   Returns the output logits.
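The classifier head's logged parameter count can be checked by hand. A `Linear(in, out)` layer has `in * out` weights plus `out` biases. `BatchNorm1d(d)` contributes `2 * d` trainable parameters; its running mean and variance are buffers, not parameters. `ReLU` and `Dropout` have no parameters. The hypothetical helper below reproduces what `sum(p.numel() for p in self.classifier.parameters())` reports. It assumes ViT_Base's 768-dimensional `[CLS]` embedding and an illustrative `dense_units=256`; the real value comes from the hyperparameter search.

```python
def classifier_head_params(num_features, dense_units):
    """Trainable parameters in Linear -> ReLU -> BatchNorm1d -> Dropout -> Linear(1)."""
    linear1 = num_features * dense_units + dense_units  # Linear(num_features, dense_units)
    batchnorm = 2 * dense_units                         # BatchNorm1d weight + bias
    linear2 = dense_units * 1 + 1                       # Linear(dense_units, 1)
    return linear1 + batchnorm + linear2                # ReLU and Dropout add none


print(classifier_head_params(768, 256))  # 197633
```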

In [12]:
class CustomModel(nn.Module):
    def __init__(self, model_name, dense_units, dropout, pretrained=True, unfreeze_layers=0):
        super().__init__()
        self.model_name = model_name
        self.base_model = None
        reported_features = 0

        logging.info(f"Initializing CustomModel: Base='{model_name}', DenseUnits={dense_units}, Dropout={dropout:.2f}, Unfreeze={unfreeze_layers}")
        
        try:
            # --- Load Base Model ---
            if model_name == "ViT_Base":
                hf_model_name = 'google/vit-base-patch16-224'
                self.base_model = ViTModel.from_pretrained(
                    hf_model_name,
                    add_pooling_layer=False,  
                    ignore_mismatched_sizes=True
                )
                reported_features = self.base_model.config.hidden_size
                logging.info(f"Loaded '{hf_model_name}' from HuggingFace Transformers. Reported features: {reported_features}")
            else:
                timm_model_name_map = {
                    "ConvNeXt_Base": "convnext_base.fb_in22k_ft_in1k", 
                    "EfficientNetV2-S": "tf_efficientnetv2_s.in21k_ft_in1k",
                    "DeiT_Base": "deit_base_patch16_224.fb_in1k",
                    "BEiT_Base": "beit_base_patch16_224",
                    "EfficientNet_B7": "tf_efficientnet_b7.ns_jft_in1k", 
                    "ResNetRS50": "resnetrs50.tf_in1k",
                    "InceptionV3": "inception_v3.tf_in1k",
                    "Xception": "xception.tf_in1k",
                    "MobileNetV3_Large": "mobilenetv3_large_100.miil_in21k_ft_in1k",
                }
                if model_name not in timm_model_name_map:
                    raise ValueError(f"Model name '{model_name}' not found in timm map or not supported.")

                timm_name = timm_model_name_map[model_name]
                self.base_model = timm.create_model(timm_name, pretrained=pretrained, num_classes=0)
                reported_features = self.base_model.num_features
                logging.info(f"Loaded '{timm_name}' from timm. Reported features: {reported_features}")
            
            # --- Verify Actual Feature Dimensions with Test Forward Pass ---
            self.base_model.eval()
            with torch.no_grad():
                try:
                    dummy_input = torch.zeros(1, 3, 224, 224)
                    if model_name == "ViT_Base":
                        actual_features = self.base_model(dummy_input).last_hidden_state[:, 0]
                    else:
                        actual_features = self.base_model(dummy_input)
                    
                    actual_num_features = actual_features.shape[1]
                    
                    if actual_num_features != reported_features:
                        logging.warning(f"Model '{model_name}' reports {reported_features} features but outputs {actual_num_features} features. Using actual dimension.")
                        num_features = actual_num_features
                    else:
                        num_features = reported_features
                        logging.info(f"Verified that model '{model_name}' outputs {num_features} features as expected.")
                        
                except Exception as e:
                    logging.warning(f"Exception during feature verification: {e}. Using reported feature count {reported_features}.")
                    num_features = reported_features

            # --- Parameter Freezing/Unfreezing ---
            if unfreeze_layers < 0:
                logging.warning(f"unfreeze_layers cannot be negative ({unfreeze_layers}). Setting to 0 (fully frozen).")
                unfreeze_layers = 0

            for param in self.base_model.parameters():
                param.requires_grad = False
            frozen_params = sum(p.numel() for p in self.base_model.parameters() if not p.requires_grad)
            logging.info(f"Initially froze {frozen_params:,} parameters in the base model.")

            if unfreeze_layers > 0:
                layers_unfrozen_count = 0
                params_unfrozen_count = 0
                
                if model_name == "ViT_Base":
                    num_layers = len(self.base_model.encoder.layer)
                    layers_to_unfreeze = min(unfreeze_layers, num_layers)
                    
                    for i in range(num_layers - layers_to_unfreeze, num_layers):
                        for param in self.base_model.encoder.layer[i].parameters():
                            param.requires_grad = True
                            params_unfrozen_count += param.numel()
                        
                        layers_unfrozen_count += 1
                        logging.info(f"Unfroze ViT encoder layer {i}")
                
                elif "EfficientNet" in model_name:
                    if hasattr(self.base_model, 'blocks'):
                        blocks = self.base_model.blocks
                        num_blocks = len(blocks)
                        blocks_to_unfreeze = min(unfreeze_layers, num_blocks)
                        
                        for i in range(num_blocks - blocks_to_unfreeze, num_blocks):
                            for param in blocks[i].parameters():
                                param.requires_grad = True
                                params_unfrozen_count += param.numel()
                            
                            layers_unfrozen_count += 1
                            logging.info(f"Unfroze EfficientNet block {i}")
                    else:
                        logging.warning(f"Could not find blocks for {model_name}, falling back to generic unfreezing.")
                        self._generic_unfreeze(unfreeze_layers, layers_unfrozen_count, params_unfrozen_count)
                
                elif "ConvNeXt" in model_name:
                    if hasattr(self.base_model, 'stages'):
                        stages = self.base_model.stages
                        num_stages = len(stages)
                        stages_to_unfreeze = min(unfreeze_layers, num_stages)
                        
                        for i in range(num_stages - stages_to_unfreeze, num_stages):
                            for param in stages[i].parameters():
                                param.requires_grad = True
                                params_unfrozen_count += param.numel()
                            
                            layers_unfrozen_count += 1
                            logging.info(f"Unfroze ConvNeXt stage {i}")
                    else:
                        logging.warning(f"Could not find stages for {model_name}, falling back to generic unfreezing.")
                        self._generic_unfreeze(unfreeze_layers, layers_unfrozen_count, params_unfrozen_count)
                
                elif "ResNet" in model_name:
                    layer_groups = []
                    for name, child in self.base_model.named_children():
                        if 'layer' in name and isinstance(child, nn.Module):
                            layer_groups.append((name, child))
                    
                    layer_groups.sort(key=lambda x: x[0])  
                    groups_to_unfreeze = min(unfreeze_layers, len(layer_groups))
                    
                    for i in range(len(layer_groups) - groups_to_unfreeze, len(layer_groups)):
                        name, layer = layer_groups[i]
                        for param in layer.parameters():
                            param.requires_grad = True
                            params_unfrozen_count += param.numel()
                        
                        layers_unfrozen_count += 1
                        logging.info(f"Unfroze ResNet {name}")
                    
                    for name, module in self.base_model.named_children():
                        if any(x in name.lower() for x in ['norm', 'bn', 'head']):
                            for param in module.parameters():
                                param.requires_grad = True
                                params_unfrozen_count += param.numel()
                            logging.info(f"Unfroze {name}")
                
                else:
                    self._generic_unfreeze(unfreeze_layers, layers_unfrozen_count, params_unfrozen_count)
                
                final_frozen_params = sum(p.numel() for p in self.base_model.parameters() if not p.requires_grad)
                final_unfrozen_params = sum(p.numel() for p in self.base_model.parameters() if p.requires_grad)
                logging.info(f"Base model state after unfreeze: {final_unfrozen_params:,} trainable params, {final_frozen_params:,} frozen params.")
            else:
                logging.info(f"Base model remains fully frozen as requested (unfreeze_layers=0).")

            # --- Define Classifier Head ---
            self.classifier = nn.Sequential(
                nn.Linear(num_features, dense_units),
                nn.ReLU(),
                nn.BatchNorm1d(dense_units),
                nn.Dropout(dropout),
                nn.Linear(dense_units, 1),  
            )
            classifier_params = sum(p.numel() for p in self.classifier.parameters())
            logging.info(f"Classifier head created with {classifier_params:,} trainable parameters.")
            
        except Exception as e:
            logging.error(f"Error initializing model '{model_name}': {e}", exc_info=True)
            raise

    def _generic_unfreeze(self, unfreeze_layers, layers_unfrozen_count=0, params_unfrozen_count=0):
        # Collect named submodules that own at least one parameter directly.
        module_list = [
            (name, module) for name, module in self.base_model.named_modules()
            if name and next(module.parameters(recurse=False), None) is not None
        ]
        # Stable sort by nesting depth (number of dot-separated name parts).
        module_list.sort(key=lambda x: len(x[0].split('.')))
        modules_to_unfreeze = module_list[-unfreeze_layers:] if unfreeze_layers < len(module_list) else module_list

        for name, module in modules_to_unfreeze:
            for param in module.parameters(recurse=False):
                param.requires_grad = True
                params_unfrozen_count += param.numel()
            layers_unfrozen_count += 1
            logging.info(f"Unfroze layer: {name}")

        # Integers are passed by value, so return the updated counts to the caller.
        return layers_unfrozen_count, params_unfrozen_count

    def forward(self, x):
        if self.model_name == "ViT_Base":
            features = self.base_model(x).last_hidden_state[:, 0]  
        else:
            features = self.base_model(x)
        if torch.is_grad_enabled() and random.random() < 0.01:  
            logging.debug(f"Model '{self.model_name}' feature shape: {features.shape}")
        output = self.classifier(features)
        return output
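`_generic_unfreeze` sorts modules by nesting depth before it takes the last `unfreeze_layers` entries. As a result, it picks the most deeply nested modules rather than those registered last. The pure-Python sketch below applies the same selection rule to a hypothetical list of module names and contrasts it with plain registration order. The names are loosely modeled on a transformer backbone and are not taken from any specific model.

```python
# Hypothetical module names in registration order; each is assumed to own parameters.
example_modules = [
    "patch_embed.proj",
    "blocks.0.norm1",
    "blocks.0.attn.qkv",
    "blocks.1.norm1",
    "blocks.1.attn.qkv",
    "norm",
    "head",
]


def last_k(names, k, by_depth):
    """Mimic _generic_unfreeze's selection (by_depth=True) or use registration order."""
    if k <= 0:
        return []
    ordered = sorted(names, key=lambda n: len(n.split("."))) if by_depth else list(names)
    return ordered if k >= len(ordered) else ordered[-k:]


print(last_k(example_modules, 2, by_depth=True))   # ['blocks.0.attn.qkv', 'blocks.1.attn.qkv']
print(last_k(example_modules, 2, by_depth=False))  # ['norm', 'head']
```

Which rule is preferable depends on the architecture. When the final layers are shallowly nested (like `norm` and `head` here), registration order is closer to "the last layers of the network".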

### 11. Evaluate Model

This cell defines the `evaluate_model` function, which is responsible for assessing the performance of a trained model on a given dataset (typically the validation or test set).
The function performs the following steps:
1.  **Set to Evaluation Mode**: Puts the `model` into evaluation mode (`model.eval()`), which disables dropout and makes batch normalization use its running statistics instead of per-batch statistics.
2.  **No Gradient Calculation**: Uses `torch.no_grad()` to ensure that gradients are not computed during evaluation, saving memory and computation.
3.  **Iterate Through Data**: Loops through the `data_loader` (e.g., validation or test loader) batch by batch using `tqdm` for a progress bar.
    -   Moves `inputs` and `labels` to the specified `device`. Labels are converted to `float` and unsqueezed to match the shape expected by `BCEWithLogitsLoss`.
    -   Performs a forward pass: `outputs = model(inputs)`.
    -   Calculates the `loss` using the provided `criterion`.
    -   Accumulates the `total_loss`.
    -   Converts model outputs (logits) to probabilities using `torch.sigmoid()`.
    -   Determines predictions by thresholding probabilities at 0.5.
    -   Stores all true `labels`, predicted `preds`, and `probs` for later metric calculation.
4.  **Calculate Metrics**:
    -   Calculates the average loss (`avg_loss`) over the entire dataset.
    -   Uses `sklearn.metrics` to compute:
        -   `accuracy_score`
        -   `f1_score` (binary average, with `pos_label=1` and `zero_division=0` to handle cases with no positive predictions/labels gracefully)
        -   `precision_score` (binary average, `pos_label=1`, `zero_division=0`)
        -   `recall_score` (binary average, `pos_label=1`, `zero_division=0`)
        -   `confusion_matrix` (with explicit labels `[0, 1]` to ensure consistent order for 'real' and 'forged' classes).
5.  **Return Metrics**: Returns the calculated average loss, accuracy, F1-score, precision, recall, and the confusion matrix.
It includes a check for an empty dataset to prevent division by zero and returns zero metrics in such a case.
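The metrics that `evaluate_model` gets from `sklearn.metrics` can also be computed by hand from the four confusion-matrix counts, which helps when interpreting the logged values. The dependency-free sketch below is illustrative; `binary_metrics` is a hypothetical helper, not part of the notebook. Like the notebook, it thresholds the sigmoid at 0.5 and treats label 1 ('forged') as the positive class. It mimics `zero_division=0` by returning 0.0 whenever a ratio's denominator is zero.

```python
import math


def sigmoid(z):
    # Numerically stable: never exponentiates a large positive number.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def binary_metrics(labels, logits, threshold=0.5):
    """Hypothetical helper: accuracy, precision, recall, F1 and [[TN, FP], [FN, TP]]."""
    preds = [1 if sigmoid(z) > threshold else 0 for z in logits]
    pairs = list(zip(labels, preds))
    tp = sum(1 for y, p in pairs if y == 1 and p == 1)
    tn = sum(1 for y, p in pairs if y == 0 and p == 0)
    fp = sum(1 for y, p in pairs if y == 0 and p == 1)
    fn = sum(1 for y, p in pairs if y == 1 and p == 0)
    # Mimic zero_division=0: report 0.0 when a denominator is zero.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * tp / (2 * tp + fp + fn) if 2 * tp + fp + fn else 0.0
    accuracy = (tp + tn) / len(pairs)
    return accuracy, precision, recall, f1, [[tn, fp], [fn, tp]]


# Label 1 = 'forged' (positive class), 0 = 'real'.
demo_labels = [0, 0, 1, 1, 1]
demo_logits = [-2.0, 0.5, 1.0, 3.0, -0.1]
print(binary_metrics(demo_labels, demo_logits))
# accuracy 0.6, precision 2/3, recall 2/3, F1 2/3, matrix [[1, 1], [1, 2]]
```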


In [13]:
def evaluate_model(model, data_loader, criterion, device):
    model.eval()
    total_loss = 0.0
    num_samples = 0  # samples actually evaluated (placeholders excluded)
    all_labels = []
    all_preds = []
    all_probs = []

    with torch.no_grad():
        eval_progress = tqdm(data_loader, desc="Evaluating", leave=False, unit="batch")
        for inputs, labels in eval_progress:
            inputs, labels = inputs.to(device), labels.to(device).float().unsqueeze(1)

            # Drop placeholder samples (label -1) that AlbumentationsDataset returns on load errors.
            valid = labels.squeeze(1) != -1
            if not torch.any(valid):
                continue
            inputs, labels = inputs[valid], labels[valid]

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Assumes the criterion averages over the batch (reduction='mean').
            total_loss += loss.item() * inputs.size(0)
            num_samples += inputs.size(0)
            probs = torch.sigmoid(outputs)
            preds = (probs > 0.5).float()

            all_labels.extend(labels.cpu().numpy().flatten())
            all_preds.extend(preds.cpu().numpy().flatten())
            all_probs.extend(probs.cpu().numpy().flatten())

    # --- Calculate Metrics ---
    if num_samples == 0:
        logging.warning("No valid samples to evaluate. Returning zero metrics.")
        return 0.0, 0.0, 0.0, 0.0, 0.0, np.zeros((2, 2), dtype=int)
    avg_loss = total_loss / num_samples

    all_labels_np = np.array(all_labels, dtype=int)
    all_preds_np = np.array(all_preds, dtype=int)

    accuracy = accuracy_score(all_labels_np, all_preds_np)
    f1 = f1_score(all_labels_np, all_preds_np, average='binary', pos_label=1, zero_division=0)
    precision = precision_score(all_labels_np, all_preds_np, average='binary', pos_label=1, zero_division=0)
    recall = recall_score(all_labels_np, all_preds_np, average='binary', pos_label=1, zero_division=0)
    conf_matrix = confusion_matrix(all_labels_np, all_preds_np, labels=[0, 1])

    return avg_loss, accuracy, f1, precision, recall, conf_matrix
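`BCEWithLogitsLoss` averages over the batch by default (`reduction='mean'`). That is why the loss accumulation multiplies each batch loss by its batch size and divides by the sample count at the end. Averaging the per-batch means directly would over-weight a smaller final batch. The dependency-free sketch below demonstrates this. It uses the numerically stable form of the per-sample loss, max(x, 0) - x*y + log(1 + exp(-|x|)), which is mathematically equivalent to binary cross-entropy on sigmoid(x). The batch split and values are made up for illustration.

```python
import math
import statistics


def bce_with_logits(x, y):
    """Per-sample BCE on logit x with target y in {0, 1} (numerically stable form)."""
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


# Hypothetical logits and targets split into a batch of 3 and a final batch of 1.
demo_batches = [
    ([2.0, -1.0, 0.5], [1, 0, 0]),
    ([-3.0], [1]),
]

# reduction='mean': one averaged loss per batch, as loss.item() would return.
batch_means = [statistics.fmean(bce_with_logits(x, y) for x, y in zip(xs, ys)) for xs, ys in demo_batches]
batch_lens = [len(xs) for xs, _ in demo_batches]

# Size-weighted average: sum(loss * batch_size) / num_samples.
weighted_avg = sum(m * n for m, n in zip(batch_means, batch_lens)) / sum(batch_lens)
# Naive mean of batch means over-weights the small final batch.
naive_avg = statistics.fmean(batch_means)
# Reference: mean over all individual samples.
per_sample_avg = statistics.fmean(
    bce_with_logits(x, y) for xs, ys in demo_batches for x, y in zip(xs, ys)
)

print(math.isclose(weighted_avg, per_sample_avg))  # True
print(weighted_avg < naive_avg)                    # True
```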

### 12. Hyperparameter Tuning and Training

This cell defines the `train_model` function, which orchestrates the model training process for a specified number of epochs.
Key functionalities include:
-   **Initialization**: Sets up variables for tracking the best validation loss, epochs without improvement (for early stopping), and the best model state. Initializes a `history` dictionary to store training and validation metrics per epoch.
-   **Checkpoint Setup**: Defines the checkpoint directory and path for saving the best model based on validation loss.
-   **Epoch Loop**: Iterates for the specified number of `epochs`.
    -   **Training Phase**:
        -   Sets the model to training mode (`model.train()`).
        -   Iterates through the `train_loader` (with a `tqdm` progress bar).
        -   For each batch:
            -   Moves inputs and labels to the `device`. Drops samples that carry the placeholder label (-1) from `AlbumentationsDataset`, and skips the batch entirely if no valid samples remain.
            -   Performs the standard training steps: zero gradients, forward pass, loss calculation, backward pass, and optimizer step.
            -   Accumulates training loss and calculates batch accuracy.
        -   Calculates average training loss and accuracy for the epoch and stores them in `history`.
    -   **Validation Phase**:
        -   Calls the `evaluate_model` function (defined in the previous cell) to get validation metrics (loss, accuracy, F1, precision, recall).
        -   Stores these validation metrics in `history`.
    -   **Logging**: Logs a summary of training and validation metrics for the current epoch, along with the epoch duration.
    -   **Early Stopping & Checkpointing**:
        -   If the current validation loss is lower than `best_val_loss`, it updates `best_val_loss`, resets `epochs_no_improve`, and saves the current model's state dictionary (`best_model_state`) to `checkpoint_path`.
        -   If validation loss does not improve, increments `epochs_no_improve`.
    -   **Optuna Pruning Integration**:
        -   If a `trial` object (from Optuna) is provided and the current epoch is beyond `hpo_warmup_steps`:
            -   It reports the validation F1 score (`val_f1`) to the Optuna `trial`.
            -   It checks if the `trial.should_prune()`. If true, it logs a pruning message, cleans up GPU memory, and raises `optuna.TrialPruned` to stop the trial.
    -   **Early Stopping Check**: If `epochs_no_improve` reaches the `patience` limit, it logs an early stopping message, prints the best validation loss, and breaks the training loop.
-   **Post-Training**:
    -   Logs the total training duration.
    -   Loads the `best_model_state` (if one was saved) back into the model. If no improvement was seen, it warns and uses the model state from the last epoch (and saves it).
-   **Plotting Training History**:
    -   Uses `matplotlib` and `seaborn` to generate plots of training/validation loss vs. epochs and training/validation accuracy & F1-score vs. epochs.
    -   Saves these plots to the model's checkpoint directory.
-   **Return**: Returns the trained `model` (with best weights loaded) and the `history` dictionary.
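The early-stopping rule can be replayed on a list of validation losses. The sketch below assumes a strict `<` comparison with no minimum improvement threshold, and it stops once `epochs_no_improve` reaches `patience`. `simulate_early_stopping` and `demo_losses` are hypothetical and involve neither PyTorch nor Optuna. In the example, the loss improves again at the last epoch, but training has already stopped by then.

```python
def simulate_early_stopping(val_losses, patience):
    """Replay the patience rule: return (best_epoch, best_loss, stop_epoch), 0-indexed."""
    best_loss = float("inf")
    best_epoch = None
    epochs_no_improve = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_no_improve = 0  # the real loop would also save a checkpoint here
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= patience:
            return best_epoch, best_loss, epoch
    return best_epoch, best_loss, len(val_losses) - 1


demo_losses = [1.0, 0.8, 0.85, 0.9, 0.7, 0.75, 0.76, 0.77, 0.6]
print(simulate_early_stopping(demo_losses, patience=3))  # (4, 0.7, 7)
```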

In [14]:
def train_model(model, model_name, train_loader, val_loader, criterion, optimizer, epochs, device, trial=None, patience=5, hpo_warmup_steps=0):
 
    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_model_state = None
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'val_f1': [], 'val_precision': [], 'val_recall': []}

    # --- Checkpoint Setup ---
    checkpoint_dir = os.path.join(CHECKPOINT_BASE_DIR, model_name)
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_best_val_loss.pth")
    logging.info(f"Best model checkpoint (based on val loss) will be saved to: {checkpoint_path}")

    total_start_time = time.time()
    logging.info(f"--- Starting Training: {model_name} for {epochs} epochs (Patience: {patience}) ---")
    model.to(device) 

    for epoch in range(epochs):
        epoch_start_time = time.time()

        # --- Training Phase ---
        model.train()  
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        train_progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} Train", leave=False, unit="batch")
        for inputs, labels in train_progress:
            inputs, labels = inputs.to(device), labels.to(device).float().unsqueeze(1) 

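            # Samples that failed to load carry the sentinel label -1; drop them and keep the rest of the batch.
            if torch.any(labels == -1):
                valid_mask = (labels != -1).squeeze(1)  # squeeze(1) keeps the mask 1-D even when the batch has a single sample
                n_dropped = int((~valid_mask).sum().item())
                logging.warning(f"Dropping {n_dropped} sample(s) with label -1 (failed to load).")
                if not torch.any(valid_mask):
                    continue  # every sample in this batch was invalid
                inputs = inputs[valid_mask]
                labels = labels[valid_mask]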


            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss * inputs.size(0)
            with torch.no_grad():
                 preds = (torch.sigmoid(outputs) > 0.5).float()
                 total_train += labels.size(0)
                 correct_train += (preds == labels).sum().item()

            train_progress.set_postfix(batch_loss=f"{batch_loss:.4f}")

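        # Normalize by the samples actually used, so dropped (-1) samples do not deflate the loss.
        epoch_train_loss = running_loss / total_train if total_train > 0 else 0.0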
        epoch_train_acc = correct_train / total_train if total_train > 0 else 0.0
        history['train_loss'].append(epoch_train_loss)
        history['train_acc'].append(epoch_train_acc)

        # --- Validation Phase ---
        val_loss, val_accuracy, val_f1, val_precision, val_recall, _ = evaluate_model(model, val_loader, criterion, device)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_accuracy)
        history['val_f1'].append(val_f1)
        history['val_precision'].append(val_precision)
        history['val_recall'].append(val_recall)


        epoch_duration = time.time() - epoch_start_time
        logging.info(
            f"Epoch {epoch+1}/{epochs} | "
            f"Train Loss: {epoch_train_loss:.4f} | Train Acc: {epoch_train_acc:.4f} | "
            f"Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.4f} | Val F1: {val_f1:.4f} | "
            f"Time: {timedelta(seconds=int(epoch_duration))}"
        )


        # --- Early Stopping & Checkpointing (based on Validation Loss) ---
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            best_model_state = copy.deepcopy(model.state_dict())
            try:
                torch.save(best_model_state, checkpoint_path)
                logging.info(f"  -> Val loss improved to {val_loss:.4f}. Saved best model checkpoint.")
            except Exception as e:
                 logging.error(f"Error saving checkpoint to {checkpoint_path}: {e}", exc_info=True)
        else:
            epochs_no_improve += 1
            logging.info(f"  (Val loss did not improve for {epochs_no_improve} epoch(s). Best: {best_val_loss:.4f})")


        # --- Optuna Pruning (based on Validation F1-score) ---
        if trial is not None and epoch >= hpo_warmup_steps:
            trial.report(val_f1, epoch)
            if trial.should_prune():
                logging.warning(f"Optuna Trial {trial.number} pruned at epoch {epoch+1} (Val F1: {val_f1:.4f}).")
                inputs = labels = outputs = loss = None  # drop batch-tensor references; unlike `del`, this cannot raise NameError
                if torch.cuda.is_available(): torch.cuda.empty_cache()
                raise optuna.TrialPruned() 


        # --- Early Stopping Check ---
        if epochs_no_improve >= patience:
            logging.info(f"Early stopping triggered after {epoch+1} epochs (patience={patience}).")
            logging.info(f"Best validation loss achieved: {best_val_loss:.4f}")
            break 


    # --- End of Training Loop ---
    total_duration = time.time() - total_start_time
    logging.info(f"--- Training loop finished for {model_name} in {timedelta(seconds=int(total_duration))} ---")

    if best_model_state is not None:
        logging.info(f"Loading best model weights from checkpoint (Val Loss: {best_val_loss:.4f}).")
        model.load_state_dict(best_model_state)
    else:
        logging.warning("No improvement in validation loss observed during training.")
        logging.warning("Model state from the last epoch will be used.")
        last_epoch_path = os.path.join(checkpoint_dir, f"{model_name}_last_epoch.pth")
        torch.save(model.state_dict(), last_epoch_path)
        logging.info(f"Saved model state from last epoch to: {last_epoch_path}")


    # --- Plotting Training History ---
    try:
        sns.set_theme(style="darkgrid") 
        num_epochs_trained = len(history['train_loss'])
        if num_epochs_trained > 0:
            epoch_range = range(1, num_epochs_trained + 1)

            fig, axes = plt.subplots(1, 2, figsize=(16, 6)) 

            axes[0].plot(epoch_range, history['train_loss'], label='Train Loss', marker='o', linestyle='-', color='royalblue')
            axes[0].plot(epoch_range, history['val_loss'], label='Val Loss', marker='x', linestyle='--', color='darkorange')
            axes[0].set_title(f'{model_name} - Loss vs. Epochs')
            axes[0].set_xlabel('Epoch')
            axes[0].set_ylabel('Loss')
            axes[0].legend()
            axes[0].grid(True, linestyle=':')

            axes[1].plot(epoch_range, history['train_acc'], label='Train Accuracy', marker='o', linestyle='-', color='royalblue')
            axes[1].plot(epoch_range, history['val_acc'], label='Val Accuracy', marker='x', linestyle='--', color='darkorange')
            axes[1].plot(epoch_range, history['val_f1'], label='Val F1 Score', marker='s', linestyle='-.', color='forestgreen')
            axes[1].set_title(f'{model_name} - Accuracy & Val F1 vs. Epochs')
            axes[1].set_xlabel('Epoch')
            axes[1].set_ylabel('Metric Value')
            axes[1].set_ylim(0.0, 1.05) 
            axes[1].legend()
            axes[1].grid(True, linestyle=':')

            fig.suptitle(f"Training History: {model_name}", fontsize=16)
            plt.tight_layout(rect=[0, 0.03, 1, 0.95]) 

            plot_path = os.path.join(checkpoint_dir, f"{model_name}_training_metrics.png")
            plt.savefig(plot_path, dpi=150)
            plt.close(fig)
            logging.info(f"Training metrics plot saved to: {plot_path}")
        else:
             logging.warning("No training history recorded (0 epochs trained). Skipping plot generation.")

    except Exception as e:
        logging.error(f"Failed to generate or save training plot for {model_name}: {e}", exc_info=True)


    return model, history
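Because `running_loss` accumulates `loss.item() * inputs.size(0)` (batch-mean loss times the number of samples kept), the epoch loss is a sample-weighted mean and should be divided by the number of samples that actually went through the model. The stdlib sketch below mirrors that arithmetic with plain lists; `drop_invalid`, `epoch_mean_loss`, and the loss values are hypothetical stand-ins for the tensor code.

```python
import math
from typing import Sequence, Tuple

INVALID_LABEL = -1  # sentinel the notebook's Dataset emits for images that fail to load


def drop_invalid(inputs: Sequence, labels: Sequence[int]) -> Tuple[list, list]:
    """Keep only samples whose label is not the sentinel (mirrors the boolean-mask step)."""
    keep = [i for i, y in enumerate(labels) if y != INVALID_LABEL]
    return [inputs[i] for i in keep], [labels[i] for i in keep]


def epoch_mean_loss(batch_mean_losses: Sequence[float], batch_sizes: Sequence[int]) -> float:
    """Sample-weighted mean of per-batch mean losses (running_loss / total_train)."""
    total = sum(batch_sizes)
    if total == 0:
        return 0.0
    return sum(m * n for m, n in zip(batch_mean_losses, batch_sizes)) / total


# Batch 1: four samples, one failed to load; batch 2: two clean samples.
x1, y1 = drop_invalid(["a", "b", "c", "d"], [0, 1, INVALID_LABEL, 1])
x2, y2 = drop_invalid(["e", "f"], [1, 0])
kept_sizes = [len(y1), len(y2)]  # [3, 2]
batch_losses = [0.5, 0.2]        # made-up mean BCE values for the kept samples

correct = epoch_mean_loss(batch_losses, kept_sizes)                   # divides by 5 kept samples
biased = sum(m * n for m, n in zip(batch_losses, kept_sizes)) / 6    # divides by all 6 sampled indices
print(round(correct, 4), round(biased, 4))
```

Dividing by the sampler length (6 here) silently shrinks the reported loss whenever samples are dropped, which is why the training loop divides by `total_train` instead.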

### 13. Optuna Objective Function

This cell defines the `objective` function, which is the core component for hyperparameter optimization using the Optuna library. This function is called by Optuna for each trial.
Its purpose and steps are:
1.  **Receive Trial Object**: Takes an `optuna.Trial` object, `model_name`, `device`, and HPO-specific settings (`hpo_epochs`, `hpo_patience`, `hpo_warmup_steps`) as input.
2.  **Suggest Hyperparameters**: Uses `trial.suggest_categorical()`, `trial.suggest_float()`, and `trial.suggest_int()` to sample hyperparameters for the current trial. These include:
    -   `dense_units`: Number of units in the classifier's dense layer (128, 256, 512, or 768).
    -   `dropout`: Dropout rate in the classifier (0.1 to 0.6 in steps of 0.05).
    -   `learning_rate`: Learning rate for the optimizer (5e-6 to 1e-3, sampled on a log scale).
    -   `optimizer`: Type of optimizer (AdamW or Adam).
    -   `batch_size`: Batch size for data loaders (16, 32, or 64).
    -   `unfreeze_layers`: Number of layers to unfreeze in the pre-trained base model (0 to 3; 0 means fully frozen).
3.  **Log Trial Parameters**: Prints the hyperparameters chosen for the current trial and the best F1 score found so far in the study.
4.  **Setup for Training**:
    -   Retrieves the appropriate `img_size` for the given `model_name`.
    -   Calls `get_data_loaders` to create data loaders with the suggested `batch_size` and `img_size`.
    -   Instantiates the `CustomModel` with the `model_name` and suggested `dense_units`, `dropout`, and `unfreeze_layers`.
    -   Creates the optimizer (Adam or AdamW) with the suggested `learning_rate`, ensuring it targets all trainable parameters. If no trainable parameters are found (e.g., due to an issue with unfreezing logic or model setup), it prunes the trial.
    -   Defines the loss function (`nn.BCEWithLogitsLoss`).
5.  **Train Model**: Calls the `train_model` function (defined earlier) with the configured model, data loaders, optimizer, criterion, and HPO-specific settings (`hpo_epochs`, `hpo_patience`, `hpo_warmup_steps`). The `trial` object is passed to `train_model` to enable pruning.
6.  **Evaluate Model**: Once `train_model` returns (after the full epoch budget or an early stop), it calls `evaluate_model` on the validation set using the best weights that `train_model` reloaded. Pruned trials never reach this step, because `optuna.TrialPruned` propagates out of `train_model`.
7.  **Log Trial Results**: Logs the duration of the trial and the validation metrics (Loss, Accuracy, F1, Precision, Recall).
8.  **Return Metric**: Returns the validation F1 score (`val_f1`), which Optuna will attempt to maximize.
9.  **Error Handling**:
    -   Catches `optuna.TrialPruned` exceptions specifically to allow Optuna to handle pruning correctly.
    -   Catches any other exception, logs the error, cleans up resources (deletes the model, loaders, optimizer, and criterion, and clears the CUDA cache), and returns an F1 score of 0.0 so the study can continue. Optuna records such a trial as COMPLETE with a value of 0.0, not as FAIL.
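To make the search space concrete, the sketch below draws configurations from the same ranges using only the standard library. It is an assumption-laden illustration: it samples each parameter independently and uniformly (closer to Optuna's `RandomSampler`), whereas Optuna's default sampler is TPE. The helper names are hypothetical. Its main purpose is to show what `log=True` and `step=0.05` mean.

```python
import math
import random
from typing import Any, Dict


def sample_log_uniform(rng: random.Random, low: float, high: float) -> float:
    """Uniform in log space: each decade of the range is equally likely (log=True)."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def sample_stepped(rng: random.Random, low: float, high: float, step: float) -> float:
    """Uniform over the grid low, low + step, ..., high."""
    n_steps = int(round((high - low) / step))
    return round(low + step * rng.randint(0, n_steps), 10)


def sample_params(rng: random.Random) -> Dict[str, Any]:
    """Draw one configuration from the same ranges as the objective's suggest_* calls."""
    return {
        "dense_units": rng.choice([128, 256, 512, 768]),
        "dropout": sample_stepped(rng, 0.1, 0.6, 0.05),
        "learning_rate": sample_log_uniform(rng, 5e-6, 1e-3),
        "optimizer": rng.choice(["AdamW", "Adam"]),
        "batch_size": rng.choice([16, 32, 64]),
        "unfreeze_layers": rng.randint(0, 3),  # inclusive bounds, like suggest_int
    }


rng = random.Random(0)
draws = [sample_params(rng) for _ in range(10_000)]
share_below_1e4 = sum(d["learning_rate"] < 1e-4 for d in draws) / len(draws)
# Expected share in log space: ln(1e-4 / 5e-6) / ln(1e-3 / 5e-6) = ln 20 / ln 200, about 0.57,
# versus about 0.10 if learning_rate were sampled uniformly on the raw scale.
print(round(share_below_1e4, 2))
```

Sampling the learning rate on a log scale spends roughly half of the trials below 1e-4, which suits fine-tuning pre-trained backbones where small learning rates are often the interesting region.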

In [15]:
def objective(trial, model_name, device, hpo_epochs, hpo_patience, hpo_warmup_steps):
    trial_start_time = time.time()
    logging.info(f"\n--- Optuna Trial {trial.number} Start ({model_name}) ---")

    # --- Suggest Hyperparameters ---
    dense_units = trial.suggest_categorical('dense_units', [128, 256, 512, 768])
    dropout = trial.suggest_float('dropout', 0.1, 0.6, step=0.05) 
    learning_rate = trial.suggest_float('learning_rate', 5e-6, 1e-3, log=True) 
    optimizer_name = trial.suggest_categorical('optimizer', ['AdamW', 'Adam']) 
    batch_size = trial.suggest_categorical('batch_size', [16, 32, 64])
    unfreeze_layers = trial.suggest_int('unfreeze_layers', 0, 3) 

    trial_params = trial.params
    print(f"  Trial {trial.number} Parameters:")
    for key, value in trial_params.items():
         print(f"    {key:<15}: {value}")
    try:
        best_value_so_far = trial.study.best_value
        print(f"  Best Value ({trial.study.direction.name} F1) So Far: {best_value_so_far:.4f}")
    except ValueError: 
        print(f"  Best Value ({trial.study.direction.name} F1) So Far: N/A (First trial)")
    img_size = model_img_sizes[model_name]
    model, train_loader, val_loader, optimizer, criterion = None, None, None, None, None

    try:
        train_loader, val_loader, _ = get_data_loaders(model_name, batch_size, img_size, num_workers=2)

        model = CustomModel(model_name=model_name,
                           dense_units=dense_units,
                           dropout=dropout,
                           pretrained=True,
                           unfreeze_layers=unfreeze_layers).to(device)

        trainable_params = [p for p in model.parameters() if p.requires_grad]
        if not trainable_params:
            logging.error(f"FATAL ERROR in Trial {trial.number}: No trainable parameters found for model {model_name} with unfreeze={unfreeze_layers}. Pruning.")

            raise optuna.TrialPruned("No trainable parameters found.")


        logging.info(f"  Trial {trial.number}: Optimizing {len(trainable_params)} parameter tensors ({sum(p.numel() for p in trainable_params):,} total parameters).")

        if optimizer_name == 'Adam':
            optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)
        elif optimizer_name == 'AdamW':
            optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)
        else:
            raise ValueError(f"Unsupported optimizer suggested: {optimizer_name}")

        criterion = nn.BCEWithLogitsLoss()

        model, history = train_model(model=model,
                                      model_name=f"{model_name}_trial_{trial.number}", 
                                      train_loader=train_loader,
                                      val_loader=val_loader,
                                      criterion=criterion,
                                      optimizer=optimizer,
                                      epochs=hpo_epochs, 
                                      device=device,
                                      trial=trial, 
                                      patience=hpo_patience,
                                      hpo_warmup_steps=hpo_warmup_steps)

        val_loss, val_accuracy, val_f1, val_precision, val_recall, _ = evaluate_model(model, val_loader, criterion, device)

        trial_duration = time.time() - trial_start_time
        logging.info(f"--- Optuna Trial {trial.number} Finished [{timedelta(seconds=int(trial_duration))}] ---")
        logging.info(f"  Validation Metrics: Loss={val_loss:.4f}, Acc={val_accuracy:.4f}, F1={val_f1:.4f}, Precision={val_precision:.4f}, Recall={val_recall:.4f}")

        return val_f1

    except optuna.TrialPruned as e:
        trial_duration = time.time() - trial_start_time
        logging.info(f"--- Optuna Trial {trial.number} Pruned [{timedelta(seconds=int(trial_duration))}] ---")
        del model, train_loader, val_loader, optimizer, criterion
        if torch.cuda.is_available(): torch.cuda.empty_cache()
        raise e 

    except Exception as e:
        trial_duration = time.time() - trial_start_time
        logging.error(f"--- Optuna Trial {trial.number} Failed [{timedelta(seconds=int(trial_duration))}] ---")
        logging.error(f"Error during Optuna trial {trial.number} for {model_name}: {e}", exc_info=True)

        if 'model' in locals() and model is not None: del model
        if 'train_loader' in locals() and train_loader is not None: del train_loader
        if 'val_loader' in locals() and val_loader is not None: del val_loader
        if 'optimizer' in locals() and optimizer is not None: del optimizer
        if 'criterion' in locals() and criterion is not None: del criterion
        if torch.cuda.is_available(): torch.cuda.empty_cache()

        return 0.0
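The error handling above decides how each trial is recorded. The toy driver below classifies objective outcomes the way an Optuna study does: a returned value means COMPLETE, `TrialPruned` means PRUNED, and any other exception means FAIL. Everything here is a stdlib stand-in (`TrialPruned`, `State`, `run_trial`, and the toy objectives are hypothetical); it contrasts re-raising errors with the notebook's choice of returning 0.0.

```python
from enum import Enum
from typing import Callable, Optional, Tuple


class TrialPruned(Exception):
    """Stand-in for optuna.TrialPruned in this sketch."""


class State(Enum):
    COMPLETE = "COMPLETE"
    PRUNED = "PRUNED"
    FAIL = "FAIL"


def run_trial(objective: Callable[[int], float], number: int) -> Tuple[State, Optional[float]]:
    """Classify one objective call the way a study driver would."""
    try:
        value = objective(number)
    except TrialPruned:
        return State.PRUNED, None
    except Exception:
        return State.FAIL, None
    return State.COMPLETE, float(value)


def train_and_score(number: int) -> float:
    """Toy objective: trial 1 is pruned, trial 2 crashes, others score 0.6 + 0.1 * number."""
    if number == 1:
        raise TrialPruned()
    if number == 2:
        raise RuntimeError("CUDA out of memory")  # simulated crash
    return 0.6 + 0.1 * number


def swallowing(number: int) -> float:
    """Mirrors the notebook: any non-pruning error becomes a returned score of 0.0."""
    try:
        return train_and_score(number)
    except TrialPruned:
        raise
    except Exception:
        return 0.0


print([run_trial(train_and_score, n)[0].name for n in range(4)])
# ['COMPLETE', 'PRUNED', 'FAIL', 'COMPLETE']
print([run_trial(swallowing, n)[0].name for n in range(4)])
# ['COMPLETE', 'PRUNED', 'COMPLETE', 'COMPLETE']
```

With the swallowing pattern, crashed trials enter the study as genuine 0.0 scores. An alternative is to re-raise and pass `catch=(Exception,)` to `study.optimize`, which lets Optuna mark the trial as FAIL and keep optimizing.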

### 14. Upload to Dataset

This cell defines three helper functions related to saving and packaging results: `plot_training_history`, `plot_confusion_matrix`, and `upload_checkpoint_to_dataset`.

1.  **`plot_training_history(history_dict, save_path, model_name)`**:
    *   Takes a history dictionary (containing lists of 'train_loss', 'val_loss', 'train_acc', 'val_acc', 'val_f1', etc.), a save path, and the model name.
    *   Generates and saves plots for:
        *   Training vs. Validation Loss.
        *   Training vs. Validation Accuracy.
        *   Validation Metrics (F1, Precision, Recall) if available.
    *   Plots are saved as PNG files in the `save_path` (typically the model's checkpoint directory).
    *   Includes error handling for missing keys or other plotting issues.

2.  **`plot_confusion_matrix(cm, save_path, model_name, class_names=None)`**:
    *   Takes a confusion matrix (`cm`), a save path, model name, and optional class names.
    *   Generates a heatmap visualization of the confusion matrix using `seaborn.heatmap`.
    *   Saves the plot as a PNG file in the `save_path`.

3.  **`upload_checkpoint_to_dataset(model_name)`**:
    *   This is the main function for packaging and making results accessible.
    *   **Load Metrics and Generate Plots**:
        *   Constructs the path to the model's checkpoint directory and to the results JSON file inside it (`{model_name}_test_results.json`).
        *   If the metrics file exists, it loads the JSON data.
        *   Extracts `training_history` and `confusion_matrix` from the loaded metrics.
        *   Calls `plot_training_history` and `plot_confusion_matrix` to generate plots and save them directly into the model's checkpoint directory.
    *   **File Collection**: Lists all relevant files in the checkpoint directory (`.pth`, `.pt`, `.onnx` for models; `.json` for metrics; `.png` for plots).
    *   **Environment-Specific Handling**:
        *   **Kaggle Environment**:
            *   Creates a zip file (e.g., `{model_name}_checkpoints_and_plots.zip`) in `/kaggle/working/` containing all collected files from the checkpoint directory.
            *   Verifies the zip file creation and logs its size and contents.
            *   Prints instructions for downloading the zip file via the Kaggle UI ('Data' tab -> 'Output').
            *   Attempts to provide a direct download link using `IPython.display.FileLink` as a secondary method.
        *   **Colab Environment**:
            *   Copies all collected files from the local checkpoint directory to a corresponding directory in Google Drive (`/content/drive/MyDrive/checkpoints/SignatureVerification/{model_name}`).
            *   Logs the number of files copied and the destination path.
        *   **Local Environment (Unknown)**:
            *   Copies all collected files to a local directory (e.g., `~/checkpoints/SignatureVerification/{model_name}`).
            *   Logs the number of files copied and the destination path.
    *   Includes error handling for various stages like JSON decoding, file operations, and zip creation.
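The Kaggle packaging step can be sketched with the standard library alone. The helper names and the demo directory below are hypothetical. The sketch collects regular files by extension, writes a flat archive, and then re-opens it and calls `ZipFile.testzip()` to verify member CRCs. Because the `with` block has fully written and closed the archive by the time it exits, no delay is needed before the check.

```python
import os
import tempfile
import zipfile
from typing import List

ARTIFACT_EXTENSIONS = (".pth", ".pt", ".onnx", ".json", ".png")


def collect_artifacts(directory: str) -> List[str]:
    """Regular files in `directory` whose names end with one of the artifact extensions."""
    return sorted(
        name for name in os.listdir(directory)
        if name.endswith(ARTIFACT_EXTENSIONS) and os.path.isfile(os.path.join(directory, name))
    )


def zip_artifacts(directory: str, zip_path: str) -> List[str]:
    """Write the artifacts to `zip_path` (flat, no folders) and return the verified member names."""
    names = collect_artifacts(directory)
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for name in names:
            zf.write(os.path.join(directory, name), arcname=name)
    # The archive is complete once the `with` block exits; re-open it to verify CRCs.
    with zipfile.ZipFile(zip_path) as zf:
        bad_member = zf.testzip()
        if bad_member is not None:
            raise zipfile.BadZipFile(f"Corrupt member in {zip_path}: {bad_member}")
        return zf.namelist()


def demo() -> List[str]:
    """Package a throwaway checkpoint directory containing mixed files."""
    with tempfile.TemporaryDirectory() as tmp:
        ckpt = os.path.join(tmp, "ckpt")
        os.makedirs(ckpt)
        for name in ("model_best_val_loss.pth", "model_test_results.json", "notes.txt"):
            with open(os.path.join(ckpt, name), "w") as fh:
                fh.write("placeholder")
        os.makedirs(os.path.join(ckpt, "plots.png"))  # a directory with a matching suffix: skipped
        return zip_artifacts(ckpt, os.path.join(tmp, "artifacts.zip"))


print(demo())  # ['model_best_val_loss.pth', 'model_test_results.json']
```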

In [16]:
def plot_training_history(history_dict, save_path, model_name):
    if not history_dict:
        logger.warning(f"No training history data provided for {model_name}. Skipping history plots.")
        return

    try:
        epochs = range(1, len(history_dict['train_loss']) + 1)

        # --- Loss Plot ---
        plt.figure(figsize=(10, 5))
        plt.plot(epochs, history_dict['train_loss'], 'bo-', label='Training Loss')
        plt.plot(epochs, history_dict['val_loss'], 'ro-', label='Validation Loss')
        plt.title(f'{model_name} - Training & Validation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, f'{model_name}_loss_plot.png'))
        plt.close() 
        logger.info(f"Saved loss plot for {model_name}")

        # --- Accuracy Plot ---
        plt.figure(figsize=(10, 5))
        plt.plot(epochs, history_dict['train_acc'], 'bo-', label='Training Accuracy')
        plt.plot(epochs, history_dict['val_acc'], 'ro-', label='Validation Accuracy')
        plt.title(f'{model_name} - Training & Validation Accuracy')
        plt.xlabel('Epochs')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, f'{model_name}_accuracy_plot.png'))
        plt.close()
        logger.info(f"Saved accuracy plot for {model_name}")

        # --- Validation Metrics Plot (F1, Precision, Recall) ---
        if all(k in history_dict for k in ['val_f1', 'val_precision', 'val_recall']):
             plt.figure(figsize=(10, 5))
             plt.plot(epochs, history_dict['val_f1'], 'go-', label='Validation F1-Score')
             plt.plot(epochs, history_dict['val_precision'], 'yo-', label='Validation Precision')
             plt.plot(epochs, history_dict['val_recall'], 'mo-', label='Validation Recall')
             plt.title(f'{model_name} - Validation Metrics')
             plt.xlabel('Epochs')
             plt.ylabel('Score')
             plt.legend()
             plt.grid(True)
             plt.tight_layout()
             plt.savefig(os.path.join(save_path, f'{model_name}_val_metrics_plot.png'))
             plt.close()
             logger.info(f"Saved validation metrics plot for {model_name}")
        else:
             logger.warning(f"Missing some validation metrics (F1, Precision, Recall) for {model_name}. Skipping val metrics plot.")

    except KeyError as e:
        logger.error(f"Missing key in history_dict for {model_name}: {e}. Cannot generate plots.")
    except Exception as e:
        logger.error(f"Error plotting training history for {model_name}: {e}", exc_info=True)


def plot_confusion_matrix(cm, save_path, model_name, class_names=None):
    if cm is None or len(cm) == 0:  # `not cm` would raise ValueError for a NumPy array
        logger.warning(f"No confusion matrix data provided for {model_name}. Skipping CM plot.")
        return

    try:
        cm_array = np.array(cm)
        if not class_names:
            # Generic labels, one per row of the matrix ('Class 0', 'Class 1' for the binary case).
            class_names = [f'Class {i}' for i in range(cm_array.shape[0])]

        plt.figure(figsize=(8, 6))
        sns.heatmap(cm_array, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names, cbar=False)
        plt.title(f'{model_name} - Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, f'{model_name}_confusion_matrix.png'))
        plt.close()
        logger.info(f"Saved confusion matrix plot for {model_name}")

    except Exception as e:
        logger.error(f"Error plotting confusion matrix for {model_name}: {e}", exc_info=True)



def upload_checkpoint_to_dataset(model_name):
    logger.info(f"Starting upload/packaging process for model: {model_name}")
    checkpoint_dir = os.path.join(CHECKPOINT_BASE_DIR, model_name)
    if not os.path.exists(checkpoint_dir):
        logger.warning(f"Checkpoint directory {checkpoint_dir} does not exist.")
        return

    # --- Load Metrics and Generate Plots ---
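    # Build the results path from model_name instead of relying on a global `results_path` left over from the training loop.
    metrics_file_path = os.path.join(checkpoint_dir, f"{model_name}_test_results.json")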
    metrics_data = None
    if os.path.exists(metrics_file_path):
        try:
            with open(metrics_file_path, 'r') as f:
                metrics_data = json.load(f)
            logger.info(f"Loaded metrics data from {metrics_file_path}")

            history = metrics_data.get('training_history')
            cm = metrics_data.get('confusion_matrix')
            plot_training_history(history, checkpoint_dir, model_name)
            plot_confusion_matrix(cm, checkpoint_dir, model_name) 

        except json.JSONDecodeError:
            logger.error(f"Error decoding JSON from {metrics_file_path}. Cannot generate plots.")
        except Exception as e:
            logger.error(f"Error processing metrics or plotting for {model_name}: {e}", exc_info=True)
    else:
        logger.warning(f"Metrics file {metrics_file_path} not found. Skipping plot generation.")
    # --- End Plot Generation ---


    try:
        checkpoint_files = [
            f for f in os.listdir(checkpoint_dir)
            if f.endswith(('.pth', '.pt', '.onnx', '.json', '.png'))
        ]
        if not checkpoint_files:
            logger.warning(f"No checkpoint-related files (.pth, .pt, .onnx, .json, .png) found in {checkpoint_dir}")
        files_to_zip = []
        for file in checkpoint_files:
             src = os.path.join(checkpoint_dir, file)
             if os.path.isfile(src):
                 files_to_zip.append(file)
             else:
                 logger.warning(f"File '{file}' listed but not found at '{src}'. Skipping.")

        if not files_to_zip:
             logger.error(f"No valid files found to zip in {checkpoint_dir}")
             return

        # --- Environment Specific Handling (Kaggle, Colab, Local) ---
        if ENVIRONMENT == 'kaggle':
            zip_filename = f'{model_name}_checkpoints_and_plots.zip' 
            zip_path = os.path.join('/kaggle/working', zip_filename)
            logger.info(f"Creating zip file at: {zip_path}")

            with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
                for file in files_to_zip:
                    src = os.path.join(checkpoint_dir, file)
                    zipf.write(src, arcname=file) 
            logger.info(f"Zip file creation process finished for {zip_path}")

            time.sleep(1)
            if os.path.exists(zip_path):
                file_size = os.path.getsize(zip_path)
                logger.info(f"Verified zip file exists: {zip_path} (Size: {file_size} bytes)")
                print(f"\nSUCCESS: Zip file created at: {zip_path} (Size: {file_size} bytes)")
                print(f"   Contains: {files_to_zip}") 

                # --- Primary Download Method: Kaggle UI ---
                print("\n--- DOWNLOAD INSTRUCTIONS (Primary Method) ---")
                print(f"1. Go to the 'Data' tab -> 'Output' section in the right-hand panel.")
                print(f"2. Find the file named '{zip_filename}'.")
                print(f"3. Click the download icon.")
                print(f"   (Note: May take a minute to appear after cell finishes.)")

                # --- Secondary Method: FileLink (Attempt) ---
                print("\n--- Direct Link (Secondary Method - May Not Work) ---")
                try:
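                    # FileLink uses the given path as the link target, so pass it relative to the working directory.
                    display(FileLink(os.path.relpath(zip_path)))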
                    print(f"-> If a link for '{zip_filename}' appeared above, try clicking it.")
                    print(f"-> If it fails, use the Primary Method (Kaggle UI Output tab).")
                except Exception as e:
                    logger.error(f"Error generating FileLink: {e}", exc_info=True)
                    print(f"   - Could not generate direct link. Use the Primary Method.")

            else:
                logger.error(f"CRITICAL ERROR: Zip file NOT found at '{zip_path}' after creation attempt!")

        elif ENVIRONMENT == 'colab':
            drive_dir = f"/content/drive/MyDrive/checkpoints/SignatureVerification/{model_name}"
            os.makedirs(drive_dir, exist_ok=True)
            copied_files_count = 0
            logger.info(f"Copying {len(files_to_zip)} files to Google Drive...")
            for file in files_to_zip:
                src = os.path.join(checkpoint_dir, file)
                dst = os.path.join(drive_dir, file)
                try:
                    shutil.copy(src, dst)
                    copied_files_count += 1
                except Exception as e:
                     logger.error(f"Failed to copy {src} to {dst}: {e}", exc_info=True)
            logger.info(f"{copied_files_count} file(s) copied to Google Drive: {drive_dir}")
            print(f"Checkpoint files & plots saved to Google Drive at: {drive_dir}")


        else: 
            local_dir = os.path.expanduser(f"~/checkpoints/SignatureVerification/{model_name}")
            os.makedirs(local_dir, exist_ok=True)
            copied_files_count = 0
            logger.info(f"Copying {len(files_to_zip)} files to local storage...")
            for file in files_to_zip:
                src = os.path.join(checkpoint_dir, file)
                dst = os.path.join(local_dir, file)
                try:
                    shutil.copy(src, dst)
                    copied_files_count += 1
                except Exception as e:
                     logger.error(f"Failed to copy {src} to {dst}: {e}", exc_info=True)
            logger.info(f"{copied_files_count} file(s) copied locally to: {local_dir}")
            print(f"Checkpoints & plots saved locally to: {local_dir}")


    except Exception as e:
        logger.error(f"Failed to process checkpoints/plots for {model_name}: {str(e)}", exc_info=True)
        print(f"\nERROR: An unexpected error occurred during checkpoint/plot processing: {e}")
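`plot_confusion_matrix` labels rows as "True Label" and columns as "Predicted Label". Since `evaluate_model` is not shown in this section, the sketch below assumes the stored matrix follows scikit-learn's layout (rows = true class, columns = predicted class) and rebuilds a binary matrix from raw logits with the same 0.5 sigmoid threshold used in training. The helper names are hypothetical, and which class index means "forged" is not fixed here.

```python
import math
from typing import List, Sequence, Tuple


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def binary_confusion_matrix(y_true: Sequence[int], logits: Sequence[float],
                            threshold: float = 0.5) -> List[List[int]]:
    """2x2 matrix with rows = true label and columns = predicted label."""
    cm = [[0, 0], [0, 0]]
    for label, logit in zip(y_true, logits):
        pred = 1 if sigmoid(logit) > threshold else 0  # same rule as (torch.sigmoid(outputs) > 0.5)
        cm[int(label)][pred] += 1
    return cm


def precision_recall_f1(cm: List[List[int]]) -> Tuple[float, float, float]:
    """Metrics for the positive class (label 1), returning 0.0 when a denominator is zero."""
    (tn, fp), (fn, tp) = cm
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


cm = binary_confusion_matrix([0, 0, 1, 1, 1], [-2.0, 0.3, 1.5, -0.1, 2.0])
print(cm)  # [[1, 1], [1, 2]]
```

With a threshold of 0.5, `sigmoid(z) > 0.5` is equivalent to `z > 0`, so the decision depends only on the sign of the logit. The resulting nested list of integers is exactly what `sns.heatmap(..., fmt='d')` expects.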



### 15. Run the Models and Generate Checkpoints

This cell is the main execution block of the notebook. It orchestrates the entire process of hyperparameter optimization (HPO), final model training, evaluation, and results storage for each model specified in `models_to_try`.

**Overall Workflow**:
1.  **Initialization**: Records the overall start time and initializes a `results` dictionary to store outcomes for each model.
2.  **Configuration**: Sets parameters for:
    -   **Optuna HPO**: `n_trials_optuna` (number of trials per model), `hpo_timeout_seconds` (timeout for HPO study), `hpo_epochs` (max epochs per HPO trial), `hpo_patience` (early stopping patience within HPO), `hpo_warmup_steps` (epochs before Optuna pruning starts).
    -   **Final Training**: `final_train_epochs` (max epochs for the final run), `final_train_patience` (early stopping patience for the final run).
    -   **Loss Function**: Defines `criterion` as `nn.BCEWithLogitsLoss()` for binary classification.
3.  **Model Loop**: Iterates through each `model_name` in the `models_to_try` list.
    -   **Setup**: Records model start time, logs the current model being processed, and creates a specific checkpoint directory for it.
    -   **Hyperparameter Tuning (Optuna)**:
        -   Logs the start of HPO and its settings.
        -   Creates an Optuna `study` (or loads an existing one from the `OPTUNA_DB_PATH`) with a `MedianPruner`. The study aims to `maximize` the validation F1 score.
        -   Calls `study.optimize()`, passing the `objective` function (defined in Section 13) along with the model name, device, and HPO settings.
        -   Handles potential errors during study creation or optimization.
        -   After HPO, retrieves the `best_trial`, `best_params`, and `best_value` (best validation F1). Logs these best HPO results.
        -   Handles cases where no successful trials or no best trial is found.
    -   **Final Training with Best Parameters**:
        -   Logs the start of final training and its settings.
        -   Extracts the best hyperparameters (batch size, learning rate, dense units, dropout, optimizer name, unfreeze layers) from `best_params`.
        -   Gets data loaders using the best batch size.
        -   Builds the `CustomModel` using the best hyperparameters (including `unfreeze_layers`).
        -   Creates the final optimizer based on the best optimizer name and learning rate, ensuring it targets all trainable parameters of the final model.
        -   Calls `train_model` to train this final model (without Optuna trial/pruning).
    -   **Final Evaluation on Test Set**:
        -   Evaluates the trained final model on the `test_loader` using `evaluate_model`.
        -   Stores all results (HPO params, HPO F1 score, final training epochs, test metrics like loss, accuracy, F1, precision, recall, confusion matrix, and the full training history) in the `results` dictionary for the current model.
        -   Logs the test set metrics.
        -   Saves these detailed results to a JSON file (e.g., `{model_name}_test_results.json`) in the model's checkpoint directory.
    -   **Resource Cleanup**: Deletes model objects, data loaders, optimizers, and clears CUDA cache to free up memory before processing the next model.
    -   **Error Handling**: Includes a broad `try-except` block for each model to catch any unexpected errors, log them, store an error status in `results`, and clean up resources.
    -   Logs the total processing time for the current model.
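The study uses `MedianPruner(n_warmup_steps=hpo_warmup_steps, n_min_trials=5)`, leaving `n_startup_trials` at its default of 5. The function below is a simplified, stdlib-only reading of Optuna's documented median-pruning rule for a maximized metric, not Optuna's actual code; `interval_steps` and NaN handling are omitted. At the most recent reported step, the running trial is pruned if its best value so far is below the median of the values that completed trials reported at that same step.

```python
import statistics
from typing import Dict, List


def median_prune(current: Dict[int, float], completed: List[Dict[int, float]],
                 n_startup_trials: int = 5, n_warmup_steps: int = 0,
                 n_min_trials: int = 1) -> bool:
    """Simplified median-pruning decision for a metric that is maximized.

    current: step -> value reported so far by the running trial.
    completed: one step -> value dict per COMPLETE trial.
    """
    if len(completed) < n_startup_trials or not current:
        return False
    step = max(current)                    # most recent reported step
    if step < n_warmup_steps:
        return False
    peers = [history[step] for history in completed if step in history]
    if len(peers) < n_min_trials:          # too few completed trials reached this step
        return False
    best_so_far = max(current.values())    # best value this trial has reported
    return best_so_far < statistics.median(peers)


# Five completed trials reported val F1 at step 2 (the first step reported when hpo_warmup_steps=2).
finished = [{2: v} for v in (0.60, 0.70, 0.80, 0.75, 0.65)]  # median at step 2 is 0.70
print(median_prune({2: 0.68}, finished, n_warmup_steps=2, n_min_trials=5))      # True
print(median_prune({2: 0.72}, finished, n_warmup_steps=2, n_min_trials=5))      # False
print(median_prune({2: 0.68}, finished[:4], n_warmup_steps=2, n_min_trials=5))  # False (only 4 completed)
```

One consequence of `n_min_trials=5` is that completed trials which early-stopped before a given epoch do not count toward that epoch's median. As a result, pruning at late epochs only happens once enough trials have actually trained that long.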

In [None]:
overall_start_time = time.time()
results = {} 

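# --- Configuration ---
# Optuna HPO
n_trials_optuna = 30         # target trials per model (trials already stored in the study count toward this)
hpo_timeout_seconds = 36000  # wall-clock limit for one model's HPO study (10 hours)
hpo_epochs = 8               # max epochs per HPO trial
hpo_patience = 3             # early-stopping patience within a trial
hpo_warmup_steps = 2         # epochs before Optuna pruning can trigger

# Final training with the best hyperparameters
final_train_epochs = 25      # max epochs for the final run
final_train_patience = 5     # early-stopping patience for the final run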

criterion = nn.BCEWithLogitsLoss()

# --- Model Loop ---
for model_name in models_to_try:
    model_start_time = time.time()
    logging.info(f"\n{'='*70}\nProcessing Model: {model_name}\n{'='*70}")

    model_checkpoint_dir = os.path.join(CHECKPOINT_BASE_DIR, model_name)
    os.makedirs(model_checkpoint_dir, exist_ok=True) 

    try:
        # ===================================
        # 1. Hyperparameter Tuning (Optuna)
        # ===================================
        hpo_start_time = time.time()
        logging.info(f"--- Starting Hyperparameter Tuning for {model_name} ---")
        logging.info(f"  Optuna Settings: Max Trials={n_trials_optuna}, Timeout={timedelta(seconds=hpo_timeout_seconds)}, "
                     f"Epochs/Trial={hpo_epochs}, Patience/Trial={hpo_patience}, Pruning Warmup={hpo_warmup_steps}")

        study_name = f"sig_verify_{model_name}_hpo_v3" 
        study = None
        best_params = None
        best_value = None

        try:
            pruner = optuna.pruners.MedianPruner(n_warmup_steps=hpo_warmup_steps, n_min_trials=5)

            study = optuna.create_study(
                direction='maximize',
                study_name=study_name,
                storage=OPTUNA_DB_PATH,
                pruner=pruner,
                load_if_exists=True
            )

            n_existing_trials = len([t for t in study.trials if t.state != optuna.trial.TrialState.WAITING])
            logging.info(f"Optuna study '{study_name}': Loaded with {n_existing_trials} existing trials.")

            trials_to_run = n_trials_optuna - n_existing_trials
            if trials_to_run <= 0:
                logging.info(f"Study already has {n_existing_trials} >= {n_trials_optuna} trials. Using existing best params.")
            else:
                logging.info(f"Running {trials_to_run} new Optuna trials...")
                study.optimize(lambda trial: objective(trial, model_name, device,
                                                       hpo_epochs, hpo_patience, hpo_warmup_steps),
                               n_trials=trials_to_run,
                               timeout=hpo_timeout_seconds,
                               gc_after_trial=True)

        except Exception as hpo_e:
            logging.error(f"Error during Optuna study creation or optimization for {model_name}: {hpo_e}", exc_info=True)
            results[model_name] = {'error': f'Optuna HPO failed: {hpo_e}'}
            del study
            if torch.cuda.is_available(): torch.cuda.empty_cache()
            continue

        hpo_duration = time.time() - hpo_start_time
        logging.info(f"--- Optuna Tuning Complete for {model_name} in {timedelta(seconds=int(hpo_duration))} ---")
        # Optuna's `study.best_trial` raises ValueError (it does not return None) when no
        # trial has completed, so check for completed trials explicitly before using it.
        completed_trials = [t for t in study.trials
                            if t.state == optuna.trial.TrialState.COMPLETE and t.value is not None]
        if not completed_trials:
            logging.error(f"Optuna study for {model_name} finished with NO successful trials. Cannot proceed.")
            results[model_name] = {'error': 'Optuna study finished with no successful trials.'}
            del study
            if torch.cuda.is_available(): torch.cuda.empty_cache()
            continue

        # With at least one COMPLETE trial, best_trial is the best of those (direction='maximize').
        best_trial = study.best_trial
        best_params = best_trial.params
        best_value = best_trial.value

        logging.info(f"  Best Trial Found: #{best_trial.number} (Value: {best_value:.4f})")
        logging.info(f"  Best Hyperparameters:")
        for key, value in best_params.items():
            logging.info(f"    {key:<15}: {value}")

        # ========================================
        # 2. Final Training with Best Parameters
        # ========================================
        final_training_start_time = time.time()
        logging.info(f"\n--- Starting Final Training: {model_name} with Best Params ---")
        logging.info(f"  Settings: Max Epochs={final_train_epochs}, Patience={final_train_patience}")

        img_size = model_img_sizes[model_name]
        final_batch_size = best_params['batch_size']
        final_lr = best_params['learning_rate']
        final_dense_units = best_params['dense_units']
        final_dropout = best_params['dropout']
        final_optimizer_name = best_params['optimizer']
        final_unfreeze_layers = best_params['unfreeze_layers']

        train_loader, val_loader, test_loader = get_data_loaders(model_name, final_batch_size, img_size, num_workers=2) 

        final_model = CustomModel(model_name=model_name,
                                 dense_units=final_dense_units,
                                 dropout=final_dropout,
                                 pretrained=True,
                                 unfreeze_layers=final_unfreeze_layers).to(device)

        trainable_params_final = [p for p in final_model.parameters() if p.requires_grad]
        if not trainable_params_final:
            logging.error(f"FATAL: No trainable parameters found for the FINAL model {model_name}! Skipping training.")
            results[model_name] = {'error': 'Final model build resulted in no trainable parameters.'}
            del final_model, train_loader, val_loader, test_loader, study
            if torch.cuda.is_available(): torch.cuda.empty_cache()
            continue

        logging.info(f"Final model: Optimizing {len(trainable_params_final)} parameter tensors ({sum(p.numel() for p in trainable_params_final):,} total parameters).")

        if final_optimizer_name == 'Adam':
            final_optimizer = torch.optim.Adam(trainable_params_final, lr=final_lr)
        elif final_optimizer_name == 'AdamW':
            final_optimizer = torch.optim.AdamW(trainable_params_final, lr=final_lr) 
        else:
            logging.warning(f"Unsupported optimizer '{final_optimizer_name}' found in best params. Defaulting to AdamW.")
            final_optimizer = torch.optim.AdamW(trainable_params_final, lr=final_lr)
        final_model, history = train_model(model=final_model,
                                           model_name=model_name,
                                           train_loader=train_loader,
                                           val_loader=val_loader,
                                           criterion=criterion,
                                           optimizer=final_optimizer,
                                           epochs=final_train_epochs,
                                           device=device,
                                           patience=final_train_patience,
                                           trial=None,
                                           hpo_warmup_steps=0) 

        final_training_duration = time.time() - final_training_start_time
        logging.info(f"--- Final Training Complete for {model_name} in {timedelta(seconds=int(final_training_duration))} ---")

        # ===================================
        # 3. Final Evaluation on Test Set
        # ===================================
        eval_start_time = time.time()
        logging.info(f"\n--- Evaluating Final Model ({model_name}) on Test Set ---")
        test_loss, test_accuracy, test_f1, test_precision, test_recall, conf_matrix = evaluate_model(
            final_model, test_loader, criterion, device
        )
        eval_duration = time.time() - eval_start_time
        logging.info(f"--- Test Set Evaluation Complete in {eval_duration:.2f} seconds ---")


        results[model_name] = {
            'status': 'Success',
            'best_params_from_hpo': best_params,
            'best_val_f1_hpo': best_value,
            'final_train_epochs_run': len(history.get('train_loss', [])),
            'test_loss': test_loss,
            'test_accuracy': test_accuracy,
            'test_f1': test_f1,
            'test_precision': test_precision,
            'test_recall': test_recall,
            'confusion_matrix': conf_matrix.tolist(), 
            'training_history': history 
        }

        logging.info(f"--- Test Set Metrics for Final Trained {model_name} ---")
        logging.info(f"  Accuracy:  {test_accuracy:.4f}")
        logging.info(f"  F1 Score:  {test_f1:.4f}")
        logging.info(f"  Precision: {test_precision:.4f}")
        logging.info(f"  Recall:    {test_recall:.4f}")
        logging.info(f"  Loss:      {test_loss:.4f}")

        results_path = os.path.join(model_checkpoint_dir, f"{model_name}_test_results.json")
        try:
            serializable_results = copy.deepcopy(results[model_name])
            serializable_results['model_name'] = model_name
            with open(results_path, 'w') as f:
                # `default` converts leftover NumPy scalars/arrays (anything with .tolist()) that json cannot encode natively.
                json.dump(serializable_results, f, indent=4,
                          default=lambda o: o.tolist() if hasattr(o, 'tolist') else str(o))
            logging.info(f"Test results saved to: {results_path}")
        except Exception as json_e:
            logging.error(f"Error saving test results to JSON {results_path}: {json_e}", exc_info=True)

        # ===================================
        # 4. Clean up Resources
        # ===================================
        logging.info(f"Cleaning up resources for {model_name}...")
        del final_model, train_loader, val_loader, test_loader, final_optimizer, history, study, trainable_params_final
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logging.info(f"Resources cleaned up for {model_name}.")


    except Exception as model_e:
        logging.error(f"!!!!!!!! UNEXPECTED ERROR processing model {model_name}: {model_e} !!!!!!!!", exc_info=True)
        results[model_name] = {'status': 'Failed', 'error': str(model_e)}

        if 'final_model' in locals(): del final_model
        if 'train_loader' in locals(): del train_loader
        if 'val_loader' in locals(): del val_loader
        if 'test_loader' in locals(): del test_loader
        if 'final_optimizer' in locals(): del final_optimizer
        if 'study' in locals(): del study
        if 'history' in locals(): del history
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logging.info(f"Resources cleaned up for {model_name} after error.")


    model_duration = time.time() - model_start_time
    logging.info(f"--- Completed ALL processing for {model_name} in {timedelta(seconds=int(model_duration))} ---")

05/07 22:53:07 - 
Processing Model: ViT_Base
05/07 22:53:07 - --- Starting Hyperparameter Tuning for ViT_Base ---
05/07 22:53:07 -   Optuna Settings: Max Trials=30, Timeout=10:00:00, Epochs/Trial=8, Patience/Trial=3, Pruning Warmup=2


[I 2025-05-07 22:53:09,587] A new study created in RDB with name: sig_verify_ViT_Base_hpo_v3


05/07 22:53:09 - Optuna study 'sig_verify_ViT_Base_hpo_v3': Loaded with 0 existing trials.
05/07 22:53:09 - Running 30 new Optuna trials...
05/07 22:53:09 - 
--- Optuna Trial 0 Start (ViT_Base) ---
  Trial 0 Parameters:
    dense_units    : 512
    dropout        : 0.1
    learning_rate  : 4.040740331614598e-05
    optimizer      : Adam
    batch_size     : 64
    unfreeze_layers: 1
  Best Value (MAXIMIZE F1) So Far: N/A (First trial)
05/07 22:53:09 - Creating DataLoaders for ViT_Base (Img Size: (224, 224), Batch Size: 64, Num Workers: 2)...
05/07 22:53:09 - Dataset sizes: Train=4318, Validation=925, Test=928
05/07 22:53:09 -   Train class distribution: Real=2231, Forged=2087
05/07 22:53:09 - DataLoaders created successfully for ViT_Base.
05/07 22:53:09 - Initializing CustomModel: Base='ViT_Base', DenseUnits=512, Dropout=0.10, Unfreeze=1
05/07 22:53:12 - Loaded 'google/vit-base-patch16-224' from HuggingFace Transformers. Reported features: 768
05/07 22:53:13 - Verified that model 'ViT_Base' outputs 768 features as expected.
05/07 22:53:13 - Initially froze 85,798,656 parameters in the base model.
05/07 22:53:13 - Unfroze ViT encoder layer 11
05/07 22:53:13 - Base model state after unfreeze: 7,087,872 trainable params, 78,710,784 frozen params.
05/07 22:53:13 - Classifier head created with 395,265 trainable parameters.
05/07 22:53:13 -   Trial 0: Optimizing 22 parameter groups (7,483,137 total parameters).
05/07 22:53:13 - Best model checkpoint (based on val loss) will be saved to: /kaggle/working/checkpoints/SignatureVerification/ViT_Base_trial_0/ViT_Base_trial_0_best_val_loss.pth
05/07 22:53:13 - --- Starting Training: ViT_Base_trial_0 for 8 epochs (Patience: 3) ---
05/07 22:56:20 - Epoch 1/8 | Train Loss: 0.6611 | Train Acc: 0.6070 | Val Loss: 0.6076 | Val Acc: 0.6649 | Val F1: 0.6586 | Time: 0:03:07
05/07 22:56:21 -   -> Val loss improved to 0.6076. Saved best model checkpoint.
05/07 22:59:39 - Epoch 2/8 | Train Loss: 0.5823 | Train Acc: 0.6964 | Val Loss: 0.5685 | Val Acc: 0.7049 | Val F1: 0.6844 | Time: 0:03:18
05/07 22:59:40 -   -> Val loss improved to 0.5685. Saved best model checkpoint.
05/07 23:03:01 - Epoch 3/8 | Train Loss: 0.5462 | Train Acc: 0.7228 | Val Loss: 0.5369 | Val Acc: 0.7373 | Val F1: 0.7315 | Time: 0:03:20
05/07 23:03:02 -   -> Val loss improved to 0.5369. Saved best model checkpoint.
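The log above shows `train_model` saving a checkpoint each time validation loss improves. The patience is 3 epochs during HPO trials and `final_train_patience = 5` during final training. `train_model` is defined in an earlier cell and is not reproduced here. The sketch below shows one common way to implement this patience rule in plain Python; the notebook's own logic may differ in details, such as whether the best weights are restored before returning. `EarlyStopping` and its `step` method are hypothetical names, and the losses in the usage example are made up.

```python
import copy
import math


class EarlyStopping:
    """Hypothetical helper: stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta          # minimum decrease that counts as an improvement
        self.best_loss = math.inf
        self.best_epoch = None
        self.best_state = None
        self.epochs_without_improvement = 0

    def step(self, epoch, val_loss, state=None):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            # With PyTorch, `state` would typically be model.state_dict(); it is
            # deep-copied so later updates do not overwrite the saved "best" snapshot.
            self.best_state = copy.deepcopy(state)
            self.epochs_without_improvement = 0
            return False
        self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Usage with made-up validation losses (patience=3, as in the HPO trials).
stopper = EarlyStopping(patience=3)
stopped_epoch = None
for epoch, val_loss in enumerate([0.61, 0.57, 0.54, 0.55, 0.545, 0.56], start=1):
    if stopper.step(epoch, val_loss, state={"epoch": epoch}):
        stopped_epoch = epoch
        break
print(f"Stopped at epoch {stopped_epoch}; best epoch {stopper.best_epoch} "
      f"(val loss {stopper.best_loss})")
# Stopped at epoch 6; best epoch 3 (val loss 0.54)
```

Stopping on validation loss (rather than F1) keeps the criterion smooth. Optuna, however, ranks trials by validation F1, so the two signals can occasionally disagree.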



# Print final results

This cell is responsible for summarizing and displaying the results obtained from the main experimentation loop (Cell 20).
It performs the following actions:
1.  **Log Total Execution Time**: Calculates and logs the total time taken for the entire script to run.
2.  **Iterate Through Results**:
    -   Sorts the model names (the keys of the `results` dictionary) alphabetically for consistent output.
    -   For each `model_name`:
        -   Prints a header for the model.
        -   **Status Check**: Checks if an error occurred during the processing of this model. If so, prints "FAILED" and the error message.
        -   **Success Case**: If processing was successful:
            -   Prints "Success".
            -   Displays the best validation F1 score achieved during Hyperparameter Optimization (HPO) and the corresponding best hyperparameters.
            -   Prints the key test set metrics (Accuracy, F1 Score, Precision, Recall, Loss) achieved by the final trained model.
            -   **Plot Confusion Matrix**: If a confusion matrix is available in the results, it's plotted using `seaborn.heatmap` (with 'Forged' and 'Real' labels) and displayed.
            -   **Plot Training History**: If training history (loss and accuracy curves) is available, it's plotted using `matplotlib`. Two subplots are created: one for loss (training vs. validation) and one for accuracy (training vs. validation) over epochs. These plots are then displayed.
3.  **Final Log Message**: Logs "--- Script Finished ---".

This cell provides a comprehensive overview of each model's performance, including both HPO insights and final test evaluation, along with visualizations to aid in analysis.
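The printed test metrics can be cross-checked against the plotted confusion matrix. Below is a minimal sketch, assuming the matrix is laid out as `cm[true][predicted]`. That is the layout of scikit-learn's `confusion_matrix`, and it is consistent with the heatmap's "True Label" rows. Which index the notebook treats as the positive class depends on its label encoding, which is not shown here, so `positive` is a parameter. The results should agree with `evaluate_model` only if it uses binary (not macro) averaging. `metrics_from_confusion` is a hypothetical helper, and the counts in the example are made up.

```python
def metrics_from_confusion(cm, positive=1):
    """Binary accuracy/precision/recall/F1 from a 2x2 matrix laid out as cm[true][pred]."""
    negative = 1 - positive
    tp = cm[positive][positive]   # positive samples predicted positive
    fn = cm[positive][negative]   # positive samples predicted negative
    fp = cm[negative][positive]   # negative samples predicted positive
    tn = cm[negative][negative]   # negative samples predicted negative
    total = tp + tn + fp + fn
    # Zero denominators yield 0.0 instead of raising ZeroDivisionError.
    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * tp / (2 * tp + fp + fn) if (tp + fp + fn) else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


cm = [[50, 10], [5, 35]]  # made-up counts: rows = true label, columns = predicted label
m = metrics_from_confusion(cm, positive=1)
print({k: round(v, 4) for k, v in m.items()})
# {'accuracy': 0.85, 'precision': 0.7778, 'recall': 0.875, 'f1': 0.8235}
```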

In [None]:
overall_duration = time.time() - overall_start_time
logging.info(f"\n{'='*60}\nTotal Script Execution Time: {timedelta(seconds=int(overall_duration))}\n{'='*60}")

print("\n" + "="*60)
print("      Summary of Test Results (from Initial Training Run)")
print("="*60)
successful_models_initial = 0
sorted_model_names = sorted(results.keys())

for model_name in sorted_model_names:
    metrics = results[model_name]
    print(f"\n--- Results for: {model_name} (Initial Run) ---")
    if 'error' in metrics:
        print(f"  Status: FAILED")
        print(f"  Error: {metrics['error']}")
        continue

    successful_models_initial += 1
    print(f"  Status: Success")
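    # Guard the format spec: applying ':.4f' to the 'N/A' fallback string would raise ValueError.
    best_val_f1 = metrics.get('best_val_f1_hpo')
    best_val_f1_str = f"{best_val_f1:.4f}" if isinstance(best_val_f1, (int, float)) else 'N/A'
    print(f"  Best Validation F1 (HPO): {best_val_f1_str}")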
    print(f"  Best Hyperparameters (from HPO): {metrics.get('best_params_from_hpo', 'N/A')}")
    print("-" * 30)
    print(f"  Test Accuracy:       | {metrics.get('test_accuracy', -1):.4f} |")
    print(f"  Test F1 Score:       | {metrics.get('test_f1', -1):.4f} |")
    print(f"  Test Precision:      | {metrics.get('test_precision', -1):.4f} |")
    print(f"  Test Recall:         | {metrics.get('test_recall', -1):.4f} |")
    print(f"  Test Loss:           | {metrics.get('test_loss', -1):.4f} |")
    print("-" * 30)

    # --- Plot Confusion Matrix (from initial run's test eval) ---
    if 'confusion_matrix' in metrics:
        conf_matrix = np.array(metrics['confusion_matrix'])
        plt.figure(figsize=(6, 5))
        sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Greens', 
                    xticklabels=['Forged', 'Real'], yticklabels=['Forged', 'Real'], cbar=False)
        plt.title(f'Confusion Matrix (Initial Test Eval) - {model_name}')
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.tight_layout()
        plt.show()


    # --- Plot Training History (Loss and Accuracy Curves) ---
    if 'training_history' in metrics:
        history = metrics['training_history']
        epochs_ran = len(history.get('train_loss', []))
        if epochs_ran > 0:
            epoch_list = range(1, epochs_ran + 1)
            plt.figure(figsize=(12, 5))

            plt.subplot(1, 2, 1)
            plt.plot(epoch_list, history['train_loss'], 'bo-', label='Training Loss')
            plt.plot(epoch_list, history['val_loss'], 'ro-', label='Validation Loss')
            plt.title(f'Loss vs. Epochs - {model_name}')
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.legend()
            plt.grid(True)


            plt.subplot(1, 2, 2)
            plt.plot(epoch_list, history['train_acc'], 'bo-', label='Training Accuracy')
            plt.plot(epoch_list, history['val_acc'], 'ro-', label='Validation Accuracy')
            plt.title(f'Accuracy vs. Epochs - {model_name}')
            plt.xlabel('Epochs')
            plt.ylabel('Accuracy')
            plt.ylim(0.0, 1.05)
            plt.legend()
            plt.grid(True)

            plt.tight_layout()
            plt.show()
        else:
            print("  Training history data is empty (training might have failed early).")
    else:
        print("  No training history available for initial run.")
print("="*60)
print(f"Models processed successfully: {successful_models_initial}/{len(sorted_model_names)}")

logging.info("--- Script Finished ---")

# Download result zip

This cell iterates through the `sorted_model_names` (which are the names of the models processed in the main loop). For each `model_name`:
1.  It logs a message indicating that it's starting the checkpoint saving/uploading process for that model.
2.  It calls the `upload_checkpoint_to_dataset(model_name)` function (defined in Cell 14). This function is responsible for:
    *   Loading the saved metrics for the model.
    *   Generating training history and confusion matrix plots and saving them in the model's checkpoint directory.
    *   Packaging all relevant files (model weights, metrics JSON, plots) from the checkpoint directory into a zip file.
    *   Handling the "upload" or saving of this zip file based on the environment:
        *   **Kaggle**: Creates the zip in `/kaggle/working/` and provides download instructions/links.
        *   **Colab**: Copies files to Google Drive.
        *   **Local**: Copies files to a local directory.

After downloading the zip file(s), extract them and upload the results to a shared drive. This is a manual step for long-term storage or sharing.
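`upload_checkpoint_to_dataset` is defined in Cell 14 and is not reproduced here. As a rough sketch of its packaging step only, the standard-library `zipfile` module can bundle a model's checkpoint directory into one archive. `zip_checkpoint_dir` is a hypothetical helper, not the notebook's function. It stores paths relative to the directory's parent, so the archive extracts into a single folder named after the model. It also skips the archive itself if `zip_path` lies inside the directory being zipped.

```python
import os
import zipfile


def zip_checkpoint_dir(checkpoint_dir, zip_path):
    """Hypothetical helper: pack every file under checkpoint_dir into zip_path."""
    parent = os.path.dirname(os.path.abspath(checkpoint_dir))
    zip_abs = os.path.abspath(zip_path)
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, files in os.walk(checkpoint_dir):
            for name in sorted(files):
                full_path = os.path.join(root, name)
                # Never add the archive being written to itself.
                if os.path.abspath(full_path) == zip_abs:
                    continue
                # e.g. "ViT_Base/ViT_Base_test_results.json"
                zf.write(full_path, arcname=os.path.relpath(full_path, parent))
    return zip_path


if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        ckpt = os.path.join(tmp, "ViT_Base")
        os.makedirs(ckpt)
        with open(os.path.join(ckpt, "ViT_Base_test_results.json"), "w") as f:
            f.write("{}")
        out = zip_checkpoint_dir(ckpt, os.path.join(tmp, "ViT_Base.zip"))
        with zipfile.ZipFile(out) as zf:
            print(zf.namelist())  # ['ViT_Base/ViT_Base_test_results.json']
```

On Kaggle, writing the archive under `/kaggle/working/` (as the notebook's function does) makes it available in the notebook's output for download.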

In [None]:
for model_name in sorted_model_names:
    logging.info(f"\n--- Saving/Uploading Checkpoints for {model_name} ---")
    upload_checkpoint_to_dataset(model_name)
