In [15]:
import os
import cv2
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image, ImageOps
from transformers import AutoModelForImageClassification, AutoImageProcessor
from facenet_pytorch import MTCNN
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report
from torch.utils.data import Dataset, DataLoader

# ----------------------------------------------------------------------
# Run configuration
# ----------------------------------------------------------------------

# Hugging Face model identifier handed to `from_pretrained`.
MODEL_ID = "sakshamkr1/deitfake-v2"

# Dataset location: the list file is looked up inside DATASET_ROOT, and the
# relative video paths it contains are joined onto the same directory.
DATASET_ROOT = "/kaggle/input/celeb-df-v2"
TXT_FILE_NAME = "List_of_testing_videos.txt"

# Number of evenly spaced frames sampled from every video.
FRAMES_PER_VIDEO = 15

# Face-crop scale: detected box width/height are multiplied by this factor.
MARGIN = 1.3

# When True, a horizontally mirrored copy of every face crop is scored too.
ENABLE_TTA = True

# DataLoader settings.
BATCH_SIZE = 1
NUM_WORKERS = 4

# Device setup: prefer an XLA (TPU) device when torch_xla is importable and a
# device can be acquired; otherwise fall back to CUDA if available, else CPU.
try:
    import torch_xla.core.xla_model as xm
    DEVICE = xm.xla_device()
    print(f"--- Running on TPU: {DEVICE} ---")
except Exception:
    # `except Exception` (rather than a bare `except:`) keeps the fallback for
    # a missing package or a failed device acquisition, but no longer swallows
    # KeyboardInterrupt / SystemExit raised while the import is in progress.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"--- Running on Device: {DEVICE} ---")

class CelebDFDataset(Dataset):
    """Map-style dataset that turns each video into a stack of face crops.

    For every video, evenly spaced frames are decoded with OpenCV, a face box
    is obtained per frame from ``mtcnn.detect``, the box is enlarged by a
    margin factor, and the crop is resized to 224x224 before being passed to
    ``processor``. Items that cannot be processed are returned as ``None`` so
    the collate function can drop them.

    Args:
        video_paths: Sequence of video file paths.
        labels: Sequence of integer labels, parallel to ``video_paths``.
        processor: Callable invoked as ``processor(images=..., return_tensors="pt")``
            whose result is indexed with ``"pixel_values"``.
        frames_per_video: Number of frame positions sampled per video.
        mtcnn: Face detector exposing ``detect(list_of_pil_images)`` and
            returning ``(boxes_per_frame, _)``. With ``None`` no faces are found.
        margin: Optional crop scale factor. ``None`` (default) uses the
            module-level ``MARGIN``.
        enable_tta: Optional flag for adding mirrored crops. ``None`` (default)
            uses the module-level ``ENABLE_TTA``.
    """

    def __init__(self, video_paths, labels, processor, frames_per_video=10,
                 mtcnn=None, margin=None, enable_tta=None):
        self.video_paths = video_paths
        self.labels = labels
        self.processor = processor
        self.frames_per_video = frames_per_video
        self.mtcnn = mtcnn
        # Kept as given; the module-level defaults are resolved lazily at use
        # time so constructing the dataset never depends on those globals.
        self.margin = margin
        self.enable_tta = enable_tta

    def __len__(self):
        """Return the number of videos."""
        return len(self.video_paths)

    @staticmethod
    def _expand_box(box, frame_width, frame_height, margin):
        """Scale a face box about its centre and clamp it to the frame.

        Returns integer ``(left, top, right, bottom)`` crop coordinates, or
        ``None`` when the clamped box is empty or inverted (which happens when
        the detector reports a box lying outside the frame).
        """
        x1, y1, x2, y2 = (float(v) for v in box[:4])
        cx = x1 + (x2 - x1) / 2
        cy = y1 + (y2 - y1) / 2
        half_w = (x2 - x1) * margin / 2
        half_h = (y2 - y1) * margin / 2

        # Round the same way PIL's crop does (int(round(...))) so the crop is
        # unchanged for valid boxes, while letting us detect degenerate ones.
        left = int(round(max(0.0, cx - half_w)))
        top = int(round(max(0.0, cy - half_h)))
        right = int(round(min(float(frame_width), cx + half_w)))
        bottom = int(round(min(float(frame_height), cy + half_h)))

        if right <= left or bottom <= top:
            return None
        return left, top, right, bottom

    def _read_frames(self, video_path):
        """Decode evenly spaced frames of a video as RGB PIL images."""
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                return []

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames <= 0:
                return []

            frame_indices = np.linspace(
                0, total_frames - 1, self.frames_per_video, dtype=int
            )
            frames_pil = []
            for idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                ret, frame = cap.read()
                if not ret:
                    # Unreadable frame: skip it rather than abandon the video.
                    continue
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames_pil.append(Image.fromarray(frame_rgb))
            return frames_pil
        finally:
            # Always release the capture, including on early returns/errors.
            cap.release()

    def extract_faces_high_res(self, video_path):
        """Return a list of 224x224 PIL face crops for ``video_path``.

        An empty list is returned when the video cannot be read, no detector
        is configured, detection fails, or no usable face box is found.
        """
        if self.mtcnn is None:
            return []

        frames_pil = self._read_frames(video_path)
        if not frames_pil:
            return []

        try:
            boxes_list, _ = self.mtcnn.detect(frames_pil)
        except Exception as exc:
            # Best effort: the video is skipped, but the reason is reported
            # instead of being silently discarded.
            print(f"!! Face detection failed for {video_path}: {exc}")
            return []

        margin = MARGIN if self.margin is None else self.margin
        final_faces = []
        for frame, boxes in zip(frames_pil, boxes_list):
            if boxes is None or len(boxes) == 0:
                continue
            # Only the first reported box of each frame is used.
            crop_box = self._expand_box(boxes[0], frame.width, frame.height, margin)
            if crop_box is None:
                continue
            face = frame.crop(crop_box)
            face = face.resize((224, 224), Image.Resampling.BILINEAR)
            final_faces.append(face)

        return final_faces

    def __getitem__(self, idx):
        """Return the processed sample for video ``idx``, or ``None`` on failure.

        The sample is a dict with ``"pixel_values"`` (processor output for all
        crops; originals followed by mirrored copies when TTA is on),
        ``"label"`` (long tensor) and ``"video_path"``.
        """
        video_path = self.video_paths[idx]
        label = self.labels[idx]

        if not os.path.exists(video_path):
            return None

        faces = self.extract_faces_high_res(video_path)
        if not faces:
            return None

        use_tta = ENABLE_TTA if self.enable_tta is None else self.enable_tta
        if use_tta:
            # Test-time augmentation: originals first, then horizontal flips.
            images = faces + [ImageOps.mirror(f) for f in faces]
        else:
            images = faces
        inputs = self.processor(images=images, return_tensors="pt")

        return {
            "pixel_values": inputs["pixel_values"],
            "label": torch.tensor(label, dtype=torch.long),
            "video_path": video_path,
        }

def collate_fn(batch):
    """Drop failed (``None``) samples from a batch.

    Returns the list of remaining samples, or ``None`` when nothing is left.
    """
    kept = list(filter(lambda sample: sample is not None, batch))
    if not kept:
        return None
    return kept

def load_paths_from_txt(dataset_root=None, txt_file_name=None):
    """Read the video list file and return video paths with binary labels.

    Each non-empty line is expected to be ``<label> <relative/path>``. A file
    label of ``1`` marks a real video and ``0`` a fake one. The returned
    labels use 0 for real and 1 for fake, with all real videos listed first.
    Lines that have no path or an unrecognised label are skipped and counted
    in a warning instead of being silently treated as fake.

    Args:
        dataset_root: Directory holding the list file and the videos.
            ``None`` (default) uses the module-level ``DATASET_ROOT``.
        txt_file_name: Name of the list file inside ``dataset_root``.
            ``None`` (default) uses the module-level ``TXT_FILE_NAME``.

    Returns:
        ``(paths, labels)`` as two parallel lists; ``([], [])`` when the list
        file does not exist.
    """
    root = DATASET_ROOT if dataset_root is None else dataset_root
    file_name = TXT_FILE_NAME if txt_file_name is None else txt_file_name

    txt_path = os.path.join(root, file_name)
    print(f"--- Loading metadata from {txt_path} ---")

    if not os.path.exists(txt_path):
        print(f"!! ERROR: Metadata file not found at {txt_path}")
        return [], []

    real_paths = []
    fake_paths = []
    skipped = 0

    # utf-8-sig transparently drops a leading BOM, which would otherwise be
    # glued onto the first label and make it unrecognisable.
    with open(txt_path, "r", encoding="utf-8-sig") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue

            # Split only once so relative paths containing spaces stay intact.
            parts = line.split(maxsplit=1)
            if len(parts) < 2:
                skipped += 1
                continue

            label_in_file, rel_path = parts
            full_path = os.path.join(root, rel_path)

            if label_in_file == '1':
                real_paths.append(full_path)
            elif label_in_file == '0':
                fake_paths.append(full_path)
            else:
                skipped += 1

    print(f"Loaded {len(real_paths)} Real videos")
    print(f"Loaded {len(fake_paths)} Fake videos")
    if skipped:
        print(f"!! WARNING: Skipped {skipped} malformed line(s) in {txt_path}")

    paths = real_paths + fake_paths
    labels = [0] * len(real_paths) + [1] * len(fake_paths)
    return paths, labels

def _resolve_fake_index(model, default=0):
    """Return the logit index of the "fake" class.

    The index is read from ``model.config.id2label`` when exactly one label
    name contains "fake" (case-insensitive); otherwise ``default`` is used.
    """
    config = getattr(model, "config", None)
    id2label = getattr(config, "id2label", None) or {}
    matches = [int(i) for i, name in id2label.items() if "fake" in str(name).lower()]
    # Ambiguous or missing label names: keep the caller's assumption.
    return matches[0] if len(matches) == 1 else default


def main():
    """Score every listed video with the classifier and report the metrics.

    Loads the image processor and model, builds the face-crop dataset, averages
    the per-crop "fake" probability for each video, writes the per-video
    scores to ``celebdf_results_tta.csv`` and prints accuracy, AUC and a
    classification report (label 0 = Real, label 1 = Fake).
    """
    print(f"--- Loading DeitFake: {MODEL_ID} ---")
    processor = AutoImageProcessor.from_pretrained(MODEL_ID, use_fast=True)
    model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
    model.to(DEVICE)
    model.eval()

    # Take the fake-class index from the model's own label map when it names
    # one unambiguously; otherwise fall back to index 0 as previously assumed.
    fake_index = _resolve_fake_index(model, default=0)
    print(f"--- Fake class index: {fake_index} "
          f"(id2label: {getattr(model.config, 'id2label', None)}) ---")

    print("--- Init MTCNN (High Accuracy) ---")
    # The detector is kept on CPU when the classifier runs on an XLA device.
    mtcnn_device = torch.device("cpu") if "xla" in str(DEVICE) else DEVICE
    mtcnn = MTCNN(
        keep_all=False,
        select_largest=True,
        device=mtcnn_device,
        thresholds=[0.6, 0.7, 0.7],
    )

    video_paths, labels = load_paths_from_txt()

    if not video_paths:
        print("No videos found! Check DATASET_ROOT.")
        return

    # Face detection runs inside the dataset. CUDA cannot be used from forked
    # DataLoader workers, so load in the main process when MTCNN is on CUDA.
    num_workers = 0 if getattr(mtcnn_device, "type", "") == "cuda" else NUM_WORKERS

    dataset = CelebDFDataset(video_paths, labels, processor, FRAMES_PER_VIDEO, mtcnn)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        collate_fn=collate_fn,
        num_workers=num_workers,
    )

    print(f"--- Starting Validation (TTA Enabled: {ENABLE_TTA}) ---")

    y_true = []
    y_scores = []
    results = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            if batch is None:
                continue

            # collate_fn returns a list of per-video samples; score each one
            # so no video is dropped when BATCH_SIZE is larger than 1.
            for data in batch:
                pixel_values = data["pixel_values"].to(DEVICE)
                label = data["label"].item()
                path = data["video_path"]

                outputs = model(pixel_values=pixel_values)
                probs = torch.softmax(outputs.logits, dim=1)

                # One video-level score: mean fake probability over all crops
                # (originals plus mirrored copies when TTA is enabled).
                fake_prob = probs[:, fake_index].mean().item()

                y_true.append(label)
                y_scores.append(fake_prob)
                results.append({
                    "video": os.path.basename(path),
                    "label": label,
                    "score": fake_prob,
                })

    if not results:
        print("No videos could be processed; nothing to evaluate.")
        return

    # Save results
    df_res = pd.DataFrame(results)
    df_res.to_csv("celebdf_results_tta.csv", index=False)

    y_pred_binary = (np.array(y_scores) > 0.5).astype(int)
    acc = accuracy_score(y_true, y_pred_binary)

    if len(set(y_true)) < 2:
        # AUC is undefined with a single class; report NaN instead of raising.
        print("!! WARNING: Only one class present in processed videos; AUC is undefined.")
        auc = float("nan")
    else:
        auc = roc_auc_score(y_true, y_scores)

    print("\n" + "="*30)
    print(f"ACCURACY: {acc:.4f}")
    print(f"AUC:      {auc:.4f}")
    print("="*30)
    # Explicit labels keep the report aligned with target_names even when a
    # class is missing from the processed videos or from the predictions.
    print(classification_report(
        y_true,
        y_pred_binary,
        labels=[0, 1],
        target_names=["Real", "Fake"],
        zero_division=0,
    ))

# Run the evaluation only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()

  DEVICE = xm.xla_device()


--- Running on TPU: xla:0 ---
--- Loading DeitFake: sakshamkr1/deitfake-v2 ---
--- Init MTCNN (High Accuracy) ---
--- Loading metadata from /kaggle/input/celeb-df-v2/List_of_testing_videos.txt ---
Loaded 178 Real videos
Loaded 340 Fake videos
--- Starting Validation (TTA Enabled: True) ---



  0%|          | 0/518 [00:00<?, ?it/s][A
  0%|          | 1/518 [00:09<1:18:29,  9.11s/it][A
  1%|          | 4/518 [00:09<15:03,  1.76s/it]  [A
  1%|▏         | 7/518 [00:09<07:06,  1.20it/s][A
  2%|▏         | 10/518 [00:11<06:51,  1.24it/s][A
  2%|▏         | 12/518 [00:12<05:24,  1.56it/s][A
  3%|▎         | 14/518 [00:13<05:21,  1.57it/s][A
  3%|▎         | 16/518 [00:13<04:24,  1.90it/s][A
  3%|▎         | 18/518 [00:16<05:45,  1.45it/s][A
  4%|▍         | 20/518 [00:16<04:37,  1.79it/s][A
  4%|▍         | 22/518 [00:18<05:30,  1.50it/s][A
  5%|▍         | 25/518 [00:18<03:27,  2.37it/s][A
  5%|▌         | 26/518 [00:21<06:09,  1.33it/s][A
  6%|▌         | 29/518 [00:21<03:47,  2.15it/s][A
  6%|▌         | 31/518 [00:23<05:11,  1.56it/s][A
  7%|▋         | 34/518 [00:24<04:44,  1.70it/s][A
  7%|▋         | 37/518 [00:24<03:10,  2.52it/s][A
  8%|▊         | 39/518 [00:26<04:21,  1.83it/s][A
  8%|▊         | 42/518 [00:29<05:09,  1.54it/s][A
  9%|▊         | 45


ACCURACY: 0.6544
AUC:      0.5000
              precision    recall  f1-score   support

        Real       0.49      0.16      0.24       178
        Fake       0.67      0.91      0.78       340

    accuracy                           0.65       518
   macro avg       0.58      0.54      0.51       518
weighted avg       0.61      0.65      0.59       518




