In [11]:
"""
Complete Test Script for MesoNet+LSTM Deepfake Detection
Test your trained model on a single example from your dataset
Just run this entire script in Google Colab!
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
from torchvision import transforms
import os
import time
import pandas as pd

# Mount Google Drive at /content/drive. This runs at import time, before any
# definition below; the dataset and checkpoint paths used later in this script
# all point under /content/drive/MyDrive.
# NOTE(review): `google.colab` is imported unconditionally, so this script
# presumably only runs inside a Colab runtime — confirm before reusing elsewhere.
from google.colab import drive
drive.mount('/content/drive')

##############################################
# MODEL ARCHITECTURE (Exact copy from training)
##############################################

class EnhancedMesoNet(nn.Module):
    """Five-stage convolutional frame encoder (Enhanced MesoNet).

    Every stage applies ``Conv2d(3x3, padding=1, bias=False) -> BatchNorm2d ->
    ReLU -> MaxPool2d(2, 2)``. Channels grow 3 -> 16 -> 32 -> 64 -> 128 -> 256
    while each pooling step halves the spatial resolution.

    The submodules are registered as ``conv1``/``bn1``/``pool1`` through
    ``conv5``/``bn5``/``pool5`` (in that order), which fixes the state-dict
    key names.

    Attributes:
        image_size: Side length passed to the constructor.
        feature_size: ``256 * (image_size // 32) ** 2``, i.e. the number of
            values in one flattened output map for a square input of
            ``image_size`` pixels.
    """

    def __init__(self, image_size=128):
        super().__init__()
        self.image_size = image_size

        # Build the five stages in order; setattr goes through
        # nn.Module.__setattr__, so each layer is registered under its name.
        in_channels = 3
        for stage, out_channels in enumerate((16, 32, 64, 128, 256), start=1):
            setattr(
                self,
                f"conv{stage}",
                nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            )
            setattr(self, f"bn{stage}", nn.BatchNorm2d(out_channels))
            setattr(self, f"pool{stage}", nn.MaxPool2d(2, 2))
            in_channels = out_channels

        # Five 2x2 poolings shrink each spatial side by a factor of 32
        side = image_size // 32
        self.feature_size = 256 * side * side

    def forward(self, x):
        """Run ``x`` through the five conv/bn/relu/pool stages and return the feature map."""
        stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
            (self.conv4, self.bn4, self.pool4),
            (self.conv5, self.bn5, self.pool5),
        )
        for conv, bn, pool in stages:
            x = pool(F.relu(bn(conv(x))))
        return x


class EnhancedMesoNetLSTM(nn.Module):
    """Video classifier: per-frame MesoNet features, BiLSTM, self-attention, MLP head.

    Args:
        config: Mapping read for the keys ``'image_size'``,
            ``'lstm_hidden_size'``, ``'lstm_layers'`` and ``'dropout_rate'``.
            It is also stored unchanged on ``self.config``.

    The forward pass returns the raw output of the final ``Linear(64, 1)``
    layer (one value per batch element, no sigmoid applied here).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Convolutional encoder applied to every frame independently
        self.mesonet = EnhancedMesoNet(config['image_size'])
        conv_features = self.mesonet.feature_size

        # Turns each [C, H, W] feature map into a flat vector
        self.flatten = nn.Flatten()

        hidden = config['lstm_hidden_size']
        drop = config['dropout_rate']
        layers = config['lstm_layers']

        # Projects the flattened conv features down to the LSTM input width
        self.feature_reducer = nn.Sequential(
            nn.Linear(conv_features, hidden),
            nn.ReLU(),
            nn.BatchNorm1d(hidden),
            nn.Dropout(drop * 0.5),
        )

        # Bidirectional LSTM over the frame sequence; inter-layer dropout is
        # only requested when there is more than one layer
        self.lstm = nn.LSTM(
            input_size=hidden,
            hidden_size=hidden,
            num_layers=layers,
            batch_first=True,
            bidirectional=True,
            dropout=drop if layers > 1 else 0,
        )

        # Self-attention over the LSTM outputs (width doubled by bidirectionality)
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden * 2,
            num_heads=8,
            dropout=drop,
            batch_first=True,
        )

        # Classification head ending in a single output unit
        self.classifier = nn.Sequential(
            nn.Dropout(drop),
            nn.Linear(hidden * 2, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(drop * 0.5),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(drop * 0.3),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        """Classify a batch of clips shaped [batch, frames, channels, height, width]."""
        n_frames = x.size(1)

        # Encode frame by frame so every encoder call sees one time step of the batch
        per_frame = [
            self.feature_reducer(self.flatten(self.mesonet(x[:, t, :, :, :])))
            for t in range(n_frames)
        ]

        # [batch, frames, hidden] sequence for the LSTM
        sequence = torch.stack(per_frame, dim=1)
        recurrent, _ = self.lstm(sequence)

        # Query, key and value are all the LSTM output (self-attention)
        attended, _ = self.attention(recurrent, recurrent, recurrent)

        # Average over the frame axis, then classify
        return self.classifier(attended.mean(dim=1))

##############################################
# DEEPFAKE DETECTOR CLASS
##############################################

class DeepfakeDetector:
    """Loads a trained MesoNet+LSTM checkpoint and classifies videos as FAKE or REAL.

    Attributes:
        device: ``torch.device`` the model runs on.
        model: The loaded ``EnhancedMesoNetLSTM`` in eval mode.
        config: The ``'config'`` entry of the checkpoint; read for
            ``'frame_count'``, ``'image_size'`` and the model hyper-parameters.
        transform: Per-frame preprocessing (tensor conversion + normalization).
    """

    def __init__(self, model_path, device=None):
        """
        Initialize the deepfake detector

        Args:
            model_path (str): Path to model_for_local_inference.pth
            device (str): Device to use ('cuda', 'cpu', or None for auto)

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            RuntimeError: If the checkpoint cannot be loaded.
        """
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print(f"🚀 Initializing MesoNet+LSTM Deepfake Detector")
        print(f"   Device: {self.device}")

        self.model = None
        self.config = None
        self.load_model(model_path)
        self.setup_transforms()

    @staticmethod
    def _format_accuracy(accuracy):
        """Render a checkpoint accuracy for display.

        Numeric values are shown with four decimals; anything that cannot be
        converted to ``float`` (e.g. the ``'Unknown'`` fallback) is shown as-is.
        """
        try:
            return f"{float(accuracy):.4f}"
        except (TypeError, ValueError):
            return str(accuracy)

    def load_model(self, model_path):
        """Load the checkpoint at ``model_path`` and build the model from it.

        The checkpoint must contain ``'config'`` and ``'model_state_dict'``;
        ``'accuracy'`` is optional and only used for the summary printout.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            RuntimeError: If reading the checkpoint or building the model
                fails (the original exception is attached as ``__cause__``).
        """

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        print(f"📁 Loading model from: {os.path.basename(model_path)}")

        try:
            # SECURITY NOTE: weights_only=False unpickles arbitrary objects, so
            # only load checkpoints from a trusted source. It is kept because
            # the checkpoint stores a plain-Python 'config' next to the weights.
            checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

            # Extract configuration and create model
            self.config = checkpoint['config']
            self.model = EnhancedMesoNetLSTM(self.config)

            # Load trained weights
            self.model.load_state_dict(checkpoint['model_state_dict'])

            # Move to device and set evaluation mode
            self.model.to(self.device)
            self.model.eval()

            # 'accuracy' may be absent; the 'Unknown' fallback is a string and
            # must not go through a float format spec.
            accuracy = self._format_accuracy(checkpoint.get('accuracy', 'Unknown'))

            print(f"✅ Model loaded successfully!")
            print(f"   Configuration: {self.config['name']}")
            # The parenthetical is a fixed label and is not derived from the checkpoint
            print(f"   Accuracy: {accuracy} (91.06% on DFDC dataset)")
            print(f"   Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
            print(f"   Frame count: {self.config['frame_count']}")
            print(f"   Image size: {self.config['image_size']}x{self.config['image_size']}")
            print(f"   LSTM hidden size: {self.config['lstm_hidden_size']}")
            print(f"   LSTM layers: {self.config['lstm_layers']}")

        except Exception as e:
            raise RuntimeError(f"Failed to load model: {str(e)}") from e

    def setup_transforms(self):
        """Create ``self.transform``: ToTensor followed by channel-wise normalization."""
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # Mean/std are the standard ImageNet channel statistics
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def extract_frames(self, video_path):
        """Sample ``config['frame_count']`` evenly spaced frames from a video.

        Each frame is converted BGR -> RGB and resized to
        ``config['image_size']`` squared. If fewer frames than requested could
        be decoded, the last decoded frame is repeated (or black frames are
        used when nothing could be decoded at all).

        Args:
            video_path (str): Path to the video file.

        Returns:
            numpy.ndarray: Stacked frames, one per requested time step.

        Raises:
            ValueError: If the video cannot be opened or reports no frames.
        """

        print(f"📹 Opening video: {os.path.basename(video_path)}")

        target_count = self.config['frame_count']
        side = self.config['image_size']
        frames = []

        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                raise ValueError(f"Could not open video: {video_path}")

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            duration = total_frames / fps if fps > 0 else 0

            print(f"   📊 Video info: {total_frames} frames, {fps:.1f} FPS, {duration:.1f}s")

            if total_frames <= 0:
                raise ValueError(f"No frames found in video")

            # Calculate frame indices to extract (evenly distributed)
            indices = np.linspace(0, total_frames - 1, target_count, dtype=int)

            print(f"   🎬 Extracting {target_count} frames...")

            for idx in indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ret, frame = cap.read()

                # Frames that fail to decode are skipped and padded below
                if ret:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = cv2.resize(frame, (side, side))
                    frames.append(frame)
        finally:
            # Always free the capture handle, including when decoding or
            # resizing raises part-way through.
            cap.release()

        # Pad with the last decoded frame (or black frames) if needed
        missing = target_count - len(frames)
        if missing > 0:
            filler = frames[-1] if frames else np.zeros((side, side, 3), dtype=np.uint8)
            frames.extend([filler] * missing)

        print(f"   ✅ Extracted {len(frames)} frames")
        return np.array(frames)

    def predict(self, video_path):
        """
        Predict if a video is fake or real

        The model output is passed through a sigmoid and interpreted as the
        probability of FAKE; values above 0.5 are labelled "FAKE".

        Args:
            video_path (str): Path to the video file

        Returns:
            dict: Prediction results with keys 'prediction',
                'fake_probability', 'real_probability', 'confidence' and
                'processing_time' (seconds, including frame extraction).
        """

        print(f"\n🔍 Analyzing: {os.path.basename(video_path)}")
        start_time = time.time()

        # Extract and preprocess frames
        frames = self.extract_frames(video_path)

        # Transform frames
        print(f"   🔄 Preprocessing frames...")
        transformed_frames = [self.transform(frame) for frame in frames]

        # Stack the frames and add a leading batch dimension of size 1
        video_tensor = torch.stack(transformed_frames).unsqueeze(0).to(self.device)

        # Run inference
        print(f"   🧠 Running inference on {self.device}...")
        with torch.no_grad():
            output = self.model(video_tensor)
            probability = torch.sigmoid(output).item()

        # Calculate results
        prediction = "FAKE" if probability > 0.5 else "REAL"
        # Confidence is the probability of whichever class was predicted
        confidence = max(probability, 1 - probability)

        processing_time = time.time() - start_time

        return {
            'prediction': prediction,
            'fake_probability': probability,
            'real_probability': 1 - probability,
            'confidence': confidence,
            'processing_time': processing_time
        }

##############################################
# DATASET INTEGRATION FUNCTIONS
##############################################

def load_dataset_metadata(base_path="/content/drive/MyDrive/Dataset-3"):
    """Read ``global_metadata_cleaned.csv`` from ``base_path`` and print a summary.

    The summary lists the total row count, the distinct values of the
    ``dataset`` column and the value counts of the ``label`` and ``subset``
    columns.

    Args:
        base_path: Directory that contains the metadata CSV.

    Returns:
        tuple: ``(metadata DataFrame, base_path)``.

    Raises:
        FileNotFoundError: If the CSV does not exist under ``base_path``.
    """
    metadata_file = os.path.join(base_path, "global_metadata_cleaned.csv")
    if not os.path.exists(metadata_file):
        raise FileNotFoundError(f"Dataset metadata not found: {metadata_file}")

    print(f"📊 Loading dataset metadata from: {metadata_file}")
    metadata = pd.read_csv(metadata_file)

    print(f"✅ Loaded {len(metadata)} total samples")
    dataset_names = list(metadata['dataset'].unique())
    print(f"📈 Available datasets: {dataset_names}")

    # One heading plus value-count table per summarized column
    for heading, column in (
        (f"🏷️  Label distribution:", 'label'),
        (f"📂 Subset distribution:", 'subset'),
    ):
        print(heading)
        print(metadata[column].value_counts())

    return metadata, base_path

def get_random_video(df, base_path, dataset='DFDC', subset='test', label=None):
    """
    Pick one random row from the metadata that matches the given criteria

    Args:
        df: Dataset metadata DataFrame (columns 'dataset', 'subset', 'label'
            and 'file_path' are read)
        base_path: Directory that 'file_path' entries are joined onto
        dataset: Value required in the 'dataset' column ('DFDC', 'FF', etc.)
        subset: Value required in the 'subset' column; a falsy value disables
            this filter
        label: Value required in the 'label' column; a falsy value disables
            this filter

    Returns:
        tuple: (video_path, video_info) where video_info is a dict with the
            keys 'path', 'filename', 'label', 'dataset' and 'subset'

    Raises:
        ValueError: If no row matches the criteria
    """
    # Narrow the table step by step; optional criteria are skipped when falsy
    candidates = df[df['dataset'] == dataset]
    for column, wanted in (('subset', subset), ('label', label)):
        if wanted:
            candidates = candidates[candidates[column] == wanted]

    if candidates.empty:
        raise ValueError(f"No videos found with criteria: dataset={dataset}, subset={subset}, label={label}")

    # random_state=None: unseeded draw, a different video on every call
    chosen = candidates.sample(n=1, random_state=None).iloc[0]

    video_path = os.path.join(base_path, chosen['file_path'])
    video_info = dict(
        path=video_path,
        filename=os.path.basename(video_path),
        label=chosen['label'],
        dataset=chosen['dataset'],
        subset=chosen['subset'],
    )

    print(f"🎲 Selected random video from {dataset} {subset} set:")
    print(f"   📹 {video_info['filename']} - {video_info['label']}")

    return video_path, video_info

def print_result_with_ground_truth(video_info, result):
    """Print a prediction report next to the ground-truth label.

    Args:
        video_info: Dict read for the keys 'filename' and 'label'.
        result: Dict read for the keys 'prediction', 'confidence',
            'fake_probability', 'real_probability' and 'processing_time'.

    Returns:
        bool: True when ``result['prediction']`` equals ``video_info['label']``.
    """

    def badge(verdict):
        # Check mark for REAL, siren for anything else
        return "✅" if verdict == "REAL" else "🚨"

    print(f"\n🎯 RESULTS for {video_info['filename']}")
    print("=" * 60)

    # Ground truth vs prediction
    truth = video_info['label']
    predicted = result['prediction']
    correct = predicted == truth
    verdict_mark = "✅" if correct else "❌"

    print(f"🏷️  Ground Truth: {badge(truth)} {truth}")
    print(f"🤖 Prediction: {badge(predicted)} {predicted}")
    print(f"🎯 Correct: {verdict_mark} {correct}")
    print(f"🎲 Confidence: {result['confidence']:.4f} ({result['confidence']*100:.1f}%)")

    # Detailed probabilities
    fake_p = result['fake_probability']
    real_p = result['real_probability']
    print(f"\n📊 Detailed Analysis:")
    print(f"   Fake probability: {fake_p:.4f} ({fake_p*100:.1f}%)")
    print(f"   Real probability: {real_p:.4f} ({real_p*100:.1f}%)")
    print(f"   Processing time: {result['processing_time']:.2f}s")

    # Interpretation: first band whose lower bound is strictly exceeded wins
    print(f"\n💡 Interpretation:")
    bands = ((0.9, "Very High"), (0.8, "High"), (0.7, "Medium"))
    confidence_level = next(
        (name for floor, name in bands if result['confidence'] > floor),
        "Low",
    )
    print(f"   Confidence Level: {confidence_level}")

    if predicted == "FAKE":
        print(f"   ⚠️  This video appears to be artificially generated or manipulated")
    else:
        print(f"   ✅ This video appears to be authentic")

    # Performance assessment
    print(f"   🎉 Model prediction is CORRECT!" if correct else f"   😞 Model prediction is INCORRECT")

    return correct

##############################################
# MAIN TESTING FUNCTION
##############################################

def test_single_video_from_dataset(
    base_path="/content/drive/MyDrive/Dataset-3",
    model_path="/content/drive/MyDrive/Dataset-3/dfdc_training_run_20250527_143110/model_for_local_inference.pth",
):
    """Run the detector on one random DFDC test-subset video and print a report.

    Args:
        base_path: Dataset root that holds the metadata CSV and the videos.
        model_path: Path to the checkpoint handed to ``DeepfakeDetector``.

    Returns:
        tuple: ``(result, video_info, correct)`` on success. On any failure,
        including a sampled video whose file is missing on disk, the tuple is
        ``(None, None, False)`` so callers can always unpack three values.
    """

    print("🎬 MesoNet+LSTM Single Video Test")
    print("📈 Trained accuracy: 91.06% on DFDC dataset")
    print("🎲 Testing on one random video from your training dataset")
    print("=" * 70)

    try:
        # Load dataset metadata
        df, base_path = load_dataset_metadata(base_path)

        # Get one random video from test set
        video_path, video_info = get_random_video(
            df, base_path,
            dataset='DFDC',
            subset='test'  # Use test set for fair evaluation
        )

        # The metadata can list files that are not on disk; report this as a
        # failed run with the same 3-tuple shape as every other exit path.
        if not os.path.exists(video_path):
            print(f"❌ Video file not found: {video_path}")
            return None, None, False

        # Initialize detector
        detector = DeepfakeDetector(model_path)

        # Run prediction
        result = detector.predict(video_path)

        # Print results with ground truth comparison
        correct = print_result_with_ground_truth(video_info, result)

        # Final summary
        print(f"\n📋 TEST SUMMARY:")
        print("=" * 30)
        print(f"Video: {video_info['filename']}")
        print(f"Ground Truth: {video_info['label']}")
        print(f"Prediction: {result['prediction']}")
        print(f"Correct: {'✅ YES' if correct else '❌ NO'}")
        print(f"Confidence: {result['confidence']:.4f}")
        print(f"Processing Time: {result['processing_time']:.2f}s")

        if correct:
            print(f"\n🎉 SUCCESS! Your model correctly identified this video!")
        else:
            print(f"\n😞 Your model made an incorrect prediction on this video.")
            print(f"💡 This is normal - even 91.06% accuracy means some errors occur.")

        return result, video_info, correct

    except Exception as e:
        # Top-level boundary of the script: report the error with its
        # traceback and return the failure tuple instead of raising.
        print(f"❌ Error during testing: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None, False

##############################################
# RUN THE TEST
##############################################

if __name__ == "__main__":
    print("🚀 Starting single video test...")
    print("🔄 This will test your 91.06% accuracy model on one random video from your dataset")
    print()

    # Run the test. Guard against a None return (no tuple) so the failure
    # branch below is reached instead of a TypeError while unpacking.
    outcome = test_single_video_from_dataset()
    result, video_info, correct = outcome if outcome is not None else (None, None, False)

    if result is not None:
        print(f"\n✅ Test completed successfully!")
        print(f"🎯 Your model's prediction was {'CORRECT' if correct else 'INCORRECT'}")
    else:
        print(f"\n❌ Test failed - check the error messages above")

    print(f"\n🏁 Test finished!")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
🚀 Starting single video test...
🔄 This will test your 91.06% accuracy model on one random video from your dataset

🎬 MesoNet+LSTM Single Video Test
📈 Trained accuracy: 91.06% on DFDC dataset
🎲 Testing on one random video from your training dataset
📊 Loading dataset metadata from: /content/drive/MyDrive/Dataset-3/global_metadata_cleaned.csv
✅ Loaded 5281 total samples
📈 Available datasets: ['DFDC', 'FF']
🏷️  Label distribution:
label
REAL    2719
FAKE    2562
Name: count, dtype: int64
📂 Subset distribution:
subset
train    4224
test     1057
Name: count, dtype: int64
🎲 Selected random video from DFDC test set:
   📹 bztdemptfg.mp4 - FAKE
🚀 Initializing MesoNet+LSTM Deepfake Detector
   Device: cpu
📁 Loading model from: model_for_local_inference.pth
✅ Model loaded successfully!
   Configuration: base_config
   Accuracy: 0.9106 (91.06% on DFDC dataset)
   Paramet