In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import torchaudio
from transformers import VideoMAEModel, VideoMAEConfig
from torchvision.transforms import Compose, Resize, Normalize
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
import warnings
warnings.filterwarnings('ignore')

In [None]:
# ============================================================================
# CONFIGURATION
# ============================================================================
@dataclass
class Config:
    """Hyper-parameters shared by the feature streams, the fusion head and the trainer.

    A single module-level instance, ``CONFIG``, is created right after the
    class and is used as the default configuration of the pipeline.
    """
    # Evaluated once, when the class body runs (not per instance).
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Temporal sampling
    n_windows: int = 3  # number of temporal windows every stream splits a video into
    frames_per_clip: int = 16  # frames sampled per window for the VideoMAE clips
    min_frames: int = 48  # process_video rejects videos with fewer decoded frames
    
    # Spatial sampling
    full_res: Tuple[int, int] = (224, 224)  # size VideoMAE frames are resized to
    
    # Feature dimensions
    mae_hidden: int = 768  # width of one pooled VideoMAE embedding
    mae_compressed: int = 128  # compressed VideoMAE width; also the cross-attention embed_dim
    fft_dim: int = 256  # length of each radial FFT profile (truncated / zero-padded)
    fft_compressed: int = 64
    physics_compressed: int = 32
    audio_compressed: int = 32
    drift_compressed: int = 16
    
    # Audio
    sample_rate: int = 16000  # sample rate the MFCC transform is configured for
    n_mfcc: int = 13  # number of MFCC coefficients
    
    # Scene change detection
    scene_change_threshold: float = 30.0  # mean absolute grayscale difference between consecutive frames
    
    # Motion detection
    low_motion_threshold: float = 0.5  # mean optical-flow magnitude below which a window counts as static
    
    # Cross-modal attention (fusion) dimensions
    fusion_dim: int = 128  # NOTE(review): not read by any code visible in this file - confirm it is still needed
    n_heads: int = 4  # attention heads of the video/audio cross-attention

    # Training
    batch_size: int = 16  # NOTE(review): not read by any code visible in this file
    lr_branches: float = 1e-3  # learning rate of the branch optimizer
    lr_fusion: float = 5e-4  # learning rate of the fusion optimizer
    
    # Model checkpoint path (for loading trained weights); None leaves the
    # compression / detector weights randomly initialised.
    checkpoint_path: Optional[str] = None

# Module-wide default configuration.
CONFIG = Config()

In [None]:
# SCENE CHANGE DETECTOR
class SceneChangeDetector:
    """Detects hard cuts in video to prevent false positives in drift detection.

    A cut is reported between two consecutive frames when the mean absolute
    difference of their grayscale versions exceeds ``threshold``.
    """

    def __init__(self, threshold: float = 30.0):
        """
        Args:
            threshold: Mean absolute grayscale difference above which two
                consecutive frames are considered to belong to different scenes.
        """
        self.threshold = threshold

    def detect_cuts(self, frames: List[np.ndarray]) -> List[int]:
        """
        Returns indices where scene cuts occur.
        Args:
            frames: List of BGR frames
        Returns:
            List of frame indices with scene cuts; index ``i`` means the cut
            lies between frame ``i - 1`` and frame ``i``. Empty when fewer
            than two frames are given.
        """
        if len(frames) < 2:
            return []

        cuts = []
        # Each frame is converted to grayscale exactly once and carried over to
        # the next iteration, instead of being converted again as "previous".
        prev_gray = cv2.cvtColor(frames[0], cv2.COLOR_BGR2GRAY).astype(float)
        for idx in range(1, len(frames)):
            curr_gray = cv2.cvtColor(frames[idx], cv2.COLOR_BGR2GRAY).astype(float)

            # Mean absolute difference between the two consecutive frames
            diff = np.mean(np.abs(prev_gray - curr_gray))
            if diff > self.threshold:
                cuts.append(idx)

            prev_gray = curr_gray

        return cuts


# STREAM 1: VideoMAE - Semantic Understanding
class VideoMAEStream(nn.Module):
    """
    Detects semantic drift and object permanence violations.
    Best for: Autoregressive models, scene coherence

    A pretrained VideoMAE backbone embeds one clip per temporal window; the
    window embeddings are compared with each other (drift metrics) and
    compressed by a small MLP.
    """
    def __init__(self, config: Config):
        """
        Args:
            config: Needs ``device``, ``mae_hidden``, ``n_windows`` and
                ``mae_compressed``.
        """
        super().__init__()
        self.config = config

        # Load pretrained VideoMAE; used only as a feature extractor
        # (extract_features runs it under torch.no_grad()).
        mae_config = VideoMAEConfig.from_pretrained("MCG-NJU/videomae-base")
        self.model = VideoMAEModel.from_pretrained("MCG-NJU/videomae-base", config=mae_config)
        self.model.eval()

        # Feature compression MLP over the concatenated window embeddings
        self.compression = nn.Sequential(
            nn.Linear(config.mae_hidden * config.n_windows, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, config.mae_compressed)
        )

        # Normalize for RGB (ImageNet stats)
        self.normalize = Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

    def extract_features(self, clips: torch.Tensor, has_scene_cuts: bool = False) -> Dict[str, torch.Tensor]:
        """Embed every window clip and derive drift metrics and compressed features.

        Args:
            clips: Tensor whose first axis enumerates the temporal windows;
                each ``clips[i]`` is fed to VideoMAE as one clip
                (the pipeline passes ``(n_windows, frames, 3, H, W)``).
            has_scene_cuts: Whether hard cuts were detected in the video;
                forwarded to the drift computation.
        Returns:
            Dict with ``'embeddings'`` (one L2-normalised row per window, on
            CPU), ``'compressed'`` (output of the compression MLP, on CPU) and
            ``'drift_metrics'`` (dict of floats).
        """
        with torch.no_grad():
            n_windows = clips.shape[0]
            embeddings = []

            for i in range(n_windows):
                # Add the batch axis VideoMAE expects in front of the clip
                clip = clips[i].unsqueeze(0)
                outputs = self.model(pixel_values=clip.to(self.config.device))

                # Pool over sequence length, then L2-normalise
                embedding = outputs.last_hidden_state.mean(dim=1)  # (1, hidden)
                embedding = F.normalize(embedding, p=2, dim=1)
                embeddings.append(embedding)

            embeddings = torch.cat(embeddings, dim=0)  # (n_windows, hidden)

        # Compute temporal drift metrics (with scene cut awareness)
        drift_metrics = self._compute_drift(embeddings, has_scene_cuts)

        # Compress features (kept outside no_grad so the MLP stays trainable)
        flat_embeddings = embeddings.flatten()
        compressed = self.compression(flat_embeddings.unsqueeze(0))

        return {
            'embeddings': embeddings.cpu(),
            'compressed': compressed.cpu(),
            'drift_metrics': drift_metrics
        }

    def _compute_drift(self, embeddings: torch.Tensor, has_scene_cuts: bool) -> Dict[str, float]:
        """Summarise how the window embeddings move over time.

        Args:
            embeddings: ``(n_windows, hidden)`` tensor, one row per window.
            has_scene_cuts: When True, primary drift and acceleration are
                scaled by 0.3 because a cut legitimately changes the content.
        Returns:
            Dict of floats: ``drift_primary`` (cosine distance first vs last
            window), ``drift_acceleration`` (cosine distance of windows 2->3
            minus windows 1->2; 0.0 when fewer than three windows exist),
            ``drift_axis_consistency`` (mean per-dimension variance across
            windows; 0.0 for a single window) and ``has_scene_cuts``.
        Raises:
            ValueError: If ``embeddings`` contains no windows.
        """
        e = embeddings
        n_windows = e.shape[0]
        if n_windows == 0:
            raise ValueError("Cannot compute drift metrics without any window embedding")

        # Primary drift: early vs late
        drift_primary = 1 - F.cosine_similarity(e[0:1], e[-1:], dim=1).item()

        # Drift acceleration needs two consecutive steps, i.e. three windows.
        if n_windows >= 3:
            drift_1_to_2 = 1 - F.cosine_similarity(e[0:1], e[1:2], dim=1).item()
            drift_2_to_3 = 1 - F.cosine_similarity(e[1:2], e[2:3], dim=1).item()
            drift_acceleration = drift_2_to_3 - drift_1_to_2
        else:
            drift_acceleration = 0.0

        # Mean per-dimension variance across windows (not a PCA). The unbiased
        # variance of a single sample is NaN, so report 0.0 in that case.
        if n_windows > 1:
            drift_axis_consistency = torch.var(e, dim=0).mean().item()
        else:
            drift_axis_consistency = 0.0

        # If scene cuts present, dampen drift signals
        if has_scene_cuts:
            drift_primary *= 0.3
            drift_acceleration *= 0.3

        return {
            'drift_primary': drift_primary,
            'drift_acceleration': drift_acceleration,
            'drift_axis_consistency': drift_axis_consistency,
            'has_scene_cuts': float(has_scene_cuts)
        }


# STREAM 2: Frequency Analysis - GAN Artifact Detection
class FrequencyStream:
    """
    Detects upscaling artifacts and checkerboard patterns.
    Best for: GAN-generated videos (StyleGAN, ProGAN)

    For the middle frame of every temporal window the stream computes the
    radially averaged, log-scaled FFT magnitude and scores a peak in the
    radial bins 32-63 against a baseline.
    """
    def __init__(self, config: Config):
        """
        Args:
            config: Needs ``n_windows`` and ``fft_dim``.
        """
        self.config = config

    @staticmethod
    def _radial_profile(frame: np.ndarray) -> np.ndarray:
        """Return the radially averaged, min-max normalised log-magnitude spectrum of a 2-D frame."""
        # Scale 8-bit intensities to [0, 1]
        frame_norm = frame.astype(float) / 255.0

        # Centred 2D FFT magnitude
        magnitude = np.abs(np.fft.fftshift(np.fft.fft2(frame_norm)))

        # Log scale, then min-max normalisation to [0, 1]
        magnitude = np.log(magnitude + 1e-8)
        magnitude = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min() + 1e-8)

        # Azimuthal integration: average the magnitude over integer radii
        h, w = magnitude.shape
        y, x = np.ogrid[:h, :w]
        r = np.sqrt((x - w // 2) ** 2 + (y - h // 2) ** 2).astype(int)

        tbin = np.bincount(r.ravel(), magnitude.ravel())
        nr = np.bincount(r.ravel())
        return tbin / (nr + 1e-8)

    @staticmethod
    def _artifact_score(profile: np.ndarray) -> float:
        """Score a peak in radial bins 32-63 relative to a baseline, clipped to [0, 1].

        ``profile`` must NOT contain zero padding: padding inside the baseline
        range would drive the baseline to zero and saturate the score.
        """
        if len(profile) < 64:
            return 0.0
        checkerboard_region = profile[32:64]
        baseline = profile[64:128].mean() if len(profile) >= 128 else profile.mean()
        artifact_score = (np.max(checkerboard_region) - baseline) / (baseline + 1e-8)
        return np.clip(artifact_score, 0, 1)

    def extract_features(self, frames: List[np.ndarray]) -> Dict[str, np.ndarray]:
        """
        Args:
            frames: List of 2-D (grayscale) frames.
        Returns:
            Dict with ``'signatures'`` ``(n_windows, fft_dim)``,
            ``'artifact_scores'`` ``(n_windows,)``, ``'temporal_stability'``
            (mean variance of the signatures across windows) and
            ``'fingerprint'`` (mean signature, ``(fft_dim,)``). Windows
            without frames contribute zeros.
        """
        n_windows = self.config.n_windows
        fft_dim = self.config.fft_dim
        chunk_size = len(frames) // n_windows

        signatures = []
        artifact_scores = []

        for i in range(n_windows):
            window = frames[i*chunk_size : (i+1)*chunk_size]
            if len(window) == 0:
                signatures.append(np.zeros(fft_dim))
                artifact_scores.append(0.0)
                continue

            # Use middle keyframe
            keyframe = window[len(window)//2]
            radial_profile = self._radial_profile(keyframe)

            # Score on the real (unpadded) bins only; truncation to fft_dim
            # matches what the fixed-size signature keeps.
            valid_profile = radial_profile[:fft_dim]
            artifact_scores.append(self._artifact_score(valid_profile))

            # Normalize to fixed size (zero-pad short profiles)
            signatures.append(np.pad(valid_profile, (0, fft_dim - len(valid_profile))))

        signatures = np.array(signatures)
        artifact_scores = np.array(artifact_scores)

        # Temporal stability: variance across windows
        temporal_stability = np.var(signatures, axis=0).mean()

        return {
            'signatures': signatures,
            'artifact_scores': artifact_scores,
            'temporal_stability': temporal_stability,
            'fingerprint': signatures.mean(axis=0)
        }


# STREAM 3: Physics Consistency - Diffusion Detector
class PhysicsStream:
    """
    Detects motion inconsistencies and physics violations.
    Best for: Diffusion models (Sora, SVD, Kling)

    Per temporal window, dense Farneback optical flow is computed forwards and
    backwards for every consecutive frame pair and reduced to a handful of
    consistency / motion statistics.
    """
    def __init__(self, config: Config):
        self.config = config

    def extract_features(self, frames: List[np.ndarray]) -> Dict[str, np.ndarray]:
        """
        Args:
            frames: List of grayscale frames
        Returns:
            Dict with per-window arrays ('consistency_errors',
            'outlier_scores', 'motion_ranges', 'smoothness_scores') plus the
            scalars 'physics_violation' and 'low_motion_score'.
        """
        num_windows = self.config.n_windows
        span = len(frames) // num_windows

        # pyr_scale, levels, winsize, iterations, poly_n, poly_sigma, flags
        farneback_args = (0.5, 3, 15, 3, 5, 1.2, 0)

        fb_error_per_window = []
        outlier_ratio_per_window = []
        range_per_window = []
        smoothness_per_window = []
        static_flags = []

        for w in range(num_windows):
            segment = frames[w * span:(w + 1) * span]

            # A window needs at least one frame pair; otherwise record neutral
            # values and mark the window as static.
            if len(segment) < 2:
                fb_error_per_window.append(0.0)
                outlier_ratio_per_window.append(0.0)
                range_per_window.append([0.0, 0.0])
                smoothness_per_window.append(0.0)
                static_flags.append(1.0)
                continue

            pair_errors = []
            pair_magnitudes = []
            for earlier, later in zip(segment[:-1], segment[1:]):
                # Optical flow (Farneback - can be replaced with RAFT)
                forward = cv2.calcOpticalFlowFarneback(earlier, later, None, *farneback_args)
                backward = cv2.calcOpticalFlowFarneback(later, earlier, None, *farneback_args)

                # Forward-backward residual, averaged over the frame
                residual = forward + backward
                pair_errors.append(np.sqrt(np.sum(residual**2, axis=2)).mean())

                # Mean forward flow magnitude
                pair_magnitudes.append(np.sqrt(np.sum(forward**2, axis=2)).mean())

            # Both lists hold at least one entry here (segment has >= 2 frames).
            mean_motion = np.mean(pair_magnitudes)
            static_flags.append(float(mean_motion < self.config.low_motion_threshold))

            fb_error_per_window.append(np.mean(pair_errors))
            outlier_ratio_per_window.append(
                np.sum(np.array(pair_errors) > 1.0) / len(pair_errors)
            )
            range_per_window.append([np.min(pair_magnitudes), np.max(pair_magnitudes)])

            # Smoothness: variance of magnitudes, ignored for near-static windows
            if mean_motion > 0.1:
                smoothness_per_window.append(np.var(pair_magnitudes))
            else:
                smoothness_per_window.append(0.0)

        # Normalize to [0, 1]
        consistency_errors = np.clip(np.array(fb_error_per_window) / 2.0, 0, 1)
        outlier_scores = np.array(outlier_ratio_per_window)
        motion_ranges = np.array(range_per_window)
        smoothness_scores = np.clip(np.array(smoothness_per_window) / 10.0, 0, 1)

        # Low-motion gating: the more static the video, the less the score counts
        avg_low_motion = np.array(static_flags).mean()
        weight = 1.0 - avg_low_motion

        combined = (consistency_errors.mean() +
                    outlier_scores.mean() +
                    smoothness_scores.mean())
        physics_violation = weight * combined / 3.0

        return {
            'consistency_errors': consistency_errors,
            'outlier_scores': outlier_scores,
            'motion_ranges': motion_ranges,
            'smoothness_scores': smoothness_scores,
            'physics_violation': physics_violation,
            'low_motion_score': avg_low_motion
        }

# STREAM 4: Audio Analysis
class AudioStream:
    """Extracts MFCC-based audio features, globally and per temporal window."""

    def __init__(self, config: Config):
        """
        Args:
            config: Needs ``sample_rate``, ``n_mfcc`` and ``n_windows``.
        """
        self.config = config
        self.mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=config.sample_rate,
            n_mfcc=config.n_mfcc
        )

    def extract_features(self, audio_waveform: Optional[torch.Tensor]) -> Dict[str, np.ndarray]:
        """Compute MFCC statistics for a waveform.

        Args:
            audio_waveform: ``(channels, samples)`` or ``(samples,)`` tensor,
                or None when the video has no audio track. Waveforms shorter
                than 1000 samples are treated as missing.
        Returns:
            Dict with ``'mfcc_features'`` (``3 * n_mfcc`` values: mean, std and
            mean frame-to-frame delta per coefficient), ``'temporal_features'``
            (``(n_windows, n_mfcc)`` per-window means), ``'artifact_score'``,
            ``'sync_score'`` (constant 0.5 placeholder) and ``'has_audio'``.
        """
        n_windows = self.config.n_windows
        n_mfcc = self.config.n_mfcc

        # If no audio, return zeros with temporal shape (n_windows, n_mfcc)
        if audio_waveform is None or audio_waveform.shape[-1] < 1000:
            return {
                'mfcc_features': np.zeros(n_mfcc * 3),
                'temporal_features': np.zeros((n_windows, n_mfcc)),
                'artifact_score': 0.0,
                'sync_score': 0.5,
                'has_audio': 0.0
            }

        # Bring the waveform to a single-channel (1, samples) layout.
        if audio_waveform.dim() == 1:
            audio_waveform = audio_waveform.unsqueeze(0)
        elif audio_waveform.shape[0] > 1:
            audio_waveform = audio_waveform.mean(dim=0, keepdim=True)

        # The transform keeps the channel axis: (1, n_mfcc, time). Drop it so
        # that the last axis is time and the first is the coefficient index.
        mfcc = self.mfcc_transform(audio_waveform)
        if mfcc.dim() == 3:
            mfcc = mfcc.squeeze(0)  # (n_mfcc, time)

        # Split the MFCC frames into n_windows chunks to match the video clips
        total_time = mfcc.shape[-1]
        chunk_size = total_time // n_windows
        temporal_features = []

        for i in range(n_windows):
            chunk = mfcc[:, i * chunk_size:(i + 1) * chunk_size]
            if chunk.shape[-1] == 0:
                # Fewer MFCC frames than windows: avoid the NaN of an empty mean
                temporal_features.append(np.zeros(n_mfcc, dtype=np.float32))
            else:
                temporal_features.append(chunk.mean(dim=-1).numpy())

        temporal_features = np.stack(temporal_features)  # (n_windows, n_mfcc)

        # Global stats, each of shape (n_mfcc,)
        mfcc_mean = mfcc.mean(dim=-1)
        mfcc_std = mfcc.std(dim=-1)
        mfcc_delta = torch.diff(mfcc, dim=-1).mean(dim=-1)
        mfcc_features = torch.cat([mfcc_mean, mfcc_std, mfcc_delta]).numpy()

        # Low MFCC variation over time maps to a high smoothness score
        spectral_smoothness = 1.0 - min(mfcc_std.mean().item() / 10.0, 1.0)

        return {
            'mfcc_features': mfcc_features,
            'temporal_features': temporal_features,
            'artifact_score': spectral_smoothness,
            'sync_score': 0.5,
            'has_audio': 1.0
        }

# FEATURE COMPRESSION & FUSION
class FeatureCompression(nn.Module):
    """Per-stream feature compression MLPs.

    Turns the raw outputs of the FFT, physics, audio and VideoMAE streams into
    fixed-size vectors, and projects the per-window video / audio sequences to
    a common width for cross-attention.
    """

    def __init__(self, config: Config):
        """
        Args:
            config: Needs ``n_windows``, ``fft_dim``, ``n_mfcc``,
                ``mae_hidden`` and the ``*_compressed`` widths.
        """
        super().__init__()
        self.config = config

        # FFT compression: signatures + artifact scores + temporal stability
        fft_input_dim = config.fft_dim * config.n_windows + config.n_windows + 1
        self.fft_compressor = nn.Sequential(
            nn.Linear(fft_input_dim, 128),
            nn.LayerNorm(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, config.fft_compressed)
        )

        # Physics compression: 5 per-window values + violation + low_motion_score
        physics_input_dim = config.n_windows * 5 + 2
        self.physics_compressor = nn.Sequential(
            nn.Linear(physics_input_dim, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, config.physics_compressed)
        )

        # Audio compression: global MFCC stats + artifact, sync and has_audio scalars
        audio_input_dim = config.n_mfcc * 3 + 3
        self.audio_compressor = nn.Sequential(
            nn.Linear(audio_input_dim, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, config.audio_compressed)
        )

        # Drift compression: three drift metrics + scene cut flag
        drift_input_dim = 4
        self.drift_compressor = nn.Sequential(
            nn.Linear(drift_input_dim, 32),
            nn.LayerNorm(32),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(32, config.drift_compressed)
        )

        # Project per-window audio MFCC means to the cross-attention width
        self.audio_temporal_proj = nn.Linear(config.n_mfcc, config.mae_compressed)

        # Project per-window VideoMAE embeddings to the cross-attention width
        self.video_temporal_proj = nn.Linear(config.mae_hidden, config.mae_compressed)

    def _row(self, *parts) -> torch.Tensor:
        """Flatten scalars / arrays into one float32 ``(1, N)`` tensor on the module's device."""
        pieces = [torch.tensor(np.asarray(p, dtype=np.float32)).reshape(-1) for p in parts]
        return torch.cat(pieces).unsqueeze(0).to(self.audio_temporal_proj.weight.device)

    def forward(self, features: Dict) -> Dict[str, torch.Tensor]:
        """Compress the raw stream features.

        Args:
            features: Dict with the keys ``'fft'``, ``'physics'``, ``'audio'``
                and ``'mae'``, each holding the dict returned by the
                corresponding stream's ``extract_features``.
        Returns:
            Dict with ``(1, width)`` tensors under ``'mae'``, ``'fft'``,
            ``'physics'``, ``'audio'``, ``'drift'`` and ``(n_windows,
            mae_compressed)`` tensors under ``'sequence_video'`` and
            ``'sequence_audio'``; all on this module's device.
        """
        fft = features['fft']
        physics = features['physics']
        audio = features['audio']
        mae = features['mae']
        drift = mae['drift_metrics']

        # The streams produce CPU data; move everything to wherever this
        # module's weights live so it also works on CUDA.
        device = self.audio_temporal_proj.weight.device

        # FFT
        fft_compressed = self.fft_compressor(self._row(
            fft['signatures'],
            fft['artifact_scores'],
            fft['temporal_stability']
        ))

        # Physics
        physics_compressed = self.physics_compressor(self._row(
            physics['consistency_errors'],
            physics['outlier_scores'],
            physics['motion_ranges'],
            physics['smoothness_scores'],
            physics['physics_violation'],
            physics['low_motion_score']
        ))

        # Audio
        audio_compressed = self.audio_compressor(self._row(
            audio['mfcc_features'],
            audio['artifact_score'],
            audio['sync_score'],
            audio['has_audio']
        ))

        # Drift
        drift_compressed = self.drift_compressor(self._row(
            drift['drift_primary'],
            drift['drift_acceleration'],
            drift['drift_axis_consistency'],
            drift['has_scene_cuts']
        ))

        # Audio: (n_windows, n_mfcc) -> (n_windows, mae_compressed)
        audio_temp = torch.from_numpy(audio['temporal_features']).float().to(device)
        audio_emb = self.audio_temporal_proj(audio_temp)

        # Video: (n_windows, mae_hidden) -> (n_windows, mae_compressed)
        # Note: uses the raw per-window embeddings, not the compressed global vector
        video_emb = self.video_temporal_proj(mae['embeddings'].to(device))

        return {
            'mae': mae['compressed'].to(device),
            'fft': fft_compressed,
            'physics': physics_compressed,
            'audio': audio_compressed,
            'drift': drift_compressed,
            'sequence_video': video_emb,
            'sequence_audio': audio_emb
        }

# HIERARCHICAL DETECTOR
class HierarchicalDetector(nn.Module):
    """
    Layer 1: Per-stream specialized detectors
    Layer 2: Cross-modal fusion (video-to-audio cross-attention) feeding a
             final classifier together with the branch scores
    """
    def __init__(self, config: Config):
        """
        Args:
            config: Needs the ``*_compressed`` widths and ``n_heads``
                (``mae_compressed`` must be divisible by ``n_heads``).
        """
        super().__init__()
        self.config = config

        # Layer 1: Branch detectors (created in a fixed order so that
        # parameter names and initialisation order stay stable)
        self.gan_detector = self._make_branch(config.fft_compressed, 32, 8)
        self.diffusion_detector = self._make_branch(config.physics_compressed, 16, 8)
        self.autoregressive_detector = self._make_branch(config.drift_compressed, 8, 4)
        self.audio_detector = self._make_branch(config.audio_compressed, 16, 8)
        self.semantic_detector = self._make_branch(config.mae_compressed, 64, 32)

        # Layer 2: Attention with Query = Video, Key/Value = Audio
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=config.mae_compressed,
            num_heads=config.n_heads,
            batch_first=True
        )

        # Final Classifier
        # Input: All branch scores (5) + Conflict (1) + attention output (mae_compressed)
        fusion_input_dim = 6 + config.mae_compressed

        self.fusion_mlp = nn.Sequential(
            nn.Linear(fusion_input_dim, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    @staticmethod
    def _make_branch(in_dim: int, hidden: int, bottleneck: int) -> nn.Sequential:
        """Build one branch detector: two hidden layers and a sigmoid probability output."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, 1),
            nn.Sigmoid()
        )

    def forward(self, compressed_features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Hierarchical detection with interpretability.

        Args:
            compressed_features: Dict with ``(1, width)`` tensors under
                ``'fft'``, ``'physics'``, ``'drift'``, ``'audio'``, ``'mae'``
                and ``(n_windows, mae_compressed)`` tensors under
                ``'sequence_video'`` and ``'sequence_audio'``.
        Returns:
            Dict with ``'p_fake'`` ``(1, 1)``, ``'branch_predictions'``
            ``(1, 5)`` in the order GAN, diffusion, autoregressive, audio,
            semantic, and ``'conflict_score'`` ``(1, 1)`` (standard deviation
            of the branch predictions).
        """
        # Layer 1: Branch predictions
        p_gan = self.gan_detector(compressed_features['fft'])
        p_diffusion = self.diffusion_detector(compressed_features['physics'])
        p_autoregressive = self.autoregressive_detector(compressed_features['drift'])
        p_audio = self.audio_detector(compressed_features['audio'])
        p_semantic = self.semantic_detector(compressed_features['mae'])

        branch_predictions = torch.cat([p_gan, p_diffusion, p_autoregressive,
                                        p_audio, p_semantic], dim=1)

        # Disagreement between the branches
        conflict_score = torch.std(branch_predictions, dim=1, keepdim=True)

        # Layer 2: Cross-modal fusion
        # Add the batch axis: (1, n_windows, dim)
        vid_seq = compressed_features['sequence_video'].unsqueeze(0)
        aud_seq = compressed_features['sequence_audio'].unsqueeze(0)

        # Video windows attend over the audio windows
        attn_out, _ = self.cross_attn(query=vid_seq, key=aud_seq, value=aud_seq)

        # Pool the sequence (mean over windows) into one context vector
        sync_context = attn_out.mean(dim=1)  # (1, dim)

        # Final prediction
        fusion_features = torch.cat([branch_predictions, conflict_score, sync_context], dim=1)
        p_fake = self.fusion_mlp(fusion_features)

        return {
            'p_fake': p_fake,
            'branch_predictions': branch_predictions,
            'conflict_score': conflict_score
        }

# MAIN PIPELINE
class AIGeneratedMediaDetectionPipeline:
    """End-to-end detection pipeline: decode video, run all streams, compress, classify."""

    def __init__(self, config: Config = CONFIG):
        """
        Args:
            config: Pipeline configuration. When ``config.checkpoint_path`` is
                set, trained compressor / detector weights are loaded from it.
        """
        self.config = config

        # Initialize streams
        self.mae_stream = VideoMAEStream(config).to(config.device)
        self.freq_stream = FrequencyStream(config)
        self.physics_stream = PhysicsStream(config)
        self.audio_stream = AudioStream(config)
        self.scene_detector = SceneChangeDetector(config.scene_change_threshold)

        # Initialize compression and detector
        self.compressor = FeatureCompression(config).to(config.device)
        self.detector = HierarchicalDetector(config).to(config.device)

        # Load trained weights if available
        if config.checkpoint_path:
            self.load_checkpoint(config.checkpoint_path)
        else:
            print("WARNING: No checkpoint loaded. Model has random weights.")
            print("For inference, please provide checkpoint_path in Config.")

        # nn.Module instances start in training mode, which keeps Dropout
        # active and makes predictions non-deterministic. Default to inference
        # mode; training code must switch the modules back with .train().
        self.mae_stream.eval()
        self.compressor.eval()
        self.detector.eval()

    def load_checkpoint(self, path: str):
        """Load trained compressor and detector weights from ``path``."""
        # NOTE(review): torch.load unpickles arbitrary objects (save_checkpoint
        # stores the Config instance) - only load checkpoints from trusted
        # sources. Newer PyTorch releases default to weights_only=True, which
        # may reject the pickled Config; confirm against the installed version.
        checkpoint = torch.load(path, map_location=self.config.device)
        self.compressor.load_state_dict(checkpoint['compressor'])
        self.detector.load_state_dict(checkpoint['detector'])
        # VideoMAE weights are already pretrained, no need to load
        print(f"Loaded checkpoint from {path}")

    def save_checkpoint(self, path: str):
        """Save compressor and detector weights together with the config."""
        torch.save({
            'compressor': self.compressor.state_dict(),
            'detector': self.detector.state_dict(),
            'config': self.config
        }, path)
        print(f"Saved checkpoint to {path}")

    @staticmethod
    def _read_frames(video_path: str) -> List[np.ndarray]:
        """Decode every frame of the video as BGR; the capture is always released."""
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                raise ValueError(f"Could not open video: {video_path}")
            frames = []
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame)
            return frames
        finally:
            cap.release()

    def _build_mae_clips(self, frames_bgr: List[np.ndarray]) -> torch.Tensor:
        """Sample ``frames_per_clip`` frames per window and preprocess them for VideoMAE.

        Returns a ``(n_windows, frames_per_clip, 3, H, W)`` float tensor.
        """
        total_f = len(frames_bgr)
        n_windows = self.config.n_windows

        # Window boundaries; for n_windows == 3 this is
        # [0, total//3, 2*total//3, total-1].
        edges = [k * total_f // n_windows for k in range(n_windows)] + [total_f - 1]

        mae_clips = []
        for start, stop in zip(edges[:-1], edges[1:]):
            clip_frames = []
            for i in np.linspace(start, stop, self.config.frames_per_clip, dtype=int):
                # Only sampled frames are converted BGR -> RGB
                rgb = cv2.cvtColor(frames_bgr[i], cv2.COLOR_BGR2RGB)
                # Convert to tensor and normalize to [0, 1]
                frame_tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
                # Resize
                frame_tensor = F.interpolate(
                    frame_tensor.unsqueeze(0),
                    size=self.config.full_res,
                    mode='bilinear',
                    align_corners=False
                ).squeeze(0)
                # Apply ImageNet normalization
                clip_frames.append(self.mae_stream.normalize(frame_tensor))

            mae_clips.append(torch.stack(clip_frames))  # (frames, 3, H, W)

        return torch.stack(mae_clips)  # (n_windows, frames, 3, H, W)

    def _load_audio(self, audio_path: Optional[str]) -> Optional[torch.Tensor]:
        """Load the audio file and resample it to ``config.sample_rate``.

        Best effort: returns None (and prints a warning) when no path is given
        or the file cannot be loaded.
        """
        if not audio_path:
            return None
        try:
            waveform, file_rate = torchaudio.load(audio_path)
            # The MFCC transform is configured for config.sample_rate
            if file_rate != self.config.sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform, file_rate, self.config.sample_rate
                )
            return waveform
        except Exception as e:
            print(f"Warning: Could not load audio: {e}")
            return None

    def process_video(self, video_path: str, audio_path: Optional[str] = None) -> Dict:
        """
        Process a single video and return detection results.

        Args:
            video_path: Path of the video file.
            audio_path: Optional path of the matching audio file.
        Returns:
            Dict with ``'probability_fake'``, the five branch scores,
            ``'conflict_score'``, ``'prediction'`` (``'FAKE'`` when the
            probability exceeds 0.5, else ``'REAL'``) and a ``'metadata'`` dict.
        Raises:
            ValueError: If the video cannot be opened or has fewer than
                ``config.min_frames`` frames.
        """
        # 1. Load video
        frames_bgr = self._read_frames(video_path)
        if len(frames_bgr) < self.config.min_frames:
            raise ValueError(f"Video too short: {len(frames_bgr)} < {self.config.min_frames}")

        # 2. Detect scene changes
        scene_cuts = self.scene_detector.detect_cuts(frames_bgr)
        has_scene_cuts = len(scene_cuts) > 0

        # 3./4. Temporal sampling and preprocessing for VideoMAE
        mae_clips = self._build_mae_clips(frames_bgr)

        # 5. Convert to grayscale for physics/FFT
        gray_frames = [cv2.cvtColor(f, cv2.COLOR_BGR2GRAY) for f in frames_bgr]

        # 6. Extract features from all streams
        mae_features = self.mae_stream.extract_features(mae_clips, has_scene_cuts)
        fft_features = self.freq_stream.extract_features(gray_frames)
        physics_features = self.physics_stream.extract_features(gray_frames)

        # 7. Load audio
        audio_features = self.audio_stream.extract_features(self._load_audio(audio_path))

        all_features = {
            'mae': mae_features,
            'fft': fft_features,
            'physics': physics_features,
            'audio': audio_features
        }

        # 8. Compress features
        compressed = self.compressor(all_features)

        # 9. Hierarchical detection
        with torch.no_grad():
            results = self.detector(compressed)

        p_fake = results['p_fake'].item()
        branch_scores = results['branch_predictions'][0]

        return {
            'probability_fake': p_fake,
            'gan_score': branch_scores[0].item(),
            'diffusion_score': branch_scores[1].item(),
            'autoregressive_score': branch_scores[2].item(),
            'audio_score': branch_scores[3].item(),
            'semantic_score': branch_scores[4].item(),
            'conflict_score': results['conflict_score'].item(),
            'prediction': 'FAKE' if p_fake > 0.5 else 'REAL',
            'metadata': {
                'scene_cuts_detected': len(scene_cuts),
                'has_scene_cuts': has_scene_cuts,
                'low_motion_score': physics_features['low_motion_score'],
                'has_audio': audio_features['has_audio']
            }
        }

# TRAINING UTILITIES
class Trainer:
    """Training utilities for the detection pipeline.

    Holds two Adam optimizers: one over the MAE compression layer, the
    feature compressor and the five specialised detectors (learning rate
    ``config.lr_branches``), and one over the fusion module (learning rate
    ``config.lr_fusion``). Both are stepped on every sample using a
    binary-cross-entropy loss.
    """

    # Number of specialised branch scores supervised by the branch loss
    # (GAN, diffusion, autoregressive, audio, semantic).
    _N_BRANCHES = 5

    def __init__(self, pipeline: "AIGeneratedMediaDetectionPipeline", config: "Config"):
        """
        Args:
            pipeline: Detection pipeline whose sub-modules are trained.
            config: Configuration providing ``lr_branches`` and ``lr_fusion``.
        """
        self.pipeline = pipeline
        self.config = config

        # Separate optimizers for hierarchical training
        detector = pipeline.detector
        branch_modules = [
            pipeline.mae_stream.compression,
            pipeline.compressor,
            detector.gan_detector,
            detector.diffusion_detector,
            detector.autoregressive_detector,
            detector.audio_detector,
            detector.semantic_detector,
        ]
        self.optimizer_branches = torch.optim.Adam(
            [{'params': module.parameters()} for module in branch_modules],
            lr=config.lr_branches
        )

        self.optimizer_fusion = torch.optim.Adam(
            detector.fusion.parameters(),
            lr=config.lr_fusion
        )

        self.criterion = nn.BCELoss()

    def _forward_with_grad(self, video_path: str,
                           audio_path: Optional[str]) -> Dict[str, torch.Tensor]:
        """Run the pipeline on one sample and return the detector's raw output tensors.

        ``process_video`` evaluates the detector under ``torch.no_grad()`` and
        returns plain Python floats, which cannot be back-propagated through.
        A forward pre-hook therefore records the detector's input while
        ``process_video`` runs, and the detector is evaluated a second time
        with autograd enabled on that same input.

        Raises:
            RuntimeError: If ``process_video`` never called the detector.
        """
        captured = {}

        def _capture(module, args):
            # Returning None leaves the detector's input unmodified.
            captured['args'] = args

        handle = self.pipeline.detector.register_forward_pre_hook(_capture)
        try:
            self.pipeline.process_video(video_path, audio_path)
        finally:
            # Always detach the hook so inference calls are unaffected.
            handle.remove()

        if 'args' not in captured:
            raise RuntimeError(
                "process_video did not invoke pipeline.detector; "
                "cannot build a differentiable forward pass"
            )

        with torch.enable_grad():
            return self.pipeline.detector(*captured['args'])

    def train_step(self, video_paths: List[str], labels: List[int],
                   audio_paths: Optional[List[str]] = None) -> Dict[str, float]:
        """
        Single training step (one optimizer update per sample).
        Args:
            video_paths: List of video file paths
            labels: List of binary labels (0=real, 1=fake)
            audio_paths: Optional list of audio file paths
        Returns:
            Dict with loss values ('total_loss', 'branch_loss', 'fusion_loss'),
            each averaged over the batch. An empty batch yields zeros.
        Raises:
            ValueError: If ``labels`` (or a non-empty ``audio_paths``) does not
                have the same length as ``video_paths``.
        """
        if len(labels) != len(video_paths):
            raise ValueError(
                f"Got {len(video_paths)} video paths but {len(labels)} labels"
            )
        if audio_paths and len(audio_paths) != len(video_paths):
            raise ValueError(
                f"Got {len(video_paths)} video paths but {len(audio_paths)} audio paths"
            )
        if not video_paths:
            return {'total_loss': 0.0, 'branch_loss': 0.0, 'fusion_loss': 0.0}

        self.pipeline.compressor.train()
        self.pipeline.detector.train()

        total_losses = []
        branch_losses = []
        fusion_losses = []

        for i, (video_path, label) in enumerate(zip(video_paths, labels)):
            audio_path = audio_paths[i] if audio_paths else None

            # Forward pass that keeps the autograd graph
            outputs = self._forward_with_grad(video_path, audio_path)

            # Prepare tensors. The first sample's first five branch scores are
            # the ones process_video reports; the target mirrors the
            # prediction's dtype and device.
            p_fake = outputs['p_fake'].reshape(1, 1)
            branch_preds = outputs['branch_predictions'][0:1, :self._N_BRANCHES]
            label_tensor = torch.full_like(p_fake, float(label))

            # Branch losses
            branch_loss = self.criterion(branch_preds, label_tensor.expand_as(branch_preds))

            # Fusion loss
            fusion_loss = self.criterion(p_fake, label_tensor)

            # Total loss
            loss = branch_loss + fusion_loss

            # Backward pass
            self.optimizer_branches.zero_grad()
            self.optimizer_fusion.zero_grad()
            loss.backward()
            self.optimizer_branches.step()
            self.optimizer_fusion.step()

            total_losses.append(loss.item())
            branch_losses.append(branch_loss.item())
            fusion_losses.append(fusion_loss.item())

        return {
            'total_loss': float(np.mean(total_losses)),
            'branch_loss': float(np.mean(branch_losses)),
            'fusion_loss': float(np.mean(fusion_losses))
        }

# USAGE EXAMPLE WITH PROPER ERROR HANDLING
# Script entry point: prints a banner, builds the pipeline with a default
# Config, runs it on a single example video and pretty-prints the result
# dictionary returned by process_video.
if __name__ == "__main__":
    # Banner listing the changes this version of the notebook advertises.
    print("=" * 70)
    print("AI-Generated Media Detection Pipeline - Robust Version")
    print("=" * 70)
    print("\nFixes Applied:")
    print("✓ BGR→RGB conversion for VideoMAE")
    print("✓ Scene change detection to prevent false positives")
    print("✓ Low-motion gating for physics stream")
    print("✓ Improved FFT normalization for lighting invariance")
    print("✓ Checkpoint loading/saving support")
    print("✓ Training utilities included")
    print("=" * 70)
    
    # Configuration
    config = Config()
    config.checkpoint_path = None  # Set to your checkpoint path for inference
    
    # Initialize pipeline
    print("\nInitializing pipeline...")
    pipeline = AIGeneratedMediaDetectionPipeline(config)
    
    # Example: Process a video
    # NOTE: hard-coded relative path; a missing file is only handled if
    # process_video raises FileNotFoundError (see the except clause below).
    video_path = "test_video.mp4"
    
    try:
        print(f"\nProcessing video: {video_path}")
        # No audio path is passed here.
        results = pipeline.process_video(video_path)
        
        # Headline verdict: label plus the fused probability of being fake.
        print("\n" + "=" * 70)
        print("DETECTION RESULTS")
        print("=" * 70)
        print(f"\nPrediction: {results['prediction']}")
        print(f"Confidence: {results['probability_fake']:.1%}")
        
        # Per-branch scores from the five specialised detectors.
        print("\n--- Specialized Detector Scores ---")
        print(f"  GAN Detector:           {results['gan_score']:.3f}")
        print(f"  Diffusion Detector:     {results['diffusion_score']:.3f}")
        print(f"  Autoregressive Detector: {results['autoregressive_score']:.3f}")
        print(f"  Audio Detector:         {results['audio_score']:.3f}")
        print(f"  Semantic Detector:      {results['semantic_score']:.3f}")
        
        # Diagnostics taken from the conflict score and the 'metadata' sub-dict.
        print("\n--- Meta Information ---")
        print(f"  Conflict Score:         {results['conflict_score']:.3f}")
        print(f"  Scene Cuts Detected:    {results['metadata']['scene_cuts_detected']}")
        print(f"  Low Motion Score:       {results['metadata']['low_motion_score']:.3f}")
        print(f"  Has Audio:              {'Yes' if results['metadata']['has_audio'] else 'No'}")
        
        # Human-readable hints; the 0.3 and 0.7 cut-offs are display-only
        # thresholds used just for these messages.
        print("\n--- Interpretation ---")
        if results['conflict_score'] > 0.3:
            print("  ⚠ High disagreement between detectors - uncertain prediction")
        if results['metadata']['has_scene_cuts']:
            print("  ℹ Scene cuts detected - drift metrics dampened")
        if results['metadata']['low_motion_score'] > 0.7:
            print("  ℹ Low motion detected - physics metrics down-weighted")
        
        print("=" * 70)
        
    except FileNotFoundError:
        # Missing input video: explain how to set the pipeline up instead of
        # showing a traceback.
        print(f"\nError: Video file not found: {video_path}")
        print("\nTo use this pipeline:")
        print("1. Prepare a dataset of real and fake videos")
        print("2. Train the model using the Trainer class")
        print("3. Save the checkpoint with pipeline.save_checkpoint('model.pth')")
        print("4. Set config.checkpoint_path = 'model.pth' for inference")
    
    except Exception as e:
        # Any other failure: report it with the full traceback but keep the
        # script running so the overview below is still printed.
        print(f"\nError during processing: {e}")
        import traceback
        traceback.print_exc()
    
    # Header for the pipeline overview section.
    print("\n" + "=" * 70)
    print("Pipeline Overview:")
    print("=" * 70)
