# Set up

In [1]:
import torch
import os
import subprocess
import sys
from PIL import Image
from torchvision import utils,transforms,models
from torch.utils.data import Dataset,DataLoader
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import glob
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import torch.nn as nn
from torch import optim
import copy
import time
from tqdm import tqdm
import random
import json

In [2]:
def preview_local_images(path_list):
    """
    Display one sample image from each local directory, side by side.

    Args:
        path_list: Sequence whose items are either ``(label, path)`` tuples
            or plain directory paths. For a plain path, the directory's
            base name is used as the label.

    Each directory is walked recursively and the first file whose name ends
    in ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) is shown. A
    directory without images, or an image that fails to load, is reported
    on stdout and gets an annotated empty panel instead of aborting the
    whole preview.
    """
    image_extensions = (".png", ".jpg", ".jpeg")
    num_paths = len(path_list)
    if num_paths == 0:
        # Nothing to lay out; avoid drawing an empty figure.
        print("No paths to preview.")
        return

    plt.figure(figsize=(15, 5))

    for i, item in enumerate(path_list):
        # Handle both tuple (label, path) and plain path
        if isinstance(item, tuple):
            label, path = item
        else:
            path = item
            # normpath drops a trailing separator so the label is not empty
            label = os.path.basename(os.path.normpath(path))

        ax = plt.subplot(1, num_paths, i + 1)
        ax.axis('off')

        # Stop at the first match instead of listing the whole tree.
        sample_file_path = next(
            (
                os.path.join(root, name)
                for root, _, files in os.walk(path)
                for name in files
                if name.lower().endswith(image_extensions)
            ),
            None,
        )

        if sample_file_path is None:
            print(f"No image files found in {path}")
            ax.set_title(f"No images: {label}")
            continue

        try:
            # The context manager closes the underlying file once converted.
            with Image.open(sample_file_path) as raw_image:
                image = raw_image.convert('RGB')
            ax.imshow(image)
            ax.set_title(f"Class: {label}\n{os.path.basename(sample_file_path)}")
        except Exception as e:
            # Best effort: one unreadable file must not kill the preview.
            print(f"Error loading image {sample_file_path}: {e}")
            ax.set_title(f"Error loading {label}")

    plt.tight_layout()
    plt.show()

# Image Preprocessing

In [3]:
class ImageDataset(Dataset):
    """
    Custom PyTorch Dataset to load images directly from the local file system
    for use with a standard PyTorch training loop (e.g., with ResNet).

    Each sample is an ``(image, label_tensor)`` pair, which is the shape the
    training, validation and Grad-CAM code in this notebook unpacks. The
    original paths stay available through ``self.file_paths``.
    """
    def __init__(self, file_paths, labels, transform=None):
        """
        Args:
            file_paths (list): List of local file paths to images.
            labels (list): List of corresponding labels (0 or 1).
            transform (torchvision.transforms, optional): Transformations to apply to the image.

        Raises:
            ValueError: If ``file_paths`` and ``labels`` differ in length.
        """
        if len(file_paths) != len(labels):
            raise ValueError(
                f"file_paths and labels must have the same length "
                f"(got {len(file_paths)} and {len(labels)})"
            )
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Returns the total number of samples."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """
        Load the image at ``idx``, apply the transform (if any) and return
        ``(image, label_tensor)``.

        Without a transform the RGB PIL image is returned as-is. If loading
        or transforming fails, the error is printed and a zero tensor of
        shape (3, 224, 224) with label 0 is returned, so a single corrupt
        file does not abort an epoch.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        local_path = self.file_paths[idx]
        label = self.labels[idx]

        try:
            # Close the file handle as soon as the pixels are decoded.
            with Image.open(local_path) as raw_image:
                image = raw_image.convert('RGB')
            if self.transform is not None:
                image = self.transform(image)
        except Exception as e:
            # Deliberate best effort: substitute a placeholder sample.
            print(f"Error loading image {local_path}: {e}")
            dummy_tensor = torch.zeros((3, 224, 224))  # Assuming 224x224
            return dummy_tensor, torch.tensor(0, dtype=torch.long)

        return image, torch.tensor(label, dtype=torch.long)

In [4]:
def _collect_image_files(root_dirs, patterns):
    """Recursively glob ``patterns`` under each root, keeping discovery order."""
    found = []
    for root in root_dirs:
        for pattern in patterns:
            found.extend(glob.glob(os.path.join(root, "**", pattern), recursive=True))
    return found


def _group_by_folder(files):
    """Map each parent directory to the files it directly contains."""
    grouped = defaultdict(list)
    for file in files:
        grouped[os.path.dirname(file)].append(file)
    return grouped


def _split_train_val_test(items, random_seed):
    """
    Split ``items`` (sorted first, for determinism) into 70% train and a 30%
    hold-out that is then halved into val and test.

    Groups too small for ``train_test_split`` degrade gracefully instead of
    raising: a single item goes to train, and a one-item hold-out goes to val.
    """
    ordered = sorted(items)
    if len(ordered) < 2:
        return ordered, [], []
    train, holdout = train_test_split(ordered, test_size=0.3, random_state=random_seed)
    if len(holdout) < 2:
        return train, holdout, []
    val, test = train_test_split(holdout, test_size=0.5, random_state=random_seed)
    return train, val, test


def _frames_in(folders, grouped):
    """Flatten the frame lists of ``folders``, in folder order."""
    return [frame for folder in folders for frame in grouped[folder]]


def _print_split_counts(header, empty_message, total, rows):
    """Print one statistics section; ``rows`` are the Train/Val/Test details."""
    if total == 0:
        print(empty_message)
        return
    print(f"{header}: {total} total")
    for tag, detail in zip(("Train:", "Val:  ", "Test: "), rows):
        print(f"  {tag} {detail}")


def get_data_mixed_structure(
    video_real_paths, video_fake_paths,
    image_real_paths, image_fake_paths,
    random_seed
):
    """
    Scans local directories and handles two types of datasets:
    1. Video-based: Split BY FOLDER to prevent frame leakage
    2. Image-based: Split BY IMAGE (no folder structure)

    Every source is split 70/15/15 into train/val/test with ``random_seed``.
    The function prints the split statistics and writes the test file list
    to ``test_set_seed_<random_seed>.json`` in the current directory.

    It reads the module-level names ``LABEL_REAL``, ``LABEL_FAKE``,
    ``train_transform`` and ``val_test_transform``, which must be defined
    before the call.

    Args:
        video_real_paths: List of paths to real video folders (e.g., Celeb-real, YouTube-real)
        video_fake_paths: List of paths to fake video folders (e.g., Celeb-synthesis)
        image_real_paths: List of paths to real image folders (e.g., FFHQ-real-v2)
        image_fake_paths: List of paths to fake image folders (e.g., StableDiffusion-fake-v2, stylegan-6000)
        random_seed: Random seed for reproducibility

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)`` of ImageDataset.
    """
    # NOTE: glob patterns are case-sensitive on case-sensitive file systems.
    patterns_to_check = ("*.png", "*.jpg", "*.jpeg")

    # ========================================
    # PART 1: Handle VIDEO-BASED datasets (split by folder)
    # ========================================
    real_video_folders = _group_by_folder(_collect_image_files(video_real_paths, patterns_to_check))
    fake_video_folders = _group_by_folder(_collect_image_files(video_fake_paths, patterns_to_check))

    # Splitting folder names (not frames) keeps every frame of one video
    # inside a single split.
    real_folder_splits = _split_train_val_test(real_video_folders.keys(), random_seed)
    fake_folder_splits = _split_train_val_test(fake_video_folders.keys(), random_seed)

    # ========================================
    # PART 2: Handle IMAGE-BASED datasets (split by image)
    # ========================================
    real_image_files = _collect_image_files(image_real_paths, patterns_to_check)
    fake_image_files = _collect_image_files(image_fake_paths, patterns_to_check)

    real_image_splits = _split_train_val_test(real_image_files, random_seed)
    fake_image_splits = _split_train_val_test(fake_image_files, random_seed)

    # ========================================
    # PART 3: Combine all sources into train / val / test lists
    # ========================================
    split_files = ([], [], [])   # train, val, test
    split_labels = ([], [], [])
    # Order matters for the saved test list: real video, fake video,
    # real images, fake images.
    sources = (
        ([_frames_in(folders, real_video_folders) for folders in real_folder_splits], LABEL_REAL),
        ([_frames_in(folders, fake_video_folders) for folders in fake_folder_splits], LABEL_FAKE),
        (real_image_splits, LABEL_REAL),
        (fake_image_splits, LABEL_FAKE),
    )
    for per_split_files, label in sources:
        for files_out, labels_out, new_files in zip(split_files, split_labels, per_split_files):
            files_out.extend(new_files)
            labels_out.extend([label] * len(new_files))

    train_files, val_files, test_files = split_files
    train_labels, val_labels, test_labels = split_labels

    # ========================================
    # PART 4: Print detailed statistics
    # ========================================
    print("\n" + "="*70)
    print("MIXED DATASET STATISTICS (Video-based + Image-based)")
    print("="*70)

    print("\n--- VIDEO-BASED DATA (split by folder) ---")
    for header, empty_message, grouped, folder_splits in (
        ("Real video folders", "No real video data", real_video_folders, real_folder_splits),
        ("\nFake video folders", "No fake video data", fake_video_folders, fake_folder_splits),
    ):
        _print_split_counts(
            header, empty_message, len(grouped),
            [
                f"{len(folders)} folders, {sum(len(grouped[f]) for f in folders)} frames"
                for folders in folder_splits
            ],
        )

    print("\n--- IMAGE-BASED DATA (split by image) ---")
    for header, empty_message, all_files, image_splits in (
        ("Real images", "No real image data", real_image_files, real_image_splits),
        ("\nFake images", "No fake image data", fake_image_files, fake_image_splits),
    ):
        _print_split_counts(
            header, empty_message, len(all_files),
            [f"{len(part)} images" for part in image_splits],
        )

    print("\n--- COMBINED TOTALS ---")
    print(f"Train: {len(train_files)} total ({train_labels.count(LABEL_REAL)} real, {train_labels.count(LABEL_FAKE)} fake)")
    print(f"Val:   {len(val_files)} total ({val_labels.count(LABEL_REAL)} real, {val_labels.count(LABEL_FAKE)} fake)")
    print(f"Test:  {len(test_files)} total ({test_labels.count(LABEL_REAL)} real, {test_labels.count(LABEL_FAKE)} fake)")
    print(f"\nGrand Total: {len(train_files) + len(val_files) + len(test_files)} images")
    print("="*70)

    # ========================================
    # PART 5: Create datasets
    # ========================================
    train_dataset = ImageDataset(
        file_paths=train_files,
        labels=train_labels,
        transform=train_transform  # augmenting transform for training
    )
    val_dataset = ImageDataset(
        file_paths=val_files,
        labels=val_labels,
        transform=val_test_transform   # deterministic val/test transform
    )
    test_dataset = ImageDataset(
        file_paths=test_files,
        labels=test_labels,
        transform=val_test_transform   # deterministic val/test transform
    )

    # ========================================
    # PART 6: Save test file paths for reproducibility
    # ========================================
    test_metadata = {
        'random_seed': random_seed,
        'test_files': test_files,
        'test_labels': test_labels,
        'num_test_samples': len(test_files),
        'num_real': test_labels.count(LABEL_REAL),
        'num_fake': test_labels.count(LABEL_FAKE),
    }

    # Save to JSON file
    metadata_path = f"test_set_seed_{random_seed}.json"
    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(test_metadata, f, indent=2)

    print(f"\n✓ Test set metadata saved to: {metadata_path}")

    return train_dataset, val_dataset, test_dataset
    

# Modeling

In [5]:
def get_model(model_name: str, pretrained: bool = True, device: str = 'cuda', num_classes: int = 2):
    """
    Loads a torchvision model by name (e.g., 'resnet34', 'resnet50'),
    replaces its final fully connected layer and moves it to the device.

    Args:
        model_name: Name understood by ``torchvision.models.get_model``.
        pretrained: If True, load the 'DEFAULT' pretrained weights;
            otherwise start from random initialisation.
        device: Device (or device string) the model is moved to.
        num_classes: Output size of the new classification head.

    Returns:
        The model with ``model.fc`` replaced by ``nn.Linear(in_features,
        num_classes)``, on ``device``.

    Raises:
        ValueError: If torchvision does not know ``model_name``.
        AttributeError: If ``models.get_model`` is unavailable, or the
            model has no ``nn.Linear`` head named ``fc``.
    """
    print(f"Loading model: {model_name} (pretrained={pretrained})")

    # Use 'DEFAULT' string for the newest weights
    weights = 'DEFAULT' if pretrained else None

    try:
        model = models.get_model(model_name, weights=weights)
    except (AttributeError, ValueError):
        # torchvision raises ValueError for an unknown name; AttributeError
        # covers torchvision versions without models.get_model.
        print(f"Error: Model '{model_name}' not found in torchvision.models.")
        raise

    head = getattr(model, "fc", None)
    if not isinstance(head, nn.Linear):
        raise AttributeError(
            f"Model '{model_name}' has no nn.Linear 'fc' layer; only "
            f"ResNet-style architectures are supported."
        )

    model.fc = nn.Linear(head.in_features, num_classes)
    model = model.to(device)
    print(f"Model {model_name} is ready and on device: {device}")

    return model

In [6]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """
    Runs a single training epoch.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of batches whose first two items are
            ``(images, labels)``. Extra items (such as file paths) are ignored.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: sample-weighted mean loss and
        accuracy in percent. ``(0.0, 0.0)`` if the loader yields no samples.
    """
    model.train()  # Set model to training mode
    running_loss = 0.0
    correct_preds = 0
    total_samples = 0

    pbar = tqdm(loader, desc="Training", leave=False)
    for batch in pbar:
        # Index rather than unpack so 2-item and 3-item batches both work.
        images = batch[0].to(device)
        labels = batch[1].to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Statistics: weight the batch-mean loss by the batch size so the
        # epoch average is correct when the last batch is smaller.
        running_loss += loss.item() * images.size(0)
        preds = torch.argmax(outputs, dim=1)
        correct_preds += (preds == labels).sum().item()
        total_samples += labels.size(0)

        pbar.set_postfix(loss=(running_loss / total_samples), acc=(correct_preds / total_samples * 100.0))

    if total_samples == 0:
        # Empty loader: avoid ZeroDivisionError.
        return 0.0, 0.0

    epoch_loss = running_loss / total_samples
    epoch_acc = (correct_preds / total_samples) * 100.0
    return epoch_loss, epoch_acc

In [7]:
def validate_one_epoch(model, loader, criterion, device):
    """
    Runs a single validation epoch (no gradient tracking, no weight updates).

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable of batches whose first two items are
            ``(images, labels)``. Extra items (such as file paths) are ignored.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: sample-weighted mean loss and
        accuracy in percent. ``(0.0, 0.0)`` if the loader yields no samples.
    """
    model.eval()  # Set model to evaluation mode
    running_loss = 0.0
    correct_preds = 0
    total_samples = 0

    pbar = tqdm(loader, desc="Validating", leave=False)
    with torch.no_grad():
        for batch in pbar:
            # Index rather than unpack so 2-item and 3-item batches both work.
            images = batch[0].to(device)
            labels = batch[1].to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Statistics
            running_loss += loss.item() * images.size(0)
            preds = torch.argmax(outputs, dim=1)
            correct_preds += (preds == labels).sum().item()
            total_samples += labels.size(0)

            pbar.set_postfix(loss=(running_loss / total_samples), acc=(correct_preds / total_samples * 100.0))

    if total_samples == 0:
        # Empty loader: avoid ZeroDivisionError.
        return 0.0, 0.0

    epoch_loss = running_loss / total_samples
    epoch_acc = (correct_preds / total_samples) * 100.0
    return epoch_loss, epoch_acc

In [8]:
def main_training_loop(model, train_loader, val_loader, criterion, optimizer, num_epochs, model_save_path, device):
    """
    Manages the main training and validation loop over N epochs,
    saving the best performing model.

    Args:
        model: Network to train.
        train_loader: Loader passed to ``train_one_epoch``.
        val_loader: Loader passed to ``validate_one_epoch``.
        criterion: Loss function.
        optimizer: Optimizer updating ``model``'s parameters.
        num_epochs: Number of epochs to run.
        model_save_path: File the best ``state_dict`` is written to.
        device: Device used for training and validation.

    Returns:
        Tuple ``(model, history)``: the model with the best-validation
        weights loaded, and a dict of per-epoch ``train_loss``,
        ``train_acc``, ``val_loss`` and ``val_acc`` lists.
    """
    best_val_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())

    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }

    start_time = time.time()

    for epoch in range(num_epochs):
        print(f"\n--- Epoch {epoch+1}/{num_epochs} ---")

        # Train
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)

        # Validate
        val_loss, val_acc = validate_one_epoch(model, val_loader, criterion, device)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"Epoch {epoch+1} Complete: "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Save the best model. The first epoch always counts as "best" so a
        # checkpoint exists on disk even if validation accuracy is 0%.
        if epoch == 0 or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(best_model_wts, model_save_path)
            print(f"New best model saved with Val Acc: {val_acc:.2f}%")

    time_elapsed = time.time() - start_time
    print(f"\nTraining complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")
    print(f"Best Val Acc: {best_val_acc:.4f}")

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model, history

In [9]:
def test_model(model, test_loader, test_dataset, device, domains=None):
    """
    Evaluates the final model on the test set and prints:
    1. Overall classification report
    2. Per-domain classification reports (Celeb-real, YouTube-real, FFHQ, etc.)

    Args:
        model: Trained network; switched to evaluation mode.
        test_loader: Iterable of batches whose first two items are
            ``(images, labels)``. If a third item is present it is taken
            as the batch's file paths.
        test_dataset: Dataset behind ``test_loader``. Its ``file_paths``
            are used when the loader yields no paths; this relies on the
            loader iterating the dataset in order (``shuffle=False``).
        device: Device used for inference.
        domains: Optional mapping ``{display name: path substring}`` used
            to group samples. Defaults to the built-in dataset patterns.

    Returns:
        Tuple ``(all_labels, all_preds)`` of per-sample true and predicted
        class indices.
    """
    print("\nStarting evaluation on test set...")
    model.eval()  # Set model to evaluation mode

    all_labels = []
    all_preds = []
    all_file_paths = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            # Index rather than unpack so 2-item and 3-item batches both work.
            images = batch[0].to(device)
            labels = batch[1]

            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            if len(batch) > 2:
                all_file_paths.extend(batch[2])

    if len(all_file_paths) != len(all_labels):
        # The loader did not supply paths: fall back to dataset order.
        all_file_paths = list(test_dataset.file_paths)

    # ========================================
    # 1. OVERALL REPORT
    # ========================================
    print("\n" + "="*70)
    print("OVERALL TEST SET CLASSIFICATION REPORT")
    print("="*70)
    # Explicit labels keep target_names valid even if one class is absent.
    report = classification_report(
        all_labels, all_preds, labels=[0, 1],
        target_names=["REAL", "FAKE"], digits=4, zero_division=0
    )
    print(report)
    print("="*70)

    # ========================================
    # 2. PER-DOMAIN REPORTS
    # ========================================
    print("\n" + "="*70)
    print("PER-DOMAIN CLASSIFICATION REPORTS")
    print("="*70)

    if domains is None:
        # Define domain patterns (adjust these to match your actual paths)
        domains = {
            'Celeb-real': '/Celeb-real/',
            'YouTube-real': '/YouTube-real/',
            'Celeb-synthesis': '/Celeb-synthesis/',
            'FFHQ-real': '/FFHQ-real',
            'StableDiffusion-fake': '/StableDiffusion-fake',
            'StyleGAN-fake': '/stylegan',
        }

    if len(all_file_paths) != len(all_labels):
        # Paths cannot be paired with predictions (e.g. the loader dropped
        # or subsampled items), so a per-domain breakdown would be wrong.
        print("Skipping per-domain reports: file paths do not line up with predictions.")
        print("="*70)
        return all_labels, all_preds

    for domain_name, domain_pattern in domains.items():
        # Find indices that belong to this domain
        domain_indices = [i for i, path in enumerate(all_file_paths)
                          if domain_pattern in path]

        if not domain_indices:
            continue  # Skip if no samples from this domain

        # Extract predictions and labels for this domain
        domain_labels = [all_labels[i] for i in domain_indices]
        domain_preds = [all_preds[i] for i in domain_indices]

        # Print domain report
        print(f"\n--- {domain_name} ({len(domain_indices)} samples) ---")

        # Calculate accuracy
        correct = sum(1 for truth, pred in zip(domain_labels, domain_preds)
                      if truth == pred)
        accuracy = correct / len(domain_labels) * 100

        print(f"Accuracy: {accuracy:.2f}% ({correct}/{len(domain_labels)})")

        # Only print detailed report if both classes are present
        unique_labels = set(domain_labels)
        if len(unique_labels) > 1:
            domain_report = classification_report(
                domain_labels, domain_preds,
                labels=[0, 1],
                target_names=["REAL", "FAKE"],
                digits=4,
                zero_division=0
            )
            print(domain_report)
        else:
            # Single class - just show confusion
            label_name = "REAL" if 0 in unique_labels else "FAKE"
            print(f"  (All samples are {label_name})")
            print(f"  True Positives: {correct}")
            print(f"  False Negatives: {len(domain_labels) - correct}")

    print("="*70)

    return all_labels, all_preds

# Main

In [10]:
# --- 1. Configuration ---
# Dataset roots under the Kaggle input mount. The three Celeb-DF folders are
# passed to get_data_mixed_structure as video sources (split by folder); the
# FFHQ, Stable Diffusion and StyleGAN folders are passed as image sources
# (split by file).
REAL_CELEB_PATH = "/kaggle/input/deepfake-images/Celeb-DF/data/Celeb-real"
REAL_YOUTUBE_PATH = "/kaggle/input/deepfake-images/Celeb-DF/data/YouTube-real"
FAKE_CELEB_PATH = "/kaggle/input/deepfake-images/Celeb-DF/data/Celeb-synthesis"
FFHQ_REAL_PATH = "/kaggle/input/deepfake-images/FFHQ-real-v2/FFHQ-real-v2"
SD_PATH = "/kaggle/input/deepfake-images/StableDiffusion-fake-v2/StableDiffusion-fake-v2"
GAN_PATH = "/kaggle/input/stylegan-6000/kaggle/working/stylegan_fake_dataset_nvidia"

# Define labels (class indices; the reports name index 0 "REAL" and 1 "FAKE")
LABEL_REAL = 0
LABEL_FAKE = 1

# Model and Training Hyperparameters
MODEL_NAME = "resnet34"  # name handed to get_model
# Checkpoint file written by main_training_loop and reloaded before testing.
MODEL_SAVE_PATH = f"/kaggle/working/best_{MODEL_NAME}_by_folder.pth"

IMG_SIZE = 224  # side length images are resized to in both transforms
BATCH_SIZE = 32  # batch size of all three DataLoaders
NUM_EPOCHS = 5 
LEARNING_RATE = 1e-4  # Adam learning rate
RANDOM_SEED = 42  # seeds torch / numpy and the train/val/test splits

In [11]:
# --- 2. Setup Device ---
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set seed for reproducibility.
# Python's own RNG is seeded too because the Grad-CAM sample search uses
# random.shuffle; without it the visualised samples change on every run.
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

Using device: cuda


In [12]:
# (label, directory) pairs for preview_local_images: one entry per data
# source, so one sample image is shown for each.
PATHS_TO_PREVIEW = [
    ("Celeb-real", REAL_CELEB_PATH),
    ("YouTube-real", REAL_YOUTUBE_PATH),
    ("Celeb-synthesis", FAKE_CELEB_PATH),
    ("FFHQ", FFHQ_REAL_PATH),
    ("Stable-Diffusion", SD_PATH),
    ("Style-GAN", GAN_PATH )
]

In [None]:
# Show one sample image per data source as a quick sanity check of the paths.
preview_local_images(PATHS_TO_PREVIEW)

In [None]:
# --- 3. Define Image Transforms ---
# Standard ImageNet normalization
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Training pipeline: resize to IMG_SIZE x IMG_SIZE, then random augmentation
# (horizontal flip, mild colour jitter, rotation of up to 10 degrees) before
# conversion to a normalized tensor.
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    normalize,
])

# Validation/test pipeline: same resize and normalization, but no random
# augmentation, so evaluation is deterministic.
val_test_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    normalize
])

In [None]:
# --- 4. Create Datasets (Split by Folder) ---
print("Loading and splitting data by folder...")

# Video-frame sources are split per folder, still-image sources per file.
video_real_paths, video_fake_paths = [REAL_CELEB_PATH, REAL_YOUTUBE_PATH], [FAKE_CELEB_PATH]
image_real_paths, image_fake_paths = [FFHQ_REAL_PATH], [SD_PATH, GAN_PATH]

# Arguments are passed positionally, in the order of the function signature.
train_dataset, val_dataset, test_dataset = get_data_mixed_structure(
    video_real_paths,
    video_fake_paths,
    image_real_paths,
    image_fake_paths,
    RANDOM_SEED,
)

In [None]:
# --- 5. Create DataLoaders ---
# Worker processes and pinned host memory only pay off when batches are
# copied to a GPU; on CPU both are switched off.
num_workers = 2 if device.type == 'cuda' else 0
pin_memory = device.type == 'cuda'

# Settings shared by all three loaders.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

# Only the training data is shuffled; val/test keep dataset order so that
# predictions line up with test_dataset.file_paths.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [None]:
# --- 6. Initialize Model, Loss, and Optimizer ---
# Cross-entropy over the two output logits.
criterion = nn.CrossEntropyLoss()

# Pretrained backbone with a fresh classification head, placed on `device`.
model = get_model(MODEL_NAME, pretrained=True, device=device)

# Adam over every parameter of the model.
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# --- 7. Run Training ---
print(f"Starting training for {NUM_EPOCHS} epochs...")

# Positional arguments follow main_training_loop's signature order; the
# returned model carries the best-validation weights.
model, history = main_training_loop(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    NUM_EPOCHS,
    MODEL_SAVE_PATH,
    device,
)

print("Training finished.")

In [None]:
# --- 8. Run Final Testing ---
print(f"Loading best model from {MODEL_SAVE_PATH} for final testing...")
# Load the best model weights that were saved during training.
# map_location keeps this working when the checkpoint was written on a
# different device than the current one (e.g. saved on GPU, tested on CPU).
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))


# NOTE(review): this mapping is not passed to test_model below, so it does
# not affect the per-domain report; it is kept as a record of the full
# dataset roots.
domains = {
    'Celeb-real (video)': '/kaggle/input/deepfake-images/Celeb-DF/data/Celeb-real',
    'YouTube-real (video)': '/kaggle/input/deepfake-images/Celeb-DF/data/YouTube-real',
    'Celeb-synthesis (video)': '/kaggle/input/deepfake-images/Celeb-DF/data/Celeb-synthesis',
    'FFHQ-real (image)': '/kaggle/input/deepfake-images/FFHQ-real-v2/FFHQ-real-v2',
    'StableDiffusion-fake (image)': '/kaggle/input/deepfake-images/StableDiffusion-fake-v2/StableDiffusion-fake-v2',
    'StyleGAN-fake (image)': '/kaggle/input/stylegan-6000',
}

# Run evaluation on the test set
test_model(model, test_loader, test_dataset, device)

print("\n--- Pipeline Complete ---")

In [None]:
# List /kaggle/working/ (the directory MODEL_SAVE_PATH points into) to confirm
# which output files were written.
!ls -l /kaggle/working/

# GradCam

In [None]:
# Install grad-cam before importing pytorch_grad_cam below.
!pip install grad-cam
import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

In [None]:
def visualize_gradcam(model, test_dataset, target_layer, device, num_images=5):
    """
    Finds correctly predicted FAKE images, runs Grad-CAM on them, and shows
    the visualizations directly in the notebook.

    Args:
        model: Trained classifier.
        test_dataset: Dataset whose samples start with ``(image_tensor,
            label)``; any further items are ignored. Tensors are expected
            to be ImageNet-normalized, as the inverse transform assumes.
        target_layer: Layer whose activations Grad-CAM explains.
        device: Device the model lives on.
        num_images: Maximum number of samples to visualize.

    Returns:
        None. Shows a figure with one row per sample (original, heatmap,
        overlay), or prints a message if no suitable sample exists.
    """
    # Inverse normalization transform for displaying images
    inv_normalize = transforms.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
        std=[1/0.229, 1/0.224, 1/0.225]
    )

    # Target class is 1 (FAKE)
    targets = [ClassifierOutputTarget(LABEL_FAKE)]

    # --- Find correctly predicted FAKE images ---
    print(f"Searching for {num_images} correctly predicted 'FAKE' images...")
    correct_fake_samples = []

    # Shuffle dataset to get random samples
    indices = list(range(len(test_dataset)))
    random.shuffle(indices)

    model.eval()
    with torch.no_grad():
        for idx in tqdm(indices, desc="Finding samples"):
            # Index rather than unpack so 2-item and 3-item samples both work.
            sample = test_dataset[idx]
            image_tensor, label = sample[0], sample[1]

            # Only FAKE images are of interest.
            if int(label) != LABEL_FAKE:
                continue

            # Keep the sample only if the model also predicts FAKE.
            output = model(image_tensor.unsqueeze(0).to(device))
            if torch.argmax(output, dim=1).item() != LABEL_FAKE:
                continue

            # Save the (unnormalized) image for visualization
            rgb_img = inv_normalize(image_tensor).permute(1, 2, 0).numpy()
            rgb_img = np.clip(rgb_img, 0, 1)  # Clip values to be valid image

            correct_fake_samples.append((image_tensor, rgb_img))

            if len(correct_fake_samples) >= num_images:
                break

    if not correct_fake_samples:
        print("Could not find any correctly predicted 'FAKE' images to visualize.")
        return

    print(f"Found {len(correct_fake_samples)} samples. Generating visualizations...")

    # --- Create the visualization plot ---
    # squeeze=False keeps `axs` two-dimensional even for a single row, so
    # axs[i, j] indexing works when only one sample was found.
    n_rows = len(correct_fake_samples)
    fig, axs = plt.subplots(n_rows, 3, figsize=(15, n_rows * 5), squeeze=False)
    fig.suptitle("Grad-CAM Visualization for 'FAKE' Predictions", fontsize=20, y=1.02)

    # The context manager releases the hooks Grad-CAM registers on the model.
    with GradCAM(model=model, target_layers=[target_layer]) as cam:
        for i, (input_tensor, rgb_img) in enumerate(correct_fake_samples):
            # Generate the CAM
            input_tensor_cam = input_tensor.unsqueeze(0).to(device)
            grayscale_cam = cam(input_tensor=input_tensor_cam, targets=targets)[0, :]

            # Create the overlay
            visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

            # Plot Original
            axs[i, 0].imshow(rgb_img)
            axs[i, 0].set_title(f"Sample {i+1}: Original Image")
            axs[i, 0].axis('off')

            # Plot Heatmap
            axs[i, 1].imshow(grayscale_cam, cmap='jet')
            axs[i, 1].set_title("Grad-CAM Heatmap")
            axs[i, 1].axis('off')

            # Plot Overlay
            axs[i, 2].imshow(visualization)
            axs[i, 2].set_title("Overlay")
            axs[i, 2].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# MODEL_NAME is "resnet34" in this notebook; 'layer4' is used as the Grad-CAM
# target (the last convolutional block of the network).
target_layer = model.layer4

# --- Run Visualization ---
visualize_gradcam(
    model=model,
    test_dataset=test_dataset,
    target_layer=target_layer,
    device=device,
    num_images=10  # Display 10 examples
)

# Model testing

In [None]:
# load model
# NOTE(review): this rebinds `model` (until now the trained network) to a
# plain path string; nothing visible here loads weights from it. Confirm that
# the code that follows expects a path.
model = "/kaggle/input/resnet-34/pytorch/default/1"

