# Environment Initialization and Dataset Verification
<p style="font-size:18px; font-weight:normal;">
 This cell imports essential libraries (NumPy, Pandas, os) and verifies that the dataset is correctly mounted in the Kaggle environment.
 It lists all files under /kaggle/input to ensure that video data is accessible before model training.
 </p>

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    # Print the full path of every file found under the read-only input mount.
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import torch
import torchvision
import timm


In [None]:
# Sanity check: prints True when PyTorch can see a CUDA device, False otherwise.
print(torch.cuda.is_available())


In [None]:
# Install dependencies
# NOTE: timm is already imported in an earlier cell, so an upgraded timm installed
# here only takes effect after a kernel restart. decord is imported in the next code
# cell and pytorch_grad_cam (package "grad-cam") in the XAI section.
!pip -q install timm decord grad-cam --upgrade

# Imports & Reproducibility Setup

<p style="font-size:18px; font-weight:normal;">
This cell imports all required libraries for deep learning, video and audio processing, and model evaluation.  
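It also seeds the Python, NumPy, and PyTorch random number generators and selects the computation device (CPU/GPU) for training. Because cuDNN autotuning (`benchmark=True`) is enabled for speed, GPU results may still differ slightly between runs.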
</p>

In [None]:
#  Imports & Seed

import os, glob, re, random, math, time, hashlib, subprocess
import numpy as np
import pandas as pd
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler

from decord import VideoReader, cpu
import timm
import torchaudio

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

import matplotlib.pyplot as plt

SEED = 42


def seed_everything(seed=SEED, deterministic=False):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Seed applied to every RNG.
        deterministic: If True, force deterministic cuDNN kernels and disable
            cuDNN autotuning so GPU runs are repeatable. The default (False)
            keeps the speed-oriented setting (``benchmark=True``), under which
            cuDNN may select different, non-deterministic kernels between runs.
    """
    # Inherited by subprocesses started later. The current interpreter's hash
    # seed is fixed at startup and is not affected.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic


seed_everything()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


# Configuration Settings

<p style="font-size:18px; font-weight:normal;">
This section defines all global hyperparameters and training settings used throughout the project, including image size, number of frames, learning rates, batch size, and audio parameters.  
It centralizes configuration values to ensure consistency, reproducibility, and easy experimentation.
</p>

In [None]:
@dataclass
class CFG:
    """Global hyperparameters and paths shared by all cells of the notebook."""
    img_size: int = 224           # side (px) of the square center crop fed to the CNN
    k_frames: int = 16            # frames sampled per video
    batch_size: int = 4
    num_workers: int = 2          # DataLoader worker processes
    epochs_baseline: int = 4      # max epochs for the video-only baseline
    epochs_av: int = 6            # max epochs for the audio-visual model
    lr_baseline: float = 3e-4
    lr_av: float = 2e-4
    weight_decay: float = 1e-4    # AdamW weight decay (both models)
    patience: int = 2             # epochs without val-F1 improvement before early stopping
    sample_rate: int = 16000      # audio resampling rate (Hz)
    audio_seconds: int = 4        # audio is trimmed / zero-padded to this many seconds
    n_mfcc: int = 40              # MFCC coefficients per audio frame
    audio_cache_dir: str = "/kaggle/working/audio_cache"  # cache of extracted WAV files
    work_dir: str = "/kaggle/working"  # split CSVs and checkpoints are written here

cfg = CFG()
os.makedirs(cfg.audio_cache_dir, exist_ok=True)

# ImageNet per-channel (RGB) mean/std used to normalise video frames.
MEAN = (0.485, 0.456, 0.406)
STD  = (0.229, 0.224, 0.225)


# Video Sampling and Preprocessing

<p style="font-size:18px; font-weight:normal;">
This section defines helper functions for extracting and preprocessing video frames.  
It samples a fixed number of frames per video, applies optional data augmentation, and performs center-crop resizing to ensure consistent input dimensions for the model.
</p>

In [None]:
def sample_frame_indices(num_frames_total, k=16, strategy="uniform"):
    """Choose ``k`` frame indices (int64) for a video with ``num_frames_total`` frames.

    Strategies:
        "uniform": evenly spaced indices from the first to the last frame (indices
            repeat when the video has fewer than ``k`` frames).
        "random": a contiguous window of ``k`` frames at a random start; shorter
            videos are padded by repeating their last frame.

    A non-positive frame count yields ``k`` zeros. Unknown strategies raise ValueError.
    """
    if num_frames_total <= 0:
        return np.zeros((k,), dtype=np.int64)

    if strategy == "uniform":
        return np.linspace(0, num_frames_total - 1, k).round().astype(np.int64)

    if strategy != "random":
        raise ValueError("Unknown strategy")

    if num_frames_total < k:
        last_frame = num_frames_total - 1
        return np.concatenate([
            np.arange(num_frames_total, dtype=np.int64),
            np.full((k - num_frames_total,), last_frame, dtype=np.int64),
        ])

    offset = random.randint(0, num_frames_total - k)
    return np.arange(offset, offset + k, dtype=np.int64)

def random_horizontal_flip(frames, p=0.5):
    """Mirror a whole clip left-right with probability ``p``.

    Args:
        frames: Array shaped (T, H, W, C); axis 2 is flipped.
        p: Probability of flipping. One random draw is made per clip, so all
            frames are flipped together and the clip stays temporally consistent.

    Returns:
        ``frames`` itself when not flipped, otherwise a C-contiguous flipped copy.
        A copy is returned instead of the ``[:, :, ::-1, :]`` view because
        negative-stride arrays are rejected by ``torch.from_numpy``.
    """
    if random.random() < p:
        return np.ascontiguousarray(frames[:, :, ::-1, :])
    return frames

def short_side_resize_dims(h, w, out_size=224):
    """Return ``(new_h, new_w)`` that scale the shorter side of an ``h`` x ``w`` image to ``out_size``.

    Both sides are clamped to at least ``out_size``. Without the clamp,
    ``int(x * scale)`` can round down to ``out_size - 1`` because of floating-point
    error, which would make the following center crop smaller than requested.
    """
    scale = out_size / min(h, w)
    return max(out_size, int(h * scale)), max(out_size, int(w * scale))


def center_crop_resize(frames, out_size=224):
    """Resize each frame so its short side is ``out_size``, then center-crop to a square.

    Args:
        frames: Sequence/array of (H, W, 3) uint8 frames, i.e. (T, H, W, 3).
        out_size: Side length of the square output.

    Returns:
        Array of shape (T, out_size, out_size, 3).
    """
    import cv2
    out = []
    for fr in frames:
        h, w = fr.shape[:2]
        nh, nw = short_side_resize_dims(h, w, out_size)
        # cv2 takes dsize as (width, height).
        fr = cv2.resize(fr, (nw, nh), interpolation=cv2.INTER_AREA)
        y0 = (nh - out_size) // 2
        x0 = (nw - out_size) // 2
        out.append(fr[y0:y0 + out_size, x0:x0 + out_size])
    return np.stack(out, axis=0)

# Audio Feature Extraction (MFCC)

<p style="font-size:18px; font-weight:normal;">
This section extracts audio features from each video using MFCC representations.  
Audio is extracted as mono 16 kHz WAV, trimmed or zero-padded to a fixed duration (4 seconds by default), and transformed into MFCC coefficients. If no audio is available, a zero tensor of the same shape is returned so input dimensions stay consistent.
</p>

In [None]:
SAMPLE_RATE = cfg.sample_rate
N_MFCC = cfg.n_mfcc

# Mel front-end of the MFCC transform: 400-sample FFT window, 160-sample hop
# (25 ms / 10 ms at the default 16 kHz), 64 mel bands, centred frames.
_mel_kwargs = dict(n_fft=400, hop_length=160, n_mels=64, center=True)
mfcc_tf = torchaudio.transforms.MFCC(sample_rate=SAMPLE_RATE, n_mfcc=N_MFCC, melkwargs=_mel_kwargs)

def audio_path_for_video(video_path: str):
    """Return the cache location of the WAV extracted from ``video_path``.

    The file name is the MD5 hex digest of the UTF-8 video path, placed in
    ``cfg.audio_cache_dir``.
    """
    digest = hashlib.md5(video_path.encode("utf-8")).hexdigest()
    return os.path.join(cfg.audio_cache_dir, digest + ".wav")

def ensure_audio(video_path: str):
    """Extract the audio of ``video_path`` to a cached mono WAV and return its path.

    The WAV is written to ``audio_path_for_video(video_path)`` at ``SAMPLE_RATE`` Hz.
    Both outcomes are cached on disk:
      * a WAV larger than 1000 bytes is treated as a valid, reusable extraction;
      * a ``<wav>.noaudio`` marker records that ffmpeg ran but produced no usable
        audio (for example, a video without an audio stream), so later epochs do
        not spawn ffmpeg again for the same file.

    Returns:
        Path to the WAV file, or None if no usable audio could be extracted.
    """
    out = audio_path_for_video(video_path)
    if os.path.exists(out) and os.path.getsize(out) > 1000:
        return out
    no_audio_marker = out + ".noaudio"
    if os.path.exists(no_audio_marker):
        return None

    # Write to a temporary name and rename on success, so an interrupted run never
    # leaves a truncated WAV that the size check above would accept as valid.
    tmp = out + ".part"
    cmd = [
        "ffmpeg",
        "-y",
        "-loglevel", "quiet",
        "-i", video_path,
        "-vn",
        "-ac", "1",
        "-ar", str(SAMPLE_RATE),
        "-f", "wav",  # explicit container: the ".part" suffix does not identify one
        tmp,
    ]
    try:
        subprocess.run(cmd, check=False, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except (OSError, subprocess.SubprocessError):
        # ffmpeg missing or not launchable: don't cache a negative result.
        return None

    if os.path.exists(tmp) and os.path.getsize(tmp) > 1000:
        os.replace(tmp, out)
        return out

    # ffmpeg ran but produced nothing usable: clean up and remember the miss.
    try:
        if os.path.exists(tmp):
            os.remove(tmp)
        with open(no_audio_marker, "w"):
            pass
    except OSError:
        pass  # caching the miss is best-effort only
    return None  # if no audio or ffmpeg failed

# Hop length of ``mfcc_tf`` (melkwargs["hop_length"] of the MFCC transform).
# Keep the two in sync: it determines the time width of the silent fallback tensor.
MFCC_HOP_LENGTH = 160


def mfcc_num_frames(num_samples: int, hop_length: int = MFCC_HOP_LENGTH) -> int:
    """Return the number of MFCC frames for ``num_samples`` with centred framing.

    With ``center=True`` the STFT pads the signal and emits
    ``num_samples // hop_length + 1`` frames, e.g. 64000 samples at hop 160 -> 401.
    """
    return num_samples // hop_length + 1


def signed_log1p(x: torch.Tensor) -> torch.Tensor:
    """Return ``sign(x) * log1p(|x|)``.

    Compresses the dynamic range like ``log1p`` for non-negative values, but unlike
    plain ``torch.log1p`` it stays finite for inputs below -1 (``log1p`` gives NaN there).
    """
    return torch.sign(x) * torch.log1p(torch.abs(x))


def load_mfcc_or_zeros(video_path: str):
    """Return a float (N_MFCC, time) MFCC feature map for the audio of ``video_path``.

    The audio is averaged to mono, resampled to SAMPLE_RATE if needed, and trimmed or
    zero-padded to ``cfg.audio_seconds`` seconds, so every call yields the same width.
    If no audio can be extracted, or the cached WAV cannot be read, an all-zero
    tensor of that same shape is returned instead.
    """
    target_len = SAMPLE_RATE * cfg.audio_seconds

    def _silence():
        # Width derived from the clip length instead of a hard-coded 401, so the
        # fallback matches real features for any audio_seconds / sample rate.
        return torch.zeros((N_MFCC, mfcc_num_frames(target_len)), dtype=torch.float32)

    apath = ensure_audio(video_path)
    if apath is None:
        return _silence()

    try:
        wav, sr = torchaudio.load(apath)
    except (RuntimeError, OSError):
        # Unreadable/corrupt cache entry: degrade to silence instead of crashing a worker.
        return _silence()

    wav = wav.mean(dim=0, keepdim=True)  # mono
    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)

    if wav.shape[1] >= target_len:
        wav = wav[:, :target_len]
    else:
        wav = F.pad(wav, (0, target_len - wav.shape[1]))

    mfcc = mfcc_tf(wav)  # (1, n_mfcc, time)
    # torchaudio's MFCC is the DCT of a dB-scaled mel spectrogram, so coefficients are
    # routinely below -1; plain log1p turned those into NaN. Use the sign-preserving form.
    mfcc = signed_log1p(mfcc)
    return mfcc.squeeze(0)  # (n_mfcc, time)


# Dataset Discovery and Automatic Label Inference

<p style="font-size:18px; font-weight:normal;">
This section scans the FaceForensics++ dataset to locate all video files and automatically assign binary labels based on directory naming conventions.  
Videos labeled as "original/pristine" are assigned class 0 (real), while manipulated methods (e.g., DeepFakes, FaceSwap) are assigned class 1 (fake).
</p>

In [None]:
def find_all_mp4(base="/kaggle/input"):
    """Recursively collect every ``*.mp4`` file under ``base`` and print the count.

    The result is sorted. ``glob`` returns paths in filesystem order, which is not
    guaranteed to be stable across machines or mounts, and the row order later feeds
    ``train_test_split(random_state=SEED)``. Without sorting, the same seed could
    produce different splits.
    """
    mp4s = sorted(glob.glob(base + "/**/*.mp4", recursive=True))
    print("Found mp4:", len(mp4s))
    return mp4s

mp4s = find_all_mp4()

# Label inference: FaceForensics++ usually has "original" vs manipulated methods
# We'll treat:
#  - Original / pristine / real => label 0
#  - Anything under manipulated methods (deepfakes, face2face, faceswap, neuraltextures, etc) => label 1
# Label keywords, matched case-insensitively against the full path.
# Keys of length <= 3 ("df", "f2f", "fs", "nt") are abbreviations and only match a
# whole path token (so "nt" no longer fires inside "content" or "frontal");
# longer keys match as substrings.
FAKE_KEYS = [
    "deepfakes", "face2face", "faceswap", "neuraltextures", "manipulated",
    "fake", "df", "f2f", "fs", "nt"
]
REAL_KEYS = ["original", "pristine", "real"]

# Manipulation-method names used by the fallback rule of infer_label_ffpp.
_FFPP_METHODS = ("deepfakes", "face2face", "faceswap", "neuraltextures", "faceshifter")


def infer_label_ffpp(path: str):
    """Infer a binary label from a FaceForensics++-style video path.

    Returns:
        0 for real (original/pristine) videos, 1 for manipulated ones, or None when
        the path gives no usable hint (such videos are skipped by the caller).
    """
    s = path.lower()
    tokens = set(re.split(r"[^a-z0-9]+", s))

    def _has_any(keys):
        return any((k in tokens) if len(k) <= 3 else (k in s) for k in keys)

    has_real = _has_any(REAL_KEYS)
    has_fake = _has_any(FAKE_KEYS)
    if has_real and not has_fake:
        return 0
    if has_fake and not has_real:
        return 1
    # Ambiguous (both or neither): an explicit "original"/"pristine" marker wins...
    if "original" in s or "pristine" in s:
        return 0
    # ...otherwise a known manipulation-method folder means fake.
    if any(k in s for k in _FFPP_METHODS):
        return 1
    return None

# Label every discovered video; paths whose label cannot be inferred are dropped.
_path_labels = [(p, infer_label_ffpp(p)) for p in mp4s]
rows = [{"video_path": p, "label": y} for p, y in _path_labels if y is not None]

df = pd.DataFrame(rows)
print("Labeled videos:", len(df), "out of", len(mp4s))
if len(df) == 0:
    raise RuntimeError("Could not infer labels. Print mp4 paths and adjust infer_label_ffpp().")

print("Label distribution:", df["label"].value_counts().to_dict())
# DataFrame.sample raises ValueError when asked for more rows than exist.
print(df.sample(min(5, len(df)), random_state=SEED))


# Train, Validation, and Test Split

<p style="font-size:18px; font-weight:normal;">
This section removes duplicate entries and performs a stratified split of the dataset into training (70%), validation (15%), and test (15%) sets.  
Stratification preserves the overall real/fake ratio in every subset (it does not rebalance the classes), and the splits are saved as CSV files for reproducibility.
</p>

In [None]:
df = df.drop_duplicates(subset=["video_path"]).reset_index(drop=True)

# 70% train; the remaining 30% is halved into val/test. Both splits are stratified on label.
# NOTE(review): the split is per video file. If manipulated clips share source footage
# with originals, related clips can land in different subsets; consider a grouped split.
train_df, tmp_df = train_test_split(df, test_size=0.30, random_state=SEED, stratify=df["label"])
val_df, test_df = train_test_split(tmp_df, test_size=0.50, random_state=SEED, stratify=tmp_df["label"])

train_csv, val_csv, test_csv = (
    os.path.join(cfg.work_dir, f"{split}.csv") for split in ("train", "val", "test")
)
for _split_df, _split_csv in ((train_df, train_csv), (val_df, val_csv), (test_df, test_csv)):
    _split_df.to_csv(_split_csv, index=False)

print("Saved:", train_csv, val_csv, test_csv)
for _tag, _split_df in (("Train:", train_df), ("Val  :", val_df), ("Test :", test_df)):
    print(_tag, _split_df["label"].value_counts().to_dict())


# Custom Dataset and DataLoaders

<p style="font-size:18px; font-weight:normal;">
This section defines a custom PyTorch Dataset that loads video frames and corresponding audio features for each sample.  
Video frames are sampled, augmented (random horizontal flip, training only), center-cropped, normalized, and converted to tensors, while MFCC audio features are extracted from the same video. DataLoaders are then created for efficient batch processing.
</p>

In [None]:
class RealVisionDataset(Dataset):
    """Audio-visual dataset read from a split CSV with ``video_path`` and ``label`` columns.

    ``__getitem__`` returns ``(v, a, y)``:
      * ``v``: float tensor (k_frames, 3, cfg.img_size, cfg.img_size). Frames are
        uniformly sampled, randomly flipped when ``train=True``, center-cropped,
        and normalised with MEAN/STD.
      * ``a``: MFCC tensor (n_mfcc, time) from ``load_mfcc_or_zeros``.
      * ``y``: int64 scalar tensor, 0 = real, 1 = fake.
    """

    def __init__(self, csv_path, k_frames=16, train=True):
        self.df = pd.read_csv(csv_path)
        self.k_frames = k_frames
        self.train = train
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        unexpected = set(self.df["label"].unique()) - {0, 1}
        if unexpected:
            raise ValueError(
                f"{csv_path}: labels must be 0 or 1, found {sorted(unexpected, key=str)}"
            )
        # Normalisation tensors are built once here instead of on every sample.
        self._mean = torch.tensor(MEAN).view(1, 3, 1, 1)
        self._std = torch.tensor(STD).view(1, 3, 1, 1)

    def _load_video_frames(self, video_path):
        """Decode ``k_frames`` frames of ``video_path`` into a normalised (T, 3, H, W) tensor."""
        vr = VideoReader(video_path, ctx=cpu(0))
        idx = sample_frame_indices(len(vr), k=self.k_frames, strategy="uniform")
        frames = vr.get_batch(idx).asnumpy()  # (T,H,W,3)

        if self.train:
            frames = random_horizontal_flip(frames, p=0.5)

        frames = center_crop_resize(frames, out_size=cfg.img_size)
        x = torch.from_numpy(frames).float() / 255.0
        # (T,H,W,3) -> (T,3,H,W), made contiguous so consumers get a standard layout.
        x = x.permute(0, 3, 1, 2).contiguous()
        # --- ImageNet normalization ---
        return (x - self._mean) / self._std

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        vpath = row["video_path"]
        y = int(row["label"])

        v = self._load_video_frames(vpath)
        a = load_mfcc_or_zeros(vpath)  # (n_mfcc, time)
        return v, a, torch.tensor(y, dtype=torch.long)

train_ds = RealVisionDataset(train_csv, k_frames=cfg.k_frames, train=True)
val_ds = RealVisionDataset(val_csv, k_frames=cfg.k_frames, train=False)
test_ds = RealVisionDataset(test_csv, k_frames=cfg.k_frames, train=False)

# Shared loader settings. Only the training loader shuffles and drops a ragged last batch.
_loader_kwargs = dict(batch_size=cfg.batch_size, num_workers=cfg.num_workers, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)


# Data Sanity Checks

<p style="font-size:18px; font-weight:normal;">
This section performs basic sanity checks to verify correct data loading before training.  
It prints the training class distribution, loads one sample to show the video and audio tensor shapes, and prints that sample's video path so dataset integrity can be checked.
</p>

In [None]:
from collections import Counter
train_label_list = train_ds.df["label"].tolist()
print("Train label counts:", Counter(train_label_list))

# Push one training sample through the full pipeline (video decode + MFCC extraction).
v, a, y = train_ds[0]
print("Sample shapes:", v.shape, a.shape, y.item())
first_video_path = train_ds.df.iloc[0]["video_path"]
print("First video path:", first_video_path)

# Model Architectures

<p style="font-size:18px; font-weight:normal;">
This section defines the model architectures used in the project.  
A video-only baseline model is implemented using a pre-trained EfficientNet backbone, while the main multi-modal model integrates spatial video features, temporal modeling via a Transformer encoder, and audio embeddings through MFCC-based convolutional layers.
</p>

In [None]:
class VideoOnlyBaseline(nn.Module):
    """Frame-level CNN baseline: per-frame features averaged over time, then classified.

    Input ``v`` is (B, T, 3, H, W); the output is (B, 2) logits.
    """

    def __init__(self, backbone="tf_efficientnet_b0", emb_dim=256):
        super().__init__()
        # num_classes=0 and global_pool="" request backbone features without the
        # classifier head; forward() pools them itself if they come back as a 4-D map.
        self.cnn = timm.create_model(backbone, pretrained=True, num_classes=0, global_pool="")
        ch = self.cnn.num_features
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.proj = nn.Linear(ch, emb_dim)
        self.classifier = nn.Linear(emb_dim, 2)

    def forward(self, v):  # (B,T,3,H,W)
        B, T, C, H, W = v.shape
        # reshape (not view) also accepts non-contiguous inputs, copying only if needed.
        x = v.reshape(B * T, C, H, W)
        feat = self.cnn(x)
        if feat.dim() == 4:  # (B*T, ch, h, w) -> (B*T, ch)
            feat = self.pool(feat).flatten(1)
        feat = feat.reshape(B, T, -1).mean(dim=1)  # average over time
        feat = self.proj(feat)
        return self.classifier(feat)

class AudioEncoder(nn.Module):
    """1-D CNN that embeds an MFCC map (B, n_mfcc, time) into a (B, emb_dim) vector."""
    def __init__(self, n_mfcc=40, emb_dim=256):
        super().__init__()
        # Convolutions run along time, with MFCC coefficients as input channels.
        # NOTE: the layer order defines the state_dict keys ("net.0", "net.1", ...);
        # reordering it breaks loading of existing checkpoints.
        self.net = nn.Sequential(
            nn.Conv1d(n_mfcc, 128, kernel_size=5, padding=2),  # preserves time length
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(2),  # halves the time axis

            nn.Conv1d(128, 256, kernel_size=5, padding=2),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),  # global average over time -> (B, 256, 1)
        )
        self.proj = nn.Linear(256, emb_dim)

    def forward(self, a):  # (B, n_mfcc, time)
        x = self.net(a).squeeze(-1)      # (B,256)
        return self.proj(x)

class RealVisionAV(nn.Module):
    """Audio-visual deepfake classifier.

    Per-frame CNN features are projected, passed through a Transformer encoder over
    time, and mean-pooled. The result is concatenated with an ``AudioEncoder`` MFCC
    embedding and classified by a small MLP into 2 logits.
    """

    def __init__(self, backbone="tf_efficientnet_b0", v_emb=256, a_emb=256, nheads=4, nlayers=2):
        super().__init__()
        self.cnn = timm.create_model(backbone, pretrained=True, num_classes=0, global_pool="")
        ch = self.cnn.num_features
        self.v_pool = nn.AdaptiveAvgPool2d(1)
        self.v_proj = nn.Linear(ch, v_emb)

        # NOTE(review): no positional encoding is added before the encoder and its output
        # is mean-pooled, so the video branch is invariant to frame order. Adding one
        # would change the state_dict (and thus existing checkpoints).
        enc_layer = nn.TransformerEncoderLayer(d_model=v_emb, nhead=nheads, batch_first=True)
        self.temporal = nn.TransformerEncoder(enc_layer, num_layers=nlayers)

        self.audio = AudioEncoder(n_mfcc=N_MFCC, emb_dim=a_emb)

        self.fusion = nn.Sequential(
            nn.Linear(v_emb + a_emb, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 2)
        )

    def forward(self, v, a):
        """v: (B, T, 3, H, W) frames, a: (B, n_mfcc, time) MFCCs -> (B, 2) logits."""
        B, T, C, H, W = v.shape
        # reshape (not view) also accepts non-contiguous inputs, copying only if needed.
        x = v.reshape(B * T, C, H, W)
        feat = self.cnn(x)
        if feat.dim() == 4:
            feat = self.v_pool(feat).flatten(1)
        feat = feat.reshape(B, T, -1)
        feat = self.v_proj(feat)          # (B,T,v_emb)

        feat = self.temporal(feat)        # (B,T,v_emb)
        v_emb = feat.mean(dim=1)          # (B,v_emb)
        a_emb = self.audio(a)             # (B,a_emb)

        # Late fusion by concatenation.
        fused = torch.cat([v_emb, a_emb], dim=1)
        return self.fusion(fused)

# Training Strategy and Evaluation Metrics

<p style="font-size:18px; font-weight:normal;">
This section defines the training loop, evaluation metrics (Accuracy, F1-score, AUC), and optimization strategy.  
The model is trained using weighted cross-entropy to address potential class imbalance, AdamW optimization, cosine learning rate scheduling, and early stopping based on validation F1-score.
</p>

In [None]:
def compute_metrics(y_true, y_prob):
    """Compute accuracy, F1 (fake = positive class) and ROC-AUC from class probabilities.

    Args:
        y_true: (N,) integer labels in {0, 1}.
        y_prob: (N, 2) class probabilities; column 1 is P(fake).

    Returns:
        ``(acc, f1, auc)``. ``auc`` is NaN when ``roc_auc_score`` rejects the input,
        e.g. when ``y_true`` contains a single class.
    """
    p_fake = y_prob[:, 1]
    y_pred = (p_fake >= 0.5).astype(int)
    acc = accuracy_score(y_true, y_pred)
    # zero_division=0 returns the same value as sklearn's default, without the warning.
    f1 = f1_score(y_true, y_pred, zero_division=0)
    try:
        auc = roc_auc_score(y_true, p_fake)
    except ValueError:
        auc = float("nan")
    return acc, f1, auc

def get_class_weights_from_train(train_df):
    """Return inverse-frequency class weights ``[1/n_0, 1/n_1]`` as a float32 tensor on ``device``.

    A class missing from ``train_df`` is counted as 1 so its weight stays finite.
    """
    counts = train_df["label"].value_counts().to_dict()
    inverse = [1.0 / max(counts.get(label, 1), 1) for label in (0, 1)]
    return torch.tensor(inverse, dtype=torch.float32, device=device)

def train_one_epoch(model, loader, optimizer, scaler, criterion, is_av: bool):
    """Run one mixed-precision training epoch.

    Args:
        model: Video-only model (``model(v)``) or AV model (``model(v, a)``), chosen by ``is_av``.
        loader: Yields ``(v, a, y)`` batches.
        optimizer: Optimizer stepped through ``scaler``.
        scaler: AMP ``GradScaler``.
        criterion: Loss applied to ``(logits, y)``.
        is_av: Whether the model also consumes the audio tensor.

    Returns:
        ``(mean_loss, acc, f1, auc)`` over all batches. Metrics are computed from
        predictions made while training (model in train mode).

    Raises:
        RuntimeError: if ``loader`` yields no batches.
    """
    model.train()
    losses, all_y, all_p = [], [], []
    for v, a, y in loader:
        v = v.to(device, non_blocking=True)
        a = a.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with autocast():
            logits = model(v, a) if is_av else model(v)
            loss = criterion(logits, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        losses.append(loss.item())
        # Logits produced under autocast may be half precision; upcast before the
        # softmax so the probabilities used for thresholding/AUC are not fp16-rounded.
        prob = torch.softmax(logits.detach().float(), dim=1).cpu().numpy()
        all_p.append(prob)
        all_y.append(y.detach().cpu().numpy())

    if not losses:
        # np.concatenate([]) would fail with an opaque message. This happens, e.g., with
        # drop_last=True when the dataset holds fewer samples than one batch.
        raise RuntimeError("train loader produced no batches; check dataset size vs. batch_size/drop_last")

    all_p = np.concatenate(all_p, axis=0)
    all_y = np.concatenate(all_y, axis=0)
    acc, f1, auc = compute_metrics(all_y, all_p)
    return float(np.mean(losses)), acc, f1, auc

@torch.no_grad()
def validate(model, loader, criterion, is_av: bool):
    """Evaluate ``model`` on ``loader`` without gradients; return ``(mean_loss, acc, f1, auc)``."""
    model.eval()
    batch_losses, labels, probs = [], [], []
    for v, a, y in loader:
        v, a, y = (t.to(device, non_blocking=True) for t in (v, a, y))
        logits = model(v, a) if is_av else model(v)
        batch_losses.append(criterion(logits, y).item())
        probs.append(torch.softmax(logits, dim=1).float().cpu().numpy())
        labels.append(y.cpu().numpy())

    acc, f1, auc = compute_metrics(np.concatenate(labels, axis=0), np.concatenate(probs, axis=0))
    return float(np.mean(batch_losses)), acc, f1, auc

def run_training(model, is_av: bool, epochs: int, lr: float, ckpt_path: str):
    """Train ``model`` with AdamW, cosine LR decay and early stopping on validation F1.

    Uses the module-level ``train_loader``/``val_loader`` and inverse-frequency class
    weights derived from ``train_df``. Whenever validation F1 improves, the state_dict
    is saved to ``ckpt_path`` as ``{"model": state_dict}``. Training stops after
    ``cfg.patience`` consecutive epochs without improvement.

    Returns:
        ``ckpt_path``.
    """
    model = model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = GradScaler()
    criterion = nn.CrossEntropyLoss(weight=get_class_weights_from_train(train_df))

    best_f1 = -1.0
    stale_epochs = 0
    for ep in range(1, epochs + 1):
        tr_loss, tr_acc, tr_f1, tr_auc = train_one_epoch(
            model, train_loader, optimizer, scaler, criterion, is_av=is_av
        )
        va_loss, va_acc, va_f1, va_auc = validate(model, val_loader, criterion, is_av=is_av)
        scheduler.step()

        print(f"Epoch {ep:02d} | "
              f"TR loss {tr_loss:.4f} acc {tr_acc:.3f} f1 {tr_f1:.3f} auc {tr_auc:.3f} | "
              f"VA loss {va_loss:.4f} acc {va_acc:.3f} f1 {va_f1:.3f} auc {va_auc:.3f}")

        if not va_f1 > best_f1:
            stale_epochs += 1
            if stale_epochs >= cfg.patience:
                print("early stopping")
                break
            continue

        best_f1 = va_f1
        stale_epochs = 0
        torch.save({"model": model.state_dict()}, ckpt_path)
        print("saved best:", ckpt_path)

    return ckpt_path

# Tiny Overfit Sanity Test

<p style="font-size:18px; font-weight:normal;">
This section performs a small overfitting test on a subset of 8 training samples to verify that the model and data pipeline are functioning correctly.  
If the model fails to achieve high accuracy on this tiny subset, it may indicate issues with data loading, labeling, or model implementation.
</p>

In [None]:
print("\n=== Tiny Overfit Test (8 samples) ===")
tiny_idx = list(range(min(8, len(train_ds))))
tiny = torch.utils.data.Subset(train_ds, tiny_idx)
tiny_loader = DataLoader(tiny, batch_size=2, shuffle=True)

tmp_model = RealVisionAV(backbone="tf_efficientnet_b0").to(device)
tmp_opt = torch.optim.AdamW(tmp_model.parameters(), lr=2e-4)
tmp_crit = nn.CrossEntropyLoss()
tmp_scaler = GradScaler()

for ep in range(1, 6):
    tmp_model.train()
    ys, ps, losses = [], [], []
    for v,a,y in tiny_loader:
        v,a,y = v.to(device), a.to(device), y.to(device)
        tmp_opt.zero_grad(set_to_none=True)
        with autocast():
            logits = tmp_model(v,a)
            loss = tmp_crit(logits, y)
        tmp_scaler.scale(loss).backward()
        tmp_scaler.step(tmp_opt)
        tmp_scaler.update()
        losses.append(loss.item())
        ys += y.detach().cpu().tolist()
        ps += logits.argmax(1).detach().cpu().tolist()
    # Training-set accuracy on the tiny subset; should approach 1.0 if the pipeline works.
    acc = sum(int(p==t) for p,t in zip(ps,ys)) / len(ys)
    print(ep, "loss", float(np.mean(losses)), "acc", acc)

print("If this stays ~0.5, your labels/paths are likely wrong or data is not read correctly.\n")

# Release the throw-away model before the real training runs. The optimizer holds the
# parameters and AdamW state, and the last ``logits``/``loss`` still carry autograd
# nodes that may reference the parameters, so drop those references too.
del tmp_model, tmp_opt, tmp_scaler
logits = loss = None
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Baseline Model Training (Video-Only)

<p style="font-size:18px; font-weight:normal;">
This section trains the video-only baseline model using spatial features extracted from video frames.  
The model serves as a reference point for evaluating the contribution of the multi-modal architecture.
</p>

In [None]:
print("\n=== Train Baseline (Video-only) ===")
baseline_ckpt = os.path.join(cfg.work_dir, "baseline_best.pt")
baseline = VideoOnlyBaseline(backbone="tf_efficientnet_b0")
# run_training returns the checkpoint path holding the best-validation-F1 weights.
baseline_ckpt = run_training(
    baseline,
    is_av=False,
    epochs=cfg.epochs_baseline,
    lr=cfg.lr_baseline,
    ckpt_path=baseline_ckpt,
)

# Multi-Modal Model Training (Video + Audio)

<p style="font-size:18px; font-weight:normal;">
This section trains the proposed multi-modal model that integrates spatial video features, temporal modeling, and audio representations.  
The objective is to evaluate whether combining visual and audio information improves deepfake detection performance compared to the baseline.
</p>

In [None]:
print("\n=== Train Main Model (Video+Audio) ===")
av_ckpt = os.path.join(cfg.work_dir, "av_best.pt")
av_model = RealVisionAV(backbone="tf_efficientnet_b0", v_emb=256, a_emb=256, nheads=4, nlayers=2)
# run_training returns the checkpoint path holding the best-validation-F1 weights.
av_ckpt = run_training(
    av_model,
    is_av=True,
    epochs=cfg.epochs_av,
    lr=cfg.lr_av,
    ckpt_path=av_ckpt,
)

# Final Test Evaluation

<p style="font-size:18px; font-weight:normal;">
This section loads the best-performing checkpoints of both models and evaluates them on the held-out test set.  
Final performance metrics (Accuracy, F1-score, and AUC) are reported to provide an unbiased comparison between the baseline and the proposed multi-modal model.
</p>

In [None]:
@torch.no_grad()
def load_best(model, ckpt_path):
    """Load the ``"model"`` state_dict saved by ``run_training`` into ``model``.

    Returns the same module, moved to ``device`` and switched to eval mode.
    """
    state = torch.load(ckpt_path, map_location=device)["model"]
    model.load_state_dict(state, strict=True)
    return model.to(device).eval()

@torch.no_grad()
def eval_on_test(model, loader, is_av: bool):
    """Return ``(acc, f1, auc)`` of ``model`` over every batch of ``loader``."""
    model.eval()
    probs, labels = [], []
    for v, a, y in loader:
        v = v.to(device, non_blocking=True)
        a = a.to(device, non_blocking=True)
        out = model(v, a) if is_av else model(v)
        probs.append(torch.softmax(out, dim=1).float().cpu().numpy())
        labels.append(y.numpy())
    return compute_metrics(np.concatenate(labels, axis=0), np.concatenate(probs, axis=0))

print("\n=== Final Test Metrics ===")
baseline2 = load_best(VideoOnlyBaseline(backbone="tf_efficientnet_b0"), baseline_ckpt)
av2 = load_best(
    RealVisionAV(backbone="tf_efficientnet_b0", v_emb=256, a_emb=256, nheads=4, nlayers=2),
    av_ckpt,
)

b_acc, b_f1, b_auc = eval_on_test(baseline2, test_loader, is_av=False)
a_acc, a_f1, a_auc = eval_on_test(av2, test_loader, is_av=True)

for _row_name, (_acc, _f1, _auc) in (
    ("Baseline  ", (b_acc, b_f1, b_auc)),
    ("AV Model  ", (a_acc, a_f1, a_auc)),
):
    print(f"{_row_name}| ACC {_acc:.3f}  F1 {_f1:.3f}  AUC {_auc:.3f}")


# Advanced Evaluation and Performance Analysis

<p style="font-size:18px; font-weight:normal;">
This section provides a comprehensive evaluation of both models using Confusion Matrices, ROC curves, Precision-Recall curves, and detailed classification reports.  
These visual and quantitative analyses offer deeper insight into class-wise performance, false positives/negatives, and the overall discriminative capability of each model.
</p>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    confusion_matrix, ConfusionMatrixDisplay,
    RocCurveDisplay, classification_report
)

@torch.no_grad()
def collect_probs_and_labels(model, loader, is_av: bool):
    """Run ``model`` over ``loader`` and return ``(labels, p_fake)`` as 1-D NumPy arrays.

    ``p_fake`` is the softmax probability of class 1 (fake) for every sample.
    """
    model.eval()
    labels, fake_probs = [], []
    for v, a, y in loader:
        v = v.to(device, non_blocking=True)
        a = a.to(device, non_blocking=True)
        logits = model(v, a) if is_av else model(v)
        fake_probs.append(torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy())
        labels.append(y.numpy())
    return np.concatenate(labels, axis=0), np.concatenate(fake_probs, axis=0)

# Collect test predictions
y_b, p1_b = collect_probs_and_labels(baseline2, test_loader, is_av=False)
y_a, p1_a = collect_probs_and_labels(av2,       test_loader, is_av=True)

# Convert probs to predicted labels (threshold 0.5)
pred_b = (p1_b >= 0.5).astype(int)
pred_a = (p1_a >= 0.5).astype(int)

CLASS_NAMES = ["real(0)", "fake(1)"]

# sklearn's ConfusionMatrixDisplay.plot and *Display.from_predictions create a new
# figure of their own unless ``ax`` is passed. An explicit axes keeps every title on
# the intended plot and lets both models share one ROC chart and one PR chart.

# ---- Confusion Matrices ----
fig, ax = plt.subplots()
cm_b = confusion_matrix(y_b, pred_b, labels=[0,1])
ConfusionMatrixDisplay(cm_b, display_labels=CLASS_NAMES).plot(ax=ax, values_format="d")
ax.set_title("Baseline - Confusion Matrix (Test)")
plt.show()

fig, ax = plt.subplots()
cm_a = confusion_matrix(y_a, pred_a, labels=[0,1])
ConfusionMatrixDisplay(cm_a, display_labels=CLASS_NAMES).plot(ax=ax, values_format="d")
ax.set_title("AV Model - Confusion Matrix (Test)")
plt.show()

# ---- ROC Curves (both models on the same axes) ----
fig, ax = plt.subplots()
RocCurveDisplay.from_predictions(y_b, p1_b, name="Baseline", ax=ax)
RocCurveDisplay.from_predictions(y_a, p1_a, name="AV Model", ax=ax)
ax.set_title("ROC Curves (Test)")
plt.show()

# ---- Per-class report ----
print("Baseline - Classification Report (Test):")
print(classification_report(y_b, pred_b, target_names=CLASS_NAMES, digits=3))

print("AV Model - Classification Report (Test):")
print(classification_report(y_a, pred_a, target_names=CLASS_NAMES, digits=3))

from sklearn.metrics import PrecisionRecallDisplay

# ---- Precision-Recall Curves (both models on the same axes) ----
fig, ax = plt.subplots()
PrecisionRecallDisplay.from_predictions(y_b, p1_b, name="Baseline", ax=ax)
PrecisionRecallDisplay.from_predictions(y_a, p1_a, name="AV Model", ax=ax)
ax.set_title("Precision-Recall Curves (Test)")
plt.show()


# Explainability (XAI) – Grad-CAM Visualization

<p style="font-size:18px; font-weight:normal;">
This section applies Grad-CAM to the trained multi-modal model in order to visualize which spatial regions influenced the deepfake prediction.  
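The heatmap highlights the regions that most influenced the prediction, which provides interpretability and helps check whether the model focuses on meaningful facial areas. Note that the heatmap is computed on a single frame repeated across the whole clip.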
</p>

In [None]:
print("Loading best AV model...")

# Same architecture/hyperparameters as training, so the checkpoint loads with strict=True.
_av_arch = RealVisionAV(backbone="tf_efficientnet_b0", v_emb=256, a_emb=256, nheads=4, nlayers=2)
av2 = load_best(_av_arch, av_ckpt)


In [None]:
print("\n=== XAI (Grad-CAM) Setup ===")
# Grad-CAM utilities come from the "grad-cam" pip package installed at the top of the notebook.
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Grad-CAM target: the backbone's final conv ("conv_head") when present,
# otherwise the last module yielded by .modules().
if hasattr(av2.cnn, "conv_head"):
    target_layers = [av2.cnn.conv_head]
else:
    target_layers = [list(av2.cnn.modules())[-1]]

class _Wrapper(nn.Module):
    def __init__(self, model, audio_tensor, T):
        super().__init__()
        self.model = model
        self.audio = audio_tensor
        self.T = T
    def forward(self, x):
        vv = x.unsqueeze(1).repeat(1, self.T, 1, 1, 1)
        return self.model(vv, self.audio)

def gradcam_for_sample(v, a, y, title_prefix=""):
    """Show a Grad-CAM overlay of ``av2`` for one unbatched sample ``(v, a)``.

    The prediction in the title comes from the full clip. The heatmap is computed on
    a single frame (index 5, or the last frame of shorter clips) that ``_Wrapper``
    tiles into a static clip, so it explains the model's response to that frame.
    """
    v_batch = v.unsqueeze(0).to(device)
    a_batch = a.unsqueeze(0).to(device)

    with torch.no_grad():
        probs = torch.softmax(av2(v_batch, a_batch), dim=1)[0].detach().cpu().numpy()
        pred = int(np.argmax(probs))

    frame_idx = min(5, v.shape[0] - 1)
    frame = v_batch[:, frame_idx]  # (1,3,H,W)

    cam = GradCAM(
        model=_Wrapper(av2, a_batch, cfg.k_frames).to(device).eval(),
        target_layers=target_layers,
    )
    heatmap = cam(input_tensor=frame, targets=None)[0]  # (H,W)

    # Min-max rescale the normalised frame into [0, 1] for display.
    rgb = frame.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
    rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)
    overlay = show_cam_on_image(rgb, heatmap, use_rgb=True)

    plt.figure(figsize=(6,6))
    plt.title(f"{title_prefix} | True={y} Pred={pred} ProbFake={probs[1]:.3f}")
    plt.imshow(overlay)
    plt.axis("off")
    plt.show()

def find_example(test_ds, want_true, want_correct=True, max_tries=300):
    """Find a sample with label ``want_true`` whose correctness matches ``want_correct``.

    Only the first ``max_tries`` samples are scanned. Returns
    ``(index, v, a, label, pred)`` for the first match, or None.
    """
    for i in range(min(max_tries, len(test_ds))):
        v, a, label = test_ds[i]
        label = int(label.item()) if hasattr(label, "item") else int(label)

        with torch.no_grad():
            out = av2(v.unsqueeze(0).to(device), a.unsqueeze(0).to(device))
            pred = int(out.argmax(dim=1).item())

        if label == want_true and (pred == label) == want_correct:
            return i, v, a, label, pred
    return None


In [None]:
# ===== XAI (Grad-CAM): True Real (True=0, Pred=0) =====

print("\n=== XAI: True REAL (True=0, Pred=0) ===")

# Imports repeated from the previous cell (re-importing is harmless).
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# --- safety checks ---
# Fail fast if prerequisite cells have not been run. The (Hebrew) messages read:
# "av2 is not defined. Run the load_best + av2 creation cell first.",
# "test_ds is not defined. Run the dataset cell first.",
# "cfg is not defined. Run the Config cell first.", "device is not defined."
assert "av2" in globals(), "av2 לא מוגדר. תריצי קודם את התא של load_best + יצירת av2."
assert "test_ds" in globals(), "test_ds לא מוגדר. תריצי קודם את התא של הדאטהסט."
assert "cfg" in globals(), "cfg לא מוגדר. תריצי קודם את תא ה-Config."
assert "device" in globals(), "device לא מוגדר."

# pick a reasonable target layer
# The backbone's final conv ("conv_head") when present, otherwise the last module from
# .modules() (NOTE(review): that fallback is not guaranteed to be a conv layer).
target_layers = [av2.cnn.conv_head] if hasattr(av2.cnn, "conv_head") else [list(av2.cnn.modules())[-1]]

class _Wrapper(nn.Module):
    def __init__(self, model, audio_tensor, T):
        super().__init__()
        self.model = model
        self.audio = audio_tensor
        self.T = T
    def forward(self, x):
        vv = x.unsqueeze(1).repeat(1, self.T, 1, 1, 1)
        return self.model(vv, self.audio)

@torch.no_grad()
def predict_one(v, a):
    """Run ``av2`` on one unbatched sample; return ``(predicted_class, class_probabilities)``."""
    logits = av2(v.unsqueeze(0).to(device), a.unsqueeze(0).to(device))
    probs = torch.softmax(logits, dim=1)[0].detach().cpu().numpy()
    return int(np.argmax(probs)), probs

def find_true_real_example(ds, max_tries=5000):
    """Return the first real sample that ``av2`` also predicts as real.

    Only the first ``max_tries`` samples are scanned. The result is
    ``(index, v, a, true_label, pred, prob_fake)`` with true_label == pred == 0,
    or None if no such sample exists.
    """
    limit = min(max_tries, len(ds))
    i = 0
    while i < limit:
        v, a, y = ds[i]
        label = int(y.item()) if hasattr(y, "item") else int(y)
        pred, prob = predict_one(v, a)
        if label == 0 and pred == 0:
            return i, v, a, label, pred, float(prob[1])
        i += 1
    return None

def gradcam_show(v, a, y_true, pred, prob_fake, title_prefix="Grad-CAM (True Real)"):
    """Plot a Grad-CAM heatmap of ``av2`` over one frame of the sample ``(v, a)``.

    ``y_true``, ``pred`` and ``prob_fake`` are supplied by the caller and only used in
    the title. The CAM is computed for frame index 5 (or the last frame of shorter
    clips), repeated ``cfg.k_frames`` times by ``_Wrapper``.
    """
    clip = v.unsqueeze(0).to(device)
    audio = a.unsqueeze(0).to(device)
    frame = clip[:, min(5, v.shape[0] - 1)]  # (1,3,H,W)

    cam = GradCAM(model=_Wrapper(av2, audio, cfg.k_frames).to(device).eval(), target_layers=target_layers)
    heatmap = cam(input_tensor=frame, targets=None)[0]  # (H,W)

    # Min-max rescale the normalised frame into [0, 1] for display.
    img = frame.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
    lo, hi = img.min(), img.max()
    img = (img - lo) / (hi - lo + 1e-8)
    overlay = show_cam_on_image(img, heatmap, use_rgb=True)

    plt.figure(figsize=(6, 6))
    plt.title(f"{title_prefix} | True={y_true} Pred={pred} ProbFake={prob_fake:.3f}")
    plt.imshow(overlay)
    plt.axis("off")
    plt.show()

res = find_true_real_example(test_ds, max_tries=5000)
if res is not None:
    i, v, a, y_true, pred, prob_fake = res
    print(f"Using index: {i} | True={y_true} Pred={pred} ProbFake={prob_fake:.3f}")
    gradcam_show(v, a, y_true, pred, prob_fake, title_prefix="Grad-CAM (True Real)")
else:
    # The Hebrew hint reads: "try increasing max_tries or check whether the model mostly predicts 1".
    print("Could not find a True=0 & Pred=0 example in the first tries. נסי להגדיל max_tries או לבדוק אם המודל מנבא בעיקר 1.")


In [None]:
# Correctly classified fake sample (label 1, prediction 1) among the first 300 test items.
res = find_example(test_ds, want_true=1, want_correct=True)
if res is not None:
    i, v, a, y, pred = res
    print("Using index:", i, "| True:", y, "| Pred:", pred)
    gradcam_for_sample(v, a, y, title_prefix="Grad-CAM (True Fake)")
else:
    print("Could not find a correct FAKE example in the first 300 samples.")


In [None]:
# Try to find any misclassified example (either real->fake or fake->real)
# Real samples (label 0) are searched first; fake ones only if no real miss is found.
found = None
for target_label in (0, 1):
    candidate = find_example(test_ds, want_true=target_label, want_correct=False)
    if candidate is None:
        continue
    found = candidate
    break

if found is not None:
    i, v, a, y, pred = found
    print("Using index:", i, "| True:", y, "| Pred:", pred)
    gradcam_for_sample(v, a, y, title_prefix="Grad-CAM (Misclassified)")
else:
    print("Could not find a misclassified example in the first 300 samples.")


#  XAI Analysis – Grad-CAM Interpretation

 <p style="font-size:18px; font-weight:normal;">
The Grad-CAM visualization highlights the spatial regions that most influenced the model’s prediction. In this example, the model focuses primarily on facial areas (eyes, nose, mouth, and skin boundaries), which are known to contain subtle inconsistencies in Deepfake manipulations. This supports the hypothesis that the network learns meaningful visual artifacts rather than relying on background or irrelevant cues, strengthening the interpretability and reliability of the proposed approach.
 </p>
