In [None]:
# Scarlet.ipynb

# Deepfake Detection using TimeSformer and FaceForensics++ (C23) on Kaggle

# --------------------
# 1. INSTALL DEPENDENCIES
# --------------------
# Install essential libraries for model training and video processing
!pip install -q transformers accelerate timm decord --no-deps
!pip install decord

In [None]:
# --------------------
# 2. IMPORT LIBRARIES
# --------------------
import json  # For saving evaluation metrics
import os  # For interacting with the file system
import random  # For random sampling and shuffling

import cv2  # OpenCV for image and video operations
import matplotlib.pyplot as plt
import pandas as pd  # For inspecting the dataset CSV metadata
import seaborn as sns
import torch  # PyTorch framework
import torchvision.transforms as T  # Image transformations
from decord import VideoReader, cpu  # Video decoding
from PIL import Image
from sklearn.metrics import f1_score, roc_auc_score, classification_report, confusion_matrix  # Evaluation metrics
from torch import nn  # Neural network modules
from torch.utils.data import Dataset, DataLoader  # Dataset and data loading utilities
from transformers import AutoImageProcessor, TimesformerForVideoClassification  # Hugging Face model and processor

In [None]:
# --------------------
# 3. SETUP DEVICE
# --------------------
# Run on the GPU when CUDA is available; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# --------------------
# 4. CONFIGURATION
# --------------------

# Root of the FaceForensics++ (C23) Kaggle dataset; update it if the dataset is mounted elsewhere.
dataset_root = "/kaggle/input/ff-c23/FaceForensics++_C23"

# Metadata CSVs for the real (original) and Deepfakes videos. They are only displayed here
# for a quick sanity check; the training pipeline below reads the .mp4 directories directly.
df = pd.read_csv(os.path.join(dataset_root, "csv", "original.csv"))
df1 = pd.read_csv(os.path.join(dataset_root, "csv", "Deepfakes.csv"))

print("First 5 records (real):")
display(df.head())
print("First 5 records (fake):")
display(df1.head())

# Directories holding the real (original) and manipulated (Deepfakes) .mp4 videos.
real_path = os.path.join(dataset_root, "original")
fake_path = os.path.join(dataset_root, "Deepfakes")

# Fail fast with a clear message instead of a bare FileNotFoundError from os.listdir later on.
missing_video_dirs = [p for p in (real_path, fake_path) if not os.path.isdir(p)]
if missing_video_dirs:
    raise FileNotFoundError(f"Video directories not found (check dataset_root): {missing_video_dirs}")

# Video processing and training parameters
num_frames = 8  # Number of frames to sample from each video

# Why 8 frames:
# - TimeSformer takes a fixed number of frames per input clip.
# - 8 frames balances temporal coverage against compute and memory cost.
# - Fewer frames (e.g. 4) may not capture enough temporal context to separate real from fake.
# - More frames (e.g. 16 or 32) increase memory use sharply and can cause OOM (out-of-memory)
#   errors on limited GPUs such as Kaggle's.

image_size = 224  # Target frame size (height and width)

# Why 224x224:
# - 224x224 is the standard input size for many pretrained image and video models, including TimeSformer.
# - The pretrained "facebook/timesformer-base-finetuned-k400" weights expect 224x224 input.
# - Resizing frames to this shape keeps inputs consistent with the model's original training distribution.

batch_size = 2  # Number of video clips per batch

# Why a batch size of 2:
# - Video models are very memory-intensive because every sample contains several frames.
# - On Kaggle's free GPU (often a 16 GB T4), a batch of 2 clips avoids OOM errors.
# - Larger batches (e.g. 4 or 8) become feasible if num_frames or image_size is reduced.

epochs = 21  # Total number of training epochs

# Why 21 epochs:
# - It is an empirically chosen upper bound that fits within Kaggle's limited GPU time.
# - Training past the optimum is tolerable: the training loop separately saves the weights with
#   the best validation accuracy, so the best model is kept even if later epochs overfit.
# - TODO: confirm this choice against the recorded per-epoch validation metrics.

checkpoint_path = "/kaggle/working/timesformer_ffpp_checkpoint.pth"  # Path to save/resume the mid-training checkpoint

In [None]:
# --------------------
# 5. TRANSFORMATIONS
# --------------------
# Per-frame preprocessing, applied to every sampled PIL frame in order:
#   1. resize to image_size x image_size,
#   2. convert to a float tensor with values in [0, 1],
#   3. map [0, 1] -> [-1, 1] via (x - 0.5) / 0.5 (the single mean/std value applies to every channel).
_frame_steps = [
    T.Resize((image_size, image_size)),
    T.ToTensor(),
    T.Normalize([0.5], [0.5]),
]
transform = T.Compose(_frame_steps)

In [None]:
# --------------------
# 6. DATASET CLASS
# --------------------
class VideoDataset(Dataset):
    """
    A custom PyTorch Dataset that decodes videos into fixed-length stacks of frame tensors.

    Args:
        video_paths (Sequence[str]): Paths to video files.
        labels (Sequence[int]): Binary labels corresponding to real (0) or fake (1) videos.
        transform (Callable): Transformation applied to each frame (a PIL image); it is
            expected to return a (C, H, W) tensor.
        num_frames (int, optional): Number of evenly spaced frames to sample per video.
            Defaults to None, which uses the notebook-level ``num_frames`` setting.

    Raises:
        ValueError: If ``video_paths`` and ``labels`` differ in length.
    """
    def __init__(self, video_paths, labels, transform, num_frames=None):
        if len(video_paths) != len(labels):
            raise ValueError(
                f"video_paths and labels must have the same length ({len(video_paths)} != {len(labels)})"
            )
        self.video_paths = video_paths  # Paths to video files
        self.labels = labels  # Associated labels
        self.transform = transform  # Per-frame transformation
        self.num_frames = num_frames  # None -> fall back to the global num_frames at load time

    def __len__(self):
        """Returns the number of samples in the dataset."""
        return len(self.video_paths)

    @staticmethod
    def _frame_indices(total_frames, num_frames):
        """Return ``num_frames`` evenly spaced frame indices in ``[0, total_frames)``.

        Index i is floor(i * total_frames / num_frames). When the video has fewer than
        ``num_frames`` frames, some indices repeat.

        Raises:
            ValueError: If either count is not positive.
        """
        if total_frames <= 0 or num_frames <= 0:
            raise ValueError(
                f"frame counts must be positive (total_frames={total_frames}, num_frames={num_frames})"
            )
        # Integer arithmetic keeps the indices exact (no float rounding).
        return [(i * total_frames) // num_frames for i in range(num_frames)]

    def __getitem__(self, idx):
        """
        Loads and processes the video at the given index.

        Args:
            idx (int): Index of the sample.

        Returns:
            video (Tensor): Tensor of shape [T, C, H, W] (frames first). Batching yields
                [B, T, C, H, W], the ``pixel_values`` layout TimeSformer consumes.
            label (int): Class label.

        Raises:
            ValueError: If the video contains no frames.
        """
        path = self.video_paths[idx]
        label = self.labels[idx]
        n_frames = self.num_frames if self.num_frames is not None else num_frames

        vr = VideoReader(path, ctx=cpu(0))
        total_frames = len(vr)
        if total_frames == 0:
            raise ValueError(f"Video has no decodable frames: {path}")
        indices = self._frame_indices(total_frames, n_frames)

        # Decode all sampled frames in one call instead of seeking frame by frame.
        frames_np = vr.get_batch(indices).asnumpy()  # (T, H, W, 3)
        frames = [self.transform(Image.fromarray(frame)) for frame in frames_np]
        video = torch.stack(frames)  # (T, C, H, W)
        return video, label

In [None]:
# --------------------
# 7. LOAD AND BALANCE DATA
# --------------------
# A fixed seed makes the sample and split identical on every run. This matters when resuming
# from a checkpoint: with an unseeded split, a restarted kernel draws a different validation
# set that overlaps videos the resumed model has already been trained on.
split_seed = 42
max_videos_per_class = 350        # cap per class -> at most 700 videos in total
train_parts, total_parts = 6, 7   # 6/7 of each class for training, 1/7 for validation (300+300 / 50+50 at the cap)

split_rng = random.Random(split_seed)

# Sort the listings: os.listdir order is arbitrary, which would defeat the fixed seed.
real_videos = sorted(os.path.join(real_path, f) for f in os.listdir(real_path) if f.endswith(".mp4"))
fake_videos = sorted(os.path.join(fake_path, f) for f in os.listdir(fake_path) if f.endswith(".mp4"))

# Sample an equal number of real and fake videos
sample_size = min(len(real_videos), len(fake_videos), max_videos_per_class)
n_train_per_class = sample_size * train_parts // total_parts
if n_train_per_class == 0:
    raise ValueError(
        f"Not enough videos for both a training and a validation split "
        f"(real={len(real_videos)}, fake={len(fake_videos)})"
    )
real_videos = split_rng.sample(real_videos, sample_size)
fake_videos = split_rng.sample(fake_videos, sample_size)

# Stratified split: each class contributes the same fraction to train and validation, so the
# validation set stays balanced and always contains both classes.
# (The training DataLoader reshuffles its samples every epoch.)
train_pairs = ([(p, 0) for p in real_videos[:n_train_per_class]]
               + [(p, 1) for p in fake_videos[:n_train_per_class]])
val_pairs = ([(p, 0) for p in real_videos[n_train_per_class:]]
             + [(p, 1) for p in fake_videos[n_train_per_class:]])

train_paths, train_labels = zip(*train_pairs)
val_paths, val_labels = zip(*val_pairs)

# Full sampled pool, kept under the previous variable names.
combined = train_pairs + val_pairs
video_paths, labels = zip(*combined)
print(f"Train: {len(train_paths)} videos, Validation: {len(val_paths)} videos")

# Create PyTorch datasets and dataloaders
train_dataset = VideoDataset(train_paths, train_labels, transform)
val_dataset = VideoDataset(val_paths, val_labels, transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# --------------------
# 8. LOAD MODEL
# --------------------
pretrained_name = "facebook/timesformer-base-finetuned-k400"
id2label = {0: "Real", 1: "Fake"}

# Step 1: Load the Kinetics-400 pretrained backbone with a fresh 2-way classification head.
# num_labels=2 makes the model build `classifier` as Linear(hidden_size, 2) and keeps
# model.config consistent with it. ignore_mismatched_sizes=True skips the pretrained
# 400-class head weights instead of failing on their shape.
# The parameter names (classifier.weight / classifier.bias) match the previous manual head
# replacement, so older Scarlet.pt files still load.
model = TimesformerForVideoClassification.from_pretrained(
    pretrained_name,
    num_labels=len(id2label),
    id2label=id2label,
    label2id={name: idx for idx, name in id2label.items()},
    ignore_mismatched_sizes=True,
)
print("✅ Loaded pretrained TimeSformer with a binary classification head.")

# Step 2: Warm-start from the best weights of a previous run, if present
best_model_path = "/kaggle/working/outputs/Scarlet.pt"
if os.path.exists(best_model_path):
    model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
    print("✅ Loaded best model from Scarlet.pt")

# Step 3: Move model to device (assigning avoids dumping the full model repr as cell output)
model = model.to(device)


In [None]:
# --------------------
# 9. TRAINING SETUP
# --------------------
# Optimizer (AdamW, small fine-tuning learning rate) and loss for 2-class logits
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=1e-2)
loss_fn = nn.CrossEntropyLoss()

# Create the output directory at the same absolute path the training loop writes to;
# a bare "outputs" would resolve against the current working directory instead.
output_dir = "/kaggle/working/outputs"
os.makedirs(output_dir, exist_ok=True)

# Best validation accuracy so far (gates saving Scarlet.pt) and the metrics of that epoch
best_val_acc = 0.0
metrics_dict = {}

In [None]:
# --------------------
# 10. RESUME CHECKPOINT (if exists)
# --------------------
# Where training (re)starts; overwritten below when a checkpoint exists
start_epoch = 0
start_batch_idx = 0

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    start_batch_idx = checkpoint['batch_idx'] + 1  # continue after the last saved batch
    print(f"Resuming from epoch {start_epoch + 1}, batch {start_batch_idx}")

    # Restore the best validation accuracy of the interrupted run. Otherwise best_val_acc
    # restarts at 0.0 and the first resumed epoch overwrites Scarlet.pt even if it is worse.
    best_metrics_path = "/kaggle/working/outputs/metrics.json"
    if os.path.exists(best_metrics_path):
        with open(best_metrics_path) as f:
            metrics_dict = json.load(f)
        best_val_acc = metrics_dict.get("accuracy", best_val_acc)
        print(f"Restored best validation accuracy: {best_val_acc:.4f}")

In [None]:
# --------------------
# 11. TRAINING LOOP
# --------------------
shuffle_seed = 42                       # base seed for the per-epoch training shuffle order
class_names = ["Real", "Fake"]          # index 0 = real, index 1 = fake
report_dir = "/kaggle/working/outputs"  # where the best model, metrics and reports are written
os.makedirs(report_dir, exist_ok=True)


def _evaluate(eval_model, loader):
    """Run ``eval_model`` over ``loader`` without gradients.

    Returns:
        (true_labels, predicted_labels, fake_probabilities) as Python lists, where
        fake_probabilities is the softmax probability of class 1 ("Fake").
    """
    eval_model.eval()
    true_labels, pred_labels, fake_probs = [], [], []
    with torch.no_grad():
        for videos, targets in loader:
            logits = eval_model(pixel_values=videos.to(device)).logits  # videos: (B, T, C, H, W)
            true_labels.extend(targets.tolist())
            pred_labels.extend(logits.argmax(dim=1).cpu().tolist())
            fake_probs.extend(torch.softmax(logits, dim=1)[:, 1].cpu().tolist())
    return true_labels, pred_labels, fake_probs


def _save_validation_report(true_labels, pred_labels):
    """Print, plot and save the classification report and confusion matrix of one validation pass."""
    # labels=[0, 1] keeps the report and the 2x2 matrix aligned with class_names even if
    # a class is missing from the validation labels or predictions.
    report = classification_report(true_labels, pred_labels, labels=[0, 1],
                                   target_names=class_names, zero_division=0)
    cm = confusion_matrix(true_labels, pred_labels, labels=[0, 1])

    print("\nClassification Report:")
    print(report)
    print("\nConfusion Matrix:")
    print(cm)

    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set(xlabel="Predicted Label", ylabel="True Label", title="Confusion Matrix")
    fig.tight_layout()
    fig.savefig(os.path.join(report_dir, "confusion_matrix.png"))
    plt.show()
    plt.close(fig)  # one figure per epoch; release it explicitly

    with open(os.path.join(report_dir, "classification_report.txt"), "w") as f:
        f.write("Classification Report:\n")
        f.write(report)
        f.write("\nConfusion Matrix:\n")
        f.write(str(cm))


for epoch in range(start_epoch, epochs):
    model.train()  # Set model to training mode
    # Re-seed torch's global RNG each epoch. The training DataLoader has no explicit generator,
    # so its shuffle seed comes from this RNG. This makes each epoch's batch order reproducible,
    # so skipping the first start_batch_idx batches on resume skips the batches that were
    # actually trained (provided the interrupted run used the same seeding).
    torch.manual_seed(shuffle_seed + epoch)
    total_loss = 0.0
    batches_trained = 0

    for batch_idx, (videos, targets) in enumerate(train_loader):
        # Skip already trained batches when resuming (they are still decoded, just not trained on)
        if epoch == start_epoch and batch_idx < start_batch_idx:
            continue

        videos = videos.to(device)  # shape: (B, T, C, H, W)
        targets = targets.to(device)

        optimizer.zero_grad()  # Clear previous gradients
        logits = model(pixel_values=videos).logits  # Forward pass
        loss = loss_fn(logits, targets)  # Compute loss
        loss.backward()  # Backpropagation
        optimizer.step()  # Update model parameters

        total_loss += loss.item()
        batches_trained += 1

        # Save checkpoint every 20 batches
        if batch_idx % 20 == 0:
            torch.save({
                'epoch': epoch,
                'batch_idx': batch_idx,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, checkpoint_path)
            print(f"Checkpoint saved at epoch {epoch+1}, batch {batch_idx}")

    # Average over the batches actually trained (fewer than len(train_loader) on a resumed epoch)
    print(f"Epoch {epoch+1}, Loss: {total_loss / max(batches_trained, 1):.4f}")

    # --------------------
    # VALIDATION
    # --------------------
    val_true, val_pred, val_fake_prob = _evaluate(model, val_loader)
    val_acc = sum(t == p for t, p in zip(val_true, val_pred)) / len(val_true)  # Accuracy
    val_f1 = f1_score(val_true, val_pred)  # F1 score for the "Fake" class
    # ROC-AUC must be computed from continuous scores (here P(fake)), not hard predictions;
    # it is undefined when the validation labels contain only one class.
    val_auc = roc_auc_score(val_true, val_fake_prob) if len(set(val_true)) > 1 else float("nan")

    print(f"Validation Accuracy: {val_acc:.4f}, F1 Score: {val_f1:.4f}, AUC: {val_auc:.4f}")
    _save_validation_report(val_true, val_pred)

    # Save best model and metrics
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), os.path.join(report_dir, "Scarlet.pt"))
        metrics_dict = {
            "epoch": epoch + 1,
            "accuracy": val_acc,
            "f1_score": val_f1,
            "auc": val_auc
        }
        with open(os.path.join(report_dir, "metrics.json"), "w") as f:
            json.dump(metrics_dict, f, indent=4)
        print(" Best model and metrics saved!")

In [None]:
# --------------------
# 12. SAVE FINAL MODEL
# --------------------
# Persist the weights from the last trained epoch. These are separate from Scarlet.pt,
# which holds the best-by-validation-accuracy weights.
final_model_path = "/kaggle/working/timesformer_ffpp_final.pth"
final_state = model.state_dict()
torch.save(final_state, final_model_path)
print(" Final model saved!")