In [None]:
# Imports and setups
!pip install kagglehub --quiet

import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
import torch.nn.functional as F

from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix

import cv2
from glob import glob
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.io import read_image


In [None]:
import os
import math
from typing import Tuple, List

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import torchvision.transforms as T
import torchvision.models as models

# Optional but recommended for wavelet features:
try:
    import pywt
    HAS_PYWT = True
except ImportError:
    HAS_PYWT = False
    print("[WARN] pywt not installed. Wavelet feature will be zeroed. Install with `pip install pywavelets`.")


Feature Extraction

# Gray Tensor

In [None]:
def to_gray_tensor(img_tensor: torch.Tensor) -> torch.Tensor:
    """
    Collapse a channel-first RGB image into a single luminance channel.

    img_tensor: (3, H, W) image; a tensor whose first dimension is already 1
                is handed back unchanged (same object, no copy).
    Returns: (1, H, W) weighted sum 0.2989*R + 0.5870*G + 0.1140*B
    """
    needs_conversion = img_tensor.shape[0] != 1
    if needs_conversion:
        red, green, blue = img_tensor[0], img_tensor[1], img_tensor[2]
        luminance = 0.2989 * red + 0.5870 * green + 0.1140 * blue
        # Re-insert the channel axis dropped by the per-channel indexing.
        return luminance.unsqueeze(0)
    return img_tensor

In [None]:
def high_pass_residual(gray: torch.Tensor) -> torch.Tensor:
    """
    High-pass filter an image with a 3x3 Laplacian (centre 8, neighbours -1).

    gray: (1, H, W)
    returns: (1, H, W), divided by its largest absolute value so the result
             lies in [-1, 1]. Borders are zero-padded before filtering.
    """
    # Build the 8-neighbour Laplacian: all -1 with an 8 in the centre.
    laplacian = torch.full((3, 3), -1.0, dtype=torch.float32, device=gray.device)
    laplacian[1, 1] = 8.0
    weight = laplacian.view(1, 1, 3, 3)  # conv2d weight layout: (out, in, kH, kW)

    # conv2d needs a batch axis; add it for the call and strip it afterwards.
    filtered = F.conv2d(gray.unsqueeze(0), weight, padding=1).squeeze(0)

    # Peak-normalise; the epsilon keeps an all-zero response from dividing by 0.
    peak = filtered.abs().max()
    return filtered / (peak + 1e-8)

In [None]:
def prnu_residual(gray: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """
    Approximate PRNU noise residual: image - gaussian_blur(image).

    gray: (1, H, W) float tensor
    sigma: standard deviation of the Gaussian in pixels (must be > 0); the
           kernel covers int(3 * sigma) pixels on each side of the centre
    returns: (1, H, W), divided by its largest absolute value (range [-1, 1])
    raises: ValueError if sigma is not positive

    NOTE(review): earlier revisions paired the vertical kernel with horizontal
    padding (and vice versa). The output shape was still correct, but the top
    and bottom `radius` rows were never blurred, so the raw image leaked into
    the residual there and dominated the normalisation. Checkpoints trained on
    that output saw different border statistics - re-validate before reusing.
    """
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")

    # Separable Gaussian: one normalised 1-D kernel applied along each axis.
    radius = int(3 * sigma)
    x = torch.arange(-radius, radius + 1, device=gray.device, dtype=torch.float32)
    gauss_1d = torch.exp(-0.5 * (x / sigma) ** 2)
    gauss_1d = gauss_1d / gauss_1d.sum()

    # conv2d weights are laid out (out, in, kH, kW) and padding is (padH, padW),
    # so the kernel's long axis must match the axis that gets padded.
    kernel_h = gauss_1d.view(1, 1, 1, -1)   # 1 x K: blurs along the width
    kernel_v = gauss_1d.view(1, 1, -1, 1)   # K x 1: blurs along the height

    # Horizontal pass, then vertical pass; each keeps the spatial size.
    blurred = F.conv2d(gray.unsqueeze(0), kernel_h, padding=(0, radius))
    blurred = F.conv2d(blurred, kernel_v, padding=(radius, 0))
    blurred = blurred.squeeze(0)

    residual = gray - blurred
    # Peak-normalise; the epsilon guards against an all-zero residual.
    residual = residual / (residual.abs().max() + 1e-8)
    return residual

In [None]:

# def fft_magnitude_map(gray: torch.Tensor) -> torch.Tensor:
#     """
#     Compute log-magnitude FFT map.
#     gray: (1, H, W)
#     returns: (1, H, W)
#     """
#     # Convert to numpy for simplicity then back to torch
#     g = gray.squeeze(0).cpu().numpy()
#     fft2 = np.fft.fft2(g)
#     fft_shift = np.fft.fftshift(fft2)
#     magnitude = np.log(1 + np.abs(fft_shift))

#     # Normalize
#     magnitude = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min() + 1e-8)
#     mag_tensor = torch.from_numpy(magnitude).float().unsqueeze(0)
#     return mag_tensor.to(gray.device)
def fft_magnitude_map(gray: torch.Tensor) -> torch.Tensor:
    """
    Centred log-magnitude spectrum of an image, computed with torch.fft.

    gray: (1, H, W)
    returns: (1, H, W) with log(1 + |FFT|) min-max scaled to [0, 1]; the
             zero-frequency term is moved to the middle by fftshift
    """
    # Drop the channel axis, transform, and centre the spectrum in one go.
    spectrum = torch.fft.fftshift(torch.fft.fft2(gray.squeeze(0)))
    log_mag = torch.log1p(torch.abs(spectrum))

    # Min-max scaling; the epsilon covers a completely flat spectrum.
    lowest, highest = log_mag.min(), log_mag.max()
    scaled = (log_mag - lowest) / (highest - lowest + 1e-8)
    return scaled.unsqueeze(0)



In [None]:

def wavelet_hh_band(gray: torch.Tensor) -> torch.Tensor:
    """
    Diagonal-detail (HH) band of a single-level Haar wavelet transform.

    gray: (1, H, W)
    returns: (1, H, W); the band is min-max scaled to [0, 1] and bilinearly
             resized back to the input's spatial size. When pywt is not
             available (HAS_PYWT is False) an all-zero tensor is returned.
    """
    if HAS_PYWT:
        # pywt operates on NumPy arrays, so round-trip through the CPU.
        plane = gray.squeeze(0).cpu().numpy()
        _, (_, _, diagonal) = pywt.dwt2(plane, 'haar')

        # Min-max scaling; the epsilon covers a constant band.
        lowest, highest = diagonal.min(), diagonal.max()
        diagonal = (diagonal - lowest) / (highest - lowest + 1e-8)

        band = torch.from_numpy(diagonal).float().unsqueeze(0).to(gray.device)
        # interpolate expects a batch axis; add it for the call only.
        resized = F.interpolate(band.unsqueeze(0),
                                size=gray.shape[1:],
                                mode='bilinear',
                                align_corners=False)
        return resized.squeeze(0)

    # Fallback: zeros
    return torch.zeros_like(gray)


In [None]:
def autocorrelation_map(gray: torch.Tensor) -> torch.Tensor:
    """
    Per-pixel correlation of an image with its one-pixel right and down
    neighbours.

    Accepted inputs:
        (H, W)
        (1, H, W)
        (C, H, W)     -> channels are averaged into one
        (B, 1, H, W)
    Output:
        The correlation map min-max scaled to [0, 1] per image, with the
        leading batch axis squeezed: (1, H, W) for any single-image input.
        For a batch with B > 1 the squeeze is a no-op and the result stays
        (B, 1, H, W).
    """
    eps = 1e-8

    # Promote every supported layout to [B, C, H, W].
    if gray.dim() == 2:
        batch = gray.unsqueeze(0).unsqueeze(0)
    elif gray.dim() == 3:
        batch = gray.unsqueeze(0)
    else:
        batch = gray

    # Autocorrelation is computed on a single channel.
    if batch.size(1) > 1:
        batch = batch.mean(dim=1, keepdim=True)

    # Neighbour images: drop the first column/row, then reflect-pad the far
    # edge so both keep the original spatial size.
    shifted_right = F.pad(batch[:, :, :, 1:], (0, 1, 0, 0), mode='reflect')
    shifted_down = F.pad(batch[:, :, 1:, :], (0, 0, 0, 1), mode='reflect')

    # Standardise all three images with the statistics of the unshifted one.
    centre = batch.mean(dim=(2, 3), keepdim=True)
    spread = batch.std(dim=(2, 3), keepdim=True) + eps

    def _standardise(t: torch.Tensor) -> torch.Tensor:
        return (t - centre) / spread

    base = _standardise(batch)
    corr_h = base * _standardise(shifted_right)
    corr_v = base * _standardise(shifted_down)
    corr = (corr_h + corr_v) / 2.0

    # Per-image min-max scaling to [0, 1].
    lowest = corr.amin(dim=(2, 3), keepdim=True)
    highest = corr.amax(dim=(2, 3), keepdim=True)
    scaled = (corr - lowest) / (highest - lowest + eps)

    return scaled.squeeze(0)


Get The Dataset

In [None]:
def extract_forensic_tensor(rgb_tensor: torch.Tensor, use_channels=None) -> torch.Tensor:
    """
    Build the stacked forensic feature tensor for one RGB image.

    rgb_tensor: (3, H, W) image (converted to grayscale internally)
    use_channels: names of the feature maps to compute, any subset of
        "hpr", "prnu", "autocorr", "fft", "wavelet"
        (a single name may also be given as a plain string).
    If use_channels is None -> the default trio ["hpr", "prnu", "autocorr"].
    If use_channels is empty -> a dummy all-zero (1, H, W) tensor
                                (used for RGB-only runs).
    Returns: (C, H, W) with one channel per requested map. Channels are always
        stacked in the fixed order hpr, prnu, autocorr, fft, wavelet,
        regardless of the order in use_channels, so the layout seen by a model
        does not depend on how the caller ordered the list.
    Raises: ValueError if use_channels contains an unknown name.
    """
    gray = to_gray_tensor(rgb_tensor)  # (1, H, W)

    if use_channels is None:
        use_channels = ["hpr", "prnu", "autocorr"]
    elif isinstance(use_channels, str):
        use_channels = [use_channels]
    else:
        use_channels = list(use_channels)

    # No forensic channels requested -> cheap dummy tensor
    if len(use_channels) == 0:
        return torch.zeros((1, gray.shape[1], gray.shape[2]), device=gray.device)

    # Insertion order here defines the channel order of the output.
    extractors = {
        "hpr": high_pass_residual,
        "prnu": prnu_residual,
        "autocorr": autocorrelation_map,
        "fft": fft_magnitude_map,
        "wavelet": wavelet_hh_band,
    }

    # Fail loudly on typos instead of silently producing fewer channels than
    # the downstream model was configured for.
    unknown = [name for name in use_channels if name not in extractors]
    if unknown:
        raise ValueError(
            f"Unknown forensic channel(s) {unknown}; expected a subset of {list(extractors)}"
        )

    requested = set(use_channels)
    feature_maps = [fn(gray) for name, fn in extractors.items() if name in requested]

    # Stack selected channels: (C, H, W)
    return torch.cat(feature_maps, dim=0)

class HFRealAIDataset(Dataset):
    """
    Torch Dataset over an indexable collection of records that each carry an
    "image" (an object with .convert("RGB")) and a "label".

    Every sample is resized to image_size x image_size, converted to a tensor,
    optionally augmented, and paired with its forensic feature tensor.
    """

    def __init__(self, hf_dataset, image_size=256, augment=True, use_channels=None):
        self.ds = hf_dataset
        self.image_size = image_size
        # Forwarded verbatim to extract_forensic_tensor for every sample.
        self.use_channels = use_channels

        deterministic_steps = [
            T.Resize((image_size, image_size)),
            T.ToTensor(),
        ]
        self.base_transform = T.Compose(deterministic_steps)

        # Light random augmentation; None disables it entirely.
        if augment:
            self.augment = T.Compose([
                T.RandomHorizontalFlip(0.5),
                T.RandomVerticalFlip(0.1),
                T.RandomRotation(5),
            ])
        else:
            self.augment = None

    def __len__(self) -> int:
        return len(self.ds)

    def __getitem__(self, idx):
        record = self.ds[idx]
        rgb = self.base_transform(record["image"].convert("RGB"))

        if self.augment:
            rgb = self.augment(rgb)

        # Forensic maps are derived from the (possibly augmented) RGB tensor so
        # both streams see the same geometry.
        return {
            "rgb": rgb,
            "forensic": extract_forensic_tensor(rgb, use_channels=self.use_channels),
            "label": torch.tensor(record["label"], dtype=torch.long),
        }


In [None]:
# =========================================================
# Model Variants for Individual Feature Testing
# =========================================================

class RGBOnlyModel(nn.Module):
    """
    Baseline classifier that looks at the RGB image only.

    The feature extractor is torchvision's EfficientNet-B0 (IMAGENET1K_V1
    weights) followed by global average pooling and a two-layer MLP head.
    """

    def __init__(self, num_classes: int = 2):
        super().__init__()
        pretrained = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
        # Width of the pooled embedding = input width of EfficientNet's own head.
        embed_dim = pretrained.classifier[1].in_features

        self.rgb_backbone = pretrained.features
        self.rgb_pool = nn.AdaptiveAvgPool2d((1, 1))

        head_layers = [
            nn.Linear(embed_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, rgb, forensic=None):
        # `forensic` is accepted only so every model variant shares one call
        # signature; it is not used here.
        pooled = self.rgb_pool(self.rgb_backbone(rgb))
        embedding = pooled.view(pooled.size(0), -1)
        return self.classifier(embedding)


class ForensicCNN(nn.Module):
    """
    Compact three-stage CNN that embeds a stack of forensic feature maps.

    Each stage is Conv3x3 -> BatchNorm -> ReLU. The first two stages end with
    2x max-pooling, the last with global average pooling, and a linear layer
    maps the resulting 128 features to `out_dim`.
    """

    def __init__(self, in_channels: int = 5, out_dim: int = 256):
        super().__init__()
        stage_widths = (32, 64, 128)
        final_stage = len(stage_widths) - 1

        layers = []
        previous_width = in_channels
        for stage, width in enumerate(stage_widths):
            layers.append(nn.Conv2d(previous_width, width, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(width))
            layers.append(nn.ReLU(inplace=True))
            # Downsample between stages; collapse to 1x1 after the last one.
            if stage == final_stage:
                layers.append(nn.AdaptiveAvgPool2d((1, 1)))
            else:
                layers.append(nn.MaxPool2d(2))
            previous_width = width

        self.features = nn.Sequential(*layers)
        self.fc = nn.Linear(128, out_dim)

    def forward(self, x):
        pooled = self.features(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

class ForensicOnlyModel(nn.Module):
    """
    Classifier driven purely by forensic feature maps (one map or a subset),
    embedded by a ForensicCNN and classified by a two-layer MLP head.
    """

    def __init__(self, num_classes: int = 2, in_channels: int = 1, forensic_out_dim: int = 256):
        super().__init__()
        self.forensic_backbone = ForensicCNN(in_channels=in_channels, out_dim=forensic_out_dim)

        head_layers = [
            nn.Linear(forensic_out_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, rgb, forensic):
        # `rgb` is ignored; it is part of the signature so all model variants
        # can be called the same way.
        return self.classifier(self.forensic_backbone(forensic))



In [None]:
# class AIRealDataset(Dataset):
#     """
#     Expects directory structure like:
#     root_dir/
#         real/
#             img1.jpg
#             img2.png
#             ...
#         fake/
#             imgA.jpg
#             imgB.png
#             ...
#     label: real -> 0, fake -> 1
#     """

#     def __init__(self, root_dir: str, image_size: int = 256):
#         super().__init__()
#         self.root_dir = root_dir
#         self.classes = ['REAL', 'FAKE']
#         self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

#         self.samples: List[Tuple[str, int]] = []
#         for cls in self.classes:
#             class_dir = os.path.join(root_dir, cls)
#             if not os.path.isdir(class_dir):
#                 print(f"[WARN] Class directory not found: {class_dir}")
#                 continue
#             for fname in os.listdir(class_dir):
#                 if fname.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.webp')):
#                     path = os.path.join(class_dir, fname)
#                     self.samples.append((path, self.class_to_idx[cls]))

#         print(f"[INFO] Loaded {len(self.samples)} images from {root_dir}")

#         # Basic transforms for RGB input (no heavy augment on noise!)
#         self.transform = T.Compose([
#             T.Resize((image_size, image_size)),
#             T.ToTensor(),  # converts to [0,1]
#         ])

#         # Optional light augmentation for training (you can extend this)
#         self.augment = T.Compose([
#             T.RandomHorizontalFlip(p=0.5),
#             T.RandomVerticalFlip(p=0.1),
#             T.RandomRotation(degrees=5),
#         ])

#     def __len__(self) -> int:
#         return len(self.samples)

#     def __getitem__(self, idx: int):
#         path, label = self.samples[idx]
#         img = Image.open(path).convert('RGB')

#         # Apply base transforms
#         img = self.transform(img)  # (3,H,W)

#         # Light augmentation that doesn't kill noise stats too much
#         img = self.augment(img)

#         # Extract forensic tensor
#         forensic = extract_forensic_tensor(img)  # (5,H,W)

#         return {
#             "rgb": img,               # (3,H,W)
#             "forensic": forensic,     # (5,H,W)
#             "label": torch.tensor(label, dtype=torch.long)
#         }



Two Stream Models

In [None]:

class TwoStreamNet(nn.Module):
    """
    Late-fusion classifier with two branches:
    - RGB branch: EfficientNet-B0 features (IMAGENET1K_V1 weights) + global average pooling
    - Forensic branch: ForensicCNN over the stacked forensic maps
    The two embeddings are concatenated and passed to a two-layer MLP head.
    """

    def __init__(self, num_classes: int = 2, forensic_out_dim: int = 256, forensic_in_channels: int = 5):
        super().__init__()
        pretrained = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
        # Width of the pooled RGB embedding = input width of EfficientNet's own head.
        rgb_embed_dim = pretrained.classifier[1].in_features

        self.rgb_backbone = pretrained.features
        self.rgb_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.forensic_backbone = ForensicCNN(in_channels=forensic_in_channels,
                                             out_dim=forensic_out_dim)

        head_layers = [
            nn.Linear(rgb_embed_dim + forensic_out_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, rgb, forensic):
        pooled = self.rgb_pool(self.rgb_backbone(rgb))
        rgb_embedding = pooled.view(pooled.size(0), -1)

        forensic_embedding = self.forensic_backbone(forensic)

        # Fuse along the feature axis: RGB features first, forensic second.
        fused = torch.cat([rgb_embedding, forensic_embedding], dim=1)
        return self.classifier(fused)


In [None]:
from datasets import load_dataset

# Fetch the real-vs-AI image dataset from the Hugging Face Hub (the first call
# downloads the parquet shards). Per the recorded output below, the result is
# a DatasetDict with a single "train" split exposing "image" and "label".
hf_ds = load_dataset("Hemg/AI-Generated-vs-Real-Images-Datasets")
# Bare expression so the notebook displays the dataset summary.
hf_ds


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/590 [00:00<?, ?B/s]

data/train-00000-of-00006-336b26d54a26e1(…):   0%|          | 0.00/91.2M [00:00<?, ?B/s]

data/train-00001-of-00006-8ad2d550254dea(…):   0%|          | 0.00/25.9M [00:00<?, ?B/s]

data/train-00002-of-00006-ac8970f21c0418(…):   0%|          | 0.00/339M [00:00<?, ?B/s]

data/train-00003-of-00006-f635132ef309a7(…):   0%|          | 0.00/311M [00:00<?, ?B/s]

data/train-00004-of-00006-1101eaf5152e1c(…):   0%|          | 0.00/40.5M [00:00<?, ?B/s]

data/train-00005-of-00006-4bd152a5ab76db(…):   0%|          | 0.00/565M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/152710 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 152710
    })
})

In [None]:
# =========================================================
# Individual Feature Extractor Testing
# =========================================================
import json
import copy
from sklearn.metrics import roc_auc_score

# Number of passes over the training split for every model variant.
EPOCHS = 20

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# One entry per experiment: the model class, a human-readable feature label
# for logs, the forensic maps the dataset should compute, and how many
# channels those maps add up to.
model_configs = {
    "TwoStream_3Features": dict(
        model_class=TwoStreamNet,
        features="RGB + PRNU + HPF + Autocorrelation",
        use_channels=["prnu", "hpr", "autocorr"],
        forensic_channels=3,
    ),
}

# Training function for a single model variant
def train_model_variant(model_name, config, train_loader, val_loader, epochs=EPOCHS):
    """
    Fine-tune one entry of `model_configs` and remember its best validation epoch.

    model_name: configuration key; only "TwoStream_3Features" is supported.
    config: dict providing "features" (label used in logs) and
        "forensic_channels" (forensic input channels of the model). An optional
        "pretrained_path" overrides the checkpoint loaded before training
        (default: "best_three_stream_detector.pt").
    train_loader / val_loader: loaders yielding dict batches with "rgb",
        "forensic" and "label" tensors.
    epochs: number of passes over train_loader.

    Returns: (model, history, best_model_state, best_val_acc, best_epoch)
        model            - the network in its final-epoch state (not the best one)
        history          - per-epoch lists: train_loss, train_acc, val_loss,
                           val_acc, val_auc
        best_model_state - deep copy of the state_dict at the best epoch
        best_val_acc     - highest validation accuracy seen
        best_epoch       - 1-based index of that epoch
        With epochs >= 1 the first epoch always records a checkpoint, so
        best_model_state is never None and best_epoch is never 0.
    Raises: ValueError for an unknown model_name.
    """
    print(f"\n{'='*60}")
    print(f"Training: {model_name} ({config['features']})")
    print(f"{'='*60}")

    # Create model
    if model_name == "TwoStream_3Features":
        model = TwoStreamNet(
            forensic_in_channels=config["forensic_channels"],  # should be 3
            forensic_out_dim=256
        ).to(device)

        # Warm-start from the previously trained checkpoint.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        print("[INFO] Loading pretrained Kaggle weights...")
        pretrained_path = config.get("pretrained_path", "best_three_stream_detector.pt")
        state_dict = torch.load(pretrained_path, map_location=device)
        model.load_state_dict(state_dict)
    else:
        raise ValueError("Unknown model type in model_configs")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

    history = {
        "train_loss": [],
        "train_acc": [],
        "val_loss": [],
        "val_acc": [],
        "val_auc": []
    }

    best_val_acc = 0.0
    best_model_state = None
    best_epoch = 0

    # Progress is reported every ~10% of training and ~25% of validation batches.
    train_log_every = max(1, len(train_loader) // 10)
    val_log_every = max(1, len(val_loader) // 4)

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_losses = []
        train_correct = 0
        train_total = 0

        for batch_idx, batch in enumerate(train_loader):
            rgb = batch["rgb"].to(device)
            forensic = batch["forensic"].to(device)
            labels = batch["label"].to(device)

            optimizer.zero_grad()
            preds = model(rgb, forensic)
            loss = criterion(preds, labels)
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())
            _, predicted = preds.max(1)
            train_correct += (predicted == labels).sum().item()
            train_total += labels.size(0)

            if (batch_idx + 1) % train_log_every == 0:
                current_acc = train_correct / train_total if train_total > 0 else 0.0
                current_loss = np.mean(train_losses)
                print(f"  Train Progress: {batch_idx+1}/{len(train_loader)} batches | "
                      f"Loss={current_loss:.4f} | Acc={current_acc:.4f}")

        avg_train_loss = np.mean(train_losses)
        train_acc = train_correct / train_total

        # Validation phase
        model.eval()
        val_losses = []
        val_correct = 0
        val_total = 0
        y_true = []
        y_pred_probs = []

        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                rgb = batch["rgb"].to(device)
                forensic = batch["forensic"].to(device)
                labels = batch["label"].to(device)

                outputs = model(rgb, forensic)
                loss = criterion(outputs, labels)
                val_losses.append(loss.item())

                _, predicted = outputs.max(1)
                val_correct += (predicted == labels).sum().item()
                val_total += labels.size(0)

                # For AUC calculation
                probs = F.softmax(outputs, dim=1)[:, 1]  # Probability of class 1
                y_true.extend(labels.cpu().numpy().tolist())
                y_pred_probs.extend(probs.cpu().numpy().tolist())

                if (batch_idx + 1) % val_log_every == 0:
                    current_acc = val_correct / val_total if val_total > 0 else 0.0
                    current_loss = np.mean(val_losses)
                    print(f"  Val Progress: {batch_idx+1}/{len(val_loader)} batches | "
                          f"Loss={current_loss:.4f} | Acc={current_acc:.4f}")

        avg_val_loss = np.mean(val_losses)
        val_acc = val_correct / val_total

        # roc_auc_score raises when y_true holds a single class (e.g. a tiny or
        # filtered validation split); report NaN instead of aborting the run.
        if len(set(y_true)) < 2:
            print("  [WARN] Validation labels contain a single class; AUC is undefined (NaN).")
            val_auc = float("nan")
        else:
            val_auc = roc_auc_score(y_true, y_pred_probs)

        # Update history
        history["train_loss"].append(avg_train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(avg_val_loss)
        history["val_acc"].append(val_acc)
        history["val_auc"].append(val_auc)

        print(f"\nEpoch {epoch+1}/{epochs}:")
        print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"  Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.4f} | Val AUC: {val_auc:.4f}")

        # Save best model. The first epoch is always recorded so callers never
        # receive a None checkpoint, even if validation accuracy is 0.0.
        if best_model_state is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = copy.deepcopy(model.state_dict())
            best_epoch = epoch + 1
            print(f"  \u2713 New best validation accuracy: {best_val_acc:.4f}")

    return model, history, best_model_state, best_val_acc, best_epoch


# Load dataset and create dataloaders for each feature
print("[INFO] Loading dataset and preparing feature-specific dataloaders...")
full_train = hf_ds["train"]

# Split dataset (fixed seed so the 80/20 train/validation partition is reproducible)
seed = 42
split_ds = full_train.train_test_split(test_size=0.2, seed=seed)
train_hf = split_ds["train"]
val_hf = split_ds["test"]

# Create dataloaders for each feature configuration
image_size = 256
batch_size = 16
# Up to 8 loader processes, but never more than the host has CPUs: asking for
# more only triggers DataLoader's oversubscription warning and slows loading.
num_workers = min(8, os.cpu_count() or 1)
# Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
pin_memory = torch.cuda.is_available()

feature_dataloaders = {}

for model_name, config in model_configs.items():
    # Create datasets with specific feature channels; augmentation is applied
    # to the training split only.
    train_ds = HFRealAIDataset(train_hf, image_size=image_size, augment=True,
                                use_channels=config.get("use_channels"))
    val_ds = HFRealAIDataset(val_hf, image_size=image_size, augment=False,
                              use_channels=config.get("use_channels"))

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)

    feature_dataloaders[model_name] = {
        "train_loader": train_loader,
        "val_loader": val_loader
    }

    print(f"  {model_name}: Train={len(train_ds)}, Val={len(val_ds)}")

# Train all model variants
def _weight_stats(v):
    """Summarise one state_dict entry as shape / mean / std / min / max.

    Reductions are used for entries with at least one dimension; 0-dim entries
    fall back to their scalar value (std is reported as 0.0 for those).
    """
    def _reduce(stat_name, scalar_fallback):
        if hasattr(v, stat_name) and len(v.shape) > 0:
            return float(getattr(v, stat_name)().item())
        if scalar_fallback and hasattr(v, 'item'):
            return float(v.item())
        return 0.0

    return {
        "shape": list(v.shape) if hasattr(v, 'shape') else [],
        "mean": _reduce('mean', True),
        "std": _reduce('std', False),
        "min": _reduce('min', True),
        "max": _reduce('max', True),
    }


results_dict = {}
all_models = {}

for model_name, config in model_configs.items():
    loaders = feature_dataloaders[model_name]
    train_loader, val_loader = loaders["train_loader"], loaders["val_loader"]

    model, history, best_state, best_val_acc, best_epoch = train_model_variant(
        model_name, config, train_loader, val_loader, epochs=EPOCHS
    )

    # Save best model
    save_path = f"best_{model_name.lower()}_model.pth"
    torch.save(best_state, save_path)
    print(f"\nBest model saved to: {save_path}")

    # Metrics are read at the epoch with the best validation accuracy
    # (best_epoch is 1-based, the history lists are 0-based).
    best_idx = best_epoch - 1
    best_train_loss = history["train_loss"][best_idx]
    best_train_acc = history["train_acc"][best_idx]
    best_val_loss = history["val_loss"][best_idx]

    # Move the checkpoint to the CPU before computing summary statistics.
    cpu_state = {key: (value.cpu() if hasattr(value, 'cpu') else value)
                 for key, value in best_state.items()}

    # Store results
    results_dict[model_name] = {
        "features": config["features"],
        "use_channels": config.get("use_channels"),
        "best_epoch": best_epoch,
        "best_train_loss": float(best_train_loss),
        "best_train_acc": float(best_train_acc),
        "best_val_loss": float(best_val_loss),
        "best_val_acc": float(best_val_acc),
        "best_val_auc": float(history["val_auc"][best_idx]),
        "model_path": save_path,
        "model_weights_summary": {key: _weight_stats(value)
                                  for key, value in cpu_state.items()},
    }

    all_models[model_name] = dict(model=model, history=history)

# Save results dictionary
with open("feature_extractor_comparison_results.json", "w") as results_file:
    json.dump(results_dict, results_file, indent=2)

# Console summary: banner, fixed-width header, one row per trained model.
banner = "=" * 60
print("\n" + banner)
print("TRAINING COMPLETE - SUMMARY")
print(banner)

header_cells = [("Model", 20), ("Features", 30), ("Train Acc", 12),
                ("Val Acc", 12), ("Val Loss", 12), ("Val AUC", 12)]
print("\n" + " ".join(f"{title:<{width}}" for title, width in header_cells))
print("-" * 100)

for model_name, results in results_dict.items():
    row_cells = [
        f"{model_name:<20}",
        f"{results['features']:<30}",
        f"{results['best_train_acc']:<12.4f}",
        f"{results['best_val_acc']:<12.4f}",
        f"{results['best_val_loss']:<12.4f}",
        f"{results['best_val_auc']:<12.4f}",
    ]
    print(" ".join(row_cells))

print("\nResults saved to: feature_extractor_comparison_results.json")
print("Best models saved with prefix: best_*_model.pth")


[INFO] Loading dataset and preparing feature-specific dataloaders...
  TwoStream_3Features: Train=122168, Val=30542

Training: TwoStream_3Features (RGB + PRNU + HPF + Autocorrelation)
[INFO] Loading pretrained Kaggle weights...




  Train Progress: 763/7636 batches | Loss=0.5741 | Acc=0.8153
  Train Progress: 1526/7636 batches | Loss=0.3908 | Acc=0.8646
  Train Progress: 2289/7636 batches | Loss=0.3136 | Acc=0.8878
  Train Progress: 3052/7636 batches | Loss=0.2679 | Acc=0.9027
  Train Progress: 3815/7636 batches | Loss=0.2380 | Acc=0.9126
  Train Progress: 4578/7636 batches | Loss=0.2163 | Acc=0.9201
  Train Progress: 5341/7636 batches | Loss=0.2012 | Acc=0.9253
  Train Progress: 6104/7636 batches | Loss=0.1884 | Acc=0.9295
  Train Progress: 6867/7636 batches | Loss=0.1785 | Acc=0.9326
  Train Progress: 7630/7636 batches | Loss=0.1705 | Acc=0.9355
  Val Progress: 477/1909 batches | Loss=0.0649 | Acc=0.9733
  Val Progress: 954/1909 batches | Loss=0.0637 | Acc=0.9725
  Val Progress: 1431/1909 batches | Loss=0.0635 | Acc=0.9733
  Val Progress: 1908/1909 batches | Loss=0.0620 | Acc=0.9741

Epoch 1/20:
  Train Loss: 0.1704 | Train Acc: 0.9355
  Val Loss: 0.0621 | Val Acc: 0.9740 | Val AUC: 0.9977
  ✓ New best validat



  Train Progress: 3052/7636 batches | Loss=0.0790 | Acc=0.9674
  Train Progress: 3815/7636 batches | Loss=0.0782 | Acc=0.9680
  Train Progress: 4578/7636 batches | Loss=0.0785 | Acc=0.9679
  Train Progress: 5341/7636 batches | Loss=0.0781 | Acc=0.9682




  Train Progress: 6104/7636 batches | Loss=0.0775 | Acc=0.9685
  Train Progress: 6867/7636 batches | Loss=0.0770 | Acc=0.9691
  Train Progress: 7630/7636 batches | Loss=0.0768 | Acc=0.9689
  Val Progress: 477/1909 batches | Loss=0.0694 | Acc=0.9704
  Val Progress: 954/1909 batches | Loss=0.0671 | Acc=0.9703
  Val Progress: 1431/1909 batches | Loss=0.0650 | Acc=0.9717
  Val Progress: 1908/1909 batches | Loss=0.0640 | Acc=0.9723

Epoch 2/20:
  Train Loss: 0.0770 | Train Acc: 0.9689
  Val Loss: 0.0641 | Val Acc: 0.9723 | Val AUC: 0.9978
  Train Progress: 763/7636 batches | Loss=0.0656 | Acc=0.9740




  Train Progress: 1526/7636 batches | Loss=0.0646 | Acc=0.9747
  Train Progress: 2289/7636 batches | Loss=0.0635 | Acc=0.9755
  Train Progress: 3052/7636 batches | Loss=0.0641 | Acc=0.9750
  Train Progress: 3815/7636 batches | Loss=0.0651 | Acc=0.9745
  Train Progress: 4578/7636 batches | Loss=0.0650 | Acc=0.9743


In [None]:
# =========================================================
# Load and Display Results
# =========================================================
import json

# Load the results dictionary
with open("feature_extractor_comparison_results.json", "r") as results_file:
    loaded_results = json.load(results_file)

wide_rule = "=" * 100

# Comparison table: fixed-width header followed by one row per model.
print(wide_rule)
print("FEATURE EXTRACTOR COMPARISON RESULTS")
print(wide_rule)
header_cells = [("Model Name", 20), ("Features", 35), ("Train Acc", 12),
                ("Val Acc", 12), ("Val Loss", 12), ("Val AUC", 12)]
print("\n" + " ".join(f"{title:<{width}}" for title, width in header_cells))
print("-" * 110)

for model_name, results in loaded_results.items():
    row_cells = [
        f"{model_name:<20}",
        f"{results['features']:<35}",
        f"{results['best_train_acc']:<12.4f}",
        f"{results['best_val_acc']:<12.4f}",
        f"{results['best_val_loss']:<12.4f}",
        f"{results['best_val_auc']:<12.4f}",
    ]
    print(" ".join(row_cells))

# Find best model: highest validation accuracy (first one wins on ties).
best_model_name = max(loaded_results, key=lambda name: loaded_results[name]['best_val_acc'])
best_model_info = loaded_results[best_model_name]

print("\n" + wide_rule)
print(f"BEST MODEL: {best_model_name}")
print(wide_rule)
detail_lines = [
    f"Features: {best_model_info['features']}",
    f"Best Epoch: {best_model_info['best_epoch']}",
    f"Training Accuracy: {best_model_info['best_train_acc']:.4f}",
    f"Validation Accuracy: {best_model_info['best_val_acc']:.4f}",
    f"Validation Loss: {best_model_info['best_val_loss']:.4f}",
    f"Validation AUC: {best_model_info['best_val_auc']:.4f}",
    f"Model Path: {best_model_info['model_path']}",
]
for line in detail_lines:
    print(line)

# Display full results dictionary structure (includes the per-tensor weight
# summaries, so this output is long).
print("\n" + wide_rule)
print("FULL RESULTS DICTIONARY STRUCTURE")
print(wide_rule)
print(json.dumps(loaded_results, indent=2))


In [None]:
# =========================================================
# Visualize Training Curves for All Feature Extractors
# =========================================================
import matplotlib.pyplot as plt

if 'all_models' not in locals() or len(all_models) == 0:
    print("Note: Run ablation experiments first to generate training data.")
else:
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # One entry per panel: (axes, title, y-label, curves), where each curve is
    # (history key, legend label template, marker). Top row compares train vs
    # validation; bottom row shows validation only.
    panel_specs = [
        (axes[0, 0], "Loss Curves (Training vs Validation)", "Loss",
         [("train_loss", "{} - train", 'o'), ("val_loss", "{} - val", 's')]),
        (axes[0, 1], "Accuracy Curves (Training vs Validation)", "Accuracy",
         [("train_acc", "{} - train", 'o'), ("val_acc", "{} - val", 's')]),
        (axes[1, 0], "Validation Loss Across Models", "Val Loss",
         [("val_loss", "{}", 's')]),
        (axes[1, 1], "Validation Accuracy Across Models", "Val Accuracy",
         [("val_acc", "{}", 's')]),
    ]

    for ax, panel_title, y_label, curves in panel_specs:
        for model_name, model_data in all_models.items():
            hist = model_data["history"]
            for history_key, label_template, marker in curves:
                ax.plot(hist[history_key],
                        label=label_template.format(model_name),
                        marker=marker)
        ax.set_title(panel_title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.savefig("feature_extractor_comparison_curves.png", dpi=300, bbox_inches="tight")
    plt.show()

    print("Comparison plots saved to feature_extractor_comparison_curves.png")


In [None]:
def train_one_epoch(model, dataloader, optimizer, device, epoch, loss_fn):
    """
    Run one optimisation pass over `dataloader`, logging progress as it goes.

    Every batch must be a dict with "rgb", "forensic" and "label" entries. The
    model is called as `model(rgb, forensic)` and its output is handed to
    `loss_fn(logits, labels)`; predictions are the argmax over dim 1.

    Args:
        model: module called as model(rgb, forensic); switched to train mode here.
        dataloader: sized iterable of batch dicts (len() drives the progress log).
        optimizer: optimizer that is zeroed and stepped once per batch.
        device: device the three batch tensors are moved to.
        epoch: epoch number, used only in log messages.
        loss_fn: callable returning a scalar loss tensor for (logits, labels).

    Returns:
        (epoch_loss, epoch_acc): the batch-size-weighted mean of the per-batch
        losses and the fraction of correct predictions. Returns (0.0, 0.0) if
        the dataloader yielded no samples, mirroring the in-loop guards instead
        of raising ZeroDivisionError.
    """
    import time
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    start_time = time.time()
    num_batches = len(dataloader)
    # Report roughly ten times per epoch (at least every batch for tiny loaders).
    report_every = max(1, num_batches // 10)

    print(f"\n[TRAIN] Epoch {epoch} - Starting training on {num_batches} batches...")

    for batch_idx, batch in enumerate(dataloader):
        rgb = batch["rgb"].to(device)
        forensic = batch["forensic"].to(device)
        labels = batch["label"].to(device)

        # Log first batch details to verify data flow
        if batch_idx == 0:
            print(f"  Batch 0: RGB shape={rgb.shape}, Forensic shape={forensic.shape}, Labels={labels.tolist()[:5]}")

        optimizer.zero_grad()
        logits = model(rgb, forensic)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch value is a per-sample mean.
        running_loss += loss.item() * rgb.size(0)
        _, preds = logits.max(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        if (batch_idx + 1) % report_every == 0 or batch_idx == 0:
            current_acc = correct / total if total > 0 else 0.0
            current_loss = running_loss / total if total > 0 else 0.0
            elapsed = time.time() - start_time
            print(f"  Progress: {batch_idx+1}/{num_batches} batches | "
                  f"Loss={current_loss:.4f} | Acc={current_acc:.4f} | "
                  f"Time={elapsed:.1f}s")

            # Show sample predictions on first batch
            if batch_idx == 0:
                sample_preds = preds[:5].tolist()
                sample_labels = labels[:5].tolist()
                sample_probs = F.softmax(logits[:5], dim=1).max(dim=1)[0].tolist()
                print(f"  Sample predictions: Pred={sample_preds}, True={sample_labels}, "
                      f"Confidence={[f'{p:.2f}' for p in sample_probs]}")

    if total == 0:
        # Nothing was iterated: avoid dividing by zero and make the cause visible.
        print(f"[TRAIN] Epoch {epoch}: dataloader produced no samples; returning zero metrics.")
        return 0.0, 0.0

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    elapsed_time = time.time() - start_time
    print(f"[TRAIN] Epoch {epoch} Complete: Loss={epoch_loss:.4f}, Acc={epoch_acc:.4f}, "
          f"Time={elapsed_time:.1f}s ({elapsed_time/60:.1f} min)")
    return epoch_loss, epoch_acc


def eval_one_epoch(model, dataloader, device, epoch, loss_fn, mode="VAL"):
    """
    Evaluate `model` over `dataloader` without gradient tracking.

    Every batch must be a dict with "rgb", "forensic" and "label" entries. The
    model is called as `model(rgb, forensic)` under torch.inference_mode(), and
    predictions are the argmax over dim 1 of its output.

    Args:
        model: module called as model(rgb, forensic); switched to eval mode here.
        dataloader: sized iterable of batch dicts (len() drives the progress log).
        device: device the three batch tensors are moved to.
        epoch: epoch number, used only in log messages.
        loss_fn: callable returning a scalar loss tensor for (logits, labels).
        mode: tag printed in the log lines (defaults to "VAL").

    Returns:
        (epoch_loss, epoch_acc): the batch-size-weighted mean of the per-batch
        losses and the fraction of correct predictions. Returns (0.0, 0.0) if
        the dataloader yielded no samples, mirroring the in-loop guards instead
        of raising ZeroDivisionError.
    """
    import time
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    start_time = time.time()
    num_batches = len(dataloader)
    # Report roughly four times per pass (at least every batch for tiny loaders).
    report_every = max(1, num_batches // 4)

    print(f"\n[{mode}] Epoch {epoch} - Starting evaluation on {num_batches} batches...")

    with torch.inference_mode():
        for batch_idx, batch in enumerate(dataloader):
            rgb = batch["rgb"].to(device)
            forensic = batch["forensic"].to(device)
            labels = batch["label"].to(device)

            logits = model(rgb, forensic)
            loss = loss_fn(logits, labels)

            # Weight the batch loss by batch size so the epoch value is a per-sample mean.
            running_loss += loss.item() * rgb.size(0)
            _, preds = logits.max(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            if (batch_idx + 1) % report_every == 0 or batch_idx == 0:
                current_acc = correct / total if total > 0 else 0.0
                current_loss = running_loss / total if total > 0 else 0.0
                elapsed = time.time() - start_time
                print(f"  Progress: {batch_idx+1}/{num_batches} batches | "
                      f"Loss={current_loss:.4f} | Acc={current_acc:.4f} | "
                      f"Time={elapsed:.1f}s")

    if total == 0:
        # Nothing was iterated: avoid dividing by zero and make the cause visible.
        print(f"[{mode}] Epoch {epoch}: dataloader produced no samples; returning zero metrics.")
        return 0.0, 0.0

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    elapsed_time = time.time() - start_time
    print(f"[{mode}] Epoch {epoch} Complete: Loss={epoch_loss:.4f}, Acc={epoch_acc:.4f}, "
          f"Time={elapsed_time:.1f}s")
    return epoch_loss, epoch_acc


Training

In [None]:
# Authenticate this session with the Hugging Face Hub. login() is called without
# a token, so it asks for one interactively and this cell needs user input on a
# fresh run.
# NOTE(review): presumably here so the dataset download in main() can reach the
# Hub — confirm whether the dataset actually requires authentication.
from huggingface_hub import login
login()

In [None]:
def main(
    image_size=256,
    batch_size=16,
    num_epochs=20,
    lr=1e-4,
    num_workers=4,
    seed=42,
    weight_decay=1e-4,
    val_fraction=0.2,
    dataset_name="Hemg/AI-Generated-vs-Real-Images-Datasets",
    save_path="best_two_stream_detector.pt",
):
    """
    Train a TwoStreamNet on a Hugging Face image dataset and keep the best checkpoint.

    Loads the "train" split of `dataset_name`, splits it into train/validation
    subsets with the dataset's own train_test_split, wraps both in
    HFRealAIDataset, and alternates train_one_epoch / eval_one_epoch for
    `num_epochs` epochs. Whenever validation accuracy strictly improves, the
    model's state_dict is written to `save_path`.

    All arguments default to the values that were previously hard-coded, so a
    bare `main()` call behaves as before.

    Args:
        image_size: forwarded to HFRealAIDataset as `image_size`.
        batch_size: batch size for both DataLoaders.
        num_epochs: number of train+validation epochs to run.
        lr: AdamW learning rate.
        num_workers: DataLoader worker count for both loaders.
        seed: seed for torch, NumPy and the train/validation split.
        weight_decay: AdamW weight decay.
        val_fraction: fraction of the "train" split held out for validation.
        dataset_name: dataset identifier passed to load_dataset.
        save_path: file the best state_dict is saved to.

    Returns:
        None. Progress is printed and the best weights are saved to disk.
    """
    import time

    torch.manual_seed(seed)
    np.random.seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] Using device: {device}")

    # Pinned host memory only helps host-to-GPU copies; on a CPU-only machine
    # requesting it is useless and makes the DataLoader emit a warning.
    pin_memory = device.type == "cuda"

    # ---- Load HuggingFace dataset ----
    hf_ds = load_dataset(dataset_name)
    # Only the "train" split is used (original author's note for the default
    # dataset: 152,710 samples, no val/test splits).
    full_train = hf_ds["train"]

    # ---- Manual train/validation split using HuggingFace's train_test_split ----
    # random_split only works on PyTorch Datasets, so we use HF's method instead
    split_ds = full_train.train_test_split(test_size=val_fraction, seed=seed)
    train_hf = split_ds["train"]
    val_hf = split_ds["test"]

    # ---- Wrap in PyTorch dataset ----
    train_ds = HFRealAIDataset(train_hf, image_size=image_size, augment=True)
    val_ds = HFRealAIDataset(val_hf, image_size=image_size, augment=False)

    # ---- DataLoaders ----
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)

    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)

    # ---- Model / Loss / Optimizer ----
    model = TwoStreamNet(num_classes=2, forensic_out_dim=256).to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"[INFO] Model initialized: {total_params:,} total parameters, {trainable_params:,} trainable")
    print(f"[INFO] Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
    print(f"[INFO] Train samples: {len(train_ds)}, Val samples: {len(val_ds)}")

    # Smoke-test the forward pass on two real samples before committing to a long run.
    print("\n[INFO] Testing forward pass...")
    model.eval()
    with torch.no_grad():
        sample_batch = next(iter(train_loader))
        sample_rgb = sample_batch["rgb"][:2].to(device)  # Just 2 samples
        sample_forensic = sample_batch["forensic"][:2].to(device)
        sample_output = model(sample_rgb, sample_forensic)
        print(f"  Input shapes: RGB={sample_rgb.shape}, Forensic={sample_forensic.shape}")
        print(f"  Output shape: {sample_output.shape}, Sample logits: {sample_output[0].tolist()}")
    model.train()
    print("[INFO] Forward pass test successful!\n")

    best_val_acc = 0.0
    training_start_time = time.time()

    # ---- Training Loop ----
    print("=" * 60)
    print("STARTING TRAINING")
    print("=" * 60)
    for epoch in range(1, num_epochs + 1):
        epoch_start = time.time()
        print(f"\n{'='*60}")
        print(f"EPOCH {epoch}/{num_epochs}")
        print(f"{'='*60}")

        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, device, epoch, loss_fn)
        val_loss, val_acc = eval_one_epoch(model, val_loader, device, epoch, loss_fn, mode="VAL")

        epoch_time = time.time() - epoch_start
        total_time = time.time() - training_start_time

        print(f"\n[EPOCH {epoch} SUMMARY]")
        print(f"  Train: Loss={train_loss:.4f}, Acc={train_acc:.4f}")
        print(f"  Val:   Loss={val_loss:.4f}, Acc={val_acc:.4f}")
        print(f"  Epoch time: {epoch_time:.1f}s ({epoch_time/60:.1f} min)")
        print(f"  Total time: {total_time/60:.1f} min")
        # Remaining time = mean epoch duration so far x epochs still to run.
        print(f"  Estimated remaining: {(total_time/epoch) * (num_epochs - epoch)/60:.1f} min")

        # Checkpoint only on a strict improvement in validation accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), save_path)
            print(f"  ✓ NEW BEST MODEL! Saved to {save_path} (val_acc={best_val_acc:.4f})")
        else:
            print(f"  (Best val_acc so far: {best_val_acc:.4f})")

    total_training_time = time.time() - training_start_time
    print(f"\n{'='*60}")
    print("[DONE] Training finished!")
    print(f"  Total time: {total_training_time/60:.1f} min ({total_training_time/3600:.2f} hours)")
    print(f"  Best validation accuracy: {best_val_acc:.4f}")
    print(f"{'='*60}")


In [None]:
# import torch
# import torch.nn as nn
# from torch.utils.data import DataLoader, random_split
# from torchvision import transforms
# from datasets import load_from_disk, load_dataset
# import numpy as np
# from PIL import Image

# # --- Prerequisites ---
# # NOTE: 'TwoStreamNet', 'train_one_epoch', and 'eval_one_epoch' must already be
# # defined before this version of main() is uncommented and run.
# # ------------------------------------------------------------------

# def main():
#     # ---- Config ----
#     dataset_source = "/content/AI-Generated-vs-Real-Images-Local" # Local path
#     # If the local path is causing issues, you can try reloading from Hugging Face:
#     # dataset_source = "Hemg/AI-Generated-vs-Real-Images-Datasets"

#     image_size = 256
#     batch_size = 16
#     num_epochs = 20
#     lr = 1e-4
#     val_split = 0.2
#     num_workers = 2
#     seed = 42

#     torch.manual_seed(seed)
#     np.random.seed(seed)

#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     print(f"[INFO] Using device: {device}")

#     # ---- Image Transforms ----
#     transform = transforms.Compose([
#         transforms.Resize((image_size, image_size)),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#     ])

#     # ---- Dataset Loading and Mapping ----
#     try:
#         # ATTEMPT 1: Load from the local path first
#         dataset_dict = load_from_disk(dataset_source)
#         print("[INFO] Successfully loaded dataset from disk.")
#     except Exception as e:
#         print(f"WARNING: Failed to load dataset from disk: {e}")
#         # ATTEMPT 2: If local load fails, try loading directly from Hugging Face (requires internet)
#         print("[INFO] Attempting to load dataset from Hugging Face Hub...")
#         try:
#              # Force a reload, which might resolve inconsistencies.
#             dataset_dict = load_dataset("Hemg/AI-Generated-vs-Real-Images-Datasets")
#             # If successful, save it again to fix the local copy for next time
#             dataset_dict.save_to_disk(dataset_source)
#             print("[INFO] Successfully downloaded and saved fresh copy locally.")
#         except Exception as e:
#             print(f"FATAL ERROR: Failed to load dataset from both local and hub: {e}")
#             return # Exit main if loading fails

#     # The dataset now has 'image' and 'label' columns based on your error message.
#     def preprocess_function(examples):
#         # 1. Process the image
#         # Ensure the image is converted to RGB before transforming
#         examples["pixel_values"] = [transform(img.convert("RGB")) for img in examples["image"]]

#         # 2. Process the label
#         # Assuming the 'label' column is already an integer (0 or 1)
#         # If 'label' is a string ('real', 'fake') you'd need mapping here.
#         # Based on your error message, we assume 'label' is the target column.
#         examples["labels"] = examples["label"]

#         return examples

#     # Apply the preprocessing and transformation to the 'train' split
#     processed_ds = dataset_dict['train'].map(
#         preprocess_function,
#         batched=True,
#         # Only remove the old 'image' and 'label' columns, keeping the new 'pixel_values' and 'labels'
#         remove_columns=['image', 'label']
#     )

#     full_train_ds = processed_ds.with_format("torch")

#     # ---- Dataset Split and Dataloaders ----
#     n_total = len(full_train_ds)
#     n_val = int(val_split * n_total)
#     n_train = n_total - n_val

#     train_ds, val_ds = random_split(full_train_ds, [n_train, n_val])

#     train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
#                               num_workers=num_workers, pin_memory=True)
#     val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
#                             num_workers=num_workers, pin_memory=True)

#     # ---- Model, Loss, Optimizer ----
#     model = TwoStreamNet(num_classes=2, forensic_out_dim=256).to(device)
#     loss_fn = nn.CrossEntropyLoss()
#     optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

#     best_val_acc = 0.0

#     # ---- Training Loop ----
#     for epoch in range(1, num_epochs + 1):
#         print('epoch: ', epoch)
#         # ... (rest of your training loop remains the same)
#         # train_one_epoch(...)
#         # _, val_acc = eval_one_epoch(...)

#         # # Save best model
#         # if val_acc > best_val_acc:
#         #     best_val_acc = val_acc
#         #     save_path = "best_two_stream_detector.pt"
#         #     torch.save(model.state_dict(), save_path)
#         #     print(f"[INFO] New best model saved to {save_path} (val_acc={best_val_acc:.4f})")

#     # print(f"[DONE] Training finished. Best val acc = {best_val_acc:.4f}")

In [None]:
# Launch the full training run with main()'s built-in configuration.
main()

In [None]:
# Using that best model, as the accuracy is similar.