# 🎯 DEEPFAKE DETECTION - INFERENCE PIPELINE

**Mission:** Generate predictions on test set using trained InceptionResNetV2 + Attention model

**Model Architecture:**
- CNN: InceptionResNetV2 (feature extractor)
- Temporal: Attention Pooling (learns frame importance)
- Single model trained on 70:30 split

**Inference Strategy:**
- Test-Time Augmentation (TTA): 5 variants per video
- Batch size: 1 (optimized for 4GB VRAM)
- Mixed precision: FP16
- Frame count testing: Compare 24 vs 32 vs 48 frames

**Output:** `PREDICTIONS.CSV` with columns:
- `filename` - Video filename
- `label` - Predicted class (0=Real, 1=Fake)
- `probability` - Confidence score of predicted class (0.0 to 1.0)

**Hardware:** RTX 3050 Laptop (4GB VRAM)  
**Estimated Time:** 
- Frame testing (10 videos): ~2 minutes
- Full inference (200 videos × 5 TTA): ~25-35 minutes

---

## 📋 INFERENCE CHECKLIST
- [ ] GPU verified
- [ ] Model checkpoint loaded
- [ ] Frame count tested (24/32/48)
- [ ] Configuration set based on test results
- [ ] Full inference complete
- [ ] `PREDICTIONS.CSV` generated and validated


## 🔧 STEP 1: Environment Setup & GPU Verification

In [None]:
# Check GPU and CUDA availability
import torch
import platform

# Report the host environment first, then the CUDA device details if one exists.
cuda_ok = torch.cuda.is_available()

print(f"System: {platform.system()} {platform.release()}")
print(f"Python: {platform.python_version()}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {cuda_ok}")

if not cuda_ok:
    print("\n‚ö†Ô∏è WARNING: GPU not detected! Inference will be slow on CPU.")
else:
    # Device 0 is the only GPU this notebook ever queries.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {gpu_props.total_memory / 1024**3:.2f} GB")
    print(f"\n‚úÖ GPU ready for inference!")


In [None]:
# Import all required libraries
import os
import cv2
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import warnings
import time
warnings.filterwarnings('ignore')

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast

# Augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Model library
import timm

# GPU monitoring (optional).
# pynvml is only used to read VRAM usage, so neither a missing package nor a
# missing/unsupported driver may abort the notebook: the import lives inside
# the guarded block and any failure simply disables monitoring.
try:
    import pynvml
    pynvml.nvmlInit()
    gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    print("‚úÖ NVML initialized for GPU monitoring")
except Exception:  # ImportError, NVML errors, ... but never KeyboardInterrupt/SystemExit
    # None is the sentinel the VRAM helper checks before touching pynvml.
    gpu_handle = None
    print("‚ö†Ô∏è NVML not available - VRAM monitoring limited")

# Set device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n‚úÖ All libraries imported!")
print(f"Device: {DEVICE}")


## 💾 STEP 2: Setup Paths & Load Test Data

In [None]:
# Project root on the local Windows machine; everything else hangs off it.
BASE_DIR = Path(r"D:\Data\Github\SheldonC2005\ModelArena")

# Inputs: the test videos and the CSV that lists them
BASE_PATH = BASE_DIR.joinpath("archive")
TEST_PATH = BASE_PATH.joinpath("test")
TEST_CSV_PATH = BASE_PATH.joinpath("test_public.csv")

# Trained weights
MODELS_PATH = BASE_DIR.joinpath("models")
MODEL_CHECKPOINT = MODELS_PATH.joinpath("inception_resnet_v2_best.pt")

# Submission folder (created on demand) and the predictions file inside it
OUTPUT_DIR = BASE_DIR.joinpath("SUBMISSION_DIRECTORY")
OUTPUT_CSV = OUTPUT_DIR.joinpath("PREDICTIONS.CSV")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print(
    f"‚úÖ Paths configured:",
    f"   Base directory: {BASE_DIR}",
    f"   Test videos: {TEST_PATH}",
    f"   Model checkpoint: {MODEL_CHECKPOINT}",
    f"   Output CSV: {OUTPUT_CSV}",
    sep="\n",
)


In [None]:
# Load the test manifest and confirm every referenced video is on disk
print("üìä Loading test data and verifying videos...\n")

# The trained weights are a hard prerequisite, so fail fast without them
if not MODEL_CHECKPOINT.exists():
    raise FileNotFoundError(
        f"‚ùå Model checkpoint not found: {MODEL_CHECKPOINT}\n"
        f"   Please train the model first using TRAINING_PIPELINE.ipynb"
    )
checkpoint_mb = MODEL_CHECKPOINT.stat().st_size / 1024**2
print(f"‚úÖ Model checkpoint found ({checkpoint_mb:.2f} MB)")

# Read the manifest; the column list is reported before any column is added
test_df = pd.read_csv(TEST_CSV_PATH)
print(f"‚úÖ Test CSV loaded: {len(test_df)} samples")
print(f"   Columns: {list(test_df.columns)}")

# Resolve each filename to an absolute path string under the test folder
test_df['video_path'] = [str(TEST_PATH / name) for name in test_df['filename']]

# FULL VERIFICATION: collect the filename of every video that is not on disk
print(f"\nüîç Verifying all {len(test_df)} test videos exist...")
missing_videos = [
    name
    for name, path in zip(test_df['filename'], test_df['video_path'])
    if not Path(path).exists()
]

if not missing_videos:
    print(f"‚úÖ All {len(test_df)} test videos verified!")
else:
    print(f"‚ùå ERROR: {len(missing_videos)} videos not found!")
    print(f"   First few missing: {missing_videos[:10]}")
    raise FileNotFoundError(f"Missing {len(missing_videos)} test videos")

# Warn (but do not stop) when a previous run's output is about to be replaced
if OUTPUT_CSV.exists():
    print(f"\n‚ö†Ô∏è WARNING: {OUTPUT_CSV.name} already exists and will be overwritten!")

divider = '=' * 70
print(f"\n{divider}")
print(f"Ready for inference on {len(test_df)} test videos")
print(f"{divider}")


## 🏗️ STEP 3: Model Architecture & Configuration

In [None]:
# Inference configuration
class InferenceConfig:
    """Constants shared by the inference cells of this notebook."""

    # --- Model settings (MUST match training config) ---
    FEATURE_DIM = 1536  # InceptionResNetV2 feature dimension
    IMG_SIZE = 299  # InceptionResNetV2 input size

    # --- Inference settings ---
    MIXED_PRECISION = True  # FP16 for efficiency
    NUM_WORKERS = 0  # Safe for Windows
    BATCH_SIZE = 1  # Ultra-safe for 4GB VRAM

    # --- VRAM safety thresholds ---
    VRAM_EMERGENCY_THRESHOLD = 0.90  # Clear cache if VRAM > 90%


config = InferenceConfig()

# One-shot summary of the values that matter for the run
print(
    "‚úÖ Inference configuration loaded!",
    f"   Image size: {config.IMG_SIZE}√ó{config.IMG_SIZE}",
    f"   Batch size: {config.BATCH_SIZE}",
    f"   Mixed precision: {config.MIXED_PRECISION}",
    f"   VRAM emergency threshold: {config.VRAM_EMERGENCY_THRESHOLD*100:.0f}%",
    sep="\n",
)


In [None]:
# Define CNN + Attention Model Architecture (MUST match training exactly)
class CNN_Attention_Model(nn.Module):
    """Per-frame CNN features pooled over time by a learned attention weighting.

    The submodule names (`cnn`, `attention`, `fc`) define the state-dict keys,
    so they must stay as they are for trained checkpoints to load.

    Args:
        feature_dim: Size of the per-frame feature vector produced by the CNN.
        attention_dim: Hidden width of the attention scoring network.
        fc_dropout: Dropout probability inside the classifier head.
        num_classes: Number of output logits.
        pretrained: Forwarded to ``timm.create_model``. Defaults to True
            (the previous hard-coded value); callers that load a complete
            checkpoint right afterwards can pass False.
    """
    def __init__(self, feature_dim=1536, attention_dim=256, fc_dropout=0.5, num_classes=2,
                 pretrained=True):
        super().__init__()

        # Backbone; num_classes=0 is passed so the model is used as a feature extractor
        self.cnn = timm.create_model('inception_resnet_v2', pretrained=pretrained, num_classes=0)

        # Attention mechanism for temporal modeling: one scalar score per frame
        self.attention = nn.Sequential(
            nn.Linear(feature_dim, attention_dim),
            nn.Tanh(),
            nn.Linear(attention_dim, 1)
        )

        # Classifier applied to the attention-pooled feature vector
        self.fc = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(fc_dropout),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: Tensor of shape [batch_size, num_frames, C, H, W].

        Returns:
            Logits of shape [batch_size, num_classes].
        """
        batch_size, num_frames, C, H, W = x.shape

        # Fold time into the batch axis so the CNN sees individual frames.
        # reshape (not view) so non-contiguous clips, e.g. permuted or sliced
        # tensors, are accepted instead of raising.
        x = x.reshape(batch_size * num_frames, C, H, W)

        # Extract features from CNN: [batch_size * num_frames, feature_dim]
        features = self.cnn(x)

        # Restore the sequence axis: [batch_size, num_frames, feature_dim]
        features = features.reshape(batch_size, num_frames, -1)

        # Softmax over the frame axis turns scores into weights that sum to 1 per clip
        attn_scores = self.attention(features)  # [batch_size, num_frames, 1]
        attn_weights = F.softmax(attn_scores, dim=1)  # [batch_size, num_frames, 1]

        # Weighted sum of features: [batch_size, feature_dim]
        context = torch.sum(features * attn_weights, dim=1)

        # Classification: [batch_size, num_classes]
        return self.fc(context)

print("‚úÖ CNN_Attention_Model architecture defined!")


## 📦 STEP 4: Load Trained Model with Thorough Verification

In [None]:
# Load model checkpoint with thorough verification
print(f"üì¶ Loading model from: {MODEL_CHECKPOINT}\n")

# Deserialize onto the CPU: load_state_dict() below copies each tensor into the
# model that already lives on DEVICE, so mapping the checkpoint onto the GPU as
# well would only keep a second copy of every weight in VRAM until it is deleted.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
checkpoint = torch.load(MODEL_CHECKPOINT, map_location='cpu')

print("=" * 70)
print("CHECKPOINT VERIFICATION")
print("=" * 70)

# Verify checkpoint structure: weights, hyperparameters and validation metrics
required_keys = ['model_state_dict', 'config', 'val_acc', 'val_f1', 'val_auc']
missing_keys = [key for key in required_keys if key not in checkpoint]
if missing_keys:
    raise KeyError(f"‚ùå Checkpoint missing required keys: {missing_keys}")
print("‚úÖ Checkpoint structure valid")

# Verify config: these entries are needed to rebuild the architecture
required_config = ['attention_dim', 'fc_dropout', 'feature_dim']
missing_config = [key for key in required_config if key not in checkpoint['config']]
if missing_config:
    raise KeyError(f"‚ùå Checkpoint config missing: {missing_config}")
print("‚úÖ Checkpoint config valid")

# Verify architecture compatibility with this notebook's configuration
if checkpoint['config'].get('feature_dim') != config.FEATURE_DIM:
    raise ValueError(
        f"‚ùå Architecture mismatch! "
        f"Checkpoint feature_dim={checkpoint['config'].get('feature_dim')}, "
        f"Expected {config.FEATURE_DIM}"
    )
print("‚úÖ Architecture compatibility verified")

# Display checkpoint information
print(f"\nCheckpoint Training Results:")
print(f"  Epoch: {checkpoint.get('epoch', 'N/A')}")
print(f"  Validation Accuracy: {checkpoint['val_acc']:.4f}")
print(f"  Validation F1: {checkpoint['val_f1']:.4f}")
print(f"  Validation AUC: {checkpoint['val_auc']:.4f}")

# Get model hyperparameters from checkpoint
attention_dim = checkpoint['config']['attention_dim']
fc_dropout = checkpoint['config']['fc_dropout']

print(f"\nModel Hyperparameters:")
print(f"  Attention dimension: {attention_dim}")
print(f"  FC dropout: {fc_dropout}")

# Create model with checkpoint hyperparameters
model = CNN_Attention_Model(
    feature_dim=config.FEATURE_DIM,
    attention_dim=attention_dim,
    fc_dropout=fc_dropout
).to(DEVICE)

# Load trained weights; chain the original error so the key/shape details survive
try:
    model.load_state_dict(checkpoint['model_state_dict'])
    print("\n‚úÖ Model weights loaded successfully!")
except Exception as e:
    raise RuntimeError(f"‚ùå Failed to load model weights: {e}") from e

# Set to evaluation mode
model.eval()
print("‚úÖ Model set to evaluation mode")

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"   Total parameters: {total_params:,}")

print("=" * 70)
print("‚úÖ MODEL READY FOR INFERENCE")
print("=" * 70)

# Clear memory: the checkpoint dict is no longer needed once the weights are copied
del checkpoint
torch.cuda.empty_cache()


In [None]:
# Frame testing function
def _sample_rgb_frames(video_path, num_frames, fallback_size):
    """Decode `num_frames` evenly spaced RGB frames, or return None when the video reports no frames."""
    cap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return None

        frames = []
        # Uniformly spaced frame positions from the first to the last frame
        for frame_idx in np.linspace(0, total_frames - 1, num_frames, dtype=int):
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
            ret, frame = cap.read()
            if ret:
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            elif frames:
                # Unreadable frame: repeat the last good one
                frames.append(frames[-1])
            else:
                # Nothing decoded yet: fall back to a black frame
                frames.append(np.zeros((fallback_size, fallback_size, 3), dtype=np.uint8))
        return frames
    finally:
        # Always release the capture handle, even if decoding raised
        cap.release()


def test_frame_counts(model, test_videos, frame_counts=(24, 32, 48), device=DEVICE):
    """
    Run the model on a subset of videos once per candidate frame count.

    Args:
        model: Network called with a [1, num_frames, C, H, W] tensor and
            returning class logits.
        test_videos: DataFrame with 'video_path' and 'filename' columns.
        frame_counts: Iterable of frame counts to try (list or tuple).
        device: Device the frame tensors are moved to.

    Returns:
        Dict mapping each frame count to {'predictions': [...], 'confidences': [...]},
        one entry per video in iteration order. Videos that report no frames
        get the default entry (label 0, confidence 0.5).
    """
    # Simple augmentation (no TTA for testing, just resize + normalize)
    transform = A.Compose([
        A.Resize(config.IMG_SIZE, config.IMG_SIZE),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Inception normalization
        ToTensorV2()
    ])

    results = {fc: {'predictions': [], 'confidences': []} for fc in frame_counts}

    print(f"üß™ Testing {len(frame_counts)} frame counts on {len(test_videos)} videos...\n")

    model.eval()
    with torch.no_grad():
        for _, row in test_videos.iterrows():
            video_path = row['video_path']
            filename = row['filename']

            print(f"Processing {filename}...")

            for num_frames in frame_counts:
                frames = _sample_rgb_frames(video_path, num_frames, config.IMG_SIZE)

                if frames is None:
                    # Use default prediction for corrupted video
                    results[num_frames]['predictions'].append(0)
                    results[num_frames]['confidences'].append(0.5)
                    continue

                # Transform every frame and add the batch axis: [1, num_frames, C, H, W]
                frames_tensor = torch.stack(
                    [transform(image=frame)['image'] for frame in frames]
                ).unsqueeze(0).to(device)

                # Forward pass
                with autocast(enabled=config.MIXED_PRECISION):
                    output = model(frames_tensor)
                    probs = F.softmax(output, dim=1)[0]  # [2]

                # Predicted class and the probability assigned to it
                results[num_frames]['predictions'].append(torch.argmax(probs).item())
                results[num_frames]['confidences'].append(torch.max(probs).item())

        torch.cuda.empty_cache()

    return results

print("‚úÖ Frame testing function defined!")

## 🧪 STEP 5: Frame Count Testing (24 vs 32 vs 48 frames)

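**Purpose:** Compare frame counts on a small sample before the full run. No ground-truth labels are used in this step, so accuracy is not measured directly; average prediction confidence serves as a proxy.  
**Method:** Run inference on the first 10 test videos with 24, 32, and 48 frames  
**Recommendation:** Choose the frame count with the highest average confidence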

In [None]:
# Try each candidate frame count on the first 10 test videos
candidate_counts = [24, 32, 48]
test_videos = test_df.head(10)
frame_test_results = test_frame_counts(model, test_videos, frame_counts=candidate_counts)

wide_rule = "=" * 90

# Summary per frame count
print("\n" + wide_rule)
print("FRAME COUNT COMPARISON RESULTS")
print(wide_rule)

mean_confidence = {}
for frame_count in candidate_counts:
    per_count = frame_test_results[frame_count]
    labels = per_count['predictions']
    mean_confidence[frame_count] = np.mean(per_count['confidences'])
    print(f"\n{frame_count} frames:")
    print(f"  Average confidence: {mean_confidence[frame_count]:.4f}")
    print(f"  Predicted Real (0): {labels.count(0)}")
    print(f"  Predicted Fake (1): {labels.count(1)}")

# Recommend the frame count whose mean confidence is highest (first one wins ties)
best_frame_count = max(candidate_counts, key=mean_confidence.get)
best_confidence = mean_confidence[best_frame_count]

print("\n" + wide_rule)
print(f"üìä RECOMMENDATION: Use {best_frame_count} frames (highest avg confidence: {best_confidence:.4f})")
print(wide_rule)

# Side-by-side table: one row per video, one column per frame count
print("\nüìã Per-Video Comparison:")
header_cells = [f"{'Video':<20}"]
for frame_count in candidate_counts:
    column_title = f"{frame_count} frames"
    header_cells.append(f"{column_title:<20}")
print(" | ".join(header_cells))
print("-" * 90)

for row_pos, full_name in enumerate(test_videos['filename']):
    short_name = full_name[:18]  # keep the first column within its width
    row_cells = [f"{short_name:<20}"]
    for frame_count in candidate_counts:
        label = frame_test_results[frame_count]['predictions'][row_pos]
        conf = frame_test_results[frame_count]['confidences'][row_pos]
        cell = f"L={label}, C={conf:.3f}"
        row_cells.append(f"{cell:<20}")
    print(" | ".join(row_cells))

print("\n‚úÖ Frame testing complete! Review results above.")


## ⚙️ STEP 6: Set Configuration Based on Test Results

**Instructions:** Review the frame count testing results above and set FRAMES_PER_VIDEO below.  
The recommended value (highest avg confidence) is auto-filled, but you can change it.

In [None]:
# Frame count used for the full inference run.
# It defaults to the recommendation computed by the frame-count test;
# replace the right-hand side with 24, 32, or 48 to override it.
FRAMES_PER_VIDEO = best_frame_count

print(
    f"‚úÖ Configuration set:",
    f"   FRAMES_PER_VIDEO = {FRAMES_PER_VIDEO}",
    f"   (Based on test results: highest avg confidence)",
    f"\n‚ö†Ô∏è If you want to use a different value, change FRAMES_PER_VIDEO above and re-run this cell.",
    sep="\n",
)


## 🔄 STEP 7: Test-Time Augmentation (TTA) & Dataset

In [None]:
# Define 5 TTA variants
def get_tta_transforms():
    """
    Build the 5 augmentation pipelines used for Test-Time Augmentation.

    Every pipeline resizes to config.IMG_SIZE, applies at most one extra
    augmentation, normalizes with mean=0.5 / std=0.5 and converts to a tensor.

    Returns:
        List of 5 albumentations Compose objects, in this order:
        original, horizontal flip, shift + scale, brightness +0.1, brightness -0.1.
    """
    def _variant(*extra):
        # Shared skeleton; `extra` is inserted between resize and normalize
        return A.Compose([
            A.Resize(config.IMG_SIZE, config.IMG_SIZE),
            *extra,
            A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ToTensorV2()
        ])

    return [
        # Variant 1: Original (resize + normalize only)
        _variant(),

        # Variant 2: Horizontal flip
        _variant(A.HorizontalFlip(p=1.0)),

        # Variant 3: Shift and scale (mild version of training aug).
        # NOTE(review): the shift/scale amounts are sampled randomly on every
        # call, so frames of one clip get different offsets - confirm intended.
        _variant(A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=0, p=1.0)),

        # Variant 4: Brightness increase.
        # A scalar limit means the symmetric random range (-limit, +limit), so
        # an explicit degenerate range is needed to pin the change to +0.1.
        _variant(A.RandomBrightnessContrast(brightness_limit=(0.1, 0.1), contrast_limit=0, p=1.0)),

        # Variant 5: Brightness decrease (pinned to -0.1 for the same reason)
        _variant(A.RandomBrightnessContrast(brightness_limit=(-0.1, -0.1), contrast_limit=0, p=1.0)),
    ]

print("‚úÖ TTA transforms defined (5 variants)!")
print("   1. Original")
print("   2. Horizontal flip")
print("   3. Shift + Scale")
print("   4. Brightness increase (+0.1)")
print("   5. Brightness decrease (-0.1)")


In [None]:
# Video Dataset for inference (optimized, no training-specific code)
class VideoInferenceDataset(Dataset):
    """
    Decode videos on the fly and return one transformed frame stack per item.

    Args:
        dataframe: Table with 'video_path' and 'filename' columns; a copy with
            a fresh 0..n-1 index is stored.
        num_frames: Number of frames sampled from each video.
        transform: Callable invoked as transform(image=frame); the processed
            frame is read from the 'image' key of its result.
    """
    def __init__(self, dataframe, num_frames, transform):
        self.df = dataframe.reset_index(drop=True)
        self.num_frames = num_frames
        self.transform = transform
        self.img_size = config.IMG_SIZE

    def __len__(self):
        return len(self.df)

    def _blank_frame(self):
        """Return a black img_size x img_size RGB frame used as a placeholder."""
        return np.zeros((self.img_size, self.img_size, 3), dtype=np.uint8)

    def extract_frames(self, video_path):
        """Return `num_frames` uniformly spaced RGB frames from the video.

        A video that reports no frames yields only black placeholder frames.
        A frame that cannot be read is replaced by the previous good frame,
        or by a black frame when none has been read yet.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Handle corrupted or empty videos
            if total_frames <= 0:
                return [self._blank_frame() for _ in range(self.num_frames)]

            frames = []
            # Uniformly spaced positions from the first to the last frame
            for idx in np.linspace(0, total_frames - 1, self.num_frames, dtype=int):
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                ret, frame = cap.read()

                if ret:
                    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                elif frames:
                    frames.append(frames[-1])
                else:
                    frames.append(self._blank_frame())

            return frames
        finally:
            # Release the capture handle on every path, including exceptions
            cap.release()

    def __getitem__(self, idx):
        """Return (frames_tensor, filename); frames_tensor is [num_frames, C, H, W]."""
        row = self.df.iloc[idx]
        video_path = row['video_path']
        filename = row['filename']

        # Extract frames
        frames = self.extract_frames(video_path)

        # Apply the transform to each frame independently, then stack them
        frames_tensor = torch.stack(
            [self.transform(image=frame)['image'] for frame in frames]
        )

        return frames_tensor, filename

print("‚úÖ VideoInferenceDataset class defined!")


## 🚀 STEP 8: Run Full Inference with TTA

**Process:** For each test video, run inference with 5 TTA variants and average predictions  
**Safety:** VRAM monitoring with emergency cache clearing + enhanced OOM fallback  
**Estimated time:** ~25-35 minutes for 200 videos

In [None]:
# Helper function to check VRAM usage
def get_vram_usage():
    """Return the used share of GPU memory as a fraction (used / total).

    Best effort: returns 0.0 when monitoring is unavailable (no NVML handle)
    or when the query fails for any reason.
    """
    try:
        if gpu_handle is None:
            return 0.0
        mem_info = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle)
        return mem_info.used / mem_info.total
    except Exception:  # keep the best-effort fallback, but let Ctrl-C through
        return 0.0

# Main inference function with TTA, cache management, and enhanced fallback
_EMERGENCY_FRAMES = 16  # frame count used when the requested count does not fit in memory


def _load_tta_clip(single_video_df, num_frames, transform):
    """Decode one video with the given transform and return it as a [1, num_frames, C, H, W] batch."""
    dataset = VideoInferenceDataset(single_video_df, num_frames, transform)
    frames_tensor, _ = dataset[0]
    return frames_tensor.unsqueeze(0)


def _tta_forward(model, frames, device):
    """Run one forward pass and return the softmax class probabilities on the CPU."""
    frames = frames.to(device)
    with autocast(enabled=config.MIXED_PRECISION):
        output = model(frames)
        softmax = F.softmax(output, dim=1)[0]  # [2] - [prob_real, prob_fake]
    return softmax.float().cpu()


def _predict_tta_variant(model, single_video_df, num_frames, transform, device):
    """Return class probabilities for one TTA variant, degrading step by step on out-of-memory."""
    frames = _load_tta_clip(single_video_df, num_frames, transform)

    try:
        return _tta_forward(model, frames, device)
    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise
    # From here on we are recovering from OOM. The recovery deliberately runs
    # AFTER the except block: while a handler is active the exception's
    # traceback keeps the failed attempt's tensors alive, so the cache clear
    # could not actually return that memory.
    torch.cuda.empty_cache()

    # Retry 1: same frames, after cache clear
    try:
        return _tta_forward(model, frames, device)
    except Exception as e:
        tqdm.write(
            f"WARNING: retry after cache clear failed ({type(e).__name__}); "
            f"falling back to {_EMERGENCY_FRAMES} frames"
        )
    torch.cuda.empty_cache()

    # Retry 2: emergency mode with a reduced frame count
    try:
        frames_reduced = _load_tta_clip(single_video_df, _EMERGENCY_FRAMES, transform)
        return _tta_forward(model, frames_reduced, device)
    except Exception as e:
        # Ultimate fallback: neutral prediction
        tqdm.write(
            f"WARNING: emergency inference failed ({type(e).__name__}); "
            f"using neutral prediction for this TTA variant"
        )
        return torch.tensor([0.5, 0.5])


def run_inference_with_tta(model, test_dataframe, num_frames, num_tta=5, device=DEVICE):
    """
    Run inference on all test videos with Test-Time Augmentation.

    For each video the first `num_tta` TTA transforms are applied, the softmax
    probabilities of all variants are averaged, and the arg-max class together
    with its averaged probability is recorded.

    Safety features:
    - VRAM usage is checked before every variant; the CUDA cache is cleared
      above config.VRAM_EMERGENCY_THRESHOLD.
    - Out-of-memory fallback: retry after a cache clear -> retry with 16
      frames -> neutral [0.5, 0.5] prediction.

    Args:
        model: Network called with a [1, num_frames, C, H, W] tensor.
        test_dataframe: DataFrame with 'filename' and 'video_path' columns.
        num_frames: Frames sampled per video.
        num_tta: Number of TTA variants to use (1 up to the number defined).
        device: Device the frame tensors are moved to.

    Returns:
        Dict mapping filename -> {'label': int, 'probability': float}.

    Raises:
        ValueError: If `num_tta` is outside the range of defined TTA variants.
        RuntimeError: Re-raised from the forward pass when it is not an
            out-of-memory error.
    """
    tta_transforms = get_tta_transforms()
    if not 1 <= num_tta <= len(tta_transforms):
        raise ValueError(
            f"num_tta must be between 1 and {len(tta_transforms)}, got {num_tta}"
        )

    all_predictions = {}

    model.eval()

    print(f"üöÄ Starting inference on {len(test_dataframe)} videos with {num_tta} TTA variants...")
    print(f"   Total inference passes: {len(test_dataframe) * num_tta}")
    print(f"   Estimated time: ~{len(test_dataframe) * num_tta * 0.02:.0f}-{len(test_dataframe) * num_tta * 0.03:.0f} minutes\n")

    start_time = time.time()

    with torch.no_grad():
        for _, row in tqdm(test_dataframe.iterrows(), total=len(test_dataframe), desc="Processing videos"):
            filename = row['filename']

            # One-row frame for the dataset; identical for every TTA variant
            single_video_df = pd.DataFrame([row])
            tta_softmax_outputs = []

            # Run inference with each TTA variant
            for transform in tta_transforms[:num_tta]:
                # VRAM safety check before each TTA variant
                if get_vram_usage() > config.VRAM_EMERGENCY_THRESHOLD:
                    torch.cuda.empty_cache()

                tta_softmax_outputs.append(
                    _predict_tta_variant(model, single_video_df, num_frames, transform, device)
                )

            # Average softmax across all TTA variants
            avg_softmax = torch.mean(torch.stack(tta_softmax_outputs), dim=0)  # [2]

            # Final prediction and the averaged probability of that class
            all_predictions[filename] = {
                'label': torch.argmax(avg_softmax).item(),  # 0 or 1
                'probability': torch.max(avg_softmax).item()
            }

            # Regular cache clear after processing all TTA variants for this video
            torch.cuda.empty_cache()

    elapsed_time = time.time() - start_time
    print(f"\n‚úÖ Inference complete!")
    print(f"   Total time: {elapsed_time/60:.1f} minutes")
    # max(..., 1) keeps the summary printable for an empty dataframe
    print(f"   Average time per video: {elapsed_time/max(len(test_dataframe), 1):.1f} seconds")

    return all_predictions

print("‚úÖ Inference function defined with TTA + enhanced safety!")


In [None]:
# Announce the run parameters, then predict every video listed in test_df
rule = "=" * 70
print(rule)
print("STARTING FULL INFERENCE")
print(rule)
print(f"Videos: {len(test_df)}")
print(f"Frames per video: {FRAMES_PER_VIDEO}")
print(f"TTA variants: 5")
print(f"Batch size: {config.BATCH_SIZE}")
print(rule + "\n")

# Result maps each filename to its predicted label and probability
predictions = run_inference_with_tta(
    model,
    test_df,
    FRAMES_PER_VIDEO,
    num_tta=5,
    device=DEVICE,
)

print(f"\n‚úÖ Generated {len(predictions)} predictions!")


## 💾 STEP 9: Save & Validate Predictions

In [None]:
# Create predictions dataframe.
# Columns are named explicitly so the schema exists even when there are no
# predictions (sorting by 'filename' would otherwise fail on an empty frame).
predictions_df = pd.DataFrame(
    [
        (filename, pred['label'], pred['probability'])
        for filename, pred in predictions.items()
    ],
    columns=['filename', 'label', 'probability'],
)

# Sort by filename for consistency
predictions_df = predictions_df.sort_values('filename').reset_index(drop=True)

print("=" * 70)
print("PREDICTIONS GENERATED - RUNNING VALIDATION")
print("=" * 70)

# ESSENTIAL VALIDATION CHECKS
validation_passed = True

# Check 1: exactly one prediction per row of the test CSV
expected_rows = len(test_df)
if len(predictions_df) != expected_rows:
    print(f"‚ùå ERROR: Expected {expected_rows} predictions, got {len(predictions_df)}")
    validation_passed = False
else:
    print(f"‚úÖ Row count: {len(predictions_df)}")

# Check 2: All labels are 0 or 1
invalid_labels = predictions_df[~predictions_df['label'].isin([0, 1])]
if len(invalid_labels) > 0:
    print(f"‚ùå ERROR: {len(invalid_labels)} invalid labels found")
    validation_passed = False
else:
    print(f"‚úÖ All labels valid (0 or 1)")

# Check 3: All probabilities between 0 and 1.
# between() is False for NaN, so NaN probabilities are flagged as invalid too
# (plain `< 0` / `> 1` comparisons would let them through).
invalid_probs = predictions_df[~predictions_df['probability'].between(0, 1)]
if len(invalid_probs) > 0:
    print(f"‚ùå ERROR: {len(invalid_probs)} probabilities out of range [0, 1]")
    validation_passed = False
else:
    print(f"‚úÖ All probabilities in valid range [0, 1]")

# Check 4: All filenames match test_public.csv
missing_files = set(test_df['filename']) - set(predictions_df['filename'])
extra_files = set(predictions_df['filename']) - set(test_df['filename'])
if missing_files or extra_files:
    print(f"‚ùå ERROR: Filename mismatch!")
    if missing_files:
        print(f"   Missing: {list(missing_files)[:5]}")
    if extra_files:
        print(f"   Extra: {list(extra_files)[:5]}")
    validation_passed = False
else:
    print(f"‚úÖ All filenames match test_public.csv")

# Stop before statistics and saving if any check failed
if not validation_passed:
    raise ValueError("‚ùå Validation failed! Check errors above.")

print("\n" + "=" * 70)
print("PREDICTION STATISTICS")
print("=" * 70)

# Display statistics
real_count = (predictions_df['label'] == 0).sum()
fake_count = (predictions_df['label'] == 1).sum()
avg_confidence = predictions_df['probability'].mean()

print(f"Total predictions: {len(predictions_df)}")
print(f"Predicted Real (0): {real_count} ({real_count/len(predictions_df)*100:.1f}%)")
print(f"Predicted Fake (1): {fake_count} ({fake_count/len(predictions_df)*100:.1f}%)")
print(f"Average confidence: {avg_confidence:.4f}")

print("\n" + "=" * 70)
print("‚úÖ ALL VALIDATIONS PASSED!")
print("=" * 70)


In [None]:
# Write the validated predictions to disk (no index column)
predictions_df.to_csv(OUTPUT_CSV, index=False)

rule = "=" * 70
size_kb = OUTPUT_CSV.stat().st_size / 1024

print(rule)
print("PREDICTIONS SAVED")
print(rule)
print(f"üìÅ Output file: {OUTPUT_CSV}")
print(f"   File size: {size_kb:.2f} KB")
print(f"   Format: filename, label, probability")

# Round-trip check: read the file back and report its shape
verify_df = pd.read_csv(OUTPUT_CSV)
print(f"\n‚úÖ CSV verification:")
print(f"   Rows: {len(verify_df)}")
print(f"   Columns: {list(verify_df.columns)}")

# Show a sample of what was written
print(f"\nüìã First 10 predictions:")
print(verify_df.head(10).to_string(index=False))

# Closing banner with the next steps
print("\n" + rule)
print("‚úÖ INFERENCE PIPELINE COMPLETE!")
print(rule)
print(f"\nüéâ SUCCESS! Generated {len(verify_df)} predictions")
print(f"üìÅ Output: {OUTPUT_CSV}")
print(f"\nNext steps:")
print(f"  1. Review predictions above")
print(f"  2. Check {OUTPUT_DIR} folder")
print(f"  3. Submit PREDICTIONS.CSV to competition")
print(rule)
