# Video Deepfake Prediction

This notebook loads a trained deepfake detection model and runs it on a new video. It performs the following steps:

1.  **Configuration**: Set paths for the input video and the trained model.
2.  **Model Loading**: Load the specified model architecture and its saved weights.
3.  **Video Processing**: 
    - Open the video file.
    - For every `frame_skip`-th frame, detect faces using MTCNN.
    - Classify each detected face as `REAL` or `FAKE` with the model.
    - Draw a bounding box and the prediction result on the frame.
4.  **Save Output**: Write all frames, annotated or not, to a new output video file.
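
The per-face decision in step 3 comes down to a softmax over two logits followed by an argmax. The sketch below reproduces that arithmetic in plain Python, without the model. The helper name `classify_logits` and the example logits are made up for illustration, and the mapping of index 0 to `REAL` is the convention this notebook assumes.

```python
import math

LABELS = ("REAL", "FAKE")  # assumed class order: index 0 = REAL, index 1 = FAKE


def classify_logits(logits):
    """Return (label, confidence) for a pair of raw model outputs (hypothetical helper)."""
    # Subtract the max logit before exponentiating for numerical stability.
    shifted = [x - max(logits) for x in logits]
    exps = [math.exp(x) for x in shifted]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Pick the class with the highest probability, as torch.max does in the main loop.
    best = max(range(len(probs)), key=probs.__getitem__)
    return LABELS[best], probs[best]


label, confidence = classify_logits([0.2, 2.1])  # made-up logits
print(f"{label}: {confidence:.2f}")
```

The reported confidence is the probability of the winning class, so with two classes it never drops below 0.5.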

### Imports

In [12]:
import os
import cv2
import torch
from PIL import Image
from torchvision import transforms
from facenet_pytorch import MTCNN
import timm
from tqdm.notebook import tqdm # Use tqdm.notebook for better notebook integration

### 1. Configuration

**Action Required:** Set `video_path` to your test video, and make sure `model_path` and `model_name` match the model you trained.
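
Since a wrong path only surfaces later as a load or open error, it can help to check both files right after setting them. The following is a minimal sketch; `missing_inputs` is a hypothetical helper, not part of the notebook, and the paths in the example call are placeholders.

```python
import os


def missing_inputs(**paths):
    """Return the names of the given path variables that do not point to a file (hypothetical helper)."""
    return [name for name, path in paths.items() if not os.path.isfile(path)]


# Example call using the notebook's variable names; the values are placeholders.
missing = missing_inputs(
    video_path="../custom_test/test_videos/01.mp4",
    model_path="../models/best_model_efficientnet_b3.pth",
)
if missing:
    print("Not found:", ", ".join(missing))
```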

In [13]:
# --- Parameters to set ---
video_path = "../custom_test/test_videos/01.mp4" # <--- CHANGE THIS
model_path = "../models/best_model_efficientnet_b3.pth"          # Path to your trained model
model_name = "efficientnet_b3"                                  # Model architecture used for training

# --- Optional parameters ---
output_dir = "../custom_test/predictions"                                  # Directory to save the output video
frame_skip = 5                                                   # Process 1 frame every 'n' frames to speed up
confidence_threshold = 0.95                                      # Confidence threshold for face detection

### 2. Setup Device and Directories

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

os.makedirs(output_dir, exist_ok=True)

Using device: cuda


### 3. Load Model and Face Detector

In [15]:
# --- Load Deepfake Detection Model ---
print(f"Loading model: {model_name} from {model_path}")
try:
    model = timm.create_model(model_name, pretrained=False, num_classes=2)
    
    # Get model-specific transforms
    model_config = model.default_cfg
    image_size = model_config['input_size'][1:]
    mean = model_config['mean']
    std = model_config['std']

    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Please ensure model_name and model_path are correct.")
    raise  # Stop here: later cells depend on model, image_size, mean, and std

# --- Load Face Detector ---
print("Loading face detector...")
mtcnn = MTCNN(keep_all=True, device=device, post_process=False, min_face_size=40)
print("Face detector loaded.")

Loading model: efficientnet_b3 from ../models/best_model_efficientnet_b3.pth
Model loaded successfully.
Loading face detector...
Face detector loaded.


### 4. Define Image Transforms

These transforms must match the validation transforms used during training. The input size, mean, and std come from the model's `default_cfg` read in the previous cell.
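
To make concrete what `ToTensor` and `Normalize` do to a pixel, the sketch below applies the same arithmetic by hand: scale 8-bit values to [0, 1], then compute `(x - mean) / std` per channel. The mean and std shown are the common ImageNet statistics and are an assumption here; the notebook reads the actual values from `model.default_cfg`.

```python
# Assumed ImageNet statistics; the notebook takes mean/std from model.default_cfg instead.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Mimic ToTensor + Normalize for a single 8-bit RGB pixel (illustrative helper)."""
    scaled = [c / 255.0 for c in rgb]  # ToTensor: uint8 [0, 255] -> float [0, 1]
    return [(x - m) / s for x, m, s in zip(scaled, mean, std)]  # Normalize per channel


print(normalize_pixel((255, 128, 0)))
```

If the statistics at inference differ from those used in training, every input is shifted and rescaled relative to what the model learned, which is why the two pipelines need to agree.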

In [16]:
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

### 5. Process the Video

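This cell contains the main loop. It reads the video, detects faces on every `frame_skip`-th frame, runs a prediction on each detected face, and writes all frames to a new video file. Skipped frames are written without annotations, so boxes appear only on the analyzed frames. Runtime depends on the video length and your hardware.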
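
As a rough illustration of how `frame_skip` thins out the work, the sketch below lists which frame indices get analyzed for a given frame count. `analyzed_frames` is a hypothetical helper; the rule `frame_count % frame_skip == 0` is the one used in the loop.

```python
def analyzed_frames(total_frames, frame_skip):
    """Indices of frames that pass the `frame_count % frame_skip == 0` check (hypothetical helper)."""
    return [i for i in range(total_frames) if i % frame_skip == 0]


indices = analyzed_frames(total_frames=20, frame_skip=5)
print(indices)  # [0, 5, 10, 15]
print(f"{len(indices)} of 20 frames are run through MTCNN and the classifier")
```

A larger `frame_skip` reduces runtime roughly in proportion, at the cost of fewer annotated frames in the output.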

In [17]:
# --- Open Video File ---
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    print(f"Error: Could not open video file {video_path}")
else:
    # --- Video Writer Setup ---
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)  # Keep fractional rates (e.g. 29.97); int() would truncate and change playback speed
    output_filename = os.path.join(output_dir, f"predicted_{os.path.basename(video_path)}")
    out = cv2.VideoWriter(output_filename, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))

    frame_count = 0
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    
    # --- Main Loop ---
    with torch.no_grad(), tqdm(total=total_frames, desc="Processing video") as pbar:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % frame_skip == 0:
                # Convert frame for face detection (BGR to RGB)
                img_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                boxes, probs = mtcnn.detect(img_pil)

                if boxes is not None:
                    for box, prob in zip(boxes, probs):
                        if prob < confidence_threshold:
                            continue

                        x1, y1, x2, y2 = [int(b) for b in box]
                        
                        # Ensure coordinates are within frame boundaries
                        x1, y1 = max(0, x1), max(0, y1)
                        x2, y2 = min(frame_width, x2), min(frame_height, y2)

                        # Crop face and predict
                        face_pil = img_pil.crop((x1, y1, x2, y2))
                        
                        if face_pil.size[0] == 0 or face_pil.size[1] == 0:
                            continue

                        face_tensor = transform(face_pil).unsqueeze(0).to(device)
                        output = model(face_tensor)
                        softmax_probs = torch.softmax(output, dim=1)
                        pred_prob, pred_label = torch.max(softmax_probs, 1)
                        
                        # Assumes class index 0 = REAL and 1 = FAKE; this must match the label order used in training
                        label_text = "REAL" if pred_label.item() == 0 else "FAKE"
                        confidence = pred_prob.item()
                        
                        # Set color based on prediction
                        color = (0, 255, 0) if label_text == "REAL" else (0, 0, 255) # Green for REAL, Red for FAKE
                        
                        # Draw bounding box and label on the frame
                        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                        label_display = f"{label_text}: {confidence:.2f}"
                        # Keep the label inside the frame when the box touches the top edge
                        cv2.putText(frame, label_display, (x1, max(y1 - 10, 20)), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

            out.write(frame)
            frame_count += 1
            pbar.update(1)

    # --- Release Resources ---
    cap.release()
    out.release()
    print(f"\nProcessing complete. Output video saved to: {output_filename}")


Processing complete. Output video saved to: ../custom_test/predictions\predicted_01.mp4
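
One detail of the loop worth isolating is the box clamping: MTCNN can return coordinates that extend slightly past the frame, so they are clipped before cropping. A minimal sketch of that step, with `clamp_box` as a hypothetical helper and made-up coordinates:

```python
def clamp_box(box, frame_width, frame_height):
    """Clip an (x1, y1, x2, y2) box to the frame, or return None if nothing is left (hypothetical helper)."""
    x1, y1, x2, y2 = (int(v) for v in box)
    x1, y1 = max(0, x1), max(0, y1)
    x2, y2 = min(frame_width, x2), min(frame_height, y2)
    # A box that lies entirely outside the frame collapses to zero or negative size.
    if x2 <= x1 or y2 <= y1:
        return None
    return x1, y1, x2, y2


print(clamp_box((-12.4, 30.9, 150.2, 700.0), frame_width=640, frame_height=480))
```

Returning `None` for a collapsed box plays the same role as the empty-crop check in the loop: such detections are skipped instead of being passed to the model.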
