# Data loading

In [None]:
# Import necessary libraries
import pandas as pd
import numpy as np, random
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import os
import kagglehub
import cv2
import pywt
from tqdm import tqdm
from PIL import Image

from sklearn.model_selection import train_test_split
import tensorflow as tf
import keras
from keras.models import load_model
from keras.preprocessing import image
import seaborn as sns
from sklearn import metrics

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, regularizers, callbacks, optimizers
from tensorflow.keras.models import Model

from sklearn.model_selection import StratifiedKFold
from statsmodels.stats.contingency_tables import mcnemar
import gc

from typing import Dict, Any

import warnings
warnings.filterwarnings('ignore')

In [None]:
from google.colab import drive

# Mount Google Drive; later cells read and write model weights and a sample
# image under /content/drive/MyDrive.
drive.mount('/content/drive')

# Download the dataset through kagglehub and point at the folder that holds
# the train/valid/test sub-directories.
path = kagglehub.dataset_download("xhlulu/140k-real-and-fake-faces")
# Use os.path.join (as the rest of the notebook does) instead of string
# concatenation so the result is well-formed whatever form `path` takes.
dataset_dir = os.path.join(path, "real_vs_fake", "real-vs-fake")

print("Data directory:", dataset_dir)

In [None]:
# Global experiment configuration shared by the cells below.
SEED = 42
IMG_SIZE = (256, 256)
CHANNELS = 3
BATCH_SIZE = 128
INITIAL_LR = 1e-4
TOTAL_IMAGE_COUNT = 140000

# Seed the Python, NumPy and TensorFlow random number generators.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)


# EDA

In [None]:
# Build a dataframe with one row per image: full file path and class label.
all_files = []
all_labels = []

for split in ["train", "valid", "test"]:
    split_dir = os.path.join(dataset_dir, split)

    for label in ["real", "fake"]:
        class_dir = os.path.join(split_dir, label)

        # os.listdir returns entries in arbitrary order; sorting keeps the row
        # order stable so that seeded sampling of `df` is reproducible.
        for fname in sorted(os.listdir(class_dir)):
            all_files.append(os.path.join(class_dir, fname))
            all_labels.append(label)

df = pd.DataFrame({"filename": all_files, "class": all_labels})

print(df)

In [None]:
# Draw 10 random images of each class (real first, then fake).
real_samples = df[df['class'] == 'real'].sample(10)
fake_samples = df[df['class'] == 'fake'].sample(10)
samples_df = pd.concat([real_samples, fake_samples])

# Show the 20 samples on a 4x5 grid, titled with their class label.
plt.figure(figsize=(12, 10))
sample_items = zip(samples_df['filename'], samples_df['class'])
for slot, (fname, cls) in enumerate(sample_items, start=1):
    plt.subplot(4, 5, slot)
    plt.imshow(plt.imread(fname) / 255.0)  # scale pixel values by 1/255
    plt.title(cls)
    plt.axis("off")
plt.show()

## Fourier transform

In [None]:
# For every sample show the image next to its centred log-magnitude spectrum.
plt.figure(figsize=(24, 10))
for i, (fname, cls) in enumerate(zip(samples_df['filename'], samples_df['class'])):
    img = plt.imread(fname) / 255.0   # scale pixel values by 1/255
    gray = np.mean(img, axis=2)       # average the colour channels

    # 2-D FFT, zero frequency shifted to the centre, log-compressed magnitude
    fshift = np.fft.fftshift(np.fft.fft2(gray))
    magnitude_spectrum = np.log(1 + np.abs(fshift))

    # Odd slots hold the image, even slots the matching spectrum.
    image_slot = 2 * i + 1
    spectrum_slot = image_slot + 1

    plt.subplot(4, 10, image_slot)
    plt.imshow(img)
    plt.title(f"{cls}")
    plt.axis("off")

    plt.subplot(4, 10, spectrum_slot)
    plt.imshow(magnitude_spectrum, cmap='gray')
    plt.axis("off")

plt.tight_layout()
plt.show()

## Wavelet transform

In [None]:
# Perform a single-level Haar wavelet transform on every sample and display
# the image with its four sub-bands, five samples (rows) per figure.
ROWS_PER_FIG = 5
COLS_PER_ROW = 5  # original image + LL, LH, HL, HH
n_samples = len(samples_df)

for i, row in enumerate(samples_df.itertuples()):
    row_in_fig = i % ROWS_PER_FIG

    # Open a figure only when a new group of rows starts. Opening it *after*
    # each completed group (as before) left an empty figure behind the last one.
    if row_in_fig == 0:
        plt.figure(figsize=(12, 10))

    # read and scale pixel values by 1/255
    img = plt.imread(row.filename) / 255.0
    gray = np.mean(img, axis=2)

    # 2D Discrete Wavelet Transform (Haar)
    coeffs2 = pywt.dwt2(gray, 'haar')
    LL, (LH, HL, HH) = coeffs2

    # Show original + 4 sub-bands on one grid row
    panels = [
        (img, f"{row._2}", None),
        (LL, "LL", 'gray'),
        (LH, "LH", 'gray'),
        (HL, "HL", 'gray'),
        (HH, "HH", 'gray'),
    ]
    for col, (data, title, cmap) in enumerate(panels, start=1):
        plt.subplot(ROWS_PER_FIG, COLS_PER_ROW, row_in_fig * COLS_PER_ROW + col)
        plt.imshow(data, cmap=cmap)
        plt.title(title)
        plt.axis("off")

    # Flush the figure when its last row is filled or the samples run out.
    if row_in_fig == ROWS_PER_FIG - 1 or i == n_samples - 1:
        plt.tight_layout()
        plt.show()


## Real image example

In [None]:
img_path = "/content/drive/MyDrive/111" # a 1024*1024 real image from original FFHQ dataset
img = plt.imread(img_path)
gray = np.mean(img, axis=2)

# Fourier transform: centred, log-compressed magnitude spectrum
f = np.fft.fft2(gray)
fshift = np.fft.fftshift(f)
magnitude = np.log(1 + np.abs(fshift))

# Wavelet transform: single-level Haar decomposition
coeffs2 = pywt.dwt2(gray, 'haar')
LL, (LH, HL, HH) = coeffs2

# 3x2 grid: image, spectrum, then the four wavelet sub-bands.
plt.figure(figsize=(8, 12))
panels = [
    (img, 'Original Image', None),
    (magnitude, 'Fourier transform', 'gray'),
    (LL, 'LL', 'gray'),
    (LH, 'LH', 'gray'),
    (HL, 'HL', 'gray'),
    (HH, 'HH', 'gray'),
]
for slot, (data, title, cmap) in enumerate(panels, start=1):
    plt.subplot(3, 2, slot)
    plt.imshow(data, cmap=cmap)
    plt.title(title)
    plt.axis("off")

## Mean and STD of Wavelets

In [None]:
# Number of stacked wavelet channels: 3 colour channels x 3 detail sub-bands
# (LH, HL, HH). NOTE: this rebinds the module-level CHANNELS, which the
# configuration cell sets to 3.
CHANNELS = 9
# Fraction of the dataframe rows sampled to estimate the wavelet statistics.
SAMPLING_FRACTION = 0.1

def compute_wavelet_features_np(np_image):
    """Stack the Haar detail sub-bands of the first three channels of one image.

    Args:
        np_image: array indexed as ``[row, col, channel]`` with at least
            three channels.

    Returns:
        Array whose last axis holds 9 planes ordered
        ``(LH, HL, HH)`` for channel 0, then channel 1, then channel 2.
    """
    detail_planes = []
    for channel_idx in range(3):
        # dwt2 returns (approximation, (LH, HL, HH)); only the details are kept.
        _, (LH, HL, HH) = pywt.dwt2(np_image[:, :, channel_idx], 'haar')
        detail_planes += [LH, HL, HH]

    return np.stack(detail_planes, axis=-1)

# Estimate the per-channel mean and standard deviation of the 9 wavelet
# detail planes on a random sample of the images.
df_sample = df.sample(frac=SAMPLING_FRACTION, random_state=42)
total_sample_size = len(df_sample)
print(f"Sampling {total_sample_size} images ({SAMPLING_FRACTION*100:.0f}% of total) for Wavelet stats...")

# Running sums instead of keeping every pixel row in memory: only two
# CHANNELS-long accumulators are needed for the mean and the variance.
channel_sum = np.zeros(CHANNELS, dtype=np.float64)
channel_sq_sum = np.zeros(CHANNELS, dtype=np.float64)
total_pixels_processed = 0

for img_path in tqdm(df_sample["filename"], total=total_sample_size):
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread signals an unreadable file by returning None.
        print(f"Skipping unreadable image: {img_path}")
        continue

    # cv2 decodes to BGR; the tf.data pipeline that consumes these statistics
    # iterates R, G, B, so convert before computing per-channel values.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    wavelet_features = compute_wavelet_features_np(image)

    reshaped_data = wavelet_features.reshape(-1, CHANNELS).astype(np.float64)
    channel_sum += reshaped_data.sum(axis=0)
    channel_sq_sum += np.square(reshaped_data).sum(axis=0)
    total_pixels_processed += reshaped_data.shape[0]

if total_pixels_processed == 0:
    raise RuntimeError("No readable images were sampled; cannot compute Wavelet statistics.")

print(f"Total {total_pixels_processed} pixel samples aggregated.")

# Population statistics: var = E[x^2] - E[x]^2 (clamped against round-off).
channel_mean = channel_sum / total_pixels_processed
channel_var = np.maximum(channel_sq_sum / total_pixels_processed - channel_mean ** 2, 0.0)

Wavelet_MEAN_NP = channel_mean.astype(np.float32)
Wavelet_STD_NP = np.sqrt(channel_var).astype(np.float32)

# Printed as ready-to-paste constants. NOTE: constants pasted from an earlier
# run of this cell should be regenerated, since the channel order changed.
print(f"Wavelet_MEAN = tf.constant({Wavelet_MEAN_NP.tolist()}, dtype=tf.float32)")
print(f"Wavelet_STD = tf.constant({Wavelet_STD_NP.tolist()}, dtype=tf.float32)")

# Model training

## Dataset Split

In [None]:
# Loader arguments shared by the three splits.
_loader_kwargs = dict(
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    label_mode="binary",
)

# Training and validation splits are loaded with a fixed seed.
train_ds = tf.keras.utils.image_dataset_from_directory(
    os.path.join(dataset_dir, "train"), seed=SEED, **_loader_kwargs
)
valid_ds = tf.keras.utils.image_dataset_from_directory(
    os.path.join(dataset_dir, "valid"), seed=SEED, **_loader_kwargs
)

# The test split is not shuffled, so its batches keep a fixed order.
test_ds = tf.keras.utils.image_dataset_from_directory(
    os.path.join(dataset_dir, "test"), shuffle=False, **_loader_kwargs
)

In [None]:
# Inspect one sample (image values + label)
# Only the first batch is read; the loop is left immediately.
for img, label in train_ds:
    print("Image values:", img[0])
    print("Label:", label[0])
    break

# Check batch shapes
# NOTE: this starts a second pass over train_ds, again reading one batch only.
for img, label in train_ds:
    print("Image batch shape:", img.shape)
    print("Label batch shape:", label.shape)
    break

# Show class-to-index mapping
# (the position of a name in class_names is its integer label)
print("Class names:", train_ds.class_names)

## Dataset Preprocessing

In [None]:
# Per-channel mean / standard deviation of the 9 wavelet detail planes, used to
# standardise the wavelet features below.
# NOTE(review): these look like values pasted from the output of the
# "Mean and STD of Wavelets" cell, which reads images with cv2.imread —
# confirm that their channel order matches the RGB order used by this pipeline.
Wavelet_MEAN = tf.constant([0.06322062760591507, 0.14566604793071747, 0.004705097526311874, 0.07626466453075409, -0.05939734727144241, 0.005020765122026205, 0.02491244673728943, 0.11974337697029114, 0.004669941496104002], dtype=tf.float32)
Wavelet_STD = tf.constant([11.035733222961426, 11.721734046936035, 7.004025459289551, 11.152276992797852, 11.841991424560547, 7.022481918334961, 11.28213119506836, 12.013513565063477, 7.008289337158203], dtype=tf.float32)

# DWT Feature Extraction (pywt + tf.py_function)
def compute_wavelet_features_py(image_tensor):
    """Return the 9 stacked Haar detail sub-bands of one image as a float32 tensor.

    Runs eagerly (it calls ``.numpy()``), so it is meant to be wrapped in
    ``tf.py_function``; this is the slow part of the input pipeline.

    Args:
        image_tensor: eager tensor indexed as ``[row, col, channel]`` with at
            least three channels.

    Returns:
        ``tf.float32`` tensor whose last axis is ordered ``(LH, HL, HH)`` for
        channel 0, then channel 1, then channel 2.
    """
    pixels = image_tensor.numpy()

    detail_planes = []
    for channel_idx in range(3):  # the three colour channels
        # dwt2 returns (approximation, (LH, HL, HH)); keep the detail bands.
        _, (LH, HL, HH) = pywt.dwt2(pixels[:, :, channel_idx], 'haar')
        detail_planes += [LH, HL, HH]

    stacked = np.stack(detail_planes, axis=-1)
    return tf.convert_to_tensor(stacked, dtype=tf.float32)

def process_wavelet_features(images_batch):
    """Compute standardised 9-channel wavelet features for a batch of images (SLOW).

    Each image goes through ``compute_wavelet_features_py`` (via
    ``tf.py_function``), is resized bilinearly back to ``IMG_SIZE`` and the
    batch is then standardised with ``Wavelet_MEAN`` / ``Wavelet_STD``.

    Args:
        images_batch: batch of images in the 0-255 value range, expected to be
            ``IMG_SIZE`` in height and width.

    Returns:
        ``tf.float32`` tensor of shape ``(batch, IMG_SIZE[0], IMG_SIZE[1], 9)``.
    """
    # The single-level transform output is declared as half the image size.
    half_h, half_w = IMG_SIZE[0] // 2, IMG_SIZE[1] // 2

    def process_single_image_wavelet(image):
        subbands = tf.py_function(compute_wavelet_features_py, [image], tf.float32)
        # py_function loses static shape information; tf.data needs it back.
        subbands.set_shape([half_h, half_w, 9])
        return tf.image.resize(subbands, IMG_SIZE, method='bilinear')

    per_image_spec = tf.TensorSpec(
        shape=(IMG_SIZE[0], IMG_SIZE[1], 9), dtype=tf.float32
    )
    # Apply the py_function-based transform to every image of the batch.
    wavelet_batch = tf.map_fn(
        process_single_image_wavelet,
        images_batch,
        fn_output_signature=per_image_spec,
    )

    # Per-channel standardisation; the epsilon guards against division by zero.
    return (wavelet_batch - Wavelet_MEAN) / (Wavelet_STD + 1e-7)

# Data Augmentation and Preprocessing Functions
def apply_augmentation(images_batch):
    """Randomly flip a batch left-right and add Gaussian pixel noise.

    Args:
        images_batch: batch of RGB images in the 0-255 value range.

    Returns:
        Augmented batch, clipped to [0, 255] and cast back to the dtype of
        ``images_batch``.
    """
    NOISE_STDDEV = 5.0

    flipped = tf.image.random_flip_left_right(images_batch)
    flipped_f32 = tf.cast(flipped, tf.float32)

    noise = tf.random.normal(
        shape=tf.shape(flipped_f32),
        mean=0.0,
        stddev=NOISE_STDDEV,
        dtype=tf.float32,
    )
    # Keep the noisy pixels inside the valid intensity range.
    clipped = tf.clip_by_value(flipped_f32 + noise, 0.0, 255.0)

    return tf.cast(clipped, images_batch.dtype)


# RGB ONLY Preprocessing (3 Channels)
def preprocess_rgb_only_train(images_batch, labels_batch):
    """Augment the batch, then scale it to [0, 1] like the validation path."""
    return preprocess_rgb_only_valid(apply_augmentation(images_batch), labels_batch)

def preprocess_rgb_only_valid(images_batch, labels_batch):
    """Cast the batch to float32 and divide by 255; labels pass through."""
    scaled = tf.cast(images_batch, tf.float32) / 255.0
    return scaled, labels_batch


# WAVELETS ONLY Preprocessing (9 Channels)
def preprocess_wavelet_only_train(images_batch, labels_batch):
    """Augment the batch, then extract wavelet features like the validation path."""
    return preprocess_wavelet_only_valid(apply_augmentation(images_batch), labels_batch)

def preprocess_wavelet_only_valid(images_batch, labels_batch):
    """Replace the images by their 9-channel wavelet features; labels pass through."""
    return process_wavelet_features(images_batch), labels_batch


# RGB + WAVELETS Preprocessing (12 Channels)
def preprocess_rgb_wavelet_train(images_batch, labels_batch):
    """Augment the batch, then build the 12-channel input like the validation path."""
    return preprocess_rgb_wavelet_valid(apply_augmentation(images_batch), labels_batch)

def preprocess_rgb_wavelet_valid(images_batch, labels_batch):
    """Concatenate scaled RGB (3 channels) with wavelet features (9 channels).

    Both parts are computed from the same input batch and joined along the
    last axis, RGB first; labels pass through unchanged.
    """
    rgb_part = tf.cast(images_batch, tf.float32) / 255.0
    wavelet_part = process_wavelet_features(images_batch)
    fused = tf.concat([rgb_part, wavelet_part], axis=-1)
    return fused, labels_batch

# Dataset Mapping
def _map_and_prefetch(dataset, preprocess_fn):
    """Apply ``preprocess_fn`` to every batch in parallel and prefetch the result."""
    mapped = dataset.map(preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)
    return mapped.prefetch(tf.data.AUTOTUNE)

# RGB only (3 channels)
train_ds_rgb = _map_and_prefetch(train_ds, preprocess_rgb_only_train)
valid_ds_rgb = _map_and_prefetch(valid_ds, preprocess_rgb_only_valid)

# Wavelets only (9 channels)
train_ds_wavelet = _map_and_prefetch(train_ds, preprocess_wavelet_only_train)
valid_ds_wavelet = _map_and_prefetch(valid_ds, preprocess_wavelet_only_valid)

# RGB + wavelets (12 channels)
train_ds_rgb_wavelet = _map_and_prefetch(train_ds, preprocess_rgb_wavelet_train)
valid_ds_rgb_wavelet = _map_and_prefetch(valid_ds, preprocess_rgb_wavelet_valid)

print("All datasets are prepared (using pywt + tf.py_function) and model uses Batch Normalization.")

In [None]:
# Test Dataset Mapping: the test split goes through the non-augmenting
# (validation) preprocessing of each feature set.
test_ds_rgb, test_ds_wavelet, test_ds_rgb_wavelet = (
    test_ds.map(preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)
    for preprocess_fn in (
        preprocess_rgb_only_valid,
        preprocess_wavelet_only_valid,
        preprocess_rgb_wavelet_valid,
    )
)

## Callbacks Setup

In [None]:
# Stop after 6 epochs without a val_loss improvement and restore the best weights.
early_stopping_cb = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=6,
    restore_best_weights=True,
)

# Halve the learning rate after 3 epochs without a val_loss improvement,
# never going below 1e-7.
reduce_lr = callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=3,
    min_lr=1e-7,
)

def get_checkpoint_cb(name):
    """Return a checkpoint callback that keeps the weights with the lowest val_loss.

    Args:
        name: tag inserted into the weights file name on Google Drive.

    Returns:
        A ``ModelCheckpoint`` that saves weights only, and only when
        ``val_loss`` improves.
    """
    weights_file = f"/content/drive/MyDrive/Thesis/{name}_optimized_best.weights.h5"
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=weights_file,
        monitor='val_loss',
        mode='min',
        save_best_only=True,
        save_weights_only=True,
    )
    return checkpoint

## Model Setup

In [None]:
def densenet_scratch(input_channels, model_name, initial_lr):
    """Build and compile a DenseNet121 binary classifier with no pretrained weights.

    The backbone takes ``(IMG_SIZE[0], IMG_SIZE[1], input_channels)`` inputs and
    is followed by: global average pooling, Dense(512, relu), batch
    normalisation, Dropout(0.5) and a single sigmoid unit.

    Args:
        input_channels: number of channels of the input tensor.
        model_name: name given to the resulting Keras model.
        initial_lr: learning rate of the Adam optimizer.

    Returns:
        The model, compiled with binary cross-entropy and an accuracy metric.
    """
    optimizer = tf.keras.optimizers.Adam(learning_rate=initial_lr, weight_decay=1e-4)

    backbone = tf.keras.applications.DenseNet121(
        weights=None,
        include_top=False,
        input_shape=(IMG_SIZE[0], IMG_SIZE[1], input_channels),
        name=f"densenet121_{input_channels}ch",
    )

    # Classification head, applied in order on top of the backbone output.
    head_layers = [
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ]
    features = backbone.output
    for layer in head_layers:
        features = layer(features)

    model = tf.keras.Model(inputs=backbone.input, outputs=features, name=model_name)
    model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
    return model

def train_single_phase_optimized(model, train_ds, valid_ds, model_tag, initial_lr):
    """Train ``model`` in one phase with checkpointing, LR reduction and early stopping.

    Args:
        model: compiled Keras model; its optimizer learning rate is reset to
            ``initial_lr`` before training.
        train_ds: dataset passed to ``model.fit`` for training.
        valid_ds: dataset passed to ``model.fit`` as validation data.
        model_tag: tag used for log messages and for the checkpoint file name.
        initial_lr: learning rate assigned to the optimizer.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """

    print(f"\n--- Starting {model_tag} Single Phase Optimization (LR={initial_lr}) ---")

    model.optimizer.learning_rate.assign(initial_lr)

    checkpoint_cb = get_checkpoint_cb(model_tag)
    callbacks_list = [checkpoint_cb, early_stopping_cb, reduce_lr]

    TOTAL_EPOCHS = 60  # upper bound; early stopping may end training sooner

    history = model.fit(
        train_ds,
        validation_data=valid_ds,
        epochs=TOTAL_EPOCHS,
        callbacks=callbacks_list,
        verbose=1
    )

    # Reload the best checkpoint from the path the callback actually wrote to,
    # rather than rebuilding that path by hand.
    best_weights_path = checkpoint_cb.filepath
    try:
        model.load_weights(best_weights_path)
    except Exception as e:
        # Best-effort: keep the in-memory weights, but say why loading failed.
        print(
            f"Could not load best weights from disk ({type(e).__name__}: {e}). "
            "Using weights restored by EarlyStopping."
        )
    else:
        print("Loaded BEST weights from checkpoint for final model object.")

    return history

In [None]:
# Rebuild the three architectures and load their saved best weights from Drive.
_WEIGHTS_DIR = "/content/drive/MyDrive/Thesis"

model_rgb = densenet_scratch(input_channels=3, model_name="rgb_only", initial_lr=INITIAL_LR)
model_rgb.load_weights(f"{_WEIGHTS_DIR}/rgb_only_optimized_best.weights.h5")

model_wavelets = densenet_scratch(input_channels=9, model_name="wavelets_only", initial_lr=INITIAL_LR)
model_wavelets.load_weights(f"{_WEIGHTS_DIR}/wavelets_only_optimized_best.weights.h5")

model_rgb_wavelets = densenet_scratch(input_channels=12, model_name="rgb_wavelets", initial_lr=INITIAL_LR)
model_rgb_wavelets.load_weights(f"{_WEIGHTS_DIR}/rgb_wavelets_optimized_best.weights.h5")

## 1. RGB ONLY Model (3 Channels)

In [None]:
# Fresh 3-channel model, trained on the RGB datasets.
model_rgb = densenet_scratch(input_channels=3, model_name="rgb_only", initial_lr=INITIAL_LR)
history_rgb = train_single_phase_optimized(
    model=model_rgb,
    train_ds=train_ds_rgb,
    valid_ds=valid_ds_rgb,
    model_tag="rgb_only",
    initial_lr=INITIAL_LR,
)

## 2. Wavelets ONLY Model (9 Channels)

In [None]:
# Fresh 9-channel model, trained on the wavelet-feature datasets.
model_wavelets = densenet_scratch(input_channels=9, model_name="wavelets_only", initial_lr=INITIAL_LR)
history_wavelets = train_single_phase_optimized(
    model=model_wavelets,
    train_ds=train_ds_wavelet,
    valid_ds=valid_ds_wavelet,
    model_tag="wavelets_only",
    initial_lr=INITIAL_LR,
)

## 3. RGB + Wavelets Model (12 Channels)

In [None]:
# Fresh 12-channel model, trained on the fused RGB + wavelet datasets.
model_rgb_wavelets = densenet_scratch(input_channels=12, model_name="rgb_wavelets", initial_lr=INITIAL_LR)
history_rgb_wavelets = train_single_phase_optimized(
    model=model_rgb_wavelets,
    train_ds=train_ds_rgb_wavelet,
    valid_ds=valid_ds_rgb_wavelet,
    model_tag="rgb_wavelets",
    initial_lr=INITIAL_LR,
)

## Evaluation

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Hard-coded per-epoch validation accuracy / loss for the three models.
# NOTE(review): presumably copied from the training output of the cells above;
# the values are not read from the History objects — verify after retraining.

# RGB Only (21 epochs)
epochs_rgb = np.arange(1, 22)
val_acc_rgb = [0.5698,0.8039,0.6641,0.8837,0.5240,0.7947,0.9186,0.9050,0.6664,0.9146,
               0.9646,0.9427,0.9625,0.9432,0.9753,0.9754,0.9765,0.9640,0.9790,0.9774,0.9567]
val_loss_rgb = [1.1607,0.4665,0.9485,0.2816,2.4571,0.6526,0.2125,0.2639,1.5332,0.2655,
                0.1048,0.1786,0.1306,0.2229,0.0772,0.0837,0.0870,0.1292,0.0780,0.0847,0.1740]

# Wavelet Only (15 epochs)
epochs_wav = np.arange(1, 16)
val_acc_wav = [0.5433,0.8595,0.6839,0.8429,0.7840,0.9059,0.8766,0.9564,0.9697,0.9624,
               0.9645,0.8691,0.9387,0.9717,0.9588]
val_loss_wav = [2.1196,0.3544,1.4302,0.4641,0.8053,0.3061,0.5417,0.1168,0.0809,0.1084,
                0.1058,0.5942,0.2248,0.0847,0.1456]

# RGB + Wavelet (20 epochs)
epochs_fuse = np.arange(1, 21)
val_acc_fuse = [0.6327,0.7728,0.9048,0.7564,0.8747,0.7071,0.8768,0.9346,0.9551,0.8801,
                0.8569,0.9366,0.9635,0.9820,0.9639,0.9143,0.9539,0.9743,0.9416,0.9767]
val_loss_fuse = [1.2879,0.6531,0.2352,1.1157,0.3708,1.2748,0.4079,0.1918,0.1312,0.4632,
                 0.6902,0.2180,0.1148,0.0525,0.1183,0.3935,0.1743,0.0861,0.2276,0.0840]

# Stats for table
def get_summary(name, epochs, acc, loss):
    """Build one table row: model name, best accuracy and lowest loss with their epochs.

    Args:
        name: label placed in the first cell.
        epochs: epoch numbers, aligned index-by-index with ``acc`` and ``loss``.
        acc: per-epoch accuracy values.
        loss: per-epoch loss values.

    Returns:
        ``[name, "<max acc> (Ep <epoch>)", "<min loss> (Ep <epoch>)"]`` with
        the values formatted to four decimals.
    """
    i_acc = int(np.argmax(acc))
    i_loss = int(np.argmin(loss))

    best_acc_cell = f"{acc[i_acc]:.4f} (Ep {epochs[i_acc]})"
    min_loss_cell = f"{loss[i_loss]:.4f} (Ep {epochs[i_loss]})"
    return [name, best_acc_cell, min_loss_cell]

# (label, epochs, validation accuracy, validation loss) for every model.
curves = [
    ("RGB Only", epochs_rgb, val_acc_rgb, val_loss_rgb),
    ("Wavelets Only", epochs_wav, val_acc_wav, val_loss_wav),
    ("RGB + Wavelets", epochs_fuse, val_acc_fuse, val_loss_fuse),
]

summary_data = [
    get_summary(curve_name, curve_epochs, curve_acc, curve_loss)
    for curve_name, curve_epochs, curve_acc, curve_loss in curves
]

colors = {"RGB Only":"#1f77b4", "Wavelets Only":"#ff7f0e", "RGB + Wavelets":"#2ca02c"}

fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
plt.subplots_adjust(hspace=0.35, bottom=0.15, top=0.9)
acc_ax, loss_ax = axes

# One line per model on each panel (accuracy on top, loss below).
for curve_name, curve_epochs, curve_acc, curve_loss in curves:
    line_style = dict(label=curve_name, color=colors[curve_name], linewidth=2)
    acc_ax.plot(curve_epochs, curve_acc, **line_style)
    loss_ax.plot(curve_epochs, curve_loss, **line_style)

# Validation Accuracy panel
acc_ax.set_ylabel("Validation Accuracy")
acc_ax.set_ylim(0.4, 1.05)
acc_ax.grid(True, linestyle="--", alpha=0.3)
acc_ax.legend(loc="lower right", fontsize=9)
acc_ax.set_title("Validation Accuracy Comparison")

# Validation Loss panel
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Validation Loss")
loss_ax.set_ylim(0, 2.6)
loss_ax.grid(True, linestyle="--", alpha=0.3)
loss_ax.legend(loc="upper right", fontsize=9)
loss_ax.set_title("Validation Loss Comparison")

# Summary table drawn on its own axis-less figure.
fig_table, ax_table = plt.subplots(figsize=(6, 1.3))
ax_table.axis("off")

col_labels = ["Model", "Best Val Acc", "Min Val Loss"]
the_table = ax_table.table(
    cellText=summary_data,
    colLabels=col_labels,
    loc='center',
    cellLoc='center',
)
the_table.auto_set_font_size(False)
the_table.set_fontsize(9)
the_table.scale(1.2, 1.3)

plt.show()


In [None]:
# Define the evaluation configuration: Model, Name, and its corresponding Dataset
eval_configurations = {
    "RGB Only": (model_rgb, test_ds_rgb),
    "Wavelets Only": (model_wavelets, test_ds_wavelet),
    "RGB + Wavelets": (model_rgb_wavelets, test_ds_rgb_wavelet)
}

# Extract true labels (Y_true) once. The labels are identical across all mapped test datasets.
Y_true = np.concatenate([y for x, y in test_ds_rgb], axis=0).astype(int).flatten()

# Evaluation and Comparison
all_preds = {}    # model name -> binary predictions
roc_plots = {}    # model name -> (fpr, tpr, auc)
metric_rows = []  # one dict per model; turned into results_df after the loop
T = 0.5 # Classification threshold

for name, (model, dataset) in eval_configurations.items():

    # Predict probabilities using the correct dataset for the model
    Y_pred_proba = model.predict(dataset, verbose=0).flatten()

    # Store binary predictions
    Y_pred_binary = (Y_pred_proba > T).astype(int)
    all_preds[name] = Y_pred_binary

    # Calculate AUC-ROC from the raw probabilities
    fpr, tpr, _ = metrics.roc_curve(Y_true, Y_pred_proba)
    auc_score = metrics.auc(fpr, tpr)
    roc_plots[name] = (fpr, tpr, auc_score)

    # Core metrics at threshold T
    metric_rows.append({
        'Model': name,
        'Accuracy': metrics.accuracy_score(Y_true, Y_pred_binary),
        'Precision': metrics.precision_score(Y_true, Y_pred_binary),
        'Recall': metrics.recall_score(Y_true, Y_pred_binary),
        'F1 Score': metrics.f1_score(Y_true, Y_pred_binary),
        'AUC-ROC': auc_score
    })

# Build the frame once instead of growing it with pd.concat inside the loop
# (repeated copying, and concatenating onto an empty frame is deprecated).
results_df = pd.DataFrame(
    metric_rows,
    columns=['Model', 'Accuracy', 'Precision', 'Recall', 'F1 Score', 'AUC-ROC'],
)

print("Evaluation complete.")

## Visualization

### Performance Metrics

In [None]:
# Performance Metrics Comparison: four-decimal floats, best AUC-ROC first.
pd.set_option('display.float_format', lambda value: '%.4f' % value)
ranked_results = results_df.sort_values(by='AUC-ROC', ascending=False)
display(ranked_results)

### ROC Curve

In [None]:
# Plot ROC Curve: one curve per model plus the chance diagonal.
plt.figure(figsize=(8, 5))

for name in roc_plots:
    fpr, tpr, auc_score = roc_plots[name]
    curve_label = f'{name} (AUC = {auc_score:.4f})'
    plt.plot(fpr, tpr, label=curve_label)

# Diagonal line for random guess
plt.plot([0, 1], [0, 1], 'k--')

plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate (FPR)', fontsize=12)
plt.ylabel('True Positive Rate (TPR)', fontsize=12)
plt.title('Receiver Operating Characteristic (ROC) Curve Comparison', fontsize=14)
plt.grid(True, linestyle=':', alpha=0.6)
plt.legend(loc="lower right")
plt.show()


### Confusion Matrix

In [None]:
# Plot Confusion Matrices: one heatmap per model, side by side.
plt.figure(figsize=(15, 5))

for i, (name, Y_pred) in enumerate(all_preds.items()):
    # Named conf_matrix rather than `cm`: `cm` is the alias of matplotlib.cm
    # imported at the top of the notebook and must not be overwritten.
    conf_matrix = metrics.confusion_matrix(Y_true, Y_pred)

    plt.subplot(1, len(all_preds), i + 1)
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', cbar=False,
                xticklabels=['Predict Fake (0)', 'Predict Real (1)'],
                yticklabels=['Actual Fake (0)', 'Actual Real (1)'])
    plt.title(f'CM: {name}')
    plt.ylabel('Actual Label')
    plt.xlabel('Predicted Label')

plt.tight_layout()
plt.show()

## Cross Validation (Stratified K-Fold + Bootstrap CI + McNemar's Test)



In [None]:
# Settings for the stratified-fold evaluation below.
k = 10              # number of stratified folds
threshold = 0.5     # probability cut-off for the positive class
random_state = 42   # seed for the fold shuffling

def collect_labels(ds):
    """Gather the labels of every batch of ``ds`` into one flat integer array.

    Args:
        ds: iterable of ``(inputs, labels)`` pairs; the inputs are ignored.

    Returns:
        1-D integer NumPy array with the labels in iteration order.
    """
    label_batches = [np.asarray(batch_labels) for _, batch_labels in ds]
    stacked = np.concatenate(label_batches, axis=0)
    return stacked.astype(int).flatten()

# Ground-truth labels of the test split, taken from the RGB test dataset.
Y = collect_labels(test_ds_rgb)
n = len(Y)  # number of test samples; predictions are checked against this length
print("labels:", n)

def predict_probs(model, ds):
    """Run ``model.predict`` over ``ds`` (with prefetching) and flatten the output.

    Args:
        model: object exposing a Keras-style ``predict`` method.
        ds: dataset supporting ``prefetch``.

    Returns:
        1-D NumPy array of the model outputs.
    """
    prefetched = ds.prefetch(tf.data.AUTOTUNE)
    raw_output = model.predict(prefetched, verbose=1)
    return np.asarray(raw_output).flatten()

# Models under test, each paired with the test dataset carrying its features.
evals = {
    "RGB Only": (model_rgb, test_ds_rgb),
    "Wavelets Only": (model_wavelets, test_ds_wavelet),
    "RGB + Wavelets": (model_rgb_wavelets, test_ds_rgb_wavelet)
}

# The folds partition the *test predictions*; no model is retrained here.
skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=random_state)
results = []

for name, (model, ds) in evals.items():
    print(f"\nEvaluating: {name}")
    probs = predict_probs(model, ds)          # streaming predict, memory-friendly
    if probs.shape[0] != n:
        raise ValueError(f"{name}: probs length {probs.shape[0]} != labels length {n}")

    # One [acc, prec, rec, f1, auc] row per fold.
    fold_metrics = []
    for _, test_idx in skf.split(np.zeros(n), Y):
        y_true, y_proba = Y[test_idx], probs[test_idx]
        y_pred = (y_proba > threshold).astype(int)

        fpr, tpr, _ = metrics.roc_curve(y_true, y_proba)
        fold_metrics.append([
            metrics.accuracy_score(y_true, y_pred),
            metrics.precision_score(y_true, y_pred, zero_division=0),
            metrics.recall_score(y_true, y_pred, zero_division=0),
            metrics.f1_score(y_true, y_pred, zero_division=0),
            metrics.auc(fpr, tpr),
        ])

    fold_metrics = np.array(fold_metrics)

    # Row layout: name, then (mean, std) across folds for each metric column.
    summary_row = [name]
    for metric_column in fold_metrics.T:
        summary_row.extend([metric_column.mean(), metric_column.std()])
    results.append(summary_row)

cols = ["Model",
        "Acc_mean","Acc_std","Prec_mean","Prec_std",
        "Rec_mean","Rec_std","F1_mean","F1_std","AUC_mean","AUC_std"]
results_df = pd.DataFrame(results, columns=cols)
print("\nFinal results:")
print(results_df)


In [None]:
# Define evaluation configurations for each model and dataset
eval_configurations = {
    "RGB Only": (model_rgb, test_ds_rgb),
    "Wavelets Only": (model_wavelets, test_ds_wavelet),
    "RGB + Wavelets": (model_rgb_wavelets, test_ds_rgb_wavelet)
}

N_BOOT = 1000  # bootstrap resamples per statistic
TH = 0.5       # probability cut-off for the positive class

# Predicted probabilities and thresholded predictions, keyed by model name.
probs = {}
preds_bin = {}
for name, (model, dataset) in eval_configurations.items():
    model_probs = model.predict(dataset, verbose=0).flatten()
    probs[name] = model_probs
    preds_bin[name] = (model_probs > TH).astype(int)

# True labels, read once from the first configured dataset.
_, first_dataset = next(iter(eval_configurations.values()))
Y_true = np.concatenate([y for x, y in first_dataset], axis=0).astype(int).flatten()

# Point estimates of every metric on the full test set.
rows = [
    {
        "Model": name,
        "Accuracy": metrics.accuracy_score(Y_true, preds_bin[name]),
        "Precision": metrics.precision_score(Y_true, preds_bin[name], zero_division=0),
        "Recall": metrics.recall_score(Y_true, preds_bin[name], zero_division=0),
        "F1": metrics.f1_score(Y_true, preds_bin[name], zero_division=0),
        "AUC": metrics.roc_auc_score(Y_true, probs[name]),
    }
    for name in probs
]
global_df = pd.DataFrame(rows)

def bootstrap_ci(y_true, y_proba, stat_fn, n_boot=N_BOOT):
    """Bootstrap a statistic and return its mean and 95% percentile interval.

    Samples are redrawn with replacement (same size as the input) using the
    global NumPy random state.

    Args:
        y_true: array of true labels.
        y_proba: array of predicted probabilities, aligned with ``y_true``.
        stat_fn: callable ``(labels, probabilities) -> number``.
        n_boot: number of bootstrap resamples.

    Returns:
        ``(mean, lower, upper)`` where the bounds are the 2.5th and 97.5th
        percentiles of the resampled statistic.
    """
    sample_size = len(y_true)

    def _one_resample():
        picked = np.random.choice(sample_size, sample_size, replace=True)
        return stat_fn(y_true[picked], y_proba[picked])

    vals = [_one_resample() for _ in range(n_boot)]
    lower, upper = np.percentile(vals, [2.5, 97.5])
    return np.mean(vals), lower, upper

def _binarize(yp):
    """Threshold probabilities at TH (strictly greater) into 0/1 predictions."""
    return (yp > TH).astype(int)

def stat_accuracy(yt, yp):
    """Accuracy of the thresholded predictions."""
    return metrics.accuracy_score(yt, _binarize(yp))

def stat_precision(yt, yp):
    """Precision of the thresholded predictions (0 when undefined)."""
    return metrics.precision_score(yt, _binarize(yp), zero_division=0)

def stat_recall(yt, yp):
    """Recall of the thresholded predictions (0 when undefined)."""
    return metrics.recall_score(yt, _binarize(yp), zero_division=0)

def stat_f1(yt, yp):
    """F1 score of the thresholded predictions (0 when undefined)."""
    return metrics.f1_score(yt, _binarize(yp), zero_division=0)

def stat_auc(yt, yp):
    """ROC AUC computed from the raw probabilities."""
    return metrics.roc_auc_score(yt, yp)

# (column prefix, statistic) pairs, in the order the columns appear.
_CI_STATS = [
    ("Acc", stat_accuracy),
    ("Prec", stat_precision),
    ("Rec", stat_recall),
    ("F1", stat_f1),
    ("AUC", stat_auc),
]

ci_rows = []
for name in probs:
    ci_row = {"Model": name}
    # Every statistic is bootstrapped independently for this model.
    for prefix, stat_fn in _CI_STATS:
        boot_mean, boot_lo, boot_hi = bootstrap_ci(Y_true, probs[name], stat_fn)
        ci_row[f"{prefix}_mean"] = boot_mean
        ci_row[f"{prefix}_95ci"] = f"{boot_lo:.4f}-{boot_hi:.4f}"
    ci_rows.append(ci_row)

ci_df = pd.DataFrame(ci_rows)

# McNemar test: RGB Only vs RGB + Wavelets
rgb_correct = preds_bin["RGB Only"] == Y_true
fused_correct = preds_bin["RGB + Wavelets"] == Y_true

# 2x2 agreement table: rows = RGB Only correct / wrong,
# columns = RGB + Wavelets correct / wrong.
table = [
    [np.sum(rgb_correct & fused_correct), np.sum(rgb_correct & ~fused_correct)],
    [np.sum(~rgb_correct & fused_correct), np.sum(~rgb_correct & ~fused_correct)],
]
# Chi-square version of the test with continuity correction.
mcnemar_res = mcnemar(table, exact=False, correction=True)

print("Global metrics:\n", global_df)
print("\nBootstrap CIs:\n", ci_df)
print("\nMcNemar (RGB Only vs RGB+Wavelets):")
print(" contingency table:", table)
print(" statistic=%.4f p=%.6f" % (mcnemar_res.statistic, mcnemar_res.pvalue))


# Grad-CAM



In [None]:
# Models and the test dataset that carries each model's input features.
models = {"RGB Only": model_rgb, "Wavelets Only": model_wavelets, "RGB + Wavelets": model_rgb_wavelets}

LAST_CONV_LAYER_NAME = "conv5_block16_concat"

# model name -> {"pv": predicted probabilities, "pc": predicted classes}
predictions = {}
ds_map = {"RGB Only": test_ds_rgb, "Wavelets Only": test_ds_wavelet, "RGB + Wavelets": test_ds_rgb_wavelet}

for name, model in models.items():
    print(f"Predicting for {name}...")
    pv_all = model.predict(ds_map[name], verbose=0).ravel()
    # Strictly greater than 0.5, the same rule the evaluation cells use, so
    # the samples picked for Grad-CAM agree with the reported metrics.
    pc_all = (pv_all > 0.5).astype(int)
    predictions[name] = {"pv": pv_all, "pc": pc_all}

In [None]:
import time

MODEL_R = "RGB Only"
MODEL_W = "Wavelets Only"
MODEL_RW = "RGB + Wavelets"
MAX_SAMPLES = 10  # cap on the number of samples picked per category

# Initialization and Data Loading

LAST_CONV_LAYER_NAME = "conv5_block16_concat"

# List the test images; the label of an image is the position of its class
# folder in the sorted folder list.
# NOTE(review): later indexing assumes this order matches the order of the
# predictions made on test_ds — confirm.
test_dir = os.path.join(dataset_dir, "test")
class_names = sorted(os.listdir(test_dir))
_IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")

paths, labels = [], []
for cls_idx, d in enumerate(class_names):
    folder = os.path.join(test_dir, d)
    image_files = [
        f for f in sorted(os.listdir(folder))
        if f.lower().endswith(_IMAGE_SUFFIXES)
    ]
    paths.extend(os.path.join(folder, f) for f in image_files)
    labels.extend([cls_idx] * len(image_files))

paths = np.array(paths)
labels = np.array(labels)
y_true = labels.flatten() # True labels (0 or 1)

models = {MODEL_R: model_rgb, MODEL_W: model_wavelets, MODEL_RW: model_rgb_wavelets}

# Sample Selection: predicted classes of each model
pc_R, pc_W, pc_RW = (
    predictions[model_key]["pc"] for model_key in (MODEL_R, MODEL_W, MODEL_RW)
)

def select_samples(y_true, pc_A, pc_B, pred_A_correct, pred_B_correct, max_count):
    """Pick indices where models A and B have the requested correctness pattern.

    Args:
        y_true: array of true labels.
        pc_A: predicted classes of model A, aligned with ``y_true``.
        pc_B: predicted classes of model B, aligned with ``y_true``.
        pred_A_correct: whether model A must be correct (True) or wrong (False).
        pred_B_correct: whether model B must be correct (True) or wrong (False).
        max_count: maximum number of indices returned.

    Returns:
        List of matching indices; when more than ``max_count`` match, a random
        subset of that size drawn with ``random.sample``.
    """
    a_matches = (pc_A == y_true) == pred_A_correct
    b_matches = (pc_B == y_true) == pred_B_correct
    candidates = list(np.where(a_matches & b_matches)[0])

    if len(candidates) <= max_count:
        return candidates
    return random.sample(candidates, max_count)

# Select samples for Figure A (RGB Only vs. Wavelets Only)
chosen_A = {
    "W_Corrects_R": select_samples(
        y_true, pc_R, pc_W, pred_A_correct=False, pred_B_correct=True, max_count=MAX_SAMPLES
    ),
    "R_Corrects_W": select_samples(
        y_true, pc_R, pc_W, pred_A_correct=True, pred_B_correct=False, max_count=MAX_SAMPLES
    ),
    "Common_Failure": select_samples(
        y_true, pc_R, pc_W, pred_A_correct=False, pred_B_correct=False, max_count=MAX_SAMPLES
    ),
}

# Select samples for Figure B (RGB Only vs. RGB + Wavelets)
chosen_B = {
    "Fusion_Improved": select_samples(
        y_true, pc_R, pc_RW, pred_A_correct=False, pred_B_correct=True, max_count=MAX_SAMPLES
    ),
    "Fusion_Degraded": select_samples(
        y_true, pc_R, pc_RW, pred_A_correct=True, pred_B_correct=False, max_count=MAX_SAMPLES
    ),
    "Common_Success": select_samples(
        y_true, pc_R, pc_RW, pred_A_correct=True, pred_B_correct=True, max_count=MAX_SAMPLES
    ),
}

# Consolidate all unique indices for Grad-CAM computation
all_indices_to_compute = set()
for chosen_group in (chosen_A, chosen_B):
    for indices in chosen_group.values():
        all_indices_to_compute.update(indices)

# Grad-CAM Pre-computation (Optimized Step)

def precompute_data_and_gradcam(indices, paths, models, predictions, y_true, IMG_SIZE, LAST_CONV_LAYER_NAME):
    """Load each selected image once and compute a Grad-CAM heatmap per model.

    Args:
        indices: Iterable of sample indices into ``paths`` / ``y_true``.
        paths: Indexable collection of image file paths.
        models: Mapping of model name -> model object.
        predictions: Mapping of model name -> dict holding indexable ``"pv"``
            and ``"pc"`` entries (presumably the predicted value and the
            predicted class per sample -- confirm where it is built).
        y_true: Indexable collection of true labels.
        IMG_SIZE: Target size passed to the image loader.
        LAST_CONV_LAYER_NAME: Layer name handed to ``make_gradcam_heatmap``.

    Returns:
        Dict keyed by sample index. Each value holds ``"orig"`` (the image as
        a uint8 array), ``"y_true"`` and ``"gc_data"``, the latter mapping
        each model name to ``{"pv", "pc", "cm_name", "heatmap"}``.
    """
    # Imported locally: the notebook's import cell does not bring in `time`.
    import time

    precomputed_data = {}

    # perf_counter is monotonic, which makes it the right clock for durations
    start_time = time.perf_counter()

    for idx in indices:
        p = paths[idx]
        orig_img = keras.utils.img_to_array(keras.utils.load_img(p, target_size=IMG_SIZE)).astype(np.uint8)

        gc_results = {}
        for name, m in models.items():
            arr = get_processed_array(p, name, IMG_SIZE)

            pv = predictions[name]["pv"][idx]
            pc = predictions[name]["pc"][idx]

            # Compute Grad-CAM only for the selected subset, explaining the
            # class the model actually predicted for this sample
            heat = make_gradcam_heatmap(arr, m, LAST_CONV_LAYER_NAME, pred_index=pc).numpy()

            gc_results[name] = {"pv": pv, "pc": pc, "cm_name": get_cm_name(y_true[idx], pc), "heatmap": heat}

        precomputed_data[idx] = {"orig": orig_img, "y_true": y_true[idx], "gc_data": gc_results}

    end_time = time.perf_counter()
    print(f"Grad-CAM pre-computation finished in {end_time - start_time:.2f} seconds.")
    return precomputed_data

# Run the Grad-CAM pre-computation once for every sample used by Figures A and B
precomputed_results = precompute_data_and_gradcam(
    indices=list(all_indices_to_compute),
    paths=paths,
    models=models,
    predictions=predictions,
    y_true=y_true,
    IMG_SIZE=IMG_SIZE,
    LAST_CONV_LAYER_NAME=LAST_CONV_LAYER_NAME,
)

# Plotting

def generate_figure(chosen_dict, precomputed_data, models_to_plot, title):
    """Plot original images next to Grad-CAM overlays using precomputed data.

    Each selected sample becomes one row: the first column shows the original
    image, followed by one overlay column per entry in ``models_to_plot``.

    Args:
        chosen_dict: Mapping of group title -> iterable of sample indices.
        precomputed_data: Mapping of sample index -> dict with ``"orig"``,
            ``"y_true"`` and ``"gc_data"`` entries.
        models_to_plot: Model names (keys of each sample's ``"gc_data"``) in
            column order.
        title: Figure title.

    Returns:
        None. When no chosen sample has precomputed data, a warning is printed
        and no figure is created.
    """
    all_rows_data = []
    row_titles = []
    for group_title, indices in chosen_dict.items():
        for idx in indices:
            if idx in precomputed_data:  # Ensure we only plot computed samples
                all_rows_data.append(precomputed_data[idx])
                row_titles.append(group_title)

    # plt.subplots cannot build a grid with zero rows
    if not all_rows_data:
        print(f"Warning: No samples found for {title}.")
        return

    n_rows = len(all_rows_data)
    n_cols = 1 + len(models_to_plot)

    # Figure size scales with the grid. squeeze=False keeps `axes` 2-D even
    # when there is a single row, so axes[r, c] indexing is always valid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows), squeeze=False)

    print(f"\nGenerating {title} with {n_rows} rows...")

    for r, (group_title, data) in enumerate(zip(row_titles, all_rows_data)):
        orig = data['orig']
        y_true_val = data['y_true']

        # Original Image Column
        axes[r, 0].imshow(orig)
        axes[r, 0].set_title(f"{group_title}\nTrue={y_true_val}", fontsize=8)
        axes[r, 0].axis('off')

        # Grad-CAM Columns
        for c, name in enumerate(models_to_plot, start=1):
            gc_data = data['gc_data'][name]
            pv = gc_data['pv']
            pc = gc_data['pc']
            cm_name = gc_data['cm_name']
            heat = gc_data['heatmap']

            axes[r, c].imshow(overlay(orig, heat))
            axes[r, c].set_title(f"{name}\n{cm_name} Pred={pv:.3f} C={pc}", fontsize=8)
            axes[r, c].axis('off')

    fig.suptitle(title, fontsize=14)
    # Leave a little room at the top for the suptitle
    plt.tight_layout(rect=[0, 0.03, 1, 0.98])
    plt.show()

# Render Figure A, then Figure B, from the shared pre-computed results
for figure_groups, figure_models, figure_title in (
    (
        chosen_A,
        [MODEL_R, MODEL_W],
        "Figure A: Independent Feature Comparison (RGB Only vs. Wavelets Only)",
    ),
    (
        chosen_B,
        [MODEL_R, MODEL_RW],
        "Figure B: Feature Fusion Effect (RGB Only vs. RGB + Wavelets)",
    ),
):
    generate_figure(
        figure_groups,
        precomputed_results,
        models_to_plot=figure_models,
        title=figure_title,
    )

In [None]:
def select_triple_samples(y_true, pc_R, pc_W, pc_RW, correct_R, correct_W, correct_RW, max_count):
    """Return indices whose correctness under three models matches a pattern.

    Args:
        y_true: Array of true labels.
        pc_R: Array of classes predicted by the R model.
        pc_W: Array of classes predicted by the W model.
        pc_RW: Array of classes predicted by the RW model.
        correct_R: Whether the R model must be correct (True) or wrong (False).
        correct_W: Whether the W model must be correct (True) or wrong (False).
        correct_RW: Whether the RW model must be correct (True) or wrong (False).
        max_count: Upper bound on the number of indices returned.

    Returns:
        A list of matching indices. When more than ``max_count`` samples
        match, a random subset of exactly ``max_count`` is drawn with
        ``random.sample``; otherwise all matches are returned in ascending
        order.
    """
    r_matches = (pc_R == y_true) == correct_R
    w_matches = (pc_W == y_true) == correct_W
    rw_matches = (pc_RW == y_true) == correct_RW

    hits = list(np.where(r_matches & w_matches & rw_matches)[0])

    # Few enough hits: hand them all back without touching the RNG
    if len(hits) <= max_count:
        return hits
    return random.sample(hits, max_count)

# Required (R, W, RW) correctness pattern for each comparison group
triple_patterns = {
    "Fusion_Win": (False, False, True),
    "Wavelets_Only_Win": (False, True, False),
    "Common_Failure": (False, False, False),
}

# Draw up to MAX_SAMPLES indices per group, in the order listed above
chosen_comparison = {
    group_name: select_triple_samples(
        y_true, pc_R, pc_W, pc_RW,
        correct_R=want_R, correct_W=want_W, correct_RW=want_RW,
        max_count=MAX_SAMPLES,
    )
    for group_name, (want_R, want_W, want_RW) in triple_patterns.items()
}

# Union of every selected index so each sample is processed only once
all_indices_to_compute = set().union(*chosen_comparison.values())

print(f"--- Step 3: Selection complete. {len(all_indices_to_compute)} unique samples chosen for triple comparison. ---")

# Grad-CAM Pre-computation

# Run the pre-computation for the triple-comparison samples
precomputed_results = precompute_data_and_gradcam(
    indices=list(all_indices_to_compute),
    paths=paths,
    models=models,
    predictions=predictions,
    y_true=y_true,
    IMG_SIZE=IMG_SIZE,
    LAST_CONV_LAYER_NAME=LAST_CONV_LAYER_NAME,
)

# Plotting (Single Figure for Three Models)


def generate_triple_figure(chosen_dict, precomputed_data, models_to_plot, title):
    """Plot original images next to Grad-CAM overlays for several models.

    Each selected sample becomes one row: the first column shows the original
    image, followed by one overlay column per entry in ``models_to_plot``.

    Args:
        chosen_dict: Mapping of group title -> iterable of sample indices.
        precomputed_data: Mapping of sample index -> dict with ``"orig"``,
            ``"y_true"`` and ``"gc_data"`` entries.
        models_to_plot: Model names (keys of each sample's ``"gc_data"``) in
            column order.
        title: Figure title.

    Returns:
        None. When no chosen sample has precomputed data, a warning is printed
        and no figure is created.
    """
    all_rows_data = []
    row_titles = []
    for group_title, indices in chosen_dict.items():
        for idx in indices:
            if idx in precomputed_data:
                all_rows_data.append(precomputed_data[idx])
                row_titles.append(group_title)

    if not all_rows_data:
        print(f"Warning: No samples found for {title}.")
        return

    n_rows = len(all_rows_data)
    # One column for the original image plus one per model
    n_cols = 1 + len(models_to_plot)

    # squeeze=False keeps `axes` 2-D even when only one sample was selected;
    # otherwise a single-row grid is returned 1-D and axes[r, c] would fail.
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(4 * n_cols, 3 * n_rows),
                             squeeze=False)

    print(f"\nGenerating {title} with {n_rows} rows...")

    for r, (group_title, data) in enumerate(zip(row_titles, all_rows_data)):
        orig = data['orig']
        y_true_val = data['y_true']

        # Original Image Column
        axes[r, 0].imshow(orig)
        axes[r, 0].set_title(f"{group_title}\nTrue={y_true_val}", fontsize=8)
        axes[r, 0].axis('off')

        # Grad-CAM Columns (one per requested model)
        for c, name in enumerate(models_to_plot, start=1):
            gc_data = data['gc_data'][name]
            pv = gc_data['pv']
            pc = gc_data['pc']
            cm_name = gc_data['cm_name']
            heat = gc_data['heatmap']

            axes[r, c].imshow(overlay(orig, heat))
            axes[r, c].set_title(f"{name}\n{cm_name} Pred={pv:.3f} C={pc}", fontsize=8)
            axes[r, c].axis('off')

    fig.suptitle(title, fontsize=14)
    # Leave a little room at the top for the suptitle
    plt.tight_layout(rect=[0, 0.03, 1, 0.98])
    plt.show()

# Render the three-model comparison figure from the pre-computed results
generate_triple_figure(
    chosen_dict=chosen_comparison,
    precomputed_data=precomputed_results,
    models_to_plot=[MODEL_R, MODEL_W, MODEL_RW],
    title="Figure C: Three-Model Comparison (Fusion Success, Degradation, and Failure)",
)


In [None]:
def get_processed_array(img_path: str, model_name: str, size: tuple):
    """Load one image and apply the validation preprocessing for a model.

    Args:
        img_path: Path of the image file to load.
        model_name: Model name; the preprocessing function is chosen by
            substring match on this value.
        size: Target size passed to the image loader.

    Returns:
        The first element returned by the selected preprocessing function.
    """
    pixels = keras.utils.img_to_array(keras.utils.load_img(img_path, target_size=size))
    # Add a leading batch axis and convert to a uint8 tensor
    batch = tf.constant(pixels[np.newaxis, ...], dtype=tf.uint8)

    # 'Wavelets Only' is checked first, then 'RGB + Wavelets'; every other
    # name falls back to the RGB-only preprocessing.
    if 'Wavelets Only' in model_name:
        preprocess = preprocess_wavelet_only_valid
    elif 'RGB + Wavelets' in model_name:
        preprocess = preprocess_rgb_wavelet_valid
    else:
        preprocess = preprocess_rgb_only_valid

    # A placeholder second argument is supplied; the second returned value is
    # discarded.
    processed, _ = preprocess(batch, tf.constant([0]))
    return processed


def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """Compute a Grad-CAM heatmap for one input and return it as a TensorFlow tensor.

    Args:
        img_array: Model input; it is passed directly to ``model`` and watched
            by the gradient tape. Only ``conv_outputs[0]`` is used for the
            heatmap (presumably the input is a batch of one -- confirm
            against callers).
        model: Keras model providing ``inputs``, ``output`` and ``get_layer``.
        last_conv_layer_name: Name of the layer whose output is weighted by
            the pooled gradients.
        pred_index: Class to explain. ``1`` explains ``predictions[:, 0]``;
            any other value explains ``1.0 - predictions[:, 0]``. When
            ``None``, the rounded model output is used.

    Returns:
        The squeezed heatmap tensor, with negative values clipped to zero and
        divided by its pre-clipping maximum plus ``1e-7``.
    """

    # Call the model once before wiring up the sub-model below
    # (presumably to make sure the model is built -- confirm).
    _ = model(img_array)

    # Auxiliary model returning both the chosen layer's output and the
    # final model output for the same input.
    grad_model = keras.models.Model(
        [model.inputs],
        [model.get_layer(last_conv_layer_name).output, model.output]
    )

    with tf.GradientTape() as tape:
        tape.watch(img_array)
        conv_outputs, predictions = grad_model(img_array)

        # No class requested: fall back to the rounded model output
        if pred_index is None:
            pred_index = tf.cast(tf.round(predictions[0]), tf.int32)

        # Score to differentiate: the output itself for class 1, its
        # complement otherwise.
        if pred_index == 1:
            class_channel = predictions[:, 0]
        else:
            class_channel = 1.0 - predictions[:, 0]

        class_channel = tf.expand_dims(class_channel, axis=-1)

    # Gradient of the class score with respect to the layer output
    grads = tape.gradient(class_channel, conv_outputs)
    # Average the gradients over axes 0, 1 and 2, leaving one weight per entry
    # of the remaining axis (assumes a 4-D layer output -- TODO confirm).
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # Keep only the first element along the leading axis
    conv_outputs = conv_outputs[0]

    # Weighted sum of the layer output along its last axis
    heatmap = conv_outputs @ tf.cast(pooled_grads, conv_outputs.dtype)[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)

    # Clip negatives and scale; the maximum is taken from the heatmap before
    # clipping, and 1e-7 guards against division by zero.
    heatmap = tf.maximum(heatmap, 0) / (tf.math.reduce_max(heatmap) + 1e-7)
    return heatmap

def overlay(orig, heat, alpha=0.5):
    """Blend a jet-coloured heatmap over an image.

    Args:
        orig: Image array whose first two axes are height and width; the
            heatmap is resized to this size.
        heat: Heatmap array; values are clipped to [0, 1] before colouring.
        alpha: Weight of the coloured heatmap in the blend; the image gets
            ``1 - alpha``.

    Returns:
        The blended image as a uint8 array, clipped to [0, 255].
    """
    # Quantise the heatmap to 0..255 so it can index a 256-entry colour table
    h = np.uint8(255 * np.clip(heat, 0, 1))

    # `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in
    # 3.9; `pyplot.get_cmap` is available on old and new versions alike.
    jet_table = np.uint8(plt.get_cmap("jet")(np.arange(256))[:, :3] * 255)
    cmap = jet_table[h]

    # PIL's resize takes (width, height), hence the swapped shape entries
    cmap = np.array(Image.fromarray(cmap).resize((orig.shape[1], orig.shape[0]), Image.BILINEAR))
    return np.clip(cmap * alpha + orig * (1 - alpha), 0, 255).astype(np.uint8)

# Helper to get Confusion Matrix name (CM name)
def get_cm_name(y_true, pc):
    """Return the confusion-matrix cell name for a true label / predicted class pair.

    Args:
        y_true: True label.
        pc: Predicted class.

    Returns:
        "TP", "FP", "TN" or "FN" for the four 0/1 combinations, and
        "Unknown" for any other pair of values.
    """
    # (true label, predicted class, cell name), checked in this order
    cells = (
        (1, 1, "TP"),
        (0, 1, "FP"),
        (0, 0, "TN"),
        (1, 0, "FN"),
    )
    for actual, predicted, cell_name in cells:
        if y_true == actual and pc == predicted:
            return cell_name
    return "Unknown"

In [None]:
# Rebuild the test-set listing; the class index is the position of the class
# folder in sorted order.
test_dir = os.path.join(dataset_dir, "test")
paths, labels = [], []
for cls, d in enumerate(sorted(os.listdir(test_dir))):
    folder = os.path.join(test_dir, d)
    for f in sorted(os.listdir(folder)):
        if f.lower().endswith((".jpg", ".jpeg", ".png")):
            paths.append(os.path.join(folder, f))
            labels.append(cls)
paths = np.array(paths)
labels = np.array(labels)
y = labels

rows = ['TP', 'FP', 'TN', 'FN']
models_list = list(models.keys())

# Find one example per confusion-matrix cell of the reference model.
# Every image is visited at most once, in random order, so the search always
# terminates -- even when a cell is empty (e.g. the reference model makes no
# false positives), in which case the whole test set is scanned once.
chosen = {}
for idx in random.sample(range(len(paths)), len(paths)):
    if len(chosen) == len(rows):
        break
    arr_ref = get_processed_array(paths[idx], REF, IMG_SIZE)
    pv_ref = float(models[REF].predict(arr_ref, verbose=0).ravel()[0])
    pc_ref = int(pv_ref >= 0.5)
    if y[idx] == 1:
        key = 'TP' if pc_ref == 1 else 'FN'
    elif y[idx] == 0:
        key = 'FP' if pc_ref == 1 else 'TN'
    else:
        # Labels other than 0/1 do not belong to any of the four cells
        continue
    # Keep the first example found for each cell
    chosen.setdefault(key, idx)

rows_found = [key for key in rows if key in chosen]
rows_missing = [key for key in rows if key not in chosen]
if rows_missing:
    print(f"Warning: no {', '.join(rows_missing)} example found for {REF}; skipping.")

if rows_found:
    # squeeze=False keeps `axes` 2-D even if only one row could be filled
    fig, axes = plt.subplots(
        len(rows_found), 1 + len(models_list),
        figsize=(4 * (1 + len(models_list)), 4 * len(rows_found)),
        squeeze=False,
    )
    for r, key in enumerate(rows_found):
        idx = chosen[key]
        p = paths[idx]
        orig = keras.utils.img_to_array(keras.utils.load_img(p, target_size=IMG_SIZE)).astype(np.uint8)
        axes[r, 0].imshow(orig)
        axes[r, 0].set_title(f"{key}\nTrue={y[idx]}")
        axes[r, 0].axis('off')

        # One Grad-CAM overlay per model, each explaining that model's own
        # predicted class for this image
        for c, name in enumerate(models_list, start=1):
            m = models[name]
            arr = get_processed_array(p, name, IMG_SIZE)
            pv = float(m.predict(arr, verbose=0).ravel()[0])
            pc = int(pv >= 0.5)
            heat = make_gradcam_heatmap(arr, m, LAST_CONV_LAYER_NAME, pred_index=pc).numpy()
            axes[r, c].imshow(overlay(orig, heat))
            axes[r, c].set_title(f"{name}\nPred={pv:.3f} C={pc}")
            axes[r, c].axis('off')

    plt.tight_layout()
    plt.show()
