# Setup

In [None]:
import gc
import glob
import io
import os
import random
import sys
from collections import defaultdict

import numpy as np
import torch
from matplotlib import pyplot as plt
from PIL import Image, ImageFilter
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report, confusion_matrix, roc_auc_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm
from transformers import ViTForImageClassification, TrainingArguments, Trainer, ViTImageProcessor, EarlyStoppingCallback, AutoImageProcessor, SwinForImageClassification

2025-11-21 22:59:37.506693: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763765977.688317      48 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763765977.744641      48 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

# Data Loading, Metrics, and Helper Functions

In [None]:
class KaggleImageDataset(Dataset):
    """
    PyTorch Dataset that loads images from the local file system for use with
    the HuggingFace Trainer, with optional training-time augmentations.

    Each item is a dict:
        'pixel_values': processor output for one image, batch dimension removed
        'labels':       torch.long scalar with the sample's label

    Args:
        file_paths: sequence of image paths, indexed by position.
        labels: sequence of integer labels aligned with ``file_paths``.
        processor: image processor called as
            ``processor(images=..., return_tensors="pt")``.
        is_train: if True, apply random PIL-level augmentations before the
            processor; validation/test datasets should leave this False.
    """
    def __init__(self, file_paths, labels, processor, is_train=False):
        self.file_paths = file_paths
        self.labels = labels
        self.processor = processor
        self.is_train = is_train

        # Augmentations operate on PIL images and run BEFORE the processor,
        # which takes care of resizing and normalization.
        if self.is_train:
            self.augmentations = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply([
                    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
                ], p=0.3),
                transforms.RandomApply([
                    transforms.GaussianBlur(kernel_size=3)
                ], p=0.3),
                transforms.RandomRotation(degrees=10),
            ])
        else:
            self.augmentations = None

    def __len__(self):
        """Returns the total number of samples."""
        return len(self.file_paths)

    def _encode(self, image, label):
        """Run the processor on a PIL image and pack the result in HuggingFace format."""
        processed = self.processor(images=image, return_tensors="pt")
        return {
            'pixel_values': processed['pixel_values'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }

    def __getitem__(self, idx):
        """
        Load the image at ``idx``, augment it (training only), run the
        processor, and return the sample in HuggingFace format.

        Unreadable files are replaced by a black 224x224 image so a single bad
        file does not abort training/evaluation. The sample keeps its TRUE
        label, so ``label_ids`` stay aligned with ``file_paths``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        local_path = self.file_paths[idx]
        label = self.labels[idx]

        try:
            # The context manager closes the file handle; convert() forces the
            # pixel data to be loaded before the file is closed.
            with Image.open(local_path) as img:
                image = img.convert('RGB')

            if self.augmentations is not None:
                image = self.augmentations(image)

            return self._encode(image, label)

        except Exception as e:
            # Deliberate best-effort fallback: report and substitute a dummy.
            print(f"Error loading image {local_path}: {e}")
            dummy_image = Image.new('RGB', (224, 224), color='black')
            return self._encode(dummy_image, label)

In [None]:
def _glob_images(root, patterns):
    """Return every file under ``root`` (recursive) matching any glob in ``patterns``, sorted."""
    matches = []
    for ext in patterns:
        matches.extend(glob.glob(os.path.join(root, "**", ext), recursive=True))
    # glob yields files in file-system order, which is not guaranteed to be
    # stable across machines/mounts; sorting makes the seeded splits reproducible.
    return sorted(matches)


def _collect_video_folders(roots, patterns, kind):
    """Group frames under each root by parent directory, printing per-root and total counts."""
    folders = defaultdict(list)
    for root in roots:
        root_files = _glob_images(root, patterns)
        root_folders = set()
        for file in root_files:
            parent_folder = os.path.dirname(file)
            folders[parent_folder].append(file)
            root_folders.add(parent_folder)
        # Counts are per root (not cumulative across roots).
        print(f"Found {len(root_folders)} {kind} video folders in {root}")
        print(f"  {kind} video frames in {root}: {len(root_files)}")
    total_frames = sum(len(files) for files in folders.values())
    print(f"  Total {kind} video frames: {total_frames}")
    return folders


def _collect_image_files(roots, patterns, kind):
    """Collect standalone image files under each root, printing per-root and total counts."""
    all_files = []
    for root in roots:
        root_files = _glob_images(root, patterns)
        all_files.extend(root_files)
        print(f"Found {len(root_files)} {kind} images in {root}")
    print(f"  Total {kind} images: {len(all_files)}")
    return all_files


def _split_70_15_15(items, random_seed):
    """Split ``items`` into (train, val, test) lists of about 70/15/15 %; empty input gives three empty lists."""
    if len(items) == 0:
        return [], [], []
    train, temp = train_test_split(items, test_size=0.3, random_state=random_seed)
    val, test = train_test_split(temp, test_size=0.5, random_state=random_seed)
    return train, val, test


def _append_labeled(files_out, labels_out, files, label):
    """Extend ``files_out`` with ``files`` and ``labels_out`` with a matching run of ``label``."""
    files_out.extend(files)
    labels_out.extend([label] * len(files))


def get_data_mixed_structure(video_real_paths, video_fake_paths,
                              image_real_paths, image_fake_paths,
                              model_name, random_seed):
    """
    Build train/val/test datasets from a mix of video-frame and still-image sources.

    1. Video-based sources are split BY FOLDER (each frame's parent directory)
       so frames of one video never land in two splits.
    2. Image-based sources are split BY IMAGE.

    Both use a 70/15/15 train/val/test split seeded with ``random_seed``. File
    lists are sorted first, so the split does not depend on glob order.

    Args:
        video_real_paths: roots of real video-frame folders (e.g., Celeb-real, YouTube-real)
        video_fake_paths: roots of fake video-frame folders (e.g., Celeb-synthesis)
        image_real_paths: roots of real image folders (e.g., FFHQ-real-v2)
        image_fake_paths: roots of fake image folders (e.g., StableDiffusion-fake-v2, stylegan-6000)
        model_name: HuggingFace model name used to load the image processor
        random_seed: seed passed to every train_test_split call

    Returns:
        (train_dataset, val_dataset, test_dataset, processor); only the
        training dataset applies augmentations.

    Notes:
        - Labels come from the notebook-level LABEL_REAL / LABEL_FAKE
          constants, which must be defined before this function is called.
        - NOTE(review): the folder split assumes each parent directory holds
          the frames of a single video; confirm against the dataset layout.
        - sklearn's train_test_split raises ValueError when a non-empty group
          is too small to split three ways (fewer than 4 folders/images).
    """
    patterns_to_check = ["*.png", "*.jpg", "*.jpeg"]

    # ========================================
    # PART 1: VIDEO-BASED datasets (split by folder)
    # ========================================
    real_video_folders = _collect_video_folders(video_real_paths, patterns_to_check, "REAL")
    fake_video_folders = _collect_video_folders(video_fake_paths, patterns_to_check, "FAKE")

    train_real_video_folders, val_real_video_folders, test_real_video_folders = \
        _split_70_15_15(sorted(real_video_folders), random_seed)
    train_fake_video_folders, val_fake_video_folders, test_fake_video_folders = \
        _split_70_15_15(sorted(fake_video_folders), random_seed)

    # ========================================
    # PART 2: IMAGE-BASED datasets (split by image)
    # ========================================
    real_image_files = _collect_image_files(image_real_paths, patterns_to_check, "REAL")
    fake_image_files = _collect_image_files(image_fake_paths, patterns_to_check, "FAKE")

    train_real_images, val_real_images, test_real_images = _split_70_15_15(real_image_files, random_seed)
    train_fake_images, val_fake_images, test_fake_images = _split_70_15_15(fake_image_files, random_seed)

    # ========================================
    # PART 3: Combine. Within each split the order is
    # real video frames, fake video frames, real images, fake images.
    # ========================================
    train_files, train_labels = [], []
    val_files, val_labels = [], []
    test_files, test_labels = [], []
    split_targets = ((train_files, train_labels), (val_files, val_labels), (test_files, test_labels))

    for folders_by_name, label, split_folders in (
        (real_video_folders, LABEL_REAL,
         (train_real_video_folders, val_real_video_folders, test_real_video_folders)),
        (fake_video_folders, LABEL_FAKE,
         (train_fake_video_folders, val_fake_video_folders, test_fake_video_folders)),
    ):
        for (files_out, labels_out), folder_names in zip(split_targets, split_folders):
            for folder in folder_names:
                _append_labeled(files_out, labels_out, folders_by_name[folder], label)

    for label, split_images in (
        (LABEL_REAL, (train_real_images, val_real_images, test_real_images)),
        (LABEL_FAKE, (train_fake_images, val_fake_images, test_fake_images)),
    ):
        for (files_out, labels_out), images in zip(split_targets, split_images):
            _append_labeled(files_out, labels_out, images, label)

    # ========================================
    # PART 4: Print detailed statistics
    # ========================================
    print("\n" + "="*70)
    print("MIXED DATASET STATISTICS (Video-based + Image-based)")
    print("="*70)

    print("\n--- VIDEO-BASED DATA (split by folder) ---")
    if len(real_video_folders) > 0:
        print(f"Real video folders: {len(real_video_folders)} total")
        print(f"  Train: {len(train_real_video_folders)} folders, {sum(len(real_video_folders[f]) for f in train_real_video_folders)} frames")
        print(f"  Val:   {len(val_real_video_folders)} folders, {sum(len(real_video_folders[f]) for f in val_real_video_folders)} frames")
        print(f"  Test:  {len(test_real_video_folders)} folders, {sum(len(real_video_folders[f]) for f in test_real_video_folders)} frames")
    else:
        print("No real video data")

    if len(fake_video_folders) > 0:
        print(f"\nFake video folders: {len(fake_video_folders)} total")
        print(f"  Train: {len(train_fake_video_folders)} folders, {sum(len(fake_video_folders[f]) for f in train_fake_video_folders)} frames")
        print(f"  Val:   {len(val_fake_video_folders)} folders, {sum(len(fake_video_folders[f]) for f in val_fake_video_folders)} frames")
        print(f"  Test:  {len(test_fake_video_folders)} folders, {sum(len(fake_video_folders[f]) for f in test_fake_video_folders)} frames")
    else:
        print("No fake video data")

    print("\n--- IMAGE-BASED DATA (split by image) ---")
    if len(real_image_files) > 0:
        print(f"Real images: {len(real_image_files)} total")
        print(f"  Train: {len(train_real_images)} images")
        print(f"  Val:   {len(val_real_images)} images")
        print(f"  Test:  {len(test_real_images)} images")
    else:
        print("No real image data")

    if len(fake_image_files) > 0:
        print(f"\nFake images: {len(fake_image_files)} total")
        print(f"  Train: {len(train_fake_images)} images")
        print(f"  Val:   {len(val_fake_images)} images")
        print(f"  Test:  {len(test_fake_images)} images")
    else:
        print("No fake image data")

    print("\n--- COMBINED TOTALS ---")
    print(f"Train: {len(train_files)} total ({train_labels.count(LABEL_REAL)} real, {train_labels.count(LABEL_FAKE)} fake)")
    print(f"Val:   {len(val_files)} total ({val_labels.count(LABEL_REAL)} real, {val_labels.count(LABEL_FAKE)} fake)")
    print(f"Test:  {len(test_files)} total ({test_labels.count(LABEL_REAL)} real, {test_labels.count(LABEL_FAKE)} fake)")
    print(f"\nGrand Total: {len(train_files) + len(val_files) + len(test_files)} images")
    print("="*70)

    # ========================================
    # PART 5: Create datasets
    # ========================================
    processor = AutoImageProcessor.from_pretrained(model_name, use_fast = True)

    train_dataset = KaggleImageDataset(
        file_paths=train_files,
        labels=train_labels,
        processor=processor,
        is_train=True
    )
    val_dataset = KaggleImageDataset(
        file_paths=val_files,
        labels=val_labels,
        processor=processor,
        is_train=False
    )
    test_dataset = KaggleImageDataset(
        file_paths=test_files,
        labels=test_labels,
        processor=processor,
        is_train=False
    )

    return train_dataset, val_dataset, test_dataset, processor

In [None]:
def compute_metrics(eval_pred):
    """
    Compute accuracy, precision, recall and F1 for the binary classifier.

    Args:
        eval_pred: ``(predictions, labels)`` pair from the Trainer.
            ``predictions`` may be class ids (what
            ``MemoryEfficientTrainer.prediction_step`` returns), per-class
            logits of shape (n, num_classes), or a tuple whose first element
            is the logits. The last two forms come from a stock ``Trainer``
            and are reduced with argmax.

    Returns:
        dict with 'accuracy', 'precision', 'recall', 'f1'. Precision, recall
        and F1 use ``average='binary'``, so label 1 (LABEL_FAKE in this
        notebook's config) is the positive class.
    """
    predictions, labels = eval_pred

    # A stock Trainer may hand over a tuple of model outputs; logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()

    predictions = np.asarray(predictions)
    if predictions.ndim > 1:
        predictions = predictions.argmax(axis=-1)

    accuracy = accuracy_score(labels, predictions)
    # zero_division=0 yields the same numbers as the default ('warn') but
    # without an UndefinedMetricWarning when no sample is predicted positive.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average='binary', zero_division=0
    )

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

In [None]:
class MemoryEfficientTrainer(Trainer):
    """
    Trainer that keeps only argmax class ids from each evaluation batch.

    The stock prediction step returns full logits, which the evaluation loop
    accumulates over the whole eval set. This subclass returns
    ``argmax(logits)`` instead, so ``compute_metrics`` receives class ids,
    not logits.
    """
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """
        Run one evaluation step without gradients.

        Args:
            model: the model being evaluated.
            inputs: batch dict (in this notebook 'pixel_values' and 'labels').
            prediction_loss_only: if True, return only the loss.
            ignore_keys: accepted for API compatibility; unused.

        Returns:
            (loss, preds, labels), where ``loss`` is ``outputs.loss``, ``preds``
            are argmax class ids over the last logits dimension and ``labels``
            is ``inputs.get("labels")``. Returns (loss, None, None) when
            ``prediction_loss_only`` is True.
        """
        inputs = self._prepare_inputs(inputs)

        with torch.no_grad():
            # Attention maps are never used here. Turning them off explicitly
            # avoids materializing one tensor per layer for every eval batch
            # when the model config enables output_attentions.
            outputs = model(**inputs, output_attentions=False)
            loss = outputs.loss
            logits = outputs.logits

        if prediction_loss_only:
            return (loss, None, None)

        # Reduce to class ids immediately so only a 1-D tensor is accumulated.
        preds = torch.argmax(logits, dim=-1)
        labels = inputs.get("labels")

        return (loss, preds, labels)

In [None]:
def attetion_rollout(attentions, discard_ratio=0.9):
    """
    Compute attention rollout (Abnar & Zuidema, 2020) for the first batch element.

    For every layer, the heads are averaged and the lowest ``discard_ratio``
    fraction of entries is zeroed. The [CLS]->[CLS] entry (flat index 0) is
    never dropped. The identity is then added to account for the residual
    connection, rows are renormalized, and the per-layer matrices are
    multiplied together.

    Args:
        attentions: sequence of per-layer attention tensors indexed as
            (batch, heads, tokens, tokens), as HuggingFace models return with
            output_attentions=True. Only batch element 0 is used; the input
            tensors are not modified.
        discard_ratio: fraction (0..1) of the lowest fused attention values to
            zero in each layer.

    Returns:
        1-D tensor of length tokens - 1 holding the rolled-out attention from
        token 0 ([CLS]) to every other token. Values are NOT rescaled; each
        normalized row sums to 1, so individual entries are small.
    """
    device = attentions[0].device
    num_tokens = attentions[0].size(-1)

    # Identity is layer-independent: build it once, on the attentions' device.
    identity = torch.eye(num_tokens, device=device)
    result = identity.clone()

    for attention in attentions:
        # Average heads. mean() allocates a new tensor, so the in-place edit
        # below never touches the caller's attention tensors.
        attention_heads_fused = attention.mean(dim=1)[0]

        # flat is a view: zeroing it zeroes attention_heads_fused as well.
        flat = attention_heads_fused.view(-1)
        num_discard = int(flat.size(-1) * discard_ratio)
        _, indices = flat.topk(k=num_discard, largest=False)
        indices = indices[indices != 0]  # keep [CLS]->[CLS]
        flat[indices] = 0

        a = (attention_heads_fused + 1.0 * identity) / 2
        a = a / a.sum(dim=-1, keepdim=True)
        result = torch.matmul(a, result)

    # Row 0 = [CLS]; drop its own column to keep only patch tokens.
    return result[0, 1:]

In [None]:
def visualize_attention(model, image_path, processor, true_label=None):
    """
    Plot attention rollout for one image: original, heatmap, and overlay.

    Args:
        model: HuggingFace ViT classifier. Attentions are requested explicitly
            with output_attentions=True, so the config flag is not required.
        image_path: path to a local image file.
        processor: image processor used to build the model inputs.
        true_label: optional ground truth. Either an integer class id, mapped
            through model.config.id2label (for the model built in this
            notebook 0 -> "REAL", 1 -> "FAKE"), or a label string.

    Returns:
        2-D numpy array, resized to the image size, with the rollout scaled to [0, 1].
    """
    # Load image from local path; the context manager closes the file handle.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    inputs = processor(images=image, return_tensors="pt")

    # Move inputs to same device as model
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # One attention tensor per layer -> rollout over patch tokens
    mask = attetion_rollout(outputs.attentions)

    # Patch tokens form a square grid (e.g. 14x14 for 224px / 16px patches)
    num_patches = int(mask.shape[0] ** 0.5)
    mask = mask.reshape(num_patches, num_patches).cpu().numpy()

    # Each rollout row sums to 1 across all tokens, so per-patch values are
    # tiny. Scale to [0, 1] before the uint8 conversion; otherwise almost
    # every pixel rounds down to 0 and the heatmap is blank.
    peak = mask.max()
    if peak > 0:
        mask = mask / peak

    # Resize to original image size
    mask = Image.fromarray((mask * 255).astype(np.uint8)).resize(
        image.size, resample=Image.BILINEAR
    )
    mask = np.array(mask) / 255.0

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(image)
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    axes[1].imshow(mask, cmap='jet')
    axes[1].set_title('Attention Rollout')
    axes[1].axis('off')

    axes[2].imshow(image)
    axes[2].imshow(mask, cmap='jet', alpha=0.5)
    axes[2].set_title('Overlay')
    axes[2].axis('off')

    # Prediction and its softmax probability (batch of one image)
    prediction = outputs.logits.argmax(-1).item()
    pred_label = model.config.id2label[prediction]
    prob = torch.softmax(outputs.logits, dim=-1)[0][prediction].item()

    title_parts = [f'Prediction: {pred_label} ({prob:.2%})']

    if true_label is not None:
        # Map numeric ids through the model's own label map
        if isinstance(true_label, (int, np.integer)):
            true_label_str = model.config.id2label[int(true_label)]
        else:
            true_label_str = true_label

        is_correct = (pred_label == true_label_str)
        correctness = "✓" if is_correct else "✗"

        title_parts.append(f'True Label: {true_label_str} {correctness}')

    fig.suptitle(' | '.join(title_parts), fontsize=16)

    plt.tight_layout()
    plt.show()
    # Release the figure; callers render many of these in loops.
    plt.close(fig)

    return mask

In [None]:
def visualize_attention_per_domain(model, test_dataset, processor, n_samples=10, random_seed=None):
    """
    Visualize attention rollout for random samples from each source domain.

    A sample belongs to a domain when the domain's pattern is a substring of
    its file path.

    Args:
        model: image classifier passed to ``visualize_attention``, which
            requests attentions itself (output_attentions=True).
        test_dataset: dataset exposing ``file_paths`` and ``labels`` lists.
        processor: image processor used to build model inputs.
        n_samples: maximum number of samples to visualize per domain.
        random_seed: optional seed for a private ``np.random.RandomState``.
            When None (default), the global NumPy RNG is used, so
            ``np.random.seed`` / ``seed_everything`` govern the draw.
    """
    domains = {
        'Celeb-real (video)': 'Celeb-real',
        'YouTube-real (video)': 'YouTube-real',
        'Celeb-synthesis (video)': 'Celeb-synthesis',
        'FFHQ-real (image)': 'FFHQ-real-v2',
        'StableDiffusion-fake (image)': 'StableDiffusion-fake-v2',
        'StyleGAN-fake (image)': 'stylegan',
    }

    paths = test_dataset.file_paths
    labels = test_dataset.labels
    rng = np.random if random_seed is None else np.random.RandomState(random_seed)

    for domain_name, pattern in domains.items():
        print("\n" + "="*70)
        print(f"VISUALIZING: {domain_name}")
        print("="*70)

        domain_indices = [i for i, p in enumerate(paths) if pattern in p]

        if not domain_indices:
            print(f"No samples found for {domain_name}")
            continue

        print(f"Total samples in domain: {len(domain_indices)}")

        # Distinct samples (no replacement), capped at the domain size
        n_to_sample = min(n_samples, len(domain_indices))
        sampled_indices = rng.choice(domain_indices, size=n_to_sample, replace=False)

        print(f"Visualizing {n_to_sample} random samples...\n")

        for idx in sampled_indices:
            image_path = paths[idx]
            print(f"Sample {idx}: {os.path.basename(image_path)}")
            visualize_attention(model, image_path, processor, true_label=labels[idx])
            print()  # Add spacing between visualizations

In [None]:
def evaluate_per_domain(trainer, test_dataset):
    """
    Print overall and per-domain classification results on ``test_dataset``.

    Runs ``trainer.predict`` once and prints an overall classification report.
    Samples are then grouped by substring matches on
    ``test_dataset.file_paths``. Domains containing both classes get a full
    report; single-class domains get plain accuracy.

    Report class names are taken from ``trainer.model.config.id2label``
    (falling back to {0: "REAL", 1: "FAKE"}, matching LABEL_REAL/LABEL_FAKE),
    so names always line up with the integer labels.

    Assumes predictions are returned in dataset order; the per-domain grouping
    indexes them by position in ``file_paths``.
    """
    predictions = trainer.predict(test_dataset)

    # Accept class ids (MemoryEfficientTrainer) or logits (stock Trainer).
    preds = predictions.predictions
    if isinstance(preds, tuple):
        preds = preds[0]
    preds = np.asarray(preds)
    if preds.ndim > 1:
        preds = preds.argmax(axis=-1)
    labels = np.asarray(predictions.label_ids)  # Shape: (n_samples,)
    paths = test_dataset.file_paths

    config = getattr(trainer.model, "config", None)
    id2label = getattr(config, "id2label", None) or {0: "REAL", 1: "FAKE"}
    id2label = {int(k): v for k, v in id2label.items()}
    class_ids = sorted(id2label)
    target_names = [id2label[i] for i in class_ids]

    print("\n" + "="*70)
    print("OVERALL RESULTS")
    print("="*70)
    print(classification_report(labels, preds, labels=class_ids,
                                target_names=target_names, digits=4))

    domains = {
        'Celeb-real': 'Celeb-real',
        'YouTube-real': 'YouTube-real',
        'Celeb-synthesis': 'Celeb-synthesis',
        'FFHQ-real': 'FFHQ-real',
        'StableDiffusion-fake': 'StableDiffusion-fake-v2',
        'StyleGAN-fake': 'stylegan',
    }

    for domain_name, pattern in domains.items():
        indices = [i for i, p in enumerate(paths) if pattern in p]
        if not indices:
            continue

        domain_labels = labels[indices]
        domain_preds = preds[indices]

        print(f"\n--- {domain_name} ({len(indices)} samples) ---")
        if len(np.unique(domain_labels)) > 1:
            print(classification_report(domain_labels, domain_preds, labels=class_ids,
                                        target_names=target_names, digits=4))
        else:
            accuracy = np.mean(domain_labels == domain_preds)
            print(f"Accuracy: {accuracy*100:.2f}%")


In [None]:
def seed_everything(seed: int):
    """
    Seed every RNG this notebook relies on and make cuDNN deterministic.

    Covers Python's ``random``, NumPy's global RNG and torch: the CPU generator,
    plus the current and all CUDA devices when CUDA is available.
    PYTHONHASHSEED is exported for child processes only; it does not change
    hash randomization in the already-running interpreter.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
# Define Parameters and Input

# --- Data sources ---------------------------------------------------------
# Video-frame datasets: split by folder inside get_data_mixed_structure.
video_real_paths = [
    "/kaggle/input/deepfake-images/Celeb-DF/data/Celeb-real",
    "/kaggle/input/deepfake-images/Celeb-DF/data/YouTube-real",
]
video_fake_paths = [
    "/kaggle/input/deepfake-images/Celeb-DF/data/Celeb-synthesis",
]

# Still-image datasets: split per image.
image_real_paths = [
    "/kaggle/input/deepfake-images/FFHQ-real-v2/FFHQ-real-v2",
]
image_fake_paths = [
    "/kaggle/input/deepfake-images/StableDiffusion-fake-v2/StableDiffusion-fake-v2",
    "/kaggle/input/stylegan-6000/kaggle/working/stylegan_fake_dataset_nvidia",
]

# --- Labels -----------------------------------------------------------------
# Integer class ids; keep in sync with the model's id2label / label2id.
LABEL_REAL, LABEL_FAKE = 0, 1

# --- Training configuration -------------------------------------------------
IMG_SIZE = 224  # NOTE(review): not referenced in the visible code; the image processor handles resizing
BATCH_SIZE = 16
RANDOM_SEED = 42
EPOCHS = 5
LEARNING_RATE = 2e-5

model_name = "google/vit-base-patch16-224"

In [None]:
seed_everything(RANDOM_SEED)

In [None]:
train_dataset, val_dataset, test_dataset, processor = get_data_mixed_structure(
    video_real_paths = video_real_paths,
    video_fake_paths = video_fake_paths,
    image_real_paths = image_real_paths,
    image_fake_paths = image_fake_paths,
    model_name = model_name,
    random_seed = RANDOM_SEED
)

In [None]:
print(f"Training dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

# Modeling

In [None]:
# Binary classifier on top of the pretrained ViT. The checkpoint's head has a
# different number of outputs and is re-initialized (ignore_mismatched_sizes).
# Label maps are built from LABEL_REAL / LABEL_FAKE so they cannot drift from
# the labels the datasets emit.
# NOTE(review): output_attentions=True in the config makes every training/eval
# forward pass return attention maps (and may force the slower eager attention
# path); visualize_attention already requests them per call. Consider dropping
# it once no other code relies on it.
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=2,
    id2label={LABEL_REAL: "REAL", LABEL_FAKE: "FAKE"},
    label2id={"REAL": LABEL_REAL, "FAKE": LABEL_FAKE},
    ignore_mismatched_sizes=True,
    output_attentions=True,
)

In [None]:
for param in model.vit.embeddings.parameters():
    param.requires_grad = False

# Freeze all encoder layers except the last 2
num_layers = len(model.vit.encoder.layer)
for i, layer in enumerate(model.vit.encoder.layer):
    if i < num_layers - 2:  # Freeze all but last 2 layers
        for param in layer.parameters():
            param.requires_grad = False

for param in model.vit.layernorm.parameters():
    param.requires_grad = True

In [None]:
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")

In [None]:
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience = 3,
    early_stopping_threshold = 0.001
)

In [None]:
training_args = TrainingArguments(
    output_dir = "./vit-fake-detector_freezed",
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size = 8,
    num_train_epochs = EPOCHS,
    eval_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate = LEARNING_RATE,
    weight_decay = 0.01,
    warmup_ratio = 0.1,
    load_best_model_at_end = True,
    metric_for_best_model = "f1",
    greater_is_better = True,
    logging_dir = './logs',
    logging_steps = 100,
    remove_unused_columns = False,
    push_to_hub = False,
    report_to = "none",
    save_total_limit = 2,
    dataloader_pin_memory = False
)

In [None]:
trainer = MemoryEfficientTrainer(
    model = model,
    args = training_args,
    train_dataset = train_dataset,
    eval_dataset = val_dataset,
    compute_metrics = compute_metrics,
    callbacks = [early_stopping_callback]
)

In [None]:
import gc
torch.cuda.empty_cache()
gc.collect()

In [None]:
trainer.train()

In [None]:
test_results = trainer.evaluate(test_dataset)
print(f"Test results: {test_results}")

In [None]:
model = trainer.model

In [None]:
for param in model.parameters():
    param.requires_grad = True

In [None]:
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience = 3,
    early_stopping_threshold = 0.001
)

In [None]:
training_args = TrainingArguments(
    output_dir = "./vit-fake-detector_unfreezed",
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size = 8,
    num_train_epochs = EPOCHS,
    eval_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate = LEARNING_RATE,
    weight_decay = 0.01,
    warmup_ratio = 0.1,
    load_best_model_at_end = True,
    metric_for_best_model = "f1",
    greater_is_better = True,
    logging_dir = './logs',
    logging_steps = 100,
    remove_unused_columns = False,
    push_to_hub = False,
    report_to = "none",
    save_total_limit = 2,
    dataloader_pin_memory = False
)

In [None]:
trainer = MemoryEfficientTrainer(
    model = model,
    args = training_args,
    train_dataset = train_dataset,
    eval_dataset = val_dataset,
    compute_metrics = compute_metrics,
    callbacks = [early_stopping_callback]
)

In [None]:
import gc
torch.cuda.empty_cache()
gc.collect()

In [None]:
trainer.train()

In [None]:
test_results = trainer.evaluate(test_dataset)
print(f"Test results: {test_results}")

In [None]:
evaluate_per_domain(trainer, test_dataset)

In [None]:
best_model = trainer.model

In [None]:
tmp = np.random.randint(0, len(test_dataset.file_paths), 10)

for idx in tmp:
    image_path = test_dataset.file_paths[idx]
    true_label = test_dataset.labels[idx]
    visualize_attention(best_model, image_path, processor, true_label = true_label)

In [None]:
tmp = np.random.randint(0, len(test_dataset.file_paths), 10)

for idx in tmp:
    image_path = test_dataset.file_paths[idx]
    true_label = test_dataset.labels[idx]
    visualize_attention(best_model, image_path, processor, true_label = true_label)

In [None]:
tmp = np.random.randint(0, len(test_dataset.file_paths), 10)

for idx in tmp:
    image_path = test_dataset.file_paths[idx]
    true_label = test_dataset.labels[idx]
    visualize_attention(best_model, image_path, processor, true_label = true_label)

In [None]:
visualize_attention_per_domain(
    model=best_model,
    test_dataset=test_dataset,
    processor=processor,
    n_samples=10,
)

## Hold-out Testing

In [None]:
def _collect_holdout_images(folder, patterns=("*.png", "*.jpg", "*.jpeg")):
    """Return a sorted, de-duplicated list of image files found recursively under `folder`.

    Matching goes through `glob`, so it is case-sensitive on case-sensitive file
    systems. With the default patterns, e.g. ``*.PNG`` or ``*.JPG`` files are skipped.
    """
    found = set()
    for pattern in patterns:
        # "**" with recursive=True also matches zero directories, so files placed
        # directly in `folder` are included as well as those in any sub-folder.
        found.update(glob.glob(os.path.join(folder, "**", pattern), recursive=True))
    # Sorting gives a deterministic file order (glob order is OS-dependent).
    return sorted(found)


def _unique_source_names(folders):
    """Return one short display name per folder, distinct whenever the folders differ.

    Uses the shortest trailing path suffix (e.g. ``test/real``) that tells all
    folders apart. The bare basename alone is ambiguous for layouts such as
    ``test/real``, ``train/real`` and ``val/real``.
    """
    parts = [os.path.normpath(f).split(os.sep) for f in folders]
    max_depth = max((len(p) for p in parts), default=0)
    for depth in range(1, max_depth + 1):
        names = ["/".join(p[-depth:]) for p in parts]
        if len(set(names)) == len(names):
            return names
    # Only reached when the same folder was passed more than once.
    return [os.path.normpath(f) for f in folders]


def evaluate_holdout_set(model, holdout_real_paths, holdout_fake_paths,
                         processor, batch_size=32, model_name="ViT"):
    """
    Evaluate model on a new hold-out testing set (pure images, not from videos).

    Images (``*.png``, ``*.jpg``, ``*.jpeg``) are collected recursively from every
    folder in `holdout_real_paths` (label 0, REAL) and `holdout_fake_paths`
    (label 1, FAKE). They are run through the model in eval mode without
    gradients, then summarised overall and per source folder.

    Args:
        model: Trained image-classification model (e.g. ViT). Called as
            ``model(pixel_values=...)``; its ``.logits`` column 1 is treated as FAKE.
        holdout_real_paths: List of paths to real image folders
        holdout_fake_paths: List of paths to fake image folders
        processor: Image processor passed to KaggleImageDataset
        batch_size: Batch size for evaluation
        model_name: Name of model for display

    Returns:
        None if no images were found. Otherwise a dict with these keys:
        'accuracy', 'predictions', 'probabilities', 'labels', 'file_paths',
        'confusion_matrix' (rows = true REAL/FAKE, columns = predicted),
        'auc' (None if it could not be computed), and 'per_source'
        (display name -> {'type', 'folder', 'samples', 'correct', 'accuracy'}).
    """
    print("="*70)
    print(f"HOLD-OUT SET EVALUATION - {model_name}")
    print("="*70)

    # LABEL_REAL = 0, LABEL_FAKE = 1
    LABEL_REAL = 0
    LABEL_FAKE = 1
    class_labels = [LABEL_REAL, LABEL_FAKE]

    # Gather files per source folder. Each file records the index of the folder
    # it came from, so the per-source breakdown does not rely on substring
    # matching of paths or on folder basenames. Both of those collide for
    # layouts like ".../test/real" vs ".../val/real".
    all_files = []
    all_labels = []
    file_source_ids = []
    sources = []  # (folder, type name) in discovery order; list index == source id
    for type_name, label, folders in (("REAL", LABEL_REAL, holdout_real_paths),
                                      ("FAKE", LABEL_FAKE, holdout_fake_paths)):
        type_total = 0
        for folder in folders:
            files = _collect_holdout_images(folder)
            print(f"Found {len(files)} {type_name} images in {folder}")
            file_source_ids.extend([len(sources)] * len(files))
            sources.append((folder, type_name))
            all_files.extend(files)
            all_labels.extend([label] * len(files))
            type_total += len(files)
        print(f"Total {type_name} images: {type_total}")

    print(f"\nTotal hold-out images: {len(all_files)}")
    print(f"  REAL: {all_labels.count(LABEL_REAL)}")
    print(f"  FAKE: {all_labels.count(LABEL_FAKE)}")

    if not all_files:
        print("Error: No images found!")
        return None

    # Create dataset
    holdout_dataset = KaggleImageDataset(
        file_paths=all_files,
        labels=all_labels,
        processor=processor,
        is_train=False  # No augmentation for testing
    )

    # shuffle=False keeps batch order aligned with `all_files`.
    holdout_loader = DataLoader(
        holdout_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2
    )

    # Run inference
    print("\nRunning inference...")
    device = next(model.parameters()).device
    model.eval()

    all_preds = []
    all_probs = []
    all_labels_list = []

    with torch.no_grad():
        for batch in tqdm(holdout_loader, desc="Evaluating"):
            pixel_values = batch['pixel_values'].to(device)

            logits = model(pixel_values=pixel_values).logits
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            # Labels are only needed on the CPU; no round-trip through the GPU.
            all_labels_list.extend(batch['labels'].cpu().numpy())

    all_preds = np.array(all_preds)
    all_probs = np.array(all_probs)
    all_labels_array = np.array(all_labels_list)

    # Calculate metrics
    print("\n" + "="*70)
    print("RESULTS")
    print("="*70)

    accuracy = accuracy_score(all_labels_array, all_preds)
    print(f"\nOverall Accuracy: {accuracy*100:.2f}%")

    # Passing `labels` explicitly keeps the report and the 2x2 confusion matrix
    # well-formed even if one class is absent from both labels and predictions.
    print("\n" + classification_report(
        all_labels_array, all_preds,
        labels=class_labels,
        target_names=["REAL", "FAKE"],
        digits=4,
        zero_division=0
    ))

    cm = confusion_matrix(all_labels_array, all_preds, labels=class_labels)
    print("Confusion Matrix:")
    print("                 Predicted")
    print("               REAL    FAKE")
    print(f"Actual REAL   {cm[0][0]:5d}   {cm[0][1]:5d}")
    print(f"       FAKE   {cm[1][0]:5d}   {cm[1][1]:5d}")

    # Calculate per-class metrics
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels_array, all_preds, average=None, labels=class_labels,
        zero_division=0
    )

    print("\nPer-Class Metrics:")
    print(f"  REAL: Precision={precision[0]:.4f}, Recall={recall[0]:.4f}, F1={f1[0]:.4f}")
    print(f"  FAKE: Precision={precision[1]:.4f}, Recall={recall[1]:.4f}, F1={f1[1]:.4f}")

    # ROC-AUC from the FAKE-class probability. roc_auc_score raises ValueError
    # when only one class is present in the labels; other errors propagate.
    try:
        auc = roc_auc_score(all_labels_array, all_probs[:, LABEL_FAKE])
        print(f"\nROC-AUC Score: {auc:.4f}")
    except ValueError:
        print("\nROC-AUC Score: Could not calculate")
        auc = None

    # Per-source breakdown
    print("\n" + "="*70)
    print("PER-SOURCE BREAKDOWN")
    print("="*70)

    source_names = _unique_source_names([folder for folder, _ in sources])
    file_source_ids = np.asarray(file_source_ids)
    per_source = {}
    for source_id, ((folder, type_name), display_name) in enumerate(zip(sources, source_names)):
        mask = file_source_ids == source_id
        n_samples = int(mask.sum())
        if n_samples == 0:
            continue

        # Every file of a source shares one label, so only accuracy (i.e. recall
        # of that class) is meaningful here; a per-source classification report
        # would always be single-class.
        correct = int((all_labels_array[mask] == all_preds[mask]).sum())
        source_acc = correct / n_samples

        print(f"\n--- {display_name} ({type_name}) ---")
        print(f"Samples: {n_samples}")
        print(f"Accuracy: {source_acc*100:.2f}% ({correct}/{n_samples})")

        per_source[display_name] = {
            'type': type_name,
            'folder': folder,
            'samples': n_samples,
            'correct': correct,
            'accuracy': source_acc,
        }

    print("="*70)

    # Return results
    results = {
        'accuracy': accuracy,
        'predictions': all_preds,
        'probabilities': all_probs,
        'labels': all_labels_array,
        'file_paths': all_files,
        'confusion_matrix': cm,
        'auc': auc,
        'per_source': per_source,
    }

    return results

In [None]:
# Image processor for the base checkpoint `model_name` (fast implementation);
# the fine-tuned weights are loaded separately below.
processor = AutoImageProcessor.from_pretrained(
    pretrained_model_name_or_path=model_name,
    use_fast=True,
)

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

In [None]:
# Fine-tuned ViT checkpoint (Kaggle input dataset) evaluated on the hold-out set
# below. Change this constant to evaluate a different checkpoint.
BEST_CHECKPOINT_DIR = "/kaggle/input/vit-trained-model/checkpoint-17030"
best_model = ViTForImageClassification.from_pretrained(BEST_CHECKPOINT_DIR)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# nn.Module.to() moves the parameters in place (and returns the module). The
# trailing ';' suppresses the very long model repr the bare expression printed.
best_model.to(device);

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

In [None]:
# The hold-out dataset is laid out as <root>/{test,train,val}/{real,fake};
# every split folder is used for evaluation.
holdout_real_paths, holdout_fake_paths = (
    [
        f"/kaggle/input/deepfake-hold-out-testing/hold-out-testing/{split}/{label_dir}"
        for split in ("test", "train", "val")
    ]
    for label_dir in ("real", "fake")
)

In [None]:
# Run the hold-out evaluation with the loaded checkpoint and base processor.
results = evaluate_holdout_set(
    best_model,
    holdout_real_paths,
    holdout_fake_paths,
    processor,
    model_name="ViT",
    batch_size=32,
)

HOLD-OUT SET EVALUATION - ViT
Found 1606 REAL images in /kaggle/input/deepfake-hold-out-testing/hold-out-testing/test/real
Found 12848 REAL images in /kaggle/input/deepfake-hold-out-testing/hold-out-testing/train/real
Found 1606 REAL images in /kaggle/input/deepfake-hold-out-testing/hold-out-testing/val/real
Total REAL images: 16060
Found 1606 FAKE images in /kaggle/input/deepfake-hold-out-testing/hold-out-testing/test/fake
Found 12848 FAKE images in /kaggle/input/deepfake-hold-out-testing/hold-out-testing/train/fake
Found 1606 FAKE images in /kaggle/input/deepfake-hold-out-testing/hold-out-testing/val/fake
Total FAKE images: 16060

Total hold-out images: 32120
  REAL: 16060
  FAKE: 16060

Running inference...


Evaluating: 100%|██████████| 1004/1004 [04:58<00:00,  3.37it/s]


RESULTS

Overall Accuracy: 58.14%

              precision    recall  f1-score   support

        REAL     0.5498    0.8980    0.6821     16060
        FAKE     0.7219    0.2648    0.3874     16060

    accuracy                         0.5814     32120
   macro avg     0.6359    0.5814    0.5347     32120
weighted avg     0.6359    0.5814    0.5347     32120

Confusion Matrix:
                 Predicted
               REAL    FAKE
Actual REAL   14422    1638
       FAKE   11808    4252

Per-Class Metrics:
  REAL: Precision=0.5498, Recall=0.8980, F1=0.6821
  FAKE: Precision=0.7219, Recall=0.2648, F1=0.3874

ROC-AUC Score: 0.5808

PER-SOURCE BREAKDOWN

--- real (REAL) ---
Samples: 1606
Accuracy: 87.36% (1403/1606)

--- fake (FAKE) ---
Samples: 1606
Accuracy: 16.06% (258/1606)





In [None]:
# evaluate_holdout_set returns None when no images were found.
if results is not None:
    print(f"\nFinal Accuracy: {results['accuracy']*100:.2f}%")
    # 'auc' is None when it could not be computed. 0.0 is still a valid score,
    # so compare against None instead of relying on truthiness.
    if results['auc'] is not None:
        print(f"AUC: {results['auc']:.4f}")


Final Accuracy: 58.14%
AUC: 0.5808
