In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fake Image Detection - Model Training\n",
    "\n",
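    "This notebook contains the complete pipeline for training deep learning models to detect fake/manipulated images.\n",
    "\n",
    "It trains a custom Keras CNN, an ImageNet-pretrained Keras ResNet50 (first with a frozen backbone, then partially fine-tuned) and a PyTorch ResNet50, and compares them on a held-out, stratified test set. Labels follow the convention **real = 1, fake = 0**.\n"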
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Import Libraries and Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from tqdm import tqdm\n",
    "import cv2\n",
    "from PIL import Image\n",
    "import albumentations as A\n",
    "from albumentations.pytorch import ToTensorV2\n",
    "\n",
    "# Deep Learning Libraries\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers, models, optimizers, callbacks\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torchvision.transforms as transforms\n",
    "# NOTE: the next line rebinds `models` to torchvision.models, shadowing the\n",
    "# tensorflow.keras `models` imported above. Build Keras models through\n",
    "# `keras.Sequential` / `keras.Model` so they do not depend on this name.\n",
    "import torchvision.models as models\n",
    "\n",
    "# Sklearn\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import (\n",
    "    classification_report, confusion_matrix, f1_score, roc_auc_score, roc_curve\n",
    ")\n",
    "\n",
    "# Add project root to path and pull in the project helpers.\n",
    "# NOTE(review): star imports can silently rebind any name imported above\n",
    "# (e.g. `models`, `layers`); explicit imports would be safer.\n",
    "sys.path.append('..')\n",
    "from utils.data_preprocessing import *\n",
    "from utils.visualization import *\n",
    "from utils.metrics import *\n",
    "from models.cnn_model import *\n",
    "from models.resnet_model import *\n",
    "\n",
    "# Set random seeds for reproducibility: Python, NumPy, TensorFlow, PyTorch (CPU and CUDA)\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed_all(SEED)\n",
    "\n",
    "# GPU Configuration: enable memory growth on every visible GPU (not only the\n",
    "# first) so TensorFlow allocates device memory on demand.\n",
    "physical_devices = tf.config.list_physical_devices('GPU')\n",
    "for gpu in physical_devices:\n",
    "    tf.config.experimental.set_memory_growth(gpu, True)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Configuration and Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: every tunable path and hyperparameter for this notebook.\n",
    "# Written as dict(...) keywords; key order follows the sections below.\n",
    "CONFIG = dict(\n",
    "    # Paths (relative)\n",
    "    DATA_DIR='../data',\n",
    "    REAL_DIR='../data/real',\n",
    "    FAKE_DIR='../data/fake',\n",
    "    TEST_DIR='../data/test',\n",
    "    MODEL_SAVE_DIR='../models/saved_models',\n",
    "\n",
    "    # Image parameters: (height, width), images per batch, colour channels\n",
    "    IMG_SIZE=(224, 224),\n",
    "    BATCH_SIZE=32,\n",
    "    CHANNELS=3,\n",
    "\n",
    "    # Training parameters; both split values are fractions of the full dataset\n",
    "    EPOCHS=50,\n",
    "    LEARNING_RATE=0.001,\n",
    "    VALIDATION_SPLIT=0.2,\n",
    "    TEST_SPLIT=0.1,\n",
    "\n",
    "    # Model parameters\n",
    "    # NOTE(review): L2_REGULARIZATION is not used by the model builders in this notebook.\n",
    "    DROPOUT_RATE=0.5,\n",
    "    L2_REGULARIZATION=0.01,\n",
    "\n",
    "    # Callbacks\n",
    "    EARLY_STOPPING_PATIENCE=10,\n",
    "    REDUCE_LR_PATIENCE=5,\n",
    "    REDUCE_LR_FACTOR=0.5,\n",
    ")\n",
    "\n",
    "# Create the checkpoint directory up front so later model saves have somewhere to write\n",
    "os.makedirs(CONFIG['MODEL_SAVE_DIR'], exist_ok=True)\n",
    "\n",
    "print(\"Configuration loaded successfully!\")\n",
    "print(f\"Image size: {CONFIG['IMG_SIZE']}\")\n",
    "print(f\"Batch size: {CONFIG['BATCH_SIZE']}\")\n",
    "print(f\"Epochs: {CONFIG['EPOCHS']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Data Loading and Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load and preprocess data\n",
    "IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')\n",
    "\n",
    "\n",
    "def _list_images(directory, extensions=IMAGE_EXTENSIONS):\n",
    "    \"\"\"Return the image filenames in `directory`, sorted by name.\n",
    "\n",
    "    os.listdir order is filesystem-dependent; sorting makes the seeded\n",
    "    train/val/test split reproducible across machines.\n",
    "    \"\"\"\n",
    "    return sorted(f for f in os.listdir(directory) if f.lower().endswith(extensions))\n",
    "\n",
    "\n",
    "def load_dataset(real_dir, fake_dir, img_size):\n",
    "    \"\"\"\n",
    "    Load images from real and fake directories.\n",
    "\n",
    "    Real images are labelled 1 and fake images 0. Files for which\n",
    "    `load_and_preprocess_image` returns None are skipped.\n",
    "\n",
    "    Args:\n",
    "        real_dir: directory containing real images.\n",
    "        fake_dir: directory containing fake images.\n",
    "        img_size: target size forwarded to `load_and_preprocess_image`.\n",
    "\n",
    "    Returns:\n",
    "        (images, labels, filepaths): `images` and integer `labels` as numpy\n",
    "        arrays, `filepaths` as a list aligned with them.\n",
    "    \"\"\"\n",
    "    images = []\n",
    "    labels = []\n",
    "    filepaths = []\n",
    "\n",
    "    for directory, label, name in ((real_dir, 1, 'real'), (fake_dir, 0, 'fake')):\n",
    "        files = _list_images(directory)\n",
    "        print(f\"Found {len(files)} {name} images\")\n",
    "\n",
    "        for filename in tqdm(files, desc=f\"Loading {name} images\"):\n",
    "            filepath = os.path.join(directory, filename)\n",
    "            img = load_and_preprocess_image(filepath, img_size)\n",
    "            if img is not None:\n",
    "                images.append(img)\n",
    "                labels.append(label)\n",
    "                filepaths.append(filepath)\n",
    "\n",
    "    # Explicit int dtype keeps np.bincount usable even if nothing was loaded\n",
    "    return np.array(images), np.array(labels, dtype=np.int64), filepaths\n",
    "\n",
    "# Load the dataset\n",
    "print(\"Loading dataset...\")\n",
    "X, y, filepaths = load_dataset(CONFIG['REAL_DIR'], CONFIG['FAKE_DIR'], CONFIG['IMG_SIZE'])\n",
    "\n",
    "print(f\"Total images loaded: {len(X)}\")\n",
    "print(f\"Real images: {np.sum(y)}\")\n",
    "print(f\"Fake images: {len(y) - np.sum(y)}\")\n",
    "print(f\"Image shape: {X.shape}\")\n",
    "print(f\"Label distribution: {np.bincount(y)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Data Exploration and Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize sample images\n",
    "plot_sample_images(X, y, num_samples=8)\n",
    "\n",
    "# Plot class distribution. Bar i counts label i (0 = Fake, 1 = Real, as assigned\n",
    "# in load_dataset); minlength keeps both bars even if one class is absent.\n",
    "labels_text = ['Fake', 'Real']\n",
    "counts = np.bincount(y, minlength=len(labels_text))\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "ax.bar(labels_text, counts, color=['red', 'green'], alpha=0.7)\n",
    "ax.set_title('Class Distribution')\n",
    "ax.set_ylabel('Number of Images')\n",
    "for i, count in enumerate(counts):\n",
    "    # Offset in points (not data units) so the label gap is the same at any scale\n",
    "    ax.annotate(str(count), (i, count), xytext=(0, 3), textcoords='offset points',\n",
    "                ha='center', va='bottom')\n",
    "plt.show()\n",
    "\n",
    "# Calculate and display basic statistics\n",
    "print(f\"\\nDataset Statistics:\")\n",
    "print(f\"Total samples: {len(X)}\")\n",
    "print(f\"Image dimensions: {X.shape[1:]}\")\n",
    "print(f\"Pixel value range: [{X.min():.3f}, {X.max():.3f}]\")\n",
    "print(f\"Mean pixel value: {X.mean():.3f}\")\n",
    "print(f\"Std pixel value: {X.std():.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Data Augmentation Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared settings for both augmentation libraries\n",
    "img_height, img_width = CONFIG['IMG_SIZE']\n",
    "IMAGENET_MEAN = [0.485, 0.456, 0.406]\n",
    "IMAGENET_STD = [0.229, 0.224, 0.225]\n",
    "\n",
    "# TensorFlow Data Augmentation (Keras ImageDataGenerator)\n",
    "# NOTE(review): the Keras fit() calls in the visible cells train on the\n",
    "# in-memory arrays and do not use these generators.\n",
    "tf_train_datagen = ImageDataGenerator(\n",
    "    rotation_range=20,\n",
    "    width_shift_range=0.2,\n",
    "    height_shift_range=0.2,\n",
    "    zoom_range=0.2,\n",
    "    shear_range=0.1,\n",
    "    horizontal_flip=True,\n",
    "    brightness_range=[0.8, 1.2],\n",
    "    validation_split=CONFIG['VALIDATION_SPLIT'],\n",
    ")\n",
    "tf_val_datagen = ImageDataGenerator(validation_split=CONFIG['VALIDATION_SPLIT'])\n",
    "\n",
    "# PyTorch Data Augmentation using Albumentations (random ops, then normalise + tensorise).\n",
    "# NOTE(review): Rotate and ShiftScaleRotate can both rotate the same sample;\n",
    "# confirm the double rotation is intended.\n",
    "train_transform = A.Compose([\n",
    "    A.Resize(img_height, img_width),\n",
    "    A.HorizontalFlip(p=0.5),\n",
    "    A.VerticalFlip(p=0.2),\n",
    "    A.Rotate(limit=20, p=0.5),\n",
    "    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),\n",
    "    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),\n",
    "    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),\n",
    "    ToTensorV2(),\n",
    "])\n",
    "\n",
    "# Validation/test: deterministic resize + normalisation only\n",
    "val_transform = A.Compose([\n",
    "    A.Resize(img_height, img_width),\n",
    "    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),\n",
    "    ToTensorV2(),\n",
    "])\n",
    "\n",
    "print(\"Data augmentation pipelines created successfully!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Data Splitting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the data: hold out TEST_SPLIT of all samples for testing, then take the\n",
    "# validation set from the remainder so that it is VALIDATION_SPLIT of all samples.\n",
    "val_fraction_of_remainder = CONFIG['VALIDATION_SPLIT'] / (1 - CONFIG['TEST_SPLIT'])\n",
    "\n",
    "X_temp, X_test, y_temp, y_test = train_test_split(\n",
    "    X, y,\n",
    "    test_size=CONFIG['TEST_SPLIT'],\n",
    "    stratify=y,\n",
    "    random_state=42,\n",
    ")\n",
    "X_train, X_val, y_train, y_val = train_test_split(\n",
    "    X_temp, y_temp,\n",
    "    test_size=val_fraction_of_remainder,\n",
    "    stratify=y_temp,\n",
    "    random_state=42,\n",
    ")\n",
    "\n",
    "for split_name, split_size in ((\"Training\", len(X_train)), (\"Validation\", len(X_val)), (\"Test\", len(X_test))):\n",
    "    print(f\"{split_name} set size: {split_size}\")\n",
    "\n",
    "# Stratified splits should show (roughly) the same real/fake ratio in every set\n",
    "for split_name, split_labels in ((\"training\", y_train), (\"validation\", y_val), (\"test\", y_test)):\n",
    "    print(f\"\\nClass distribution in {split_name} set:\")\n",
    "    print(f\"Real: {np.sum(split_labels)}, Fake: {len(split_labels) - np.sum(split_labels)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. TensorFlow/Keras Model Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build CNN Model\n",
    "def build_cnn_model(input_shape, num_classes=1):\n",
    "    \"\"\"\n",
    "    Build a four-block convolutional classifier with a sigmoid output.\n",
    "\n",
    "    Each block is Conv2D(3x3, ReLU) -> BatchNorm -> MaxPool(2x2) -> Dropout(0.25)\n",
    "    with 32, 64, 128 and 256 filters. The head is Dense(512) -> BatchNorm ->\n",
    "    Dropout -> Dense(256) -> Dropout (rate CONFIG['DROPOUT_RATE']) -> Dense(sigmoid).\n",
    "\n",
    "    Args:\n",
    "        input_shape: (height, width, channels) of the input images.\n",
    "        num_classes: number of sigmoid output units (1 for real/fake).\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled Keras model.\n",
    "    \"\"\"\n",
    "    conv_blocks = []\n",
    "    for filters in (32, 64, 128, 256):\n",
    "        conv_blocks += [\n",
    "            layers.Conv2D(filters, (3, 3), activation='relu'),\n",
    "            layers.BatchNormalization(),\n",
    "            layers.MaxPooling2D((2, 2)),\n",
    "            layers.Dropout(0.25),\n",
    "        ]\n",
    "\n",
    "    # `keras.Sequential` rather than `models.Sequential`: the import cell rebinds\n",
    "    # `models` to torchvision.models, which has no `Sequential`.\n",
    "    model = keras.Sequential([\n",
    "        # Explicit Input layer instead of the deprecated `input_shape=` layer argument\n",
    "        keras.Input(shape=input_shape),\n",
    "        *conv_blocks,\n",
    "\n",
    "        # Dense layers\n",
    "        layers.Flatten(),\n",
    "        layers.Dense(512, activation='relu'),\n",
    "        layers.BatchNormalization(),\n",
    "        layers.Dropout(CONFIG['DROPOUT_RATE']),\n",
    "        layers.Dense(256, activation='relu'),\n",
    "        layers.Dropout(CONFIG['DROPOUT_RATE']),\n",
    "        layers.Dense(num_classes, activation='sigmoid')\n",
    "    ])\n",
    "\n",
    "    return model\n",
    "\n",
    "# Build the model\n",
    "input_shape = (*CONFIG['IMG_SIZE'], CONFIG['CHANNELS'])\n",
    "cnn_model = build_cnn_model(input_shape)\n",
    "\n",
    "# Compile the model. Named metric objects are accepted by both Keras 2 and\n",
    "# Keras 3 and keep the 'precision'/'recall' keys in the training history.\n",
    "cnn_model.compile(\n",
    "    optimizer=optimizers.Adam(learning_rate=CONFIG['LEARNING_RATE']),\n",
    "    loss='binary_crossentropy',\n",
    "    metrics=['accuracy', keras.metrics.Precision(name='precision'), keras.metrics.Recall(name='recall')]\n",
    ")\n",
    "\n",
    "# Model summary\n",
    "cnn_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup callbacks.\n",
    "# The checkpoint monitors val_loss, the same quantity EarlyStopping uses to pick\n",
    "# the weights it restores, so the saved file and the restored weights are\n",
    "# selected by one criterion.\n",
    "callbacks_list = [\n",
    "    callbacks.EarlyStopping(\n",
    "        monitor='val_loss',\n",
    "        patience=CONFIG['EARLY_STOPPING_PATIENCE'],\n",
    "        restore_best_weights=True,\n",
    "        verbose=1\n",
    "    ),\n",
    "    callbacks.ReduceLROnPlateau(\n",
    "        monitor='val_loss',\n",
    "        factor=CONFIG['REDUCE_LR_FACTOR'],\n",
    "        patience=CONFIG['REDUCE_LR_PATIENCE'],\n",
    "        verbose=1,\n",
    "        min_lr=1e-7\n",
    "    ),\n",
    "    callbacks.ModelCheckpoint(\n",
    "        filepath=os.path.join(CONFIG['MODEL_SAVE_DIR'], 'best_cnn_model.h5'),\n",
    "        monitor='val_loss',\n",
    "        save_best_only=True,\n",
    "        verbose=1\n",
    "    )\n",
    "]\n",
    "\n",
    "print(\"Callbacks configured successfully!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the CNN model on the in-memory training arrays\n",
    "print(\"Starting CNN model training...\")\n",
    "\n",
    "cnn_fit_kwargs = dict(\n",
    "    batch_size=CONFIG['BATCH_SIZE'],\n",
    "    epochs=CONFIG['EPOCHS'],\n",
    "    validation_data=(X_val, y_val),  # source of the val_* metrics the callbacks watch\n",
    "    callbacks=callbacks_list,\n",
    "    verbose=1,\n",
    ")\n",
    "cnn_history = cnn_model.fit(X_train, y_train, **cnn_fit_kwargs)\n",
    "\n",
    "print(\"CNN model training completed!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Transfer Learning with ResNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build ResNet-based model\n",
    "def build_resnet_model(input_shape, num_classes=1):\n",
    "    \"\"\"\n",
    "    Build a classifier on top of an ImageNet-pretrained ResNet50 backbone.\n",
    "\n",
    "    The backbone (without its top) is frozen; a pooled head of\n",
    "    BatchNorm/Dropout/Dense layers ending in a sigmoid is trained on top.\n",
    "\n",
    "    Args:\n",
    "        input_shape: (height, width, channels) of the input images.\n",
    "        num_classes: number of sigmoid output units (1 for real/fake).\n",
    "\n",
    "    Returns:\n",
    "        (model, base_model): the uncompiled full model and the ResNet50\n",
    "        backbone, returned separately so it can be unfrozen for fine-tuning.\n",
    "    \"\"\"\n",
    "    # Load pre-trained ResNet50\n",
    "    base_model = tf.keras.applications.ResNet50(\n",
    "        weights='imagenet',\n",
    "        include_top=False,\n",
    "        input_shape=input_shape\n",
    "    )\n",
    "\n",
    "    # Freeze base model layers\n",
    "    base_model.trainable = False\n",
    "\n",
    "    # Functional model that calls the backbone with training=False, so its\n",
    "    # BatchNormalization layers stay in inference mode even after the backbone\n",
    "    # is unfrozen for fine-tuning (the pattern from the Keras transfer-learning\n",
    "    # guide). `keras.*` is used instead of `models.*` because the import cell\n",
    "    # rebinds `models` to torchvision.models.\n",
    "    # NOTE(review): ResNet50's ImageNet weights expect inputs prepared with\n",
    "    # tf.keras.applications.resnet50.preprocess_input; confirm the scaling that\n",
    "    # load_and_preprocess_image applies to X.\n",
    "    inputs = keras.Input(shape=input_shape)\n",
    "    x = base_model(inputs, training=False)\n",
    "    x = layers.GlobalAveragePooling2D()(x)\n",
    "    x = layers.BatchNormalization()(x)\n",
    "    x = layers.Dropout(0.5)(x)\n",
    "    x = layers.Dense(512, activation='relu')(x)\n",
    "    x = layers.BatchNormalization()(x)\n",
    "    x = layers.Dropout(0.3)(x)\n",
    "    x = layers.Dense(256, activation='relu')(x)\n",
    "    x = layers.Dropout(0.2)(x)\n",
    "    outputs = layers.Dense(num_classes, activation='sigmoid')(x)\n",
    "\n",
    "    model = keras.Model(inputs, outputs)\n",
    "    return model, base_model\n",
    "\n",
    "# Build ResNet model\n",
    "resnet_model, base_model = build_resnet_model(input_shape)\n",
    "\n",
    "# Compile the model; named metric objects keep the 'precision'/'recall' history keys\n",
    "resnet_model.compile(\n",
    "    optimizer=optimizers.Adam(learning_rate=CONFIG['LEARNING_RATE']),\n",
    "    loss='binary_crossentropy',\n",
    "    metrics=['accuracy', keras.metrics.Precision(name='precision'), keras.metrics.Recall(name='recall')]\n",
    ")\n",
    "\n",
    "# Model summary\n",
    "resnet_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup callbacks for ResNet. The checkpoint monitors val_loss, the same\n",
    "# quantity EarlyStopping uses to pick the weights it restores, so the saved\n",
    "# \"best\" file and the restored weights are selected by one criterion.\n",
    "resnet_callbacks = [\n",
    "    callbacks.EarlyStopping(\n",
    "        monitor='val_loss',\n",
    "        patience=CONFIG['EARLY_STOPPING_PATIENCE'],\n",
    "        restore_best_weights=True,\n",
    "        verbose=1\n",
    "    ),\n",
    "    callbacks.ReduceLROnPlateau(\n",
    "        monitor='val_loss',\n",
    "        factor=CONFIG['REDUCE_LR_FACTOR'],\n",
    "        patience=CONFIG['REDUCE_LR_PATIENCE'],\n",
    "        verbose=1,\n",
    "        min_lr=1e-7\n",
    "    ),\n",
    "    callbacks.ModelCheckpoint(\n",
    "        filepath=os.path.join(CONFIG['MODEL_SAVE_DIR'], 'best_resnet_model.h5'),\n",
    "        monitor='val_loss',\n",
    "        save_best_only=True,\n",
    "        verbose=1\n",
    "    )\n",
    "]\n",
    "\n",
    "# Train ResNet model (backbone frozen: only the new head is trained here)\n",
    "print(\"Starting ResNet model training...\")\n",
    "\n",
    "resnet_history = resnet_model.fit(\n",
    "    X_train, y_train,\n",
    "    batch_size=CONFIG['BATCH_SIZE'],\n",
    "    epochs=CONFIG['EPOCHS'],\n",
    "    validation_data=(X_val, y_val),\n",
    "    callbacks=resnet_callbacks,\n",
    "    verbose=1\n",
    ")\n",
    "\n",
    "print(\"ResNet model training completed!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. Fine-tuning ResNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tune the ResNet model\n",
    "print(\"Starting fine-tuning of ResNet model...\")\n",
    "\n",
    "# Unfreeze the top layers of the base model\n",
    "base_model.trainable = True\n",
    "\n",
    "# Fine-tune from this layer onwards\n",
    "fine_tune_at = 100\n",
    "\n",
    "# Re-freeze every layer before `fine_tune_at`, and keep all BatchNormalization\n",
    "# layers frozen (as the Keras transfer-learning guide recommends) so fine-tuning\n",
    "# does not update the pretrained batch statistics.\n",
    "for layer_index, layer in enumerate(base_model.layers):\n",
    "    if layer_index < fine_tune_at or isinstance(layer, layers.BatchNormalization):\n",
    "        layer.trainable = False\n",
    "\n",
    "# Recompile with a lower learning rate (changes to `trainable` only take effect after compile)\n",
    "resnet_model.compile(\n",
    "    optimizer=optimizers.Adam(learning_rate=CONFIG['LEARNING_RATE']/10),\n",
    "    loss='binary_crossentropy',\n",
    "    metrics=['accuracy', keras.metrics.Precision(name='precision'), keras.metrics.Recall(name='recall')]\n",
    ")\n",
    "\n",
    "# Fine-tune callbacks; the checkpoint monitors val_loss like EarlyStopping so the\n",
    "# saved file and the restored weights are selected by the same criterion\n",
    "finetune_callbacks = [\n",
    "    callbacks.EarlyStopping(\n",
    "        monitor='val_loss',\n",
    "        patience=5,\n",
    "        restore_best_weights=True,\n",
    "        verbose=1\n",
    "    ),\n",
    "    callbacks.ModelCheckpoint(\n",
    "        filepath=os.path.join(CONFIG['MODEL_SAVE_DIR'], 'best_resnet_finetuned.h5'),\n",
    "        monitor='val_loss',\n",
    "        save_best_only=True,\n",
    "        verbose=1\n",
    "    )\n",
    "]\n",
    "\n",
    "# Continue the epoch count where the frozen-backbone phase stopped so the two\n",
    "# histories line up; run at most `finetune_epochs` additional epochs.\n",
    "finetune_epochs = 10\n",
    "initial_epoch = len(resnet_history.history['loss'])\n",
    "total_epochs = initial_epoch + finetune_epochs\n",
    "\n",
    "resnet_finetune_history = resnet_model.fit(\n",
    "    X_train, y_train,\n",
    "    batch_size=CONFIG['BATCH_SIZE'],\n",
    "    epochs=total_epochs,\n",
    "    initial_epoch=initial_epoch,\n",
    "    validation_data=(X_val, y_val),\n",
    "    callbacks=finetune_callbacks,\n",
    "    verbose=1\n",
    ")\n",
    "\n",
    "print(\"Fine-tuning completed!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10. PyTorch Model Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PyTorch Dataset Class\n",
    "class FakeImageDataset(Dataset):\n",
    "    \"\"\"\n",
    "    In-memory PyTorch dataset over parallel image and label arrays.\n",
    "\n",
    "    Args:\n",
    "        images: indexable collection of HxWxC images. Float images are\n",
    "            presumably in [0, 1] (they are scaled by 255 for albumentations);\n",
    "            uint8 images are passed through unchanged.\n",
    "        labels: indexable collection of 0/1 labels aligned with `images`.\n",
    "        transform: optional albumentations pipeline returning a tensor under\n",
    "            the 'image' key; without it images become float CxHxW tensors.\n",
    "    \"\"\"\n",
    "    def __init__(self, images, labels, transform=None):\n",
    "        self.images = images\n",
    "        self.labels = labels\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.images)\n",
    "\n",
    "    @staticmethod\n",
    "    def _to_uint8(image):\n",
    "        \"\"\"Convert an image to uint8 for albumentations.\"\"\"\n",
    "        image = np.asarray(image)\n",
    "        if image.dtype == np.uint8:\n",
    "            # Already 0-255; multiplying by 255 would overflow and wrap\n",
    "            return image\n",
    "        # Clip before casting so out-of-range floats saturate instead of wrapping\n",
    "        return np.clip(image * 255, 0, 255).astype(np.uint8)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        image = self.images[idx]\n",
    "        label = self.labels[idx]\n",
    "\n",
    "        if self.transform:\n",
    "            transformed = self.transform(image=self._to_uint8(image))\n",
    "            image = transformed['image']\n",
    "        else:\n",
    "            image = torch.FloatTensor(image).permute(2, 0, 1)\n",
    "\n",
    "        return image, torch.FloatTensor([label])\n",
    "\n",
    "# Create PyTorch datasets (training gets augmentation; val/test only resize + normalise)\n",
    "train_dataset = FakeImageDataset(X_train, y_train, transform=train_transform)\n",
    "val_dataset = FakeImageDataset(X_val, y_val, transform=val_transform)\n",
    "test_dataset = FakeImageDataset(X_test, y_test, transform=val_transform)\n",
    "\n",
    "# Create data loaders\n",
    "train_loader = DataLoader(train_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=True)\n",
    "val_loader = DataLoader(val_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=False)\n",
    "\n",
    "print(f\"PyTorch datasets created:\")\n",
    "print(f\"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PyTorch ResNet Model\n",
    "# Import the torchvision model zoo under an explicit alias: the bare name\n",
    "# `models` is rebound from tensorflow.keras.models to torchvision.models in the\n",
    "# import cell, and the project star imports may rebind it again.\n",
    "from torchvision import models as tv_models\n",
    "\n",
    "\n",
    "class PyTorchResNetModel(nn.Module):\n",
    "    \"\"\"\n",
    "    ImageNet-pretrained ResNet50 with a custom classification head.\n",
    "\n",
    "    Every backbone parameter is frozen except those in `layer4`. The final\n",
    "    `fc` layer is replaced by Dropout/Linear/ReLU/BatchNorm blocks ending in a\n",
    "    Sigmoid, so forward() returns probabilities (hence nn.BCELoss below).\n",
    "    \"\"\"\n",
    "    def __init__(self, num_classes=1):\n",
    "        super(PyTorchResNetModel, self).__init__()\n",
    "        # NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in\n",
    "        # favour of `weights=...`; kept here for compatibility with older versions.\n",
    "        self.resnet = tv_models.resnet50(pretrained=True)\n",
    "\n",
    "        # Freeze early layers\n",
    "        for param in self.resnet.parameters():\n",
    "            param.requires_grad = False\n",
    "\n",
    "        # Unfreeze last few layers\n",
    "        for param in self.resnet.layer4.parameters():\n",
    "            param.requires_grad = True\n",
    "\n",
    "        # Replace the classifier (newly created layers are trainable)\n",
    "        num_features = self.resnet.fc.in_features\n",
    "        self.resnet.fc = nn.Sequential(\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(num_features, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm1d(512),\n",
    "            nn.Dropout(0.3),\n",
    "            nn.Linear(512, 256),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.2),\n",
    "            nn.Linear(256, num_classes),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.resnet(x)\n",
    "\n",
    "# Initialize PyTorch model\n",
    "pytorch_model = PyTorchResNetModel().to(device)\n",
    "criterion = nn.BCELoss()\n",
    "# Only hand trainable parameters to the optimizer; the frozen backbone is skipped\n",
    "trainable_params = [p for p in pytorch_model.parameters() if p.requires_grad]\n",
    "optimizer = optim.Adam(trainable_params, lr=CONFIG['LEARNING_RATE'])\n",
    "# Plateau schedule on validation loss, using the same settings as the Keras callbacks\n",
    "scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n",
    "    optimizer, 'min', patience=CONFIG['REDUCE_LR_PATIENCE'], factor=CONFIG['REDUCE_LR_FACTOR']\n",
    ")\n",
    "\n",
    "print(\"PyTorch model initialized!\")\n",
    "print(f\"Trainable parameters: {sum(p.numel() for p in trainable_params)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PyTorch training function\n",
    "def _run_epoch(model, loader, criterion, optimizer=None, desc=None):\n",
    "    \"\"\"\n",
    "    Run one pass over `loader`: a training pass when `optimizer` is given,\n",
    "    otherwise an evaluation pass without gradients.\n",
    "\n",
    "    Returns:\n",
    "        (mean per-sample loss, accuracy in percent at a 0.5 threshold).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `loader` yields no samples.\n",
    "    \"\"\"\n",
    "    training = optimizer is not None\n",
    "    model.train(training)\n",
    "    total_loss = 0.0\n",
    "    correct = 0\n",
    "    total = 0\n",
    "\n",
    "    batches = tqdm(loader, desc=desc) if desc else loader\n",
    "    with torch.set_grad_enabled(training):\n",
    "        for data, target in batches:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            loss = criterion(output, target)\n",
    "\n",
    "            if training:\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "            # Weight by batch size so a short final batch does not skew the mean\n",
    "            batch_size = target.size(0)\n",
    "            total_loss += loss.item() * batch_size\n",
    "            correct += ((output > 0.5).float() == target).sum().item()\n",
    "            total += batch_size\n",
    "\n",
    "    if total == 0:\n",
    "        raise ValueError(\"data loader yielded no samples\")\n",
    "    return total_loss / total, 100. * correct / total\n",
    "\n",
    "\n",
    "def train_pytorch_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs,\n",
    "                        restore_best=True):\n",
    "    \"\"\"\n",
    "    Train `model` for `num_epochs`, validating after every epoch.\n",
    "\n",
    "    The scheduler is stepped with the validation loss. Whenever validation\n",
    "    accuracy improves, the weights are saved to\n",
    "    MODEL_SAVE_DIR/best_pytorch_model.pth; with `restore_best=True` those\n",
    "    weights are loaded back into `model` before returning, so later evaluation\n",
    "    uses the best epoch rather than the last one.\n",
    "\n",
    "    Returns:\n",
    "        dict of per-epoch lists: 'train_loss', 'val_loss', 'train_acc',\n",
    "        'val_acc' (accuracies in percent).\n",
    "    \"\"\"\n",
    "    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}\n",
    "    checkpoint_path = os.path.join(CONFIG['MODEL_SAVE_DIR'], 'best_pytorch_model.pth')\n",
    "    best_val_acc = float('-inf')\n",
    "    best_state = None\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        train_loss, train_acc = _run_epoch(\n",
    "            model, train_loader, criterion, optimizer, desc=f'Epoch {epoch+1}/{num_epochs}'\n",
    "        )\n",
    "        val_loss, val_acc = _run_epoch(model, val_loader, criterion)\n",
    "\n",
    "        history['train_loss'].append(train_loss)\n",
    "        history['val_loss'].append(val_loss)\n",
    "        history['train_acc'].append(train_acc)\n",
    "        history['val_acc'].append(val_acc)\n",
    "\n",
    "        # Learning rate scheduling\n",
    "        scheduler.step(val_loss)\n",
    "\n",
    "        # Keep a CPU copy of the best weights (state_dict() shares storage with the live model)\n",
    "        if val_acc > best_val_acc:\n",
    "            best_val_acc = val_acc\n",
    "            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}\n",
    "            torch.save(best_state, checkpoint_path)\n",
    "\n",
    "        print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '\n",
    "              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')\n",
    "\n",
    "    if restore_best and best_state is not None:\n",
    "        model.load_state_dict(best_state)\n",
    "\n",
    "    return history\n",
    "\n",
    "# Train PyTorch model\n",
    "print(\"Starting PyTorch model training...\")\n",
    "pytorch_history = train_pytorch_model(\n",
    "    pytorch_model, train_loader, val_loader, \n",
    "    criterion, optimizer, scheduler, CONFIG['EPOCHS']\n",
    ")\n",
    "print(\"PyTorch model training completed!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 11. Model Evaluation and Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate TensorFlow models\n",
    "from sklearn.metrics import f1_score, roc_auc_score\n",
    "\n",
    "print(\"Evaluating TensorFlow models...\")\n",
    "\n",
    "# CNN Model Evaluation\n",
    "cnn_test_loss, cnn_test_acc, cnn_test_precision, cnn_test_recall = cnn_model.evaluate(X_test, y_test, verbose=0)\n",
    "# predict() returns an (n, 1) column of sigmoid outputs; flatten it to the\n",
    "# 1-D scores sklearn's metrics expect\n",
    "cnn_predictions = cnn_model.predict(X_test).ravel()\n",
    "cnn_pred_binary = (cnn_predictions > 0.5).astype(int)\n",
    "\n",
    "# ResNet Model Evaluation (resnet_model holds the fine-tuned weights at this point)\n",
    "resnet_test_loss, resnet_test_acc, resnet_test_precision, resnet_test_recall = resnet_model.evaluate(X_test, y_test, verbose=0)\n",
    "resnet_predictions = resnet_model.predict(X_test).ravel()\n",
    "resnet_pred_binary = (resnet_predictions > 0.5).astype(int)\n",
    "\n",
    "print(f\"CNN Model - Test Accuracy: {cnn_test_acc:.4f}\")\n",
    "print(f\"ResNet Model - Test Accuracy: {resnet_test_acc:.4f}\")\n",
    "\n",
    "# Calculate additional metrics: F1 on thresholded predictions, AUC on raw scores\n",
    "cnn_f1 = f1_score(y_test, cnn_pred_binary)\n",
    "cnn_auc = roc_auc_score(y_test, cnn_predictions)\n",
    "\n",
    "resnet_f1 = f1_score(y_test, resnet_pred_binary)\n",
    "resnet_auc = roc_auc_score(y_test, resnet_predictions)\n",
    "\n",
    "print(f\"CNN Model - F1 Score: {cnn_f1:.4f}, AUC: {cnn_auc:.4f}\")\n",
    "print(f\"ResNet Model - F1 Score: {resnet_f1:.4f}, AUC: {resnet_auc:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate PyTorch model\n",
    "def evaluate_pytorch_model(model, test_loader):\n",
    "    \"\"\"\n",
    "    Run `model` over `test_loader` without gradients.\n",
    "\n",
    "    Returns:\n",
    "        (accuracy, predictions, targets): accuracy in percent at a 0.5\n",
    "        threshold, the (n, 1) array of model outputs, and the flat targets.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `test_loader` yields no samples.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    output_batches = []\n",
    "    target_batches = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            output = model(data.to(device))\n",
    "            output_batches.append(output.cpu().numpy())\n",
    "            target_batches.append(target.numpy())\n",
    "\n",
    "    if not output_batches:\n",
    "        raise ValueError(\"test_loader yielded no samples\")\n",
    "\n",
    "    predictions = np.concatenate(output_batches)\n",
    "    targets = np.concatenate(target_batches)\n",
    "    accuracy = 100. * np.mean((predictions > 0.5) == targets)\n",
    "    return accuracy, predictions, targets.flatten()\n",
    "\n",
    "pytorch_acc, pytorch_predictions, pytorch_targets = evaluate_pytorch_model(pytorch_model, test_loader)\n",
    "# Flatten the (n, 1) outputs to the 1-D scores sklearn's metrics expect\n",
    "pytorch_predictions = pytorch_predictions.ravel()\n",
    "pytorch_pred_binary = (pytorch_predictions > 0.5).astype(int)\n",
    "\n",
    "pytorch_f1 = f1_score(pytorch_targets, pytorch_pred_binary)\n",
    "pytorch_auc = roc_auc_score(pytorch_targets, pytorch_predictions)\n",
    "\n",
    "print(f\"PyTorch Model - Test Accuracy: {pytorch_acc:.2f}%\")\n",
    "print(f\"PyTorch Model - F1 Score: {pytorch_f1:.4f}, AUC: {pytorch_auc:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create comprehensive evaluation report\n",
    "from sklearn.metrics import precision_score, recall_score\n",
    "\n",
    "\n",
    "def create_evaluation_report():\n",
    "    \"\"\"\n",
    "    Collect the test-set metrics of all three models into one table and print it.\n",
    "\n",
    "    Keras precision/recall come from model.evaluate(); the PyTorch values are\n",
    "    computed here from its thresholded test predictions. PyTorch accuracy is\n",
    "    converted from percent to a fraction so every row uses the same scale.\n",
    "\n",
    "    Returns:\n",
    "        pandas DataFrame with one row per model.\n",
    "    \"\"\"\n",
    "    # zero_division=0 reports 0 instead of warning when a metric is undefined\n",
    "    pytorch_precision = precision_score(pytorch_targets, pytorch_pred_binary, zero_division=0)\n",
    "    pytorch_recall = recall_score(pytorch_targets, pytorch_pred_binary, zero_division=0)\n",
    "\n",
    "    models_performance = {\n",
    "        'Model': ['CNN', 'ResNet', 'PyTorch ResNet'],\n",
    "        'Accuracy': [cnn_test_acc, resnet_test_acc, pytorch_acc/100],\n",
    "        'Precision': [cnn_test_precision, resnet_test_precision, pytorch_precision],\n",
    "        'Recall': [cnn_test_recall, resnet_test_recall, pytorch_recall],\n",
    "        'F1-Score': [cnn_f1, resnet_f1, pytorch_f1],\n",
    "        'AUC': [cnn_auc, resnet_auc, pytorch_auc]\n",
    "    }\n",
    "\n",
    "    df_performance = pd.DataFrame(models_performance)\n",
    "    print(\"\\n=== Model Performance Comparison ===\")\n",
    "    print(df_performance.to_string(index=False))\n",
    "\n",
    "    return df_performance\n",
    "\n",
    "performance_df = create_evaluation_report()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 12. Visualization of Training History"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot training history\n",
    "def plot_training_history(histories, model_names):\n",
    "    fig, axes = plt.subplots(2, 2, figsize=(15, 10))\n",
    "    \n",
    "    # Plot training & validation accuracy\n",
    "    axes[0, 0].set_title('Model Accuracy')\n",
    "    for i, (history, name) in enumerate(zip(histories, model_names)):\n",
    "        if 'accuracy' in history:\n",
    "            axes[0, 0].plot(history['accuracy'], label=f'{name} Train')\n",
    "            axes[0, 0].plot(history['val_accuracy'], label=f'{name} Val')\n",
    "        elif 'train_acc' in history:\n",
    "            axes[0, 0].plot([acc/100 for acc in history['train_acc']], label=f'{name} Train')\n",
    "            axes[0, 0].plot([acc/100 for acc in history['val_acc']], label=f'{name} Val')\n",
    "    axes[0, 0].set_ylabel('Accuracy')\n",
    "    axes[0, 0].set_xlabel('Epoch')\n",
    "    axes[0, 0].legend()\n",
    "    \n",
    "    # Plot training & validation loss\n",
    "    axes[0, 1].set_title('Model Loss')\n",
    "    for i, (history, name) in enumerate(zip(histories, model_names)):\n",
    "        if 'loss' in history:\n",
    "            axes[0, 1].plot(history['loss'], label=f'{name} Train')\n",
    "            axes[0, 1].plot(history['val_loss'], label=f'{name} Val')\n",
    "        elif 'train_loss' in history:\n",
    "            axes[0, 1].plot(history['train_loss'], label=f'{name} Train')\n",
    "            axes[0, 1].plot(history['val_loss'], label=f'{name} Val')\n",
    "    axes[0, 1].set_ylabel('Loss')\n",
    "    axes[0, 1].set_xlabel('Epoch')\n",
    "    axes[0, 1].legend()\n",
    "    \n",
    "    # Plot ROC curves\n",
    "    axes[1, 0].set_title('ROC Curves')\n",
    "    \n",
    "    # CNN ROC\n",
    "    fpr_cnn, tpr_cnn, _ = roc_curve(y_test, cnn_predictions)\n",
    "    axes[1, 0].plot(fpr_cnn, tpr_cnn, label=f'CNN (AUC = {cnn_auc:.3f})')\n",
    "    \n",
    "    # ResNet ROC\n",
    "    fpr_resnet, tpr_resnet, _ = roc_curve(y_test, resnet_predictions)\n",
    "    axes[1, 0].plot(fpr_resnet, tpr_resnet, label=f'ResNet (AUC = {resnet_auc:.3f})')\n",
    "    \n",
    "    # PyTorch ROC\n",
    "    fpr_pytorch, tpr_pytorch, _ = roc_curve(pytorch_targets, pytorch_predictions)\n",
    "    axes[1, 0].plot(fpr_pytorch, tpr_pytorch, label=f'PyTorch (AUC = {pytorch_auc:.3f})')\n",
    "    \n",
    "    axes[1, 0].plot([0, 1], [0, 1], 'k--', label='Random')\n",
    "    axes[1, 0].set_xlabel('False Positive Rate')\n",
    "    axes[1, 0].set_ylabel('True Positive Rate')\n",
    "    axes[1, 0].legend()\n",
    "    \n",
    "    # Plot confusion matrices\n",
    "    axes[1, 1].set_title('Model Performance Comparison')\n",
    "    x = np.arange(len(model_names))\n",
    "    accuracies = [cnn_test_acc, resnet_test_acc, pytorch_acc/100]\n",
    "    f1_scores = [cnn_f1, resnet_f1, pytorch_f1]\n",
    "    \n",
    "    width = 0.35\n",
    "    axes[1, 1].bar(x - width/2, accuracies, width, label='Accuracy', alpha=0.8)\n",
    "    axes[1, 1].bar(x + width/2, f1_scores, width, label='F1-Score', alpha=0.8)\n",
    "    axes[1, 1].set_xlabel('Models')\n",
    "    axes[1, 1].set_ylabel('Score')\n",
    "    axes[1, 1].set_xticks(x)\n",
    "    axes[1, 1].set_xticklabels(model_names)\n",
    "    axes[1, 1].legend()\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "# Plot training histories\n",
    "histories = [cnn_history.history, resnet_history.history, pytorch_history]\n",
    "model_names = ['CNN', 'ResNet', 'PyTorch']\n",
    "plot_training_history(histories, model_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 13. Confusion Matrix Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot confusion matrices for all models\n",
    "fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
    "\n",
    "models_data = [\n",
    "    (y_test, cnn_pred_binary, 'CNN Model'),\n",
    "    (y_test, resnet_pred_binary, 'ResNet Model'),\n",
    "    (pytorch_targets, pytorch_pred_binary, 'PyTorch Model')\n",
    "]\n",
    "\n",
    "for i, (true_labels, predictions, title) in enumerate(models_data):\n",
    "    cm = confusion_matrix(true_labels, predictions)\n",
    "    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[i])\n",
    "    axes[i].set_title(title)\n",
    "    axes[i].set_xlabel('Predicted')\n",
    "    axes[i].set_ylabel('Actual')\n",
    "    axes[i].set_xticklabels(['Fake', 'Real'])\n",
    "    axes[i].set_yticklabels(['Fake', 'Real'])\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 14. Model Interpretability - Feature Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grad-CAM implementation for model interpretability\n",
    "def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):\n",
    "    # First, we create a model that maps the input image to the activations\n",
    "    # of the last conv layer as well as the output predictions\n",
    "    grad_model = tf.keras.models.Model(\n",
    "        [model.inputs], [model.get_layer(last_conv_layer_name).output, model.output]\n",
    "    )\n",
    "    \n",
    "    # Then, we compute the gradient of the top predicted class for our input image\n",
    "    # with respect to the activations of the last conv layer\n",
    "    with tf.GradientTape() as tape:\n",
    "        last_conv_layer_output, preds = grad_model(img_array)\n",
    "        if pred_index is None:\n",
    "            pred_index = tf.argmax(preds[0])\n",
    "        class_channel = preds[:, pred_index]\n",
    "    \n",
    "    # This is the gradient of the output neuron (top predicted or chosen)\n",
    "    # with regard to the output feature map of the last conv layer\n",
    "    grads = tape.gradient(class_channel, last_conv_layer_output)\n",
    "    \n",
    "    # This is a vector where each entry is the mean intensity of the gradient\n",
    "    # over a specific feature map channel\n",
    "    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))\n",
    "    \n",
    "    # We multiply each channel in the feature map array\n",
    "    # by \"how important this channel is\" with regard to the top predicted class\n",
    "    last_conv_layer_output = last_conv_layer_output[0]\n",
    "    heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]\n",
    "    heatmap = tf.squeeze(heatmap)\n",
    "    \n",
    "    # For visualization purpose, we will also normalize the heatmap between 0 & 1\n",
    "    heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)\n",
    "    return heatmap.numpy()\n",
    "\n",
    "# Visualize Grad-CAM for a few test images\n",
    "def visualize_gradcam(model, images, true_labels, predictions, num_images=4):\n",
    "    fig, axes = plt.subplots(2, num_images, figsize=(16, 8))\n",
    "    \n",
    "    for i in range(num_images):\n",
    "        img = images[i:i+1]\n",
    "        \n",
    "        # Get the last convolutional layer name\n",
    "        last_conv_layer_name = None\n",
    "        for layer in reversed(model.layers):\n",
    "            if len(layer.output_shape) == 4:\n",
    "                last_conv_layer_name = layer.name\n",
    "                break\n",
    "        \n",
    "        if last_conv_layer_name:\n",
    "            # Generate heatmap\n",
    "            heatmap = make_gradcam_heatmap(img, model, last_conv_layer_name)\n",
    "            \n",
    "            # Display original image\n",
    "            axes[0, i].imshow(images[i])\n",
    "            axes[0, i].set_title(f'True: {\"Real\" if true_labels[i] else \"Fake\"}\\n'\n",
    "                               f'Pred: {\"Real\" if predictions[i] > 0.5 else \"Fake\"} ({predictions[i]:.3f})')\n",
    "            axes[0, i].axis('off')\n",
    "            \n",
    "            # Display heatmap\n",
    "            axes[1, i].imshow(images[i])\n",
    "            axes[1, i].imshow(heatmap, alpha=0.6, cmap='jet')\n",
    "            axes[1, i].set_title('Grad-CAM')\n",
    "            axes[1, i].axis('off')\n",
    "        else:\n",
    "            axes[0, i].text(0.5, 0.5, 'No conv layer found', ha='center', va='center')\n",
    "            axes[1, i].text(0.5, 0
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "# Continuation of the fake image detection notebook

# Complete the Grad-CAM visualization function
def _find_last_4d_layer_name(model):
    """Return the name of the last layer in ``model.layers`` with a 4-D output, or None.

    Grad-CAM needs a spatial feature map, i.e. a rank-4 layer output
    (``make_gradcam_heatmap`` pools gradients over axes 0-2, so it presumably
    expects channels-last). ``layer.output.shape`` is used instead of
    ``layer.output_shape`` because the latter no longer exists in Keras 3.
    """
    for layer in reversed(model.layers):
        try:
            rank = len(layer.output.shape)
        except (AttributeError, ValueError, RuntimeError):
            # Layers without a single well-defined output tensor cannot be targeted.
            continue
        if rank == 4:
            return layer.name
    return None


def visualize_gradcam(model, images, true_labels, predictions, num_images=4):
    """Plot images (top row) and their Grad-CAM overlays (bottom row).

    Args:
        model: Keras model handed to ``make_gradcam_heatmap``.
        images: Images indexable as ``images[i]`` (for display) and sliceable as
            ``images[i:i+1]`` (a batch of one for the model).
        true_labels: Ground-truth labels; truthy is shown as "Real", falsy as "Fake".
        predictions: Per-image scores; ``> 0.5`` is shown as "Real".
        num_images: Maximum number of images to show; clipped to ``len(images)``.
    """
    num_images = min(num_images, len(images))
    if num_images <= 0:
        print("No images to visualize with Grad-CAM.")
        return

    # The target layer depends only on the model, so look it up once rather than per image.
    last_conv_layer_name = _find_last_4d_layer_name(model)

    # squeeze=False keeps `axes` 2-D (2 x num_images) even when num_images == 1.
    fig, axes = plt.subplots(2, num_images, figsize=(16, 8), squeeze=False)

    for i in range(num_images):
        if last_conv_layer_name:
            # A batch containing just image i
            heatmap = make_gradcam_heatmap(images[i:i+1], model, last_conv_layer_name)

            # Display original image
            axes[0, i].imshow(images[i])
            axes[0, i].set_title(f'True: {"Real" if true_labels[i] else "Fake"}\n'
                                 f'Pred: {"Real" if predictions[i] > 0.5 else "Fake"} ({predictions[i]:.3f})')
            axes[0, i].axis('off')

            # The heatmap has the spatial size of the conv feature map, much smaller than
            # the image. Without `extent`, matplotlib places it in its own pixel
            # coordinates, so the overlay does not line up with the picture. Giving it the
            # image's extent stretches it over the whole image.
            height, width = np.shape(images[i])[:2]
            axes[1, i].imshow(images[i])
            axes[1, i].imshow(heatmap, alpha=0.6, cmap='jet',
                              extent=(-0.5, width - 0.5, height - 0.5, -0.5))
            axes[1, i].set_title('Grad-CAM')
            axes[1, i].axis('off')
        else:
            axes[0, i].text(0.5, 0.5, 'No conv layer found', ha='center', va='center')
            axes[1, i].text(0.5, 0.5, 'No conv layer found', ha='center', va='center')

    plt.tight_layout()
    plt.show()

# Grad-CAM on 4 random test images with the ResNet model. This is best-effort:
# any failure is reported and the notebook carries on.
print("Generating Grad-CAM visualizations...")
try:
    # Four distinct test-set indices, drawn from the global NumPy RNG.
    sample_indices = np.random.choice(len(X_test), 4, replace=False)
    sample_images, sample_labels = X_test[sample_indices], y_test[sample_indices]
    sample_predictions = resnet_predictions[sample_indices].flatten()

    visualize_gradcam(resnet_model, sample_images, sample_labels, sample_predictions)
except Exception as gradcam_error:
    print(f"Grad-CAM visualization failed: {gradcam_error}")
    print("Skipping Grad-CAM visualization...")

# 15. Error Analysis
print("\n=== Error Analysis ===")
# Find misclassified samples
def analyze_errors(true_labels, predictions, images, model_name):
    """Print error counts for a binary classifier and plot up to 4 misclassified samples.

    Args:
        true_labels: Ground-truth labels (1 = Real, 0 = Fake, as assigned in load_dataset).
        predictions: Scores for the "Real" class; a score > 0.5 counts as a Real
            prediction. Lists or arrays of any shape are accepted and flattened.
        images: Images aligned with ``true_labels``; only indexed when there are errors.
        model_name: Label used in the printed summary and the figure title.

    Returns:
        np.ndarray of indices whose thresholded prediction differs from the label.
    """
    true_labels = np.asarray(true_labels).flatten()
    predictions = np.asarray(predictions).flatten()
    pred_binary = (predictions > 0.5).astype(int)
    misclassified = np.where(true_labels != pred_binary)[0]

    print(f"\n{model_name} - Misclassified samples: {len(misclassified)}")

    if len(misclassified) > 0:
        # False positives (predicted real, actually fake)
        false_positives = misclassified[(true_labels[misclassified] == 0) & (pred_binary[misclassified] == 1)]
        # False negatives (predicted fake, actually real)
        false_negatives = misclassified[(true_labels[misclassified] == 1) & (pred_binary[misclassified] == 0)]

        print(f"False Positives: {len(false_positives)}")
        print(f"False Negatives: {len(false_negatives)}")

        # Show up to 4 errors, including when there are fewer than 4.
        num_shown = min(4, len(misclassified))
        sample_errors = np.random.choice(misclassified, num_shown, replace=False)

        # squeeze=False keeps `axes` 2-D even for a single panel.
        fig, axes = plt.subplots(1, num_shown, figsize=(16, 4), squeeze=False)

        for ax, idx in zip(axes[0], sample_errors):
            score = predictions[idx]
            # Confidence in the *predicted* class (the same definition create_inference_function
            # uses); `score` itself is the score for "Real".
            confidence = score if pred_binary[idx] else 1 - score
            ax.imshow(images[idx])
            ax.set_title(f'True: {"Real" if true_labels[idx] else "Fake"}\n'
                         f'Pred: {"Real" if pred_binary[idx] else "Fake"}\n'
                         f'Confidence: {confidence:.3f}')
            ax.axis('off')

        plt.suptitle(f'{model_name} - Misclassified Samples')
        plt.tight_layout()
        plt.show()

    return misclassified

# Error analysis per model.
# NOTE(review): the PyTorch call indexes X_test, which assumes pytorch_predictions and
# pytorch_targets follow X_test order (test_loader is built with shuffle=False) — confirm.
cnn_errors = analyze_errors(true_labels=y_test, predictions=cnn_predictions.flatten(),
                            images=X_test, model_name="CNN")
resnet_errors = analyze_errors(true_labels=y_test, predictions=resnet_predictions.flatten(),
                               images=X_test, model_name="ResNet")
pytorch_errors = analyze_errors(true_labels=pytorch_targets, predictions=pytorch_predictions.flatten(),
                                images=X_test, model_name="PyTorch")

# 16. Model Ensemble: equal-weight average of the three models' scores, thresholded at 0.5.
print("\n=== Creating Model Ensemble ===")

_member_scores = (
    cnn_predictions.flatten(),
    resnet_predictions.flatten(),
    pytorch_predictions.flatten(),
)
ensemble_predictions = (_member_scores[0] + _member_scores[1] + _member_scores[2]) / 3
ensemble_pred_binary = (ensemble_predictions > 0.5).astype(int)

# Score the ensemble against the TensorFlow test labels.
ensemble_acc = np.mean(ensemble_pred_binary == y_test)
ensemble_f1 = f1_score(y_test, ensemble_pred_binary)
ensemble_auc = roc_auc_score(y_test, ensemble_predictions)

print("Ensemble Model Performance:")
for _metric_name, _metric_value in (("Accuracy", ensemble_acc), ("F1-Score", ensemble_f1), ("AUC", ensemble_auc)):
    print(f"{_metric_name}: {_metric_value:.4f}")

# Append the ensemble to the comparison table; it has no precision/recall values.
ensemble_row = {'Model': 'Ensemble', 'Accuracy': ensemble_acc, 'Precision': None,
                'Recall': None, 'F1-Score': ensemble_f1, 'AUC': ensemble_auc}

performance_df = pd.concat([performance_df, pd.DataFrame([ensemble_row])], ignore_index=True)
print("\n=== Updated Model Performance Comparison ===")
print(performance_df.to_string(index=False))

# 17. Feature Importance Analysis
print("\n=== Feature Importance Analysis ===")

def analyze_feature_importance(model, X_sample, layer_name=None):
    """
    Analyze which parts of images are most important for classification by plotting
    the first 4 feature maps of one convolutional layer for up to 4 samples.

    Args:
        model: Keras model; a sub-model from ``model.input`` to the chosen layer's
            output is built and run.
        X_sample: Batch of model inputs; only the first 4 are used.
        layer_name: Layer to visualize. Defaults to the last layer (in
            ``model.layers`` order) whose name contains 'conv'.

    Failures are printed rather than raised so the notebook keeps running.
    NOTE(review): for the Sequential ResNet built in this notebook, the top-level layers
    (the ResNet50 backbone plus the dense head) presumably have no 'conv' in their names,
    so the default search finds nothing; pass `layer_name` explicitly there.
    """
    max_samples, max_filters = 4, 4
    try:
        if layer_name is None:
            # Find last convolutional layer
            layer_name = next(
                (layer.name for layer in reversed(model.layers) if 'conv' in layer.name.lower()),
                None,
            )

        if not layer_name:
            # Make the skip visible instead of silently doing nothing.
            print("Feature importance analysis skipped: no layer with 'conv' in its name found.")
            return

        intermediate_model = tf.keras.Model(
            inputs=model.input,
            outputs=model.get_layer(layer_name).output
        )

        # Get feature maps
        feature_maps = intermediate_model.predict(X_sample[:max_samples])

        # One row per available sample; X_sample may hold fewer than max_samples.
        num_rows = min(max_samples, feature_maps.shape[0])
        if num_rows == 0:
            print("Feature importance analysis skipped: no samples provided.")
            return

        # squeeze=False keeps `axes` 2-D even for a single row.
        fig, axes = plt.subplots(num_rows, max_filters, figsize=(12, 3 * num_rows), squeeze=False)
        for i in range(num_rows):
            for j in range(max_filters):
                ax = axes[i, j]
                if j < feature_maps.shape[-1]:
                    ax.imshow(feature_maps[i, :, :, j], cmap='viridis')
                    ax.set_title(f'Sample {i+1}, Filter {j+1}')
                # Hide axes on every panel, including unused ones when the layer has < 4 channels.
                ax.axis('off')

        plt.suptitle(f'Feature Maps from {layer_name}')
        plt.tight_layout()
        plt.show()

    except Exception as e:
        print(f"Feature importance analysis failed: {e}")

# Feature-map visualization for the ResNet model
analyze_feature_importance(resnet_model, X_test)

# 18. Model Comparison Summary
print("\n=== Final Model Comparison Summary ===")


def _fmt4(value):
    """Format a metric with 4 decimal places for the summary table."""
    return f"{value:.4f}"


# One column per model; the last row is a qualitative training-cost label.
comparison_data = {
    'Metric': ['Accuracy', 'F1-Score', 'AUC-ROC', 'Training Time (relative)'],
    'CNN': [_fmt4(cnn_test_acc), _fmt4(cnn_f1), _fmt4(cnn_auc), "Fast"],
    'ResNet': [_fmt4(resnet_test_acc), _fmt4(resnet_f1), _fmt4(resnet_auc), "Medium"],
    'PyTorch': [_fmt4(pytorch_acc / 100), _fmt4(pytorch_f1), _fmt4(pytorch_auc), "Medium"],
    'Ensemble': [_fmt4(ensemble_acc), _fmt4(ensemble_f1), _fmt4(ensemble_auc), "Slow"],
}

comparison_df = pd.DataFrame(comparison_data)
print(comparison_df.to_string(index=False))

# The best model is the performance_df row with the highest AUC.
best_model_idx = performance_df['AUC'].idxmax()
best_model = performance_df.loc[best_model_idx, 'Model']
best_auc = performance_df.loc[best_model_idx, 'AUC']

print(f"\nBest performing model: {best_model} with AUC: {best_auc:.4f}")

# 19. Save Final Models and Results
print("\n=== Saving Models and Results ===")

# Best-effort: a failure is reported instead of stopping the notebook.
try:
    _save_dir = CONFIG['MODEL_SAVE_DIR']

    # Keras models, HDF5 format
    cnn_model.save(os.path.join(_save_dir, 'final_cnn_model.h5'))
    resnet_model.save(os.path.join(_save_dir, 'final_resnet_model.h5'))

    # PyTorch: model weights, optimizer state (presumably the PyTorch optimizer defined
    # outside this view) and the run configuration
    _pytorch_checkpoint = {
        'model_state_dict': pytorch_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'config': CONFIG,
    }
    torch.save(_pytorch_checkpoint, os.path.join(_save_dir, 'final_pytorch_model.pth'))

    print("Models saved successfully!")

except Exception as save_error:
    print(f"Error saving models: {save_error}")

# Save performance results (best-effort, like the model save above)
try:
    performance_df.to_csv(os.path.join(CONFIG['MODEL_SAVE_DIR'], 'model_performance.csv'), index=False)

    # Per-sample scores of every model, the true labels and the config, as one JSON document.
    results = {
        f'{_name}_predictions': _scores.flatten().tolist()
        for _name, _scores in (
            ('cnn', cnn_predictions),
            ('resnet', resnet_predictions),
            ('pytorch', pytorch_predictions),
        )
    }
    results['ensemble_predictions'] = ensemble_predictions.tolist()
    results['true_labels'] = y_test.tolist()
    results['config'] = CONFIG

    import json
    with open(os.path.join(CONFIG['MODEL_SAVE_DIR'], 'detailed_results.json'), 'w') as f:
        json.dump(results, f, indent=2)

    print("Results saved successfully!")

except Exception as results_error:
    print(f"Error saving results: {results_error}")

# 20. Model Deployment Preparation
print("\n=== Model Deployment Preparation ===")

def create_inference_function(model, img_size):
    """
    Build a closure that classifies a single image file with ``model``.

    The returned ``predict_image(image_path)`` yields ``(result, error)``:
      * on success, ``result`` is a dict with 'prediction' ('Real' or 'Fake'),
        'confidence' (score of the predicted class) and 'raw_score' (the model's
        output for "Real"), and ``error`` is None;
      * on failure (unloadable image or any exception), ``result`` is None and
        ``error`` is a message string. It never raises.
    """
    def predict_image(image_path):
        try:
            img = load_and_preprocess_image(image_path, img_size)
            if img is None:
                return None, "Error loading image"

            # The model expects a batch; take the single scalar score back out.
            score = model.predict(np.expand_dims(img, axis=0))[0][0]

            if score > 0.5:
                label, confidence = 'Real', score
            else:
                label, confidence = 'Fake', 1 - score

            return {
                'prediction': label,
                'confidence': float(confidence),
                'raw_score': float(score),
            }, None

        except Exception as exc:
            return None, str(exc)

    return predict_image

# Create an inference function for the best model.
# create_inference_function wraps a Keras model (it calls model.predict on a NumPy batch),
# so only the CNN and ResNet can be used directly. If the best model is anything else
# (e.g. PyTorch or the Ensemble), fall back to ResNet and say so, instead of printing the
# name of a model that is not actually being used.
_tf_inference_candidates = {'ResNet': resnet_model, 'CNN': cnn_model}
inference_model_name = best_model if best_model in _tf_inference_candidates else 'ResNet'
best_tf_model = _tf_inference_candidates[inference_model_name]

inference_fn = create_inference_function(best_tf_model, CONFIG['IMG_SIZE'])

if inference_model_name != best_model:
    print(f"Best model '{best_model}' cannot be wrapped directly; using ResNet for inference instead")
print(f"Inference function created for {inference_model_name} model")
print("Usage: result, error = inference_fn('path/to/image.jpg')")

# 21. Generate Classification Report
print("\n=== Detailed Classification Reports ===")

# (name, ground truth, thresholded predictions) for each model, including the ensemble.
models_info = [
    ('CNN', y_test, cnn_pred_binary),
    ('ResNet', y_test, resnet_pred_binary),
    ('PyTorch', pytorch_targets, pytorch_pred_binary),
    ('Ensemble', y_test, ensemble_pred_binary),
]

for model_name, true_labels, predictions in models_info:
    print(f"\n{model_name} Model Classification Report:\n{'=' * 50}")
    print(classification_report(
        true_labels,
        predictions,
        target_names=['Fake', 'Real'],
        digits=4,
    ))

# 22. Final Recommendations and Next Steps
print("\n=== Recommendations and Next Steps ===")
print("""
RECOMMENDATIONS:
1. Best Model: {} with AUC: {:.4f}
2. Ensemble approach shows promising results
3. Consider data augmentation techniques for better generalization
4. Implement additional evaluation metrics (precision-recall curves)
5. Test on different types of fake images (deepfakes, GAN-generated, etc.)

NEXT STEPS:
1. Deploy the best model to production
2. Implement real-time inference pipeline
3. Create model monitoring and retraining pipeline
4. Collect more diverse training data
5. Experiment with newer architectures (Vision Transformers, EfficientNet)
6. Implement adversarial training for robustness

MODEL FILES SAVED:
- final_cnn_model.h5
- final_resnet_model.h5
- final_pytorch_model.pth
- model_performance.csv
- detailed_results.json
""".format(best_model, best_auc))

print("\n=== Training Pipeline Completed Successfully ===")
print("Total models trained: 4 (CNN, ResNet, PyTorch ResNet, Ensemble)")
print(f"Best model: {best_model}")

# ---------------------------------------------------------------------------
# Notebook setup section follows (imports, seeds, configuration).
# ---------------------------------------------------------------------------
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import cv2
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Deep Learning Libraries
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, optimizers, callbacks
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models

# Sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve, f1_score

# Add project root to path
sys.path.append('..')
from utils.data_preprocessing import load_and_preprocess_image
from utils.visualization import plot_sample_images, plot_roc_curve, plot_confusion_matrix
from utils.metrics import plot_training_history_tf, plot_training_history_torch # Assuming these are in utils.metrics
from models.cnn_model import build_cnn_model # Assuming these are defined in models
from models.resnet_model import build_resnet_model # Assuming these are defined in models

# Set random seeds for reproducibility: Python's `random`, NumPy, TensorFlow and PyTorch.
import random

random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)
torch.manual_seed(42)

# GPU Configuration: enable memory growth on every visible GPU (not just the first), so
# TensorFlow allocates GPU memory on demand instead of reserving it all up front. This
# notebook also trains PyTorch models on the same device.
for gpu in tf.config.list_physical_devices('GPU'):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Raised once the GPUs have been initialized (e.g. when this cell is re-run);
        # report it instead of aborting the notebook.
        print(f"Could not enable memory growth on {gpu.name}: {err}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

---

## 2. Configuration and Hyperparameters

```python
# Central configuration: every path and hyperparameter a reader might tune.
_DATA_ROOT = '../data'

CONFIG = dict(
    # Paths
    DATA_DIR=_DATA_ROOT,
    REAL_DIR=f'{_DATA_ROOT}/real',
    FAKE_DIR=f'{_DATA_ROOT}/fake',
    TEST_DIR=f'{_DATA_ROOT}/test',
    MODEL_SAVE_DIR='../models/saved_models',

    # Image parameters
    IMG_SIZE=(224, 224),
    BATCH_SIZE=32,
    CHANNELS=3,

    # Training parameters
    EPOCHS=50,
    LEARNING_RATE=0.001,
    VALIDATION_SPLIT=0.2,
    TEST_SPLIT=0.1,

    # Model parameters
    DROPOUT_RATE=0.5,
    L2_REGULARIZATION=0.01,

    # Callbacks
    EARLY_STOPPING_PATIENCE=10,
    REDUCE_LR_PATIENCE=5,
    REDUCE_LR_FACTOR=0.5,
)
del _DATA_ROOT

# Checkpoints and final models are written here.
os.makedirs(CONFIG['MODEL_SAVE_DIR'], exist_ok=True)

print("Configuration loaded successfully!")
print("Image size: {}".format(CONFIG['IMG_SIZE']))
print("Batch size: {}".format(CONFIG['BATCH_SIZE']))
print("Epochs: {}".format(CONFIG['EPOCHS']))

```

---

## 3. Data Loading and Preprocessing

```python
# Load and preprocess data
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _load_class_images(directory, label, class_name, img_size, images, labels, filepaths):
    """Append every loadable image in ``directory``, tagged with ``label``, to the output lists."""
    # os.listdir returns entries in arbitrary order. Sorting makes the array order, and
    # therefore the seeded train/val/test split downstream, reproducible across machines.
    files = sorted(f for f in os.listdir(directory) if f.lower().endswith(_IMAGE_EXTENSIONS))
    print(f"Found {len(files)} {class_name} images")

    for filename in tqdm(files, desc=f"Loading {class_name} images"):
        filepath = os.path.join(directory, filename)
        img = load_and_preprocess_image(filepath, img_size)
        # Files for which the loader returns None are skipped.
        if img is not None:
            images.append(img)
            labels.append(label)
            filepaths.append(filepath)


def load_dataset(real_dir, fake_dir, img_size):
    """
    Load images from real and fake directories

    Args:
        real_dir: Directory of real images (label 1).
        fake_dir: Directory of fake images (label 0).
        img_size: Target size passed to ``load_and_preprocess_image``.

    Returns:
        ``(X, y, filepaths)``: array of images, integer label array and list of source
        paths, all in the same order (real images first, each class sorted by filename).
    """
    images, labels, filepaths = [], [], []
    _load_class_images(real_dir, 1, 'real', img_size, images, labels, filepaths)  # Real = 1
    _load_class_images(fake_dir, 0, 'fake', img_size, images, labels, filepaths)  # Fake = 0
    # dtype=int keeps `y` integer even when nothing was loaded (np.bincount needs ints).
    return np.array(images), np.array(labels, dtype=int), filepaths

# Load the full dataset into memory (real = 1, fake = 0).
print("Loading dataset...")
X, y, filepaths = load_dataset(
    real_dir=CONFIG['REAL_DIR'],
    fake_dir=CONFIG['FAKE_DIR'],
    img_size=CONFIG['IMG_SIZE'],
)

# Sanity summary: per-class counts, array shape, and the bincount as [fake, real].
print("Total images loaded: {}".format(len(X)))
print("Real images: {}".format(np.sum(y)))
print("Fake images: {}".format(len(y) - np.sum(y)))
print("Image shape: {}".format(X.shape))
print("Label distribution: {}".format(np.bincount(y)))

```

---

## 4. Data Exploration and Visualization

```python
# Visualize sample images
plot_sample_images(X, y, num_samples=8)

# Plot class distribution.
# Index 0 = Fake and 1 = Real (the labels assigned in load_dataset). minlength keeps one
# count per class even if a class is absent, so the counts always line up with the bars.
labels_text = ['Fake', 'Real']
counts = np.bincount(y, minlength=len(labels_text))

fig, ax = plt.subplots(figsize=(8, 6))
ax.bar(labels_text, counts, color=['red', 'green'], alpha=0.7)
ax.set_title('Class Distribution')
ax.set_ylabel('Number of Images')
for i, count in enumerate(counts):
    # Annotate each bar with its count, just above the bar.
    ax.text(i, count + 10, str(count), ha='center', va='bottom')
plt.show()

# Calculate and display basic statistics
print(f"\nDataset Statistics:")
print(f"Total samples: {len(X)}")
print(f"Image dimensions: {X.shape[1:]}")
print(f"Pixel value range: [{X.min():.3f}, {X.max():.3f}]")
print(f"Mean pixel value: {X.mean():.3f}")
print(f"Std pixel value: {X.std():.3f}")

```

---

## 5. Data Augmentation Setup

```python
# --- TensorFlow/Keras augmentation (ImageDataGenerator) ----------------------
# Random geometric and photometric jitter for training. Both generators reserve the
# same VALIDATION_SPLIT fraction.
_tf_augmentation = dict(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    zoom_range=0.2,
    shear_range=0.1,
    brightness_range=[0.8, 1.2],
)
tf_train_datagen = ImageDataGenerator(**_tf_augmentation, validation_split=CONFIG['VALIDATION_SPLIT'])
tf_val_datagen = ImageDataGenerator(validation_split=CONFIG['VALIDATION_SPLIT'])

# --- PyTorch augmentation (albumentations) ----------------------------------
_height, _width = CONFIG['IMG_SIZE']


def _imagenet_normalize():
    """Return a fresh ImageNet mean/std Normalize transform."""
    return A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


# Training: resize, random flips/rotations/brightness/shift-scale, then normalize.
train_transform = A.Compose([
    A.Resize(_height, _width),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.Rotate(limit=20, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
    _imagenet_normalize(),
    ToTensorV2(),
])

# Validation/test: deterministic resize + normalization only.
val_transform = A.Compose([A.Resize(_height, _width), _imagenet_normalize(), ToTensorV2()])

print("Data augmentation pipelines created successfully!")

```

---

## 6. Data Splitting

```python
# Two-stage stratified split. First hold out TEST_SPLIT of all data, then carve the
# validation set out of the remainder. Dividing by (1 - TEST_SPLIT) makes the validation
# set VALIDATION_SPLIT of the *whole* dataset rather than of the remainder.
_val_fraction_of_remainder = CONFIG['VALIDATION_SPLIT'] / (1 - CONFIG['TEST_SPLIT'])

X_temp, X_test, y_temp, y_test = train_test_split(
    X, y,
    test_size=CONFIG['TEST_SPLIT'],
    stratify=y,
    random_state=42,
)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp,
    test_size=_val_fraction_of_remainder,
    stratify=y_temp,
    random_state=42,
)

print(f"Training set size: {len(X_train)}")
print(f"Validation set size: {len(X_val)}")
print(f"Test set size: {len(X_test)}")

for _split_name, _split_labels in (("training", y_train), ("validation", y_val), ("test", y_test)):
    print(f"\nClass distribution in {_split_name} set:")
    _n_real = np.sum(_split_labels)
    print(f"Real: {_n_real}, Fake: {len(_split_labels) - _n_real}")

```

---

## 7. TensorFlow/Keras Model Training

```python
# Build CNN Model
def build_cnn_model(input_shape, num_classes=1):
    """Build a 4-block CNN classifier with a sigmoid output.

    Args:
        input_shape: Input shape for the first Conv2D, e.g. ``(*IMG_SIZE, CHANNELS)``.
        num_classes: Units in the final sigmoid Dense layer (1 for real-vs-fake).

    Returns:
        An uncompiled Keras Sequential model. The dense-head dropout rate is read
        from ``CONFIG['DROPOUT_RATE']``.
    """
    # Use tf.keras.Sequential explicitly. The import cell later does
    # `import torchvision.models as models`, which rebinds `models` to torchvision
    # (it has no `Sequential`), so `models.Sequential` raises AttributeError.
    model = tf.keras.Sequential([
        # First Convolutional Block
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),

        # Second Convolutional Block
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),

        # Third Convolutional Block
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),

        # Fourth Convolutional Block
        layers.Conv2D(256, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),

        # Dense layers
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(CONFIG['DROPOUT_RATE']),
        layers.Dense(256, activation='relu'),
        layers.Dropout(CONFIG['DROPOUT_RATE']),
        layers.Dense(num_classes, activation='sigmoid')
    ])

    return model

# Build and compile the CNN for binary (real vs. fake) classification.
input_shape = tuple(CONFIG['IMG_SIZE']) + (CONFIG['CHANNELS'],)
cnn_model = build_cnn_model(input_shape)

cnn_model.compile(
    optimizer=optimizers.Adam(learning_rate=CONFIG['LEARNING_RATE']),
    loss='binary_crossentropy',
    metrics=['accuracy', 'precision', 'recall'],
)

# Layer-by-layer overview
cnn_model.summary()

```

```python
# Callbacks for the CNN run:
#   EarlyStopping     - stop once val_loss has not improved for EARLY_STOPPING_PATIENCE
#                       epochs, restoring the best weights seen.
#   ReduceLROnPlateau - scale the learning rate by REDUCE_LR_FACTOR after
#                       REDUCE_LR_PATIENCE epochs without val_loss improvement (floor 1e-7).
#   ModelCheckpoint   - write best_cnn_model.h5 whenever val_accuracy improves.
_cnn_early_stopping = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=CONFIG['EARLY_STOPPING_PATIENCE'],
    restore_best_weights=True,
    verbose=1,
)
_cnn_reduce_lr = callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=CONFIG['REDUCE_LR_FACTOR'],
    patience=CONFIG['REDUCE_LR_PATIENCE'],
    min_lr=1e-7,
    verbose=1,
)
_cnn_checkpoint = callbacks.ModelCheckpoint(
    filepath=os.path.join(CONFIG['MODEL_SAVE_DIR'], 'best_cnn_model.h5'),
    monitor='val_accuracy',
    save_best_only=True,
    verbose=1,
)
callbacks_list = [_cnn_early_stopping, _cnn_reduce_lr, _cnn_checkpoint]

print("Callbacks configured successfully!")

```

```python
# Train the CNN on the in-memory arrays, validating on (X_val, y_val) after every epoch.
print("Starting CNN model training...")

cnn_history = cnn_model.fit(
    x=X_train,
    y=y_train,
    validation_data=(X_val, y_val),
    epochs=CONFIG['EPOCHS'],
    batch_size=CONFIG['BATCH_SIZE'],
    callbacks=callbacks_list,
    verbose=1,
)

print("CNN model training completed!")

```

---

## 8. Transfer Learning with ResNet

```python
# Build ResNet-based model
def build_resnet_model(input_shape, num_classes=1):
    """Build an ImageNet-pretrained ResNet50 backbone (frozen) with a dense sigmoid head.

    Args:
        input_shape: Input shape for ResNet50, e.g. ``(*IMG_SIZE, CHANNELS)``.
        num_classes: Units in the final sigmoid Dense layer (1 for real-vs-fake).

    Returns:
        ``(model, base_model)``: the uncompiled Sequential classifier and the ResNet50
        backbone. The backbone is returned separately so the fine-tuning step can
        unfreeze part of it.

    NOTE(review): Keras' ImageNet ResNet50 weights are meant to be used with
    ``tf.keras.applications.resnet50.preprocess_input``. No such preprocessing is applied
    here or in the fit calls. FakeImageDataset's ``* 255`` suggests X is scaled to [0, 1],
    so the backbone likely sees differently scaled inputs than it was trained on — confirm.
    """
    # Load pre-trained ResNet50
    base_model = tf.keras.applications.ResNet50(
        weights='imagenet',
        include_top=False,
        input_shape=input_shape
    )

    # Freeze base model layers: only the new head trains in the first phase
    base_model.trainable = False

    # Add custom layers. Use tf.keras.Sequential rather than `models.Sequential`:
    # the import cell rebinds `models` to torchvision.models, which has no `Sequential`.
    model = tf.keras.Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(512, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.3),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.2),
        layers.Dense(num_classes, activation='sigmoid')
    ])

    return model, base_model

# Build the transfer-learning model; `base_model` is kept for the fine-tuning step.
resnet_model, base_model = build_resnet_model(input_shape)

_resnet_compile_args = dict(
    optimizer=optimizers.Adam(learning_rate=CONFIG['LEARNING_RATE']),
    loss='binary_crossentropy',
    metrics=['accuracy', 'precision', 'recall'],
)
resnet_model.compile(**_resnet_compile_args)

# Layer-by-layer overview
resnet_model.summary()

```

```python
# ResNet callbacks use the same policy as the CNN run, checkpointing to best_resnet_model.h5.
_resnet_early_stopping = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=CONFIG['EARLY_STOPPING_PATIENCE'],
    restore_best_weights=True,
    verbose=1,
)
_resnet_reduce_lr = callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=CONFIG['REDUCE_LR_FACTOR'],
    patience=CONFIG['REDUCE_LR_PATIENCE'],
    min_lr=1e-7,
    verbose=1,
)
_resnet_checkpoint = callbacks.ModelCheckpoint(
    filepath=os.path.join(CONFIG['MODEL_SAVE_DIR'], 'best_resnet_model.h5'),
    monitor='val_accuracy',
    save_best_only=True,
    verbose=1,
)
resnet_callbacks = [_resnet_early_stopping, _resnet_reduce_lr, _resnet_checkpoint]

# Train the ResNet model with the backbone frozen, so only the new head learns.
print("Starting ResNet model training...")

resnet_history = resnet_model.fit(
    x=X_train,
    y=y_train,
    validation_data=(X_val, y_val),
    epochs=CONFIG['EPOCHS'],
    batch_size=CONFIG['BATCH_SIZE'],
    callbacks=resnet_callbacks,
    verbose=1,
)

print("ResNet model training completed!")

```

---

## 9. Fine-tuning ResNet

```python
# Fine-tune the ResNet model
print("Starting fine-tuning of ResNet model...")

# Unfreeze the top layers of the base model
base_model.trainable = True

# Fine-tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

# Keep the unfrozen BatchNormalization layers frozen too. In TF2 Keras, a
# BatchNormalization layer with trainable=False runs in inference mode, so its ImageNet
# moving mean/variance are not overwritten by statistics from small fine-tuning batches.
# The Keras transfer-learning guide recommends this.
for layer in base_model.layers[fine_tune_at:]:
    if isinstance(layer, layers.BatchNormalization):
        layer.trainable = False

# Recompile with a lower learning rate (recompiling is also what makes the
# `trainable` changes above take effect)
resnet_model.compile(
    optimizer=optimizers.Adam(learning_rate=CONFIG['LEARNING_RATE']/10),
    loss='binary_crossentropy',
    metrics=['accuracy', 'precision', 'recall']
)

# Fine-tune callbacks
finetune_callbacks = [
    callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
        verbose=1
    ),
    callbacks.ModelCheckpoint(
        filepath=os.path.join(CONFIG['MODEL_SAVE_DIR'], 'best_resnet_finetuned.h5'),
        monitor='val_accuracy',
        save_best_only=True,
        verbose=1
    )
]

# Continue training with fine-tuning; epoch numbering resumes where phase one stopped.
finetune_epochs = 10
total_epochs = len(resnet_history.history['loss']) + finetune_epochs

resnet_finetune_history = resnet_model.fit(
    X_train, y_train,
    batch_size=CONFIG['BATCH_SIZE'],
    epochs=total_epochs,
    initial_epoch=len(resnet_history.history['loss']),
    validation_data=(X_val, y_val),
    callbacks=finetune_callbacks,
    verbose=1
)

print("Fine-tuning completed!")

```

---

## 10. PyTorch Model Training

```python
# PyTorch Dataset Class
class FakeImageDataset(Dataset):
    """Torch ``Dataset`` over in-memory images and binary labels (1 = Real, 0 = Fake).

    Args:
        images: Indexable collection of HxWxC images. Float images are assumed to be
            scaled to [0, 1] (NOTE(review): confirm against load_and_preprocess_image);
            uint8 images are used as-is.
        labels: Indexable collection of labels, each returned as a float tensor of shape (1,).
        transform: Optional albumentations-style callable, invoked as
            ``transform(image=uint8_image)``, that returns a dict with an ``'image'`` entry.
            Without it, the image is converted directly to a float32 CHW tensor.
    """

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            # albumentations works on uint8 HWC arrays
            image = self.transform(image=self._to_uint8(image))['image']
        else:
            # HWC -> CHW, the layout torch conv layers expect
            image = torch.tensor(np.asarray(image), dtype=torch.float32).permute(2, 0, 1)

        return image, torch.tensor([float(label)], dtype=torch.float32)

    @staticmethod
    def _to_uint8(image):
        """Convert an image to uint8 for albumentations (uint8 input passes through)."""
        image = np.asarray(image)
        if image.dtype == np.uint8:
            return image
        # Round rather than truncate: float32 k/255 * 255 can land just below k, and a
        # bare astype would turn it into k - 1. Clipping prevents uint8 wrap-around.
        return np.clip(np.rint(image * 255.0), 0, 255).astype(np.uint8)

# Wrap the numpy splits in PyTorch datasets: `train_transform` for training,
# `val_transform` for both validation and test.
train_dataset = FakeImageDataset(X_train, y_train, transform=train_transform)
val_dataset = FakeImageDataset(X_val, y_val, transform=val_transform)
test_dataset = FakeImageDataset(X_test, y_test, transform=val_transform)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=CONFIG['BATCH_SIZE'])
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=CONFIG['BATCH_SIZE'])
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=CONFIG['BATCH_SIZE'])

print("PyTorch datasets created:")
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

```

```python
# PyTorch ResNet Model
class PyTorchResNetModel(nn.Module):
    """torchvision ResNet-50 (pretrained) with a custom MLP classification head.

    The backbone is frozen except for `layer4`. The replacement head is
    created after freezing and so is fully trainable. The head ends in a
    Sigmoid, so `forward` returns probabilities (pair with `nn.BCELoss`).

    Args:
        num_classes: number of outputs of the final linear layer.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        self.resnet = models.resnet50(pretrained=True)

        # Freeze the whole backbone, then re-enable gradients for layer4 only.
        self.resnet.requires_grad_(False)
        self.resnet.layer4.requires_grad_(True)

        # Swap the ImageNet classifier for a dropout-regularised MLP head.
        in_features = self.resnet.fc.in_features
        head_layers = [
            nn.Dropout(0.5),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
            nn.Sigmoid(),
        ]
        self.resnet.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        return self.resnet(x)

# Initialize PyTorch model
pytorch_model = PyTorchResNetModel().to(device)
# The model's head ends in Sigmoid, so plain BCE on probabilities is used.
criterion = nn.BCELoss()

# Only layer4 and the new head require gradients; hand just those to Adam.
trainable_params = [p for p in pytorch_model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=CONFIG['LEARNING_RATE'])
# Halve the learning rate when validation loss stops improving (patience=5 epochs).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

print("PyTorch model initialized!")
# The count covers trainable parameters only; frozen backbone weights are excluded.
print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")

```

```python
# PyTorch training function
def train_pytorch_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs):
    """Train a sigmoid-output binary classifier and record per-epoch metrics.

    Each epoch runs one optimisation pass over `train_loader`, then one
    evaluation pass over `val_loader`. After that it calls
    `scheduler.step(val_loss)`, the signature `ReduceLROnPlateau` expects.
    Whenever validation accuracy beats the best seen so far, the model's
    state_dict is saved to `CONFIG['MODEL_SAVE_DIR']/best_pytorch_model.pth`.

    NOTE: on return the model holds the weights of the *last* epoch, not the
    best one. Reload the checkpoint to use the best-validation weights.

    Args:
        model: torch module whose outputs are thresholded at 0.5.
        train_loader, val_loader: iterables of `(data, target)` batches; the
            target is assumed to have the same shape as the model output.
        criterion: loss called as `criterion(output, target)`.
        optimizer: optimizer over the model's trainable parameters.
        scheduler: LR scheduler stepped with the epoch's validation loss.
        num_epochs: number of epochs to run.

    Returns:
        dict with per-epoch lists 'train_loss', 'val_loss', 'train_acc' and
        'val_acc' (accuracies in percent).
    """
    # Send batches to wherever the model's parameters live.
    device = next(model.parameters()).device

    save_dir = CONFIG['MODEL_SAVE_DIR']
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, 'best_pytorch_model.pth')

    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    # Start below any reachable accuracy so the first epoch always writes a checkpoint.
    best_val_acc = float('-inf')

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for data, target in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            batch_size = target.size(0)
            # Weight the batch loss (assumed mean-reduced, the default) by the
            # batch size, so the epoch loss is a per-sample mean even when the
            # last batch is smaller.
            running_loss += loss.item() * batch_size
            predicted = (output > 0.5).float()
            total += batch_size
            correct += (predicted == target).sum().item()

        train_loss = running_loss / total
        train_acc = 100. * correct / total

        # Validation phase
        model.eval()
        val_loss_sum = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                batch_size = target.size(0)
                val_loss_sum += criterion(output, target).item() * batch_size
                predicted = (output > 0.5).float()
                total += batch_size
                correct += (predicted == target).sum().item()

        val_loss = val_loss_sum / total
        val_acc = 100. * correct / total

        # Store metrics
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)

        print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    return {
        'train_loss': train_losses,
        'val_loss': val_losses,
        'train_acc': train_accuracies,
        'val_acc': val_accuracies
    }

# Train the PyTorch model for CONFIG['EPOCHS'] epochs; the returned dict of
# per-epoch metrics is used for plotting later.
print("Starting PyTorch model training...")
pytorch_history = train_pytorch_model(
    model=pytorch_model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=CONFIG['EPOCHS'],
)
print("PyTorch model training completed!")

```

---

## 11. Model Evaluation and Comparison

```python
# Evaluate TensorFlow models
print("Evaluating TensorFlow models...")

# evaluate() returns the loss first, then the compiled metrics. The unpacking
# assumes each model was compiled with exactly accuracy, precision, recall.
cnn_test_loss, cnn_test_acc, cnn_test_precision, cnn_test_recall = cnn_model.evaluate(X_test, y_test, verbose=0)
cnn_predictions = cnn_model.predict(X_test)
# Probabilities -> hard 0/1 labels as a flat vector.
cnn_pred_binary = np.where(cnn_predictions > 0.5, 1, 0).ravel()

resnet_test_loss, resnet_test_acc, resnet_test_precision, resnet_test_recall = resnet_model.evaluate(X_test, y_test, verbose=0)
resnet_predictions = resnet_model.predict(X_test)
resnet_pred_binary = np.where(resnet_predictions > 0.5, 1, 0).ravel()

print(f"CNN Model - Test Accuracy: {cnn_test_acc:.4f}")
print(f"ResNet Model - Test Accuracy: {resnet_test_acc:.4f}")

# F1 uses the thresholded labels; AUC uses the raw probabilities.
from sklearn.metrics import f1_score, roc_auc_score

cnn_f1, cnn_auc = f1_score(y_test, cnn_pred_binary), roc_auc_score(y_test, cnn_predictions)
resnet_f1, resnet_auc = f1_score(y_test, resnet_pred_binary), roc_auc_score(y_test, resnet_predictions)

print(f"CNN Model - F1 Score: {cnn_f1:.4f}, AUC: {cnn_auc:.4f}")
print(f"ResNet Model - F1 Score: {resnet_f1:.4f}, AUC: {resnet_auc:.4f}")

```

```python
# Evaluate PyTorch model
def evaluate_pytorch_model(model, test_loader):
    """Run `model` over `test_loader` and collect binary-classification outputs.

    Model outputs are treated as probabilities: a sample is predicted
    positive when its output exceeds 0.5.

    Args:
        model: torch module; batches are moved to the device of its parameters.
        test_loader: iterable of `(data, target)` batches, with target assumed
            to have the same shape as the model output.

    Returns:
        (accuracy, predictions, targets):
        - accuracy in percent;
        - raw model outputs stacked along the batch axis, e.g. (N, 1) for a
          single-output model;
        - targets flattened to 1-D.

    Raises:
        ValueError: if `test_loader` yields no samples.
    """
    # Send batches to wherever the model's parameters live.
    device = next(model.parameters()).device
    model.eval()
    correct = 0
    total = 0
    output_batches = []
    target_batches = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            predicted = (output > 0.5).float()
            total += target.size(0)
            correct += (predicted == target).sum().item()

            output_batches.append(output.cpu().numpy())
            target_batches.append(target.cpu().numpy())

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")

    accuracy = 100. * correct / total
    return accuracy, np.concatenate(output_batches), np.concatenate(target_batches).flatten()

# Score the PyTorch model on the held-out test loader.
pytorch_acc, pytorch_predictions, pytorch_targets = evaluate_pytorch_model(pytorch_model, test_loader)

# Threshold the outputs at 0.5 and flatten to a 1-D label vector.
pytorch_pred_binary = np.asarray(pytorch_predictions > 0.5, dtype=int).ravel()

pytorch_f1, pytorch_auc = (
    f1_score(pytorch_targets, pytorch_pred_binary),
    roc_auc_score(pytorch_targets, pytorch_predictions),
)

print(f"PyTorch Model - Test Accuracy: {pytorch_acc:.2f}%")
print(f"PyTorch Model - F1 Score: {pytorch_f1:.4f}, AUC: {pytorch_auc:.4f}")

```

```python
# Create comprehensive evaluation report
def create_evaluation_report():
    """Build, print and return a table comparing the three models' test metrics.

    Reads the module-level metric variables produced by the evaluation cells.
    Keras precision/recall come from `model.evaluate`. For the PyTorch model
    they are computed here with scikit-learn from its thresholded test
    predictions, since its evaluation function returns only accuracy and
    raw outputs.

    Returns:
        pandas.DataFrame with one row per model.
    """
    from sklearn.metrics import precision_score, recall_score

    pytorch_precision = precision_score(pytorch_targets, pytorch_pred_binary)
    pytorch_recall = recall_score(pytorch_targets, pytorch_pred_binary)

    models_performance = {
        'Model': ['CNN', 'ResNet', 'PyTorch ResNet'],
        # PyTorch accuracy is reported in percent; rescale to a [0, 1] fraction.
        'Accuracy': [cnn_test_acc, resnet_test_acc, pytorch_acc / 100],
        'Precision': [cnn_test_precision, resnet_test_precision, pytorch_precision],
        'Recall': [cnn_test_recall, resnet_test_recall, pytorch_recall],
        'F1-Score': [cnn_f1, resnet_f1, pytorch_f1],
        'AUC': [cnn_auc, resnet_auc, pytorch_auc]
    }

    df_performance = pd.DataFrame(models_performance)
    print("\n=== Model Performance Comparison ===")
    print(df_performance.to_string(index=False))

    return df_performance

performance_df = create_evaluation_report()

```

---

## 12. Visualization of Training History

```python
# Plot training history
def plot_training_history(histories, model_names):
    """Draw a 2x2 summary: accuracy curves, loss curves, ROC curves, CNN confusion matrix.

    Args:
        histories: per-model history dicts. Each is either Keras-style
            (`accuracy`/`val_accuracy`, `loss`/`val_loss`) or PyTorch-style
            (`train_acc`/`val_acc` in percent, `train_loss`/`val_loss`).
        model_names: legend labels aligned with `histories`.

    The ROC and confusion-matrix panels read module-level test-set variables
    (`y_test`, `cnn_predictions`, `resnet_predictions`, `pytorch_targets`,
    `pytorch_predictions`, `cnn_pred_binary`).
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    acc_ax, loss_ax = axes[0, 0], axes[0, 1]
    roc_ax, cm_ax = axes[1, 0], axes[1, 1]

    # Accuracy curves; PyTorch percentages are rescaled to [0, 1].
    acc_ax.set_title('Model Accuracy')
    for history, name in zip(histories, model_names):
        if 'accuracy' in history:
            train_curve, val_curve = history['accuracy'], history['val_accuracy']
        elif 'train_acc' in history:
            train_curve = [acc / 100 for acc in history['train_acc']]
            val_curve = [acc / 100 for acc in history['val_acc']]
        else:
            continue
        acc_ax.plot(train_curve, label=f'{name} Train')
        acc_ax.plot(val_curve, label=f'{name} Val')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.legend()

    # Loss curves; both history styles name the validation key 'val_loss'.
    loss_ax.set_title('Model Loss')
    for history, name in zip(histories, model_names):
        if 'loss' in history:
            train_key = 'loss'
        elif 'train_loss' in history:
            train_key = 'train_loss'
        else:
            continue
        loss_ax.plot(history[train_key], label=f'{name} Train')
        loss_ax.plot(history['val_loss'], label=f'{name} Val')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.legend()

    # ROC curves for all three models on the test set.
    roc_ax.set_title('ROC Curve')
    roc_inputs = [
        ('CNN', y_test, cnn_predictions),
        ('ResNet', y_test, resnet_predictions),
        ('PyTorch ResNet', pytorch_targets, pytorch_predictions),
    ]
    for label, y_true, y_score in roc_inputs:
        fpr, tpr, _ = roc_curve(y_true, y_score)
        auc_value = roc_auc_score(y_true, y_score)
        roc_ax.plot(fpr, tpr, label=f'{label} (AUC = {auc_value:.2f})')

    roc_ax.plot([0, 1], [0, 1], 'k--', label='Random Guess')
    roc_ax.set_xlim([0.0, 1.0])
    roc_ax.set_ylim([0.0, 1.05])
    roc_ax.set_xlabel('False Positive Rate')
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.legend(loc="lower right")

    # Confusion matrix of the CNN's thresholded test predictions.
    cm_ax.set_title('CNN Confusion Matrix')
    class_names = ['Fake', 'Real']  # NOTE(review): assumes label 0 = Fake, 1 = Real — confirm
    sns.heatmap(confusion_matrix(y_test, cnn_pred_binary), annot=True, fmt='d', cmap='Blues',
                ax=cm_ax, xticklabels=class_names, yticklabels=class_names)
    cm_ax.set_ylabel('Actual')
    cm_ax.set_xlabel('Predicted')

    plt.tight_layout()
    plt.show()

# Collect histories for plotting. `resnet_history` covers only the initial
# run, while the evaluated ResNet is the fine-tuned one, so append the
# fine-tuning epochs for every metric the first run recorded.
resnet_combined_history = {
    key: list(values) + list(resnet_finetune_history.history.get(key, []))
    for key, values in resnet_history.history.items()
}
histories_to_plot = [cnn_history.history, resnet_combined_history, pytorch_history]
model_names_to_plot = ['CNN', 'ResNet', 'PyTorch ResNet']

plot_training_history(histories_to_plot, model_names_to_plot)

# Additionally, plot confusion matrices for all models.
# `plot_confusion_matrix` comes from the utils.visualization star import and
# is presumably drawing on the current subplot — confirm its signature there.
print("\n=== Confusion Matrices ===")
plt.figure(figsize=(18, 5))

plt.subplot(1, 3, 1)
plot_confusion_matrix(y_test, cnn_pred_binary, title='CNN Confusion Matrix')

plt.subplot(1, 3, 2)
plot_confusion_matrix(y_test, resnet_pred_binary, title='ResNet Confusion Matrix')

plt.subplot(1, 3, 3)
plot_confusion_matrix(pytorch_targets, pytorch_pred_binary, title='PyTorch ResNet Confusion Matrix')

plt.tight_layout()
plt.show()

# Highest test-set AUC among the three evaluated models.
best_auc = max(cnn_auc, resnet_auc, pytorch_auc)
print(f"Best AUC score: {best_auc:.4f}")
# NOTE(review): only model checkpoints are written (by the Keras callbacks and
# torch.save in the trainer); performance_df is not saved to disk here.
print("All models and results have been saved to the models directory.")

# Optional: Memory cleanup
import gc
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("Memory cleanup completed.")