In [None]:
# ============================================================================
# AI-GENERATED MEDIA DETECTION - GEN 3 AGNOSTIC MODEL (FIXED)
# ============================================================================
# This model detects AI-generated content across GAN, Diffusion, and 
# Autoregressive generators using multi-stream feature extraction and
# hierarchical fusion.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import torchaudio
import os
import glob
import random
from torch.utils.data import Dataset, DataLoader
from transformers import VideoMAEModel, VideoMAEConfig
from torchvision.transforms import Compose, Resize, Normalize
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
import warnings

warnings.filterwarnings('ignore')

In [None]:
# ============================================================================
# CONFIGURATION
# ============================================================================
@dataclass
class Config:
    """Hyper-parameters shared by the feature streams, the fusion model and the demo cells."""
    # Chosen once, when this class body is evaluated: "cuda" if available, else "cpu".
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Temporal sampling
    n_windows: int = 3  # number of temporal windows every stream summarises
    frames_per_clip: int = 16  # frames sampled per window for the VideoMAE input
    min_frames: int = 48  # videos with fewer decoded frames are skipped by the pipeline
    
    # Spatial sampling
    full_res: Tuple[int, int] = (224, 224)  # size passed to F.interpolate before VideoMAE
    
    # Feature dimensions
    mae_hidden: int = 768  # per-window embedding width assumed by the compression layers
    mae_compressed: int = 128  # also the embed_dim of the cross-attention layer
    fft_dim: int = 256  # fixed length of each radial FFT signature
    fft_compressed: int = 64
    physics_compressed: int = 32
    audio_compressed: int = 32
    drift_compressed: int = 16
    
    # Audio
    sample_rate: int = 16000  # loaded audio is resampled to this rate when it differs
    n_mfcc: int = 13
    
    # Scene change detection
    scene_change_threshold: float = 30.0  # mean absolute gray-level difference that counts as a cut
    
    # Motion detection
    low_motion_threshold: float = 0.5  # mean forward-flow magnitude below which a window is "low motion"
    
    # Transformer dimensions
    fusion_dim: int = 128  # NOTE(review): not referenced by the code visible in this notebook
    n_heads: int = 4  # heads of the video/audio cross-attention

    # Training
    batch_size: int = 8  # NOTE(review): not referenced by the code visible in this notebook
    lr_branches: float = 1e-4  # learning rate used by the training demo optimizer
    lr_fusion: float = 5e-5  # NOTE(review): not referenced by the code visible in this notebook
    
    checkpoint_path: Optional[str] = None  # NOTE(review): not referenced by the code visible in this notebook

# Single shared configuration instance used by the cells below.
CONFIG = Config()
print(f"Using device: {CONFIG.device}")

In [None]:
# ============================================================================
# UTILITY CLASSES
# ============================================================================

class SceneChangeDetector:
    """Detects hard cuts in video to prevent false positives in drift detection.

    A cut is reported between two consecutive frames when the mean absolute
    difference of their grayscale intensities exceeds ``threshold``.
    """

    def __init__(self, threshold: float = 30.0):
        self.threshold = threshold

    @staticmethod
    def _to_gray(frame: np.ndarray) -> np.ndarray:
        """Return ``frame`` as a float grayscale image.

        2-D input is treated as already grayscale; anything else is converted
        from BGR with OpenCV.
        """
        if frame.ndim == 2:
            return frame.astype(float)
        return cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(float)

    def detect_cuts(self, frames: List[np.ndarray]) -> List[int]:
        """Return the indices of frames that start a new scene.

        Args:
            frames: BGR (H, W, 3) or grayscale (H, W) frames in display order.
        Returns:
            Index ``i`` for every frame whose mean absolute gray-level
            difference to frame ``i - 1`` is above the threshold. Empty and
            single-frame inputs yield an empty list.
        """
        cuts = []
        previous = None
        for index, frame in enumerate(frames):
            # Each frame is converted exactly once and reused for the next pair.
            current = self._to_gray(frame)
            if previous is not None:
                diff = np.mean(np.abs(previous - current))
                if diff > self.threshold:
                    cuts.append(index)
            previous = current
        return cuts

In [None]:
# ============================================================================
# STREAM 1: VideoMAE - Semantic Understanding (Autoregressive Detection)
# ============================================================================

class VideoMAEStream(nn.Module):
    """
    Detects semantic drift and object permanence violations.
    Best for: Autoregressive models, scene coherence

    Each temporal window is embedded with a frozen VideoMAE backbone; the
    per-window embeddings are compared to each other (drift metrics) and
    compressed by a small trainable MLP.
    """
    def __init__(self, config: "Config"):
        super().__init__()
        self.config = config
        mae_config = VideoMAEConfig.from_pretrained("MCG-NJU/videomae-base")
        self.model = VideoMAEModel.from_pretrained("MCG-NJU/videomae-base", config=mae_config)

        # Freeze backbone to save VRAM (unfreeze for fine-tuning)
        for param in self.model.parameters():
            param.requires_grad = False

        # Input width is mae_hidden * n_windows, so extract_features must be
        # called with exactly config.n_windows clips.
        self.compression = nn.Sequential(
            nn.Linear(config.mae_hidden * config.n_windows, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, config.mae_compressed)
        )
        self.normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def extract_features(self, clips: torch.Tensor, has_scene_cuts: bool = False) -> Dict[str, object]:
        """
        Args:
            clips: (n_windows, frames, 3, H, W) tensor
            has_scene_cuts: when True, drift_primary and drift_acceleration are
                scaled by 0.3 and the flag is reported in the drift metrics.
        Returns:
            Dict with:
              'embeddings': (n_windows, hidden) L2-normalised window embeddings
              'compressed': (1, mae_compressed) output of the compression MLP
              'drift_metrics': dict of Python floats (drift_primary,
                  drift_acceleration, drift_axis_consistency, has_scene_cuts)
        """
        n_windows = clips.shape[0]
        # Follow the backbone's real device instead of trusting config.device.
        device = next(self.model.parameters()).device
        embeddings = []

        with torch.no_grad():  # Backbone is frozen
            for i in range(n_windows):
                clip = clips[i].unsqueeze(0)
                outputs = self.model(pixel_values=clip.to(device))
                embedding = outputs.last_hidden_state.mean(dim=1)
                embedding = F.normalize(embedding, p=2, dim=1)
                embeddings.append(embedding)

        embeddings = torch.cat(embeddings, dim=0)  # (n_windows, hidden)

        # Drift metrics (plain floats, not part of the autograd graph)
        e = embeddings
        if n_windows >= 2:
            drift_primary = 1 - F.cosine_similarity(e[0:1], e[-1:], dim=1).item()
            # Drift between every pair of consecutive windows.
            step_sims = F.cosine_similarity(e[:-1], e[1:], dim=1).tolist()
            step_drifts = [1 - sim for sim in step_sims]
            drift_axis_consistency = torch.var(e, dim=0).mean().item()
        else:
            # A single window has nothing to drift from (variance would be NaN).
            drift_primary = 0.0
            step_drifts = []
            drift_axis_consistency = 0.0

        if len(step_drifts) >= 2:
            # Mean change of the step drift; with exactly three windows this is
            # drift(2->3) - drift(1->2).
            drift_acceleration = (step_drifts[-1] - step_drifts[0]) / (len(step_drifts) - 1)
        else:
            drift_acceleration = 0.0

        if has_scene_cuts:
            drift_primary *= 0.3
            drift_acceleration *= 0.3

        # Compress (gradient flows through the compression MLP only)
        flat_embeddings = embeddings.flatten().unsqueeze(0)
        compressed = self.compression(flat_embeddings)

        return {
            # Built under no_grad: constant inputs for downstream trainable layers.
            'embeddings': embeddings,
            'compressed': compressed,
            'drift_metrics': {
                'drift_primary': drift_primary,
                'drift_acceleration': drift_acceleration,
                'drift_axis_consistency': drift_axis_consistency,
                'has_scene_cuts': float(has_scene_cuts)
            }
        }

In [None]:
# ============================================================================
# STREAM 2: Audio Analysis (Voice Cloning / Lip-Sync Detection)
# ============================================================================

class AudioStream:
    """
    Extracts MFCC features for audio artifact detection and temporal sync.
    Best for: Voice cloners (ElevenLabs), lip-sync tools (Wav2Lip)
    """
    def __init__(self, config: "Config"):
        self.config = config
        self.mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=config.sample_rate,
            n_mfcc=config.n_mfcc
        )

    def extract_features(self, audio_waveform: Optional[torch.Tensor]) -> Dict[str, object]:
        """
        Args:
            audio_waveform: (channels, samples) or (samples,) tensor, or None
        Returns:
            Dict with:
              'mfcc_features': (3 * n_mfcc,) mean, std and mean first difference
                  of every MFCC coefficient over time
              'temporal_features': (n_windows, n_mfcc) per-window MFCC means
              'artifact_score', 'sync_score', 'has_audio': Python floats
            Missing audio, or fewer than 1000 samples, yields all-zero tensors
            with has_audio = 0.0.
        """
        n_windows = self.config.n_windows

        # No audio case - return zeros
        if audio_waveform is None or audio_waveform.shape[-1] < 1000:
            return {
                'mfcc_features': torch.zeros(self.config.n_mfcc * 3),
                'temporal_features': torch.zeros((n_windows, self.config.n_mfcc)),
                'artifact_score': 0.0,
                'sync_score': 0.5,
                'has_audio': 0.0
            }

        # A bare (samples,) waveform is treated as one channel; without this
        # the stereo-to-mono step below would average it down to one value.
        if audio_waveform.dim() == 1:
            audio_waveform = audio_waveform.unsqueeze(0)

        # Convert stereo to mono
        if audio_waveform.shape[0] > 1:
            audio_waveform = audio_waveform.mean(dim=0, keepdim=True)

        # MFCC: output shape is (1, n_mfcc, time)
        mfcc = self.mfcc_transform(audio_waveform)

        # Split the time axis (last dimension) into n_windows chunks.
        total_time = mfcc.shape[-1]
        chunk_size = max(1, total_time // n_windows)
        temporal_features = []

        for i in range(n_windows):
            start = i * chunk_size
            if i == n_windows - 1:
                # The last window absorbs the remainder so no trailing frames are dropped.
                end = total_time
            else:
                end = min((i + 1) * chunk_size, total_time)
            chunk = mfcc[:, :, start:end]
            if chunk.shape[-1] == 0:
                # Fewer MFCC frames than windows: use zeros rather than the
                # NaN produced by averaging an empty slice.
                temporal_features.append(mfcc.new_zeros(mfcc.shape[1]))
            else:
                temporal_features.append(chunk.mean(dim=-1).squeeze(0))  # (n_mfcc,)

        temporal_features = torch.stack(temporal_features)  # (n_windows, n_mfcc)

        # Global stats for artifact detection
        mfcc_mean = mfcc.mean(dim=-1).squeeze(0)  # (n_mfcc,)
        mfcc_std = mfcc.std(dim=-1).squeeze(0)    # (n_mfcc,)
        mfcc_delta = torch.diff(mfcc, dim=-1).mean(dim=-1).squeeze(0)  # (n_mfcc,)
        mfcc_features = torch.cat([mfcc_mean, mfcc_std, mfcc_delta])

        # Spectral smoothness (synthetic audio is often "too clean")
        spectral_smoothness = 1.0 - min(mfcc_std.mean().item() / 10.0, 1.0)

        return {
            'mfcc_features': mfcc_features,
            'temporal_features': temporal_features,
            'artifact_score': spectral_smoothness,
            'sync_score': 0.5,  # TODO: Replace with actual sync computation
            'has_audio': 1.0
        }

In [None]:
# ============================================================================
# STREAM 3: Physics Consistency (Diffusion Model Detection)
# ============================================================================

class PhysicsStream:
    """
    Motion-consistency features computed from dense optical flow.
    Best for: Diffusion models (Sora, SVD, Kling, Runway)

    Note: the flow comes from OpenCV, so nothing here is differentiable; the
    outputs are static inputs for the downstream network. For differentiable
    flow, consider RAFT integration.
    """
    def __init__(self, config: Config):
        self.config = config

    def extract_features(self, frames: List[np.ndarray]) -> Dict[str, torch.Tensor]:
        """
        Summarise forward/backward flow agreement for each temporal window.

        Args:
            frames: List of grayscale frames
        Returns:
            Dict of float32 tensors: three per-window vectors
            (consistency_errors, outlier_scores, smoothness_scores), a zero
            placeholder for motion_ranges, and two one-element summaries
            (physics_violation, low_motion_score).
        """
        window_count = self.config.n_windows
        frames_per_window = len(frames) // window_count

        fb_error_per_window = []
        outlier_ratio_per_window = []
        magnitude_var_per_window = []
        still_flag_per_window = []

        for w in range(window_count):
            segment = frames[w * frames_per_window:(w + 1) * frames_per_window]

            if len(segment) >= 2:
                pair_errors = []
                pair_magnitudes = []

                # Walk over consecutive frame pairs inside the window.
                for earlier, later in zip(segment[:-1], segment[1:]):
                    # Forward-Backward Optical Flow Consistency
                    forward = cv2.calcOpticalFlowFarneback(earlier, later, None, 0.5, 3, 15, 3, 5, 1.2, 0)
                    backward = cv2.calcOpticalFlowFarneback(later, earlier, None, 0.5, 3, 15, 3, 5, 1.2, 0)
                    round_trip_sq = np.sum((forward + backward)**2, axis=2)
                    pair_errors.append(np.sqrt(round_trip_sq).mean())
                    pair_magnitudes.append(np.sqrt(np.sum(forward**2, axis=2)).mean())

                mean_magnitude = np.mean(pair_magnitudes)
                fb_error_per_window.append(np.mean(pair_errors))
                # Flow variance is only reported when there is noticeable motion.
                if mean_magnitude > 0.1:
                    magnitude_var_per_window.append(np.var(pair_magnitudes))
                else:
                    magnitude_var_per_window.append(0.0)
                outlier_ratio_per_window.append(np.sum(np.array(pair_errors) > 1.0) / len(pair_errors))
                still_flag_per_window.append(float(mean_magnitude < self.config.low_motion_threshold))
            else:
                # Not enough frames to form a pair: neutral values, flagged as low motion.
                fb_error_per_window.append(0.0)
                outlier_ratio_per_window.append(0.0)
                magnitude_var_per_window.append(0.0)
                still_flag_per_window.append(1.0)

        # Pack into tensors for downstream processing
        features = {}
        features['consistency_errors'] = torch.tensor(fb_error_per_window, dtype=torch.float32)
        features['outlier_scores'] = torch.tensor(outlier_ratio_per_window, dtype=torch.float32)
        features['smoothness_scores'] = torch.tensor(magnitude_var_per_window, dtype=torch.float32)
        features['motion_ranges'] = torch.zeros(window_count * 2)  # Simplified placeholder
        features['physics_violation'] = torch.tensor([np.mean(fb_error_per_window)], dtype=torch.float32)
        features['low_motion_score'] = torch.tensor([np.mean(still_flag_per_window)], dtype=torch.float32)
        return features

In [None]:
# ============================================================================
# STREAM 4: Frequency Analysis (GAN Artifact Detection)
# ============================================================================

class FrequencyStream:
    """
    Detects upscaling artifacts and checkerboard patterns in frequency domain.
    Best for: GAN-generated videos (StyleGAN, ProGAN, DeepFaceLab)
    """
    def __init__(self, config: "Config"):
        self.config = config

    @staticmethod
    def _radial_profile(keyframe: np.ndarray) -> np.ndarray:
        """Return the azimuthally averaged, min-max normalised log-magnitude spectrum.

        Element ``k`` is the mean spectrum value over all pixels whose integer
        distance from the spectrum centre is ``k``.
        """
        keyframe_norm = keyframe.astype(float) / 255.0

        # 2D FFT with the zero frequency shifted to the centre
        magnitude = np.abs(np.fft.fftshift(np.fft.fft2(keyframe_norm)))

        # Log scale normalization
        magnitude = np.log(magnitude + 1e-8)
        magnitude = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min() + 1e-8)

        # Azimuthal integration (radial profile)
        h, w = magnitude.shape
        y, x = np.ogrid[:h, :w]
        r = np.sqrt((x - w // 2)**2 + (y - h // 2)**2).astype(int)

        tbin = np.bincount(r.ravel(), magnitude.ravel())
        nr = np.bincount(r.ravel())
        return tbin / (nr + 1e-8)

    @staticmethod
    def _checkerboard_score(profile: np.ndarray) -> float:
        """Score in [0, 1] for excess energy in radial bins 32-63.

        ``profile`` must contain only measured bins (no padding). Profiles
        shorter than 64 bins cannot cover the band and score 0.0.
        """
        if len(profile) < 64:
            return 0.0
        checkerboard = profile[32:64]
        reference = profile[64:128]
        # Fall back to the whole profile when there are no bins past the band.
        baseline = reference.mean() if reference.size else profile.mean()
        return np.clip((np.max(checkerboard) - baseline) / (baseline + 1e-8), 0, 1)

    def extract_features(self, frames: List[np.ndarray]) -> Dict[str, np.ndarray]:
        """
        Args:
            frames: List of grayscale frames
        Returns:
            Dict with FFT signatures (as numpy for legacy compatibility):
              'signatures': (n_windows, fft_dim) radial profiles, truncated or
                  zero-padded to fft_dim
              'artifact_scores': (n_windows,) checkerboard scores
              'temporal_stability': mean per-bin variance of the signatures
                  across windows
        """
        n_windows = self.config.n_windows
        fft_dim = self.config.fft_dim
        chunk_size = len(frames) // n_windows

        signatures = []
        artifact_scores = []

        for i in range(n_windows):
            window = frames[i*chunk_size : (i+1)*chunk_size]
            if len(window) == 0:
                signatures.append(np.zeros(fft_dim))
                artifact_scores.append(0.0)
                continue

            # Use middle keyframe
            profile = self._radial_profile(window[len(window)//2])

            # Normalize to fixed size: keep the first fft_dim bins, zero-pad the rest
            valid = profile[:fft_dim]
            signature = np.zeros(fft_dim)
            signature[:len(valid)] = valid
            signatures.append(signature)

            # Score only the measured bins. Scoring the padded signature makes
            # the zero padding look like an empty baseline, which pushes small
            # frames to the maximum score.
            artifact_scores.append(self._checkerboard_score(valid))

        signatures = np.array(signatures)
        artifact_scores = np.array(artifact_scores)
        temporal_stability = np.var(signatures, axis=0).mean()

        return {
            'signatures': signatures,
            'artifact_scores': artifact_scores,
            'temporal_stability': temporal_stability
        }

In [None]:
# ============================================================================
# FEATURE COMPRESSION MODULE (Gradient-Safe)
# ============================================================================

class FeatureCompression(nn.Module):
    """
    Compresses multi-stream features into fixed-size vectors for fusion.

    Hand-crafted statistics (FFT, physics, audio, drift) enter as constants;
    gradients reach only the layers defined in this module.
    """
    def __init__(self, config: "Config"):
        super().__init__()
        self.config = config

        # FFT Compression: flattened signatures + per-window scores + stability
        fft_input_dim = config.fft_dim * config.n_windows + config.n_windows + 1
        self.fft_compressor = nn.Sequential(
            nn.Linear(fft_input_dim, 128),
            nn.LayerNorm(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, config.fft_compressed)
        )

        # Physics Compression: three per-window vectors, a 2 * n_windows
        # motion-range vector, and two scalar summaries
        physics_input_dim = config.n_windows * 5 + 2
        self.physics_compressor = nn.Sequential(
            nn.Linear(physics_input_dim, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, config.physics_compressed)
        )

        # Audio Compression: mean/std/delta MFCC statistics + three scalars
        audio_input_dim = config.n_mfcc * 3 + 3
        self.audio_compressor = nn.Sequential(
            nn.Linear(audio_input_dim, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, config.audio_compressed)
        )

        # Drift Compression: primary, acceleration, axis consistency, scene-cut flag
        drift_input_dim = 4
        self.drift_compressor = nn.Sequential(
            nn.Linear(drift_input_dim, 32),
            nn.LayerNorm(32),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(32, config.drift_compressed)
        )

        # Projections for Cross-Attention Transformer
        self.audio_temporal_proj = nn.Linear(config.n_mfcc, config.mae_compressed)
        self.video_temporal_proj = nn.Linear(config.mae_hidden, config.mae_compressed)

    def forward(self, features: Dict) -> Dict[str, torch.Tensor]:
        """Compress all stream features for hierarchical fusion.

        Args:
            features: dict with keys 'fft', 'physics', 'audio' and 'mae', each
                holding the dict produced by the corresponding stream.
                ``features['mae']['drift_metrics']`` may carry an optional
                'has_scene_cuts' entry; when absent it is treated as 0.0.
        Returns:
            Dict with one (1, dim) vector per branch ('mae', 'fft', 'physics',
            'audio', 'drift') and two (n_windows, mae_compressed) sequences
            ('sequence_video', 'sequence_audio').
        """
        # Follow the module's actual device so a later .to(...) cannot
        # desynchronise inputs and weights.
        dev = next(self.parameters()).device

        # --- FFT ---
        fft_sig = torch.from_numpy(features['fft']['signatures']).float().to(dev).flatten()
        fft_scores = torch.from_numpy(features['fft']['artifact_scores']).float().to(dev)
        fft_stab = torch.tensor([features['fft']['temporal_stability']]).float().to(dev)
        fft_in = torch.cat([fft_sig, fft_scores, fft_stab])
        fft_out = self.fft_compressor(fft_in.unsqueeze(0))

        # --- Physics ---
        phy = features['physics']
        phy_in = torch.cat([
            phy['consistency_errors'].float().to(dev),
            phy['outlier_scores'].float().to(dev),
            phy['motion_ranges'].float().to(dev),
            phy['smoothness_scores'].float().to(dev),
            phy['physics_violation'].float().to(dev),
            phy['low_motion_score'].float().to(dev)
        ])
        phy_out = self.physics_compressor(phy_in.unsqueeze(0))

        # --- Audio ---
        aud = features['audio']
        aud_in = torch.cat([
            aud['mfcc_features'].float().to(dev),
            torch.tensor([aud['artifact_score']]).float().to(dev),
            torch.tensor([aud['sync_score']]).float().to(dev),
            torch.tensor([aud['has_audio']]).float().to(dev)
        ])
        aud_out = self.audio_compressor(aud_in.unsqueeze(0))

        # --- Drift (from MAE) ---
        mae = features['mae']
        drift_metrics = mae['drift_metrics']
        drift_in = torch.tensor([
            drift_metrics['drift_primary'],
            drift_metrics['drift_acceleration'],
            drift_metrics['drift_axis_consistency'],
            # Scene-cut flag; 0.0 when the producer does not report it.
            float(drift_metrics.get('has_scene_cuts', 0.0))
        ]).float().to(dev)
        drift_out = self.drift_compressor(drift_in.unsqueeze(0))

        # --- Sequence Projections for Transformer ---
        # Audio: (n_windows, n_mfcc) -> (n_windows, mae_compressed)
        aud_seq_in = aud['temporal_features'].float().to(dev)
        aud_seq_emb = self.audio_temporal_proj(aud_seq_in)

        # Video: (n_windows, mae_hidden) -> (n_windows, mae_compressed)
        vid_seq_in = mae['embeddings'].to(dev)
        vid_seq_emb = self.video_temporal_proj(vid_seq_in)

        return {
            'mae': mae['compressed'],
            'fft': fft_out,
            'physics': phy_out,
            'audio': aud_out,
            'drift': drift_out,
            'sequence_video': vid_seq_emb,
            'sequence_audio': aud_seq_emb
        }

In [None]:
# ============================================================================
# HIERARCHICAL DETECTOR (Multi-Branch + Transformer Fusion)
# ============================================================================

class HierarchicalDetector(nn.Module):
    """
    Two-stage classifier.
    Stage 1: one linear+sigmoid head per compressed stream
             (GAN, Diffusion, Autoregressive, Audio, Semantic).
    Stage 2: video-to-audio cross-attention whose pooled output is fused with
             the head scores and their disagreement into the final probability.
    """
    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        def score_head(in_features: int) -> nn.Sequential:
            # Single linear unit squashed to (0, 1).
            return nn.Sequential(nn.Linear(in_features, 1), nn.Sigmoid())

        # Stage 1: branch heads (created in a fixed order so seeded
        # initialisation stays reproducible)
        self.gan_detector = score_head(config.fft_compressed)
        self.diffusion_detector = score_head(config.physics_compressed)
        self.autoregressive_detector = score_head(config.drift_compressed)
        self.audio_detector = score_head(config.audio_compressed)
        self.semantic_detector = score_head(config.mae_compressed)

        # Stage 2: video queries attend over audio keys/values
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=config.mae_compressed,
            num_heads=config.n_heads,
            batch_first=True,
        )

        # Fusion input = 5 head scores + 1 conflict value + pooled attention context
        self.fusion_mlp = nn.Sequential(
            nn.Linear(config.mae_compressed + 6, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, feats: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            feats: Compressed features from FeatureCompression
        Returns:
            (p_fake, branch_predictions) tuple; the branch columns are ordered
            GAN, Diffusion, Autoregressive, Audio, Semantic.
        """
        # Stage 1: score every stream with its own head
        heads_and_inputs = (
            (self.gan_detector, 'fft'),
            (self.diffusion_detector, 'physics'),
            (self.autoregressive_detector, 'drift'),
            (self.audio_detector, 'audio'),
            (self.semantic_detector, 'mae'),
        )
        branch_preds = torch.cat(
            [head(feats[key]) for head, key in heads_and_inputs], dim=1
        )

        # Spread of the head scores: large when the branches disagree
        conflict = branch_preds.std(dim=1, keepdim=True)

        # Stage 2: add a batch axis, let video attend to audio, pool over windows
        video_seq = feats['sequence_video'][None]  # (1, n_windows, dim)
        audio_seq = feats['sequence_audio'][None]  # (1, n_windows, dim)
        attended, _ = self.cross_attn(query=video_seq, key=audio_seq, value=audio_seq)
        sync_ctx = attended.mean(dim=1)  # (1, dim)

        # Final fusion
        p_final = self.fusion_mlp(torch.cat([branch_preds, conflict, sync_ctx], dim=1))
        return p_final, branch_preds

In [None]:
# ============================================================================
# DETECTION PIPELINE (End-to-End)
# ============================================================================

class DetectionPipeline(nn.Module):
    """
    Complete end-to-end pipeline for AI-generated media detection.
    Handles video loading, feature extraction, and classification.
    """
    def __init__(self, config: "Config"):
        super().__init__()
        self.config = config

        # Feature Extraction Streams
        self.mae_stream = VideoMAEStream(config)
        self.audio_stream = AudioStream(config)
        self.physics_stream = PhysicsStream(config)
        self.freq_stream = FrequencyStream(config)

        # Neural Processing
        self.compressor = FeatureCompression(config)
        self.detector = HierarchicalDetector(config)

        # Utility
        self.scene_detector = SceneChangeDetector(config.scene_change_threshold)

    @staticmethod
    def _default_audio_path(video_path: str) -> str:
        """Return the sidecar audio path: the video path with its extension swapped for '.wav'.

        Only the final extension is replaced, so directory names containing
        '.mp4' are untouched and non-mp4 videos no longer map onto themselves.
        """
        return os.path.splitext(video_path)[0] + '.wav'

    @staticmethod
    def _read_frames(video_path: str) -> List[np.ndarray]:
        """Decode every frame of ``video_path`` with OpenCV (BGR order).

        Returns an empty list when the file cannot be opened or decoded.
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame)
        finally:
            # Release the capture even if decoding raises.
            cap.release()
        return frames

    def _build_mae_input(self, frames: List[np.ndarray]) -> torch.Tensor:
        """Sample, resize and normalise frames into (n_windows, frames_per_clip, 3, H, W).

        Frame indices are spread evenly over the whole video. Only the sampled
        frames are converted to RGB.
        """
        n_windows = self.config.n_windows
        frames_per_clip = self.config.frames_per_clip
        indices = np.linspace(0, len(frames) - 1, n_windows * frames_per_clip, dtype=int)

        sampled = []
        for idx in indices:
            rgb = cv2.cvtColor(frames[idx], cv2.COLOR_BGR2RGB)
            tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
            # NOTE(review): F.interpolate defaults to nearest-neighbour; kept as-is.
            tensor = F.interpolate(tensor.unsqueeze(0), size=self.config.full_res).squeeze(0)
            sampled.append(self.mae_stream.normalize(tensor))

        stacked = torch.stack(sampled)  # (n_windows * frames_per_clip, 3, H, W)
        return stacked.reshape(n_windows, frames_per_clip, *stacked.shape[1:])

    def _load_waveform(self, video_path: str, audio_path: Optional[str]) -> Optional[torch.Tensor]:
        """Best-effort audio loading; returns None when no usable audio is found.

        If ``audio_path`` is None the sidecar '.wav' next to the video is
        tried. Load or resample failures are reported with a printed warning
        and treated as "no audio".
        """
        if audio_path is None:
            audio_path = self._default_audio_path(video_path)
        if not os.path.exists(audio_path):
            return None
        try:
            waveform, sr = torchaudio.load(audio_path)
            if sr != self.config.sample_rate:
                resampler = torchaudio.transforms.Resample(sr, self.config.sample_rate)
                waveform = resampler(waveform)
            return waveform
        except Exception as e:
            print(f"Audio load warning: {e}")
            return None

    def forward_one_video(self, video_path: str, audio_path: Optional[str] = None):
        """
        Process a single video through the full pipeline.

        Args:
            video_path: Path to video file
            audio_path: Optional separate audio file path
        Returns:
            (p_fake, branch_predictions), or None if fewer than
            ``config.min_frames`` frames could be decoded (this includes
            unreadable files).
        """
        # 1. LOAD VIDEO
        frames = self._read_frames(video_path)
        if len(frames) < self.config.min_frames:
            return None  # Skip short videos

        # 2. PREPROCESS
        gray_frames = [cv2.cvtColor(f, cv2.COLOR_BGR2GRAY) for f in frames]
        has_scene_cuts = len(self.scene_detector.detect_cuts(frames)) > 0

        # 3. PREPARE VideoMAE INPUT
        mae_input = self._build_mae_input(frames)

        # 4. EXTRACT FEATURES
        f_mae = self.mae_stream.extract_features(mae_input, has_scene_cuts)
        f_phy = self.physics_stream.extract_features(gray_frames)
        f_frq = self.freq_stream.extract_features(gray_frames)
        f_aud = self.audio_stream.extract_features(self._load_waveform(video_path, audio_path))

        # 5. FORWARD PASS (gradients flow through compressor and detector)
        all_feats = {'mae': f_mae, 'physics': f_phy, 'fft': f_frq, 'audio': f_aud}
        compressed = self.compressor(all_feats)
        p_fake, branches = self.detector(compressed)

        return p_fake, branches

    def predict(self, video_path: str, audio_path: Optional[str] = None) -> Dict:
        """
        Inference mode prediction with detailed results.

        The model is switched to eval mode for the forward pass and then put
        back into whatever mode it was in, so calling this between training
        steps does not silently disable dropout.

        Returns:
            {'error': 'Video too short'} when the video is skipped, otherwise
            a dict with the fake probability, the REAL/FAKE label, a
            confidence in [0, 1] and the five branch scores.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                result = self.forward_one_video(video_path, audio_path)
        finally:
            self.train(was_training)

        if result is None:
            return {'error': 'Video too short'}

        p_fake, branches = result
        probability = p_fake.item()
        return {
            'probability_fake': probability,
            'prediction': 'FAKE' if probability > 0.5 else 'REAL',
            'confidence': abs(probability - 0.5) * 2,
            'gan_score': branches[0, 0].item(),
            'diffusion_score': branches[0, 1].item(),
            'autoregressive_score': branches[0, 2].item(),
            'audio_score': branches[0, 3].item(),
            'semantic_score': branches[0, 4].item()
        }

In [None]:
# ============================================================================
# TRAINING & TESTING
# ============================================================================

def create_dummy_video(path: str, n_frames: int = 60, size: Tuple[int, int] = (224, 224)):
    """Create a dummy video for testing.

    Writes ``n_frames`` solid-colour frames at 30 fps with the 'mp4v' codec.

    Args:
        path: Output file path.
        n_frames: Number of frames to write.
        size: (width, height) of the video in pixels.
    Raises:
        IOError: if OpenCV cannot open a writer for ``path``.
    """
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(path, fourcc, 30, size)
    # Without this check a failed open is silent and the success message
    # below would be printed for a file that was never written.
    if not out.isOpened():
        out.release()
        raise IOError(f"Could not open video writer for: {path}")
    try:
        for i in range(n_frames):
            # Uniform frame; only the blue level changes from frame to frame
            frame = np.zeros((size[1], size[0], 3), dtype=np.uint8)
            frame[:, :, 0] = i * 4 % 256  # Blue channel varies
            out.write(frame)
    finally:
        out.release()
    print(f"Created dummy video: {path}")

# Print the banner, then build the whole pipeline on the configured device
print("=" * 70, "AI-Generated Media Detection - Gen 3 Agnostic Model", "=" * 70, sep="\n")
print(f"\nDevice: {CONFIG.device}", "Loading VideoMAE from HuggingFace...", sep="\n")

pipeline = DetectionPipeline(CONFIG).to(CONFIG.device)
print(
    "✓ Pipeline initialized",
    "  - VideoMAE backbone: frozen",
    # Count only the weights the optimizer will actually update
    f"  - Trainable params: {sum(w.numel() for w in pipeline.parameters() if w.requires_grad):,}",
    sep="\n",
)

In [None]:
# ============================================================================
# TRAINING DEMO
# ============================================================================

# Make sure a sample clip exists on disk before training on it
test_video_path = "dummy_train.mp4"
if not os.path.exists(test_video_path):
    create_dummy_video(test_video_path)

# Loss and optimizer for a single demonstration step
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(pipeline.parameters(), lr=CONFIG.lr_branches)

# One optimisation step on the sample clip
pipeline.train()
print("\n" + "=" * 70, "TRAINING DEMO", "=" * 70, sep="\n")

optimizer.zero_grad()
result = pipeline.forward_one_video(test_video_path)

if result is None:
    print("❌ Video too short for processing")
else:
    p_fake, branches = result
    label = torch.tensor([[1.0]]).to(CONFIG.device)  # Assume FAKE

    # Forward loss, backward pass, parameter update
    loss = criterion(p_fake, label)
    loss.backward()
    optimizer.step()

    print(
        "\n✓ Training step completed!",
        f"  Loss: {loss.item():.4f}",
        f"  P(Fake): {p_fake.item():.4f}",
        "\n  Branch Scores:",
        f"    GAN:           {branches[0,0].item():.4f}",
        f"    Diffusion:     {branches[0,1].item():.4f}",
        f"    Autoregressive: {branches[0,2].item():.4f}",
        f"    Audio:         {branches[0,3].item():.4f}",
        f"    Semantic:      {branches[0,4].item():.4f}",
        sep="\n",
    )

print("\n" + "=" * 70, "Model ready for training on real dataset!", "=" * 70, sep="\n")

In [None]:
# ============================================================================
# INFERENCE DEMO
# ============================================================================

print("\n" + "=" * 70, "INFERENCE DEMO", "=" * 70, sep="\n")

# Score the sample clip in inference mode
results = pipeline.predict(test_video_path)

if 'error' in results:
    print(f"Error: {results['error']}")
else:
    print(
        f"\nPrediction: {results['prediction']}",
        f"Confidence: {results['confidence']:.1%}",
        f"P(Fake):    {results['probability_fake']:.4f}",
        sep="\n",
    )
    print(
        "\nSpecialized Detector Scores:",
        f"  GAN Detector:           {results['gan_score']:.4f}",
        f"  Diffusion Detector:     {results['diffusion_score']:.4f}",
        f"  Autoregressive Detector: {results['autoregressive_score']:.4f}",
        f"  Audio Detector:         {results['audio_score']:.4f}",
        f"  Semantic Detector:      {results['semantic_score']:.4f}",
        sep="\n",
    )

# Static overview of how the streams feed the detector heads
print("\n" + "=" * 70, "ARCHITECTURE SUMMARY", "=" * 70, sep="\n")
print("""
Category A (Forensic - GAN Detection):
  └─ FrequencyStream → fft_compressor → gan_detector
  
Category B (Semantic/Physics - Diffusion Detection):
  └─ VideoMAEStream → video_temporal_proj + compression → semantic_detector
  └─ PhysicsStream → physics_compressor → diffusion_detector
  └─ Drift metrics → drift_compressor → autoregressive_detector
  
Category C (Multi-Modal - Sync Detection):
  └─ AudioStream → audio_compressor → audio_detector
  └─ Video×Audio Cross-Attention → sync_context
  
Fusion:
  └─ All branches + conflict + sync_context → fusion_mlp → P(Fake)
""")
print("=" * 70)