In [1]:
import torch
import sys
import warnings
warnings.filterwarnings('ignore')

# Environment report: interpreter and framework versions, then CUDA status.
print(f"Python: {sys.version.split()[0]}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")

# Select the compute device. Two globals are produced here:
#   device        - torch.device used by the rest of the notebook
#   gpu_memory_gb - total memory of GPU 0 in decimal GB (bytes / 1e9), 0 on CPU
if not torch.cuda.is_available():
    print("‚ö†Ô∏è WARNING: No GPU detected, using CPU")
    gpu_memory_gb = 0
    device = torch.device('cpu')
else:
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Memory: {gpu_memory_gb:.2f} GB")
    device = torch.device('cuda')

print(f"\n‚úÖ Device: {device}")


Python: 3.10.19
PyTorch: 2.5.1+cu121
CUDA Available: True
CUDA Version: 12.1
GPU: NVIDIA RTX A6000
GPU Memory: 48.31 GB

‚úÖ Device: cuda


In [2]:
# Import all required libraries
import os
import sys
import json
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, Dict, List, Tuple

import numpy as np
import pandas as pd
import cv2
import librosa
import soundfile as sf
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as transforms

# Vision models
import timm
# open_clip is optional: record whether it could be imported instead of failing
# the whole cell. `except Exception` keeps the best-effort behaviour for any
# import-time failure but, unlike a bare `except:`, lets KeyboardInterrupt and
# SystemExit propagate.
try:
    import open_clip
    OPEN_CLIP_AVAILABLE = True
except Exception:
    OPEN_CLIP_AVAILABLE = False
    print("‚ö†Ô∏è open_clip not available")

# Audio models
import torchaudio
from transformers import Wav2Vec2Model, Wav2Vec2Processor

# NLP models
# sentence-transformers is optional: SENTENCE_TRANSFORMERS_AVAILABLE is read by
# TextEncoder to decide which text backbone to build. `except Exception` keeps
# the best-effort behaviour for any import-time failure but, unlike a bare
# `except:`, lets KeyboardInterrupt and SystemExit propagate.
try:
    from sentence_transformers import SentenceTransformer
    SENTENCE_TRANSFORMERS_AVAILABLE = True
except Exception:
    SENTENCE_TRANSFORMERS_AVAILABLE = False
    print("‚ö†Ô∏è sentence-transformers not available")

# Metrics
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix, roc_auc_score

# Reached only when every mandatory import above succeeded. The optional
# packages (open_clip, sentence-transformers) do not abort the cell; their
# status is carried by OPEN_CLIP_AVAILABLE / SENTENCE_TRANSFORMERS_AVAILABLE.
# NOTE: `device` is not defined in this cell - it must already exist from the
# environment-check cell, so this cell fails on a kernel where that cell was skipped.
print("‚úÖ All imports successful!")
print(f"Device: {device}")


‚úÖ All imports successful!
Device: cuda


In [3]:
@dataclass
class ModelConfig:
    """Configuration for model architecture, training and data sampling.

    All fields have defaults that correspond to the "large" preset. Use
    :meth:`from_gpu_memory` to pick between the large and small presets based
    on the amount of GPU memory available.
    """

    # Model size
    preset: str = "large"
    d_model: int = 512          # common token width every encoder projects to
    n_heads: int = 8
    n_layers: int = 4
    dropout: float = 0.1

    # Encoders
    vision_backbone: str = "vit_base_patch16_224"
    audio_backbone: str = "facebook/wav2vec2-large-960h"
    text_backbone: str = "sentence-transformers/all-MiniLM-L6-v2"

    # Forwarded as `pretrained=` when the vision backbone is created with timm.
    vision_pretrained: bool = True

    # When True the corresponding backbone's parameters get requires_grad=False.
    freeze_vision: bool = True
    freeze_audio: bool = True
    freeze_text: bool = True

    # Training
    batch_size: int = 8
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4
    epochs: int = 10
    gradient_accumulation_steps: int = 4
    alpha_domain: float = 0.5   # initial gradient-reversal strength

    # Data
    k_frames: int = 5
    k_audio_chunks: int = 5
    sample_rate: int = 16000
    image_size: int = 224
    max_text_tokens: int = 256

    @classmethod
    def from_gpu_memory(cls, gpu_memory_gb: float, large_threshold_gb: float = 40.0):
        """Build a preset configuration from the available GPU memory.

        Args:
            gpu_memory_gb: Detected GPU memory in GB (0 when running on CPU).
            large_threshold_gb: Minimum memory, inclusive, that selects the
                "large" preset. Defaults to 40 GB, the previously hard-coded
                cut-off.

        Returns:
            ModelConfig: the default (large) configuration when
            ``gpu_memory_gb >= large_threshold_gb``, otherwise a reduced
            configuration with smaller backbones, width, depth and batch size.
        """
        if gpu_memory_gb >= large_threshold_gb:
            print("üöÄ Using LARGE model configuration")
            return cls(preset="large")

        print("‚ö° Using SMALL model configuration")
        return cls(
            preset="small",
            vision_backbone="resnet50",
            audio_backbone="facebook/wav2vec2-base",
            d_model=256,
            n_heads=4,
            n_layers=2,
            batch_size=4
        )

# Choose the preset from the GPU memory measured in the environment-check cell,
# then echo the hyperparameters that differ between presets.
config = ModelConfig.from_gpu_memory(gpu_memory_gb)
print(
    "\nüìä Model Config:",
    f"  - Preset: {config.preset.upper()}",
    f"  - Model dim: {config.d_model}",
    f"  - Layers: {config.n_layers}",
    f"  - Heads: {config.n_heads}",
    f"  - Batch size: {config.batch_size}",
    sep="\n",
)

üöÄ Using LARGE model configuration

üìä Model Config:
  - Preset: LARGE
  - Model dim: 512
  - Layers: 4
  - Heads: 8
  - Batch size: 8


In [4]:
class GradientReversalFunction(torch.autograd.Function):
    """
    Gradient reversal op (see 'Domain-Adversarial Training of Neural Networks').

    Forward is the identity; backward multiplies the incoming gradient by
    ``-alpha`` so that upstream layers are pushed in the opposite direction of
    whatever loss sits downstream of this op.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep the scale on the context; it is a plain Python value, so it is
        # stored as an attribute rather than via save_for_backward.
        ctx.alpha = alpha
        # Identity expressed as a view so autograd records this Function.
        return x.view(x.size())

    @staticmethod
    def backward(ctx, grad_output):
        flipped = -grad_output
        # One gradient per forward input: the reversed one for `x`, none for `alpha`.
        return flipped * ctx.alpha, None

class GradientReversalLayer(nn.Module):
    """Module wrapper around GradientReversalFunction with an adjustable strength."""

    def __init__(self, alpha=1.0):
        super().__init__()
        # Plain attribute (not a Parameter/buffer): it is a schedule value, not a weight.
        self.alpha = alpha

    def set_alpha(self, alpha):
        """Replace the reversal strength used by subsequent forward passes."""
        self.alpha = alpha

    def forward(self, x):
        """Return `x` unchanged; its gradient is scaled by ``-alpha`` on the way back."""
        strength = self.alpha
        return GradientReversalFunction.apply(x, strength)

# Cell-completion marker: confirms the two gradient-reversal classes were defined.
print("GRL defined!")



GRL defined!


In [5]:
class VisualEncoder(nn.Module):
    """
    Visual encoder for images/video frames.

    Each frame is passed through a timm backbone created with ``num_classes=0``
    and the resulting feature vector is linearly projected to ``config.d_model``,
    yielding one token per frame.
    """

    # Substrings of ``config.vision_backbone`` that select a supported backbone family.
    _SUPPORTED_FAMILIES = ("vit", "resnet")

    def __init__(self, config: "ModelConfig"):
        """
        Args:
            config: Model configuration. Reads ``vision_backbone``,
                ``vision_pretrained``, ``freeze_vision``, ``d_model`` and
                ``image_size``.

        Raises:
            ValueError: If ``vision_backbone`` names neither a ViT nor a ResNet.
        """
        super().__init__()
        self.config = config

        # Both supported families are built the same way, so validate the name
        # once and share a single construction path.
        backbone_name = config.vision_backbone.lower()
        if not any(family in backbone_name for family in self._SUPPORTED_FAMILIES):
            raise ValueError(f"Unsupported vision backbone: {config.vision_backbone}")

        self.backbone = timm.create_model(
            config.vision_backbone,
            pretrained=config.vision_pretrained,
            num_classes=0  # Remove classification head
        )
        self.feature_dim = self.backbone.num_features

        # Freeze backbone if specified
        if config.freeze_vision:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Projection to common dimension
        self.projection = nn.Linear(self.feature_dim, config.d_model)

        # Image preprocessing for raw arrays (not applied inside forward, which
        # expects already-preprocessed tensors).
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((config.image_size, config.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def forward(self, images):
        """
        Args:
            images: Tensor of shape (batch, num_frames, C, H, W) or (batch, C, H, W)

        Returns:
            tokens: Tensor of shape (batch, num_frames, d_model), or None when
                ``images`` is None or empty
            available: Boolean indicating if visual data is available
        """
        if images is None or images.numel() == 0:
            return None, False

        # Handle single images vs video frames
        if images.ndim == 4:
            # Single image: (batch, C, H, W) -> (batch, 1, C, H, W)
            images = images.unsqueeze(1)
        batch_size, num_frames = images.size(0), images.size(1)

        # Fold frames into the batch axis. reshape (not view) so that
        # non-contiguous inputs, e.g. produced by permute(), are accepted too.
        images_flat = images.reshape(batch_size * num_frames, *images.shape[2:])

        # Extract features; skip graph construction when the backbone is frozen.
        with torch.set_grad_enabled(not self.config.freeze_vision):
            features = self.backbone(images_flat)  # (batch*num_frames, feature_dim)

        # Project to common dimension
        tokens = self.projection(features)  # (batch*num_frames, d_model)

        # Unfold back to (batch, num_frames, d_model)
        tokens = tokens.view(batch_size, num_frames, -1)

        return tokens, True


class AudioEncoder(nn.Module):
    """
    Audio encoder producing one token per waveform chunk.

    Uses a pretrained Wav2Vec2 backbone (mean-pooled over time) when it can be
    loaded, otherwise a small 1-D CNN over the raw waveform. ``self.available``
    records which of the two is active (True = Wav2Vec2).
    """

    def __init__(self, config: "ModelConfig"):
        """
        Args:
            config: Model configuration. Reads ``audio_backbone``,
                ``freeze_audio`` and ``d_model``.
        """
        super().__init__()
        self.config = config
        self.processor = None

        # Load the Wav2Vec2 backbone; only a failure here triggers the fallback.
        try:
            self.backbone = Wav2Vec2Model.from_pretrained(config.audio_backbone)
        except Exception as e:
            print(f"Warning: Could not load audio model: {e}")
            print("Using fallback CNN encoder")
            self.available = False
            self._build_fallback_encoder(config)
            return

        # The processor is not used by forward(), so failing to load it (some
        # checkpoints ship without tokenizer files) must not discard a backbone
        # that loaded fine. Keep it as an optional convenience attribute.
        try:
            self.processor = Wav2Vec2Processor.from_pretrained(config.audio_backbone)
        except Exception as e:
            print(f"Warning: Could not load audio processor: {e}")

        self.feature_dim = self.backbone.config.hidden_size

        # Freeze backbone if specified
        if config.freeze_audio:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Projection to common dimension
        self.projection = nn.Linear(self.feature_dim, config.d_model)
        self.available = True

    def _build_fallback_encoder(self, config):
        """Build the fallback 1-D CNN that consumes the raw waveform as a single channel."""
        self.backbone = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=10, stride=5),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(8),
            nn.Conv1d(64, 128, kernel_size=3),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(32)  # fixed-length output regardless of input length
        )
        self.projection = nn.Linear(128 * 32, config.d_model)
        self.feature_dim = 128 * 32

    def forward(self, waveforms):
        """
        Args:
            waveforms: Tensor of shape (batch, num_chunks, samples) or (batch, samples)

        Returns:
            tokens: Tensor of shape (batch, num_chunks, d_model), or None when
                ``waveforms`` is None or empty
            available: Boolean indicating if audio data is available
        """
        if waveforms is None or waveforms.numel() == 0:
            return None, False

        # Handle single waveform vs chunks
        if waveforms.ndim == 2:
            waveforms = waveforms.unsqueeze(1)  # (batch, 1, samples)
        batch_size, num_chunks = waveforms.size(0), waveforms.size(1)

        # Fold chunks into the batch axis. reshape (not view) so that
        # non-contiguous inputs are accepted too.
        waveforms_flat = waveforms.reshape(batch_size * num_chunks, -1)

        # Extract features
        if self.available:
            with torch.set_grad_enabled(not self.config.freeze_audio):
                outputs = self.backbone(waveforms_flat)
                features = outputs.last_hidden_state.mean(dim=1)  # Pool over time
        else:
            # Fallback CNN expects (N, channels=1, samples)
            waveforms_flat = waveforms_flat.unsqueeze(1)
            features = self.backbone(waveforms_flat)
            features = features.view(batch_size * num_chunks, -1)

        # Project to common dimension
        tokens = self.projection(features)  # (batch*num_chunks, d_model)

        # Unfold back to (batch, num_chunks, d_model)
        tokens = tokens.view(batch_size, num_chunks, -1)

        return tokens, True


class TextEncoder(nn.Module):
    """
    Text encoder producing a single pooled token per transcript.

    Backbone selection:
      * SentenceTransformer(config.text_backbone) when sentence-transformers
        was importable;
      * otherwise DistilBERT (CLS token) via transformers;
      * if neither can be built, ``self.available`` is False and forward()
        emits all-zero tokens of width ``d_model``.
    """

    def __init__(self, config: "ModelConfig"):
        """
        Args:
            config: Model configuration. Reads ``text_backbone``,
                ``freeze_text``, ``d_model`` and ``max_text_tokens``.
        """
        super().__init__()
        self.config = config

        # Zero-size, non-persistent buffer: it follows .to()/.cuda() so forward()
        # can always find this module's device, even when the module ends up
        # with no parameters at all. It is excluded from state_dict.
        self.register_buffer("_device_anchor", torch.empty(0), persistent=False)

        # Remember which backbone was actually built, so forward() does not
        # depend on the module-level flag still having the same value.
        self._uses_sentence_transformers = False

        # Load text model
        try:
            if SENTENCE_TRANSFORMERS_AVAILABLE:
                self.backbone = SentenceTransformer(config.text_backbone)
                self.feature_dim = self.backbone.get_sentence_embedding_dimension()
                self._uses_sentence_transformers = True
            else:
                # Fallback to distilbert. Imported here because the notebook's
                # import cell does not bring these two names into scope.
                from transformers import AutoModel, AutoTokenizer
                self.backbone = AutoModel.from_pretrained('distilbert-base-uncased')
                self.tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
                self.feature_dim = 768

            # Freeze backbone if specified
            if config.freeze_text:
                for param in self.backbone.parameters():
                    param.requires_grad = False

            # Projection to common dimension
            self.projection = nn.Linear(self.feature_dim, config.d_model)
            self.available = True

        except Exception as e:
            print(f"Warning: Could not load text model: {e}")
            self.available = False
            # Zero embeddings are produced directly at d_model width, so no
            # projection is needed.
            self.feature_dim = config.d_model
            self.projection = nn.Identity()

    def forward(self, texts):
        """
        Args:
            texts: List of strings or None

        Returns:
            tokens: Tensor of shape (batch, 1, d_model) - pooled text embedding,
                or None when ``texts`` is None or empty
            available: Boolean indicating if text data is available
        """
        if texts is None or len(texts) == 0:
            return None, False

        batch_size = len(texts)

        # Extract features
        if self.available:
            if self._uses_sentence_transformers:
                with torch.set_grad_enabled(not self.config.freeze_text):
                    embeddings = self.backbone.encode(
                        texts,
                        convert_to_tensor=True,
                        show_progress_bar=False
                    )
            else:
                # Fallback: use tokenizer + model
                inputs = self.tokenizer(
                    texts,
                    return_tensors='pt',
                    padding=True,
                    truncation=True,
                    max_length=self.config.max_text_tokens
                ).to(next(self.backbone.parameters()).device)

                with torch.set_grad_enabled(not self.config.freeze_text):
                    outputs = self.backbone(**inputs)
                    embeddings = outputs.last_hidden_state[:, 0, :]  # CLS token
        else:
            # No backbone: return zeros on this module's device. The projection
            # is an nn.Identity here and has no parameters to take a device from.
            embeddings = torch.zeros(
                batch_size, self.feature_dim, device=self._device_anchor.device
            )

        # Project to common dimension
        tokens = self.projection(embeddings)  # (batch, d_model)

        # Add sequence dimension
        tokens = tokens.unsqueeze(1)  # (batch, 1, d_model)

        return tokens, True


class MetadataEncoder(nn.Module):
    """
    Metadata encoder for categorical features.

    Each field ('uploader', 'platform', 'date', 'likes') has its own embedding
    table; the four embeddings are concatenated and passed through an MLP to
    produce a single ``d_model``-wide token per sample.
    """

    def __init__(self, config: "ModelConfig",
                 n_uploaders=100, n_platforms=10, n_date_buckets=12, n_likes_buckets=10):
        """
        Args:
            config: Model configuration. Reads ``d_model`` and ``dropout``.
            n_uploaders, n_platforms, n_date_buckets, n_likes_buckets:
                Vocabulary size of the respective embedding table.
        """
        super().__init__()
        self.config = config

        # Categorical embeddings
        self.uploader_emb = nn.Embedding(n_uploaders, 64)
        self.platform_emb = nn.Embedding(n_platforms, 32)
        self.date_emb = nn.Embedding(n_date_buckets, 32)
        self.likes_emb = nn.Embedding(n_likes_buckets, 32)

        # MLP to project to common dimension
        total_dim = 64 + 32 + 32 + 32
        self.mlp = nn.Sequential(
            nn.Linear(total_dim, config.d_model),
            nn.LayerNorm(config.d_model),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model, config.d_model)
        )

    def forward(self, metadata):
        """
        Args:
            metadata: Dict with keys 'uploader', 'platform', 'date', 'likes' (LongTensor).
                Any subset of the keys may be present; a missing field
                contributes an all-zero embedding.

        Returns:
            tokens: Tensor of shape (batch, 1, d_model), or None when no known
                field is present
            available: Boolean indicating if metadata is available
        """
        if metadata is None or len(metadata) == 0:
            return None, False

        # Fixed field order: it defines the layout of the concatenated vector
        # that the MLP's first Linear layer was sized for.
        fields = (
            ('uploader', self.uploader_emb),
            ('platform', self.platform_emb),
            ('date', self.date_emb),
            ('likes', self.likes_emb),
        )

        present = [metadata[key] for key, _ in fields if key in metadata]
        if len(present) == 0:
            return None, False

        # Index shape of a present field; used to size zero placeholders.
        index_shape = tuple(present[0].shape)

        embs = []
        for key, table in fields:
            if key in metadata:
                embs.append(table(metadata[key]))
            else:
                # Keep the concatenated width constant so the MLP input always
                # matches; zeros share dtype/device with the embedding weights.
                embs.append(table.weight.new_zeros((*index_shape, table.embedding_dim)))

        # Concatenate and project
        combined = torch.cat(embs, dim=-1)
        tokens = self.mlp(combined)

        # Add sequence dimension
        tokens = tokens.unsqueeze(1)  # (batch, 1, d_model)

        return tokens, True

# Cell-completion marker: confirms the four encoder classes were defined.
print("All encoders defined!")



All encoders defined!


In [6]:
class CrossModalFusionTransformer(nn.Module):
    """
    Cross-modal fusion using a Transformer encoder.

    Tokens from every provided modality get a learned modality embedding added,
    are concatenated behind a learned CLS token, and are mixed with
    self-attention. The CLS output is returned as the fused representation.
    """

    def __init__(self, config: "ModelConfig"):
        """
        Args:
            config: Model configuration. Reads ``d_model``, ``n_heads``,
                ``n_layers`` and ``dropout``.
        """
        super().__init__()
        self.config = config

        # Modality embeddings (learned)
        self.modality_embeddings = nn.Embedding(4, config.d_model)  # 4 modalities

        # CLS token for pooling
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.d_model))

        # Transformer encoder (pre-norm, batch-first)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.n_heads,
            dim_feedforward=config.d_model * 4,
            dropout=config.dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )

        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.n_layers,
            norm=nn.LayerNorm(config.d_model)
        )

        # Modality IDs (rows of `modality_embeddings`)
        self.VISUAL_ID = 0
        self.AUDIO_ID = 1
        self.TEXT_ID = 2
        self.META_ID = 3

    def forward(self, visual_tokens=None, audio_tokens=None,
                text_tokens=None, meta_tokens=None, attention_mask=None):
        """
        Args:
            visual_tokens: (batch, n_visual, d_model) or None
            audio_tokens: (batch, n_audio, d_model) or None
            text_tokens: (batch, n_text, d_model) or None
            meta_tokens: (batch, n_meta, d_model) or None
            attention_mask: (batch, total_tokens) - True for valid tokens.
                ``total_tokens`` is the length of the concatenated sequence
                fed to the transformer, which starts with the CLS token.

        Returns:
            fused_vector: (batch, d_model) - pooled representation
            all_tokens: (batch, total_tokens, d_model) - all output tokens

        Raises:
            ValueError: If every modality is None.
        """
        # Fixed order: visual, audio, text, metadata.
        streams = (
            (self.VISUAL_ID, visual_tokens),
            (self.AUDIO_ID, audio_tokens),
            (self.TEXT_ID, text_tokens),
            (self.META_ID, meta_tokens),
        )
        present = [(modality_id, tokens) for modality_id, tokens in streams
                   if tokens is not None]

        # Validate before touching any tensor: the CLS token is always added,
        # so checking the assembled list afterwards could never fail.
        if len(present) == 0:
            raise ValueError("At least one modality must be provided")

        # Batch size and device come from the first provided modality.
        reference = present[0][1]
        batch_size = reference.size(0)
        device = reference.device

        # CLS token first; it carries no modality embedding.
        all_tokens = [self.cls_token.expand(batch_size, -1, -1)]

        for modality_id, tokens in present:
            modality_index = torch.full(
                (batch_size, tokens.size(1)), modality_id,
                dtype=torch.long, device=device
            )
            all_tokens.append(tokens + self.modality_embeddings(modality_index))

        combined_tokens = torch.cat(all_tokens, dim=1)  # (batch, total_tokens, d_model)

        # Create attention mask if not provided (everything valid)
        if attention_mask is None:
            attention_mask = torch.ones(
                batch_size, combined_tokens.size(1),
                dtype=torch.bool, device=device
            )

        # Convert mask for transformer (True = mask out)
        src_key_padding_mask = ~attention_mask

        # Apply transformer
        output_tokens = self.transformer(
            combined_tokens,
            src_key_padding_mask=src_key_padding_mask
        )

        # Extract CLS token as fused representation
        fused_vector = output_tokens[:, 0, :]  # (batch, d_model)

        return fused_vector, output_tokens


# =============================================================================
# Domain Discriminator
# =============================================================================

class DomainDiscriminator(nn.Module):
    """
    Domain discriminator for adversarial training: an MLP
    d_model -> 256 -> 128 -> n_domains that outputs domain logits.

    NOTE: a class with the same name and an identical body is defined again in
    the following cell; after a top-to-bottom run that later definition is the
    one bound to the name.
    """

    def __init__(self, d_model, n_domains, dropout=0.3):
        super().__init__()

        # Layers are created in forward order so parameter initialisation
        # consumes the RNG in a fixed sequence.
        layers = [
            nn.Linear(d_model, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, n_domains),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """
        Args:
            x: (batch, d_model) - features from encoder

        Returns:
            logits: (batch, n_domains) - domain classification logits
        """
        logits = self.network(x)
        return logits


# =============================================================================
# Classifier
# =============================================================================

class ClassifierMLP(nn.Module):
    """
    Binary classifier head for fake/real detection: an MLP
    d_model -> 256 -> 64 -> 1 that outputs a raw logit.

    NOTE: a class with the same name and an identical body is defined again in
    the following cell; after a top-to-bottom run that later definition is the
    one bound to the name.
    """

    def __init__(self, d_model, dropout=0.3):
        super().__init__()

        # Layers are created in forward order so parameter initialisation
        # consumes the RNG in a fixed sequence.
        layers = [
            nn.Linear(d_model, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Dropout(dropout),
            # Single output, no sigmoid: meant to be paired with BCEWithLogitsLoss.
            nn.Linear(64, 1),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """
        Args:
            x: (batch, d_model) - fused features

        Returns:
            logits: (batch, 1) - raw logits for fake/real
        """
        logits = self.network(x)
        return logits

# Cell-completion marker for this cell (fusion transformer plus the two heads defined above).
print("Fusion transformer defined!")



Fusion transformer defined!


In [7]:
class DomainDiscriminator(nn.Module):
    """
    Domain discriminator for adversarial training.
    Maps a feature vector to logits over the source domains.

    NOTE: this redefines, with an identical body, the DomainDiscriminator
    already declared in the previous cell and shadows it.
    """

    def __init__(self, d_model, n_domains, dropout=0.3):
        super().__init__()

        # Two hidden stages of Linear -> LayerNorm -> ReLU -> Dropout,
        # followed by the output projection onto the domain classes.
        stages = []
        width_in = d_model
        for width_out in (256, 128):
            stages.append(nn.Linear(width_in, width_out))
            stages.append(nn.LayerNorm(width_out))
            stages.append(nn.ReLU())
            stages.append(nn.Dropout(dropout))
            width_in = width_out
        stages.append(nn.Linear(width_in, n_domains))

        self.network = nn.Sequential(*stages)

    def forward(self, x):
        """
        Args:
            x: (batch, d_model) - features from encoder

        Returns:
            logits: (batch, n_domains) - domain classification logits
        """
        return self.network(x)


# =============================================================================
# Classifier
# =============================================================================

class ClassifierMLP(nn.Module):
    """
    Binary classifier for fake/real detection.
    Maps the fused feature vector to one raw logit per sample.

    NOTE: this redefines, with an identical body, the ClassifierMLP already
    declared in the previous cell and shadows it.
    """

    def __init__(self, d_model, dropout=0.3):
        super().__init__()

        # Two hidden stages of Linear -> LayerNorm -> ReLU -> Dropout,
        # followed by a single-logit output layer.
        stages = []
        width_in = d_model
        for width_out in (256, 64):
            stages.append(nn.Linear(width_in, width_out))
            stages.append(nn.LayerNorm(width_out))
            stages.append(nn.ReLU())
            stages.append(nn.Dropout(dropout))
            width_in = width_out
        # Binary classification (no sigmoid, use BCEWithLogitsLoss)
        stages.append(nn.Linear(width_in, 1))

        self.network = nn.Sequential(*stages)

    def forward(self, x):
        """
        Args:
            x: (batch, d_model) - fused features

        Returns:
            logits: (batch, 1) - raw logits for fake/real
        """
        return self.network(x)

# Cell-completion marker: confirms DomainDiscriminator and ClassifierMLP were (re)defined.
print("Classifiers defined!")



Classifiers defined!


In [8]:
class MultimodalDeepfakeDetector(nn.Module):
    """
    Complete multimodal deepfake detection model with domain-adversarial training.

    Pipeline: per-modality encoders -> cross-modal fusion transformer ->
    (a) fake/real classifier on the fused vector and
    (b) domain discriminator behind a gradient reversal layer.
    """

    def __init__(self, config: "ModelConfig", n_domains=5):
        """
        Args:
            config: Model configuration shared by all sub-modules.
            n_domains: Number of classes predicted by the domain discriminator.
        """
        super().__init__()
        self.config = config

        # Encoders
        self.visual_encoder = VisualEncoder(config)
        self.audio_encoder = AudioEncoder(config)
        self.text_encoder = TextEncoder(config)
        self.meta_encoder = MetadataEncoder(config)

        # Fusion
        self.fusion = CrossModalFusionTransformer(config)

        # Gradient Reversal Layer
        self.grl = GradientReversalLayer(alpha=config.alpha_domain)

        # Domain discriminator
        self.domain_discriminator = DomainDiscriminator(
            config.d_model, n_domains, config.dropout
        )

        # Classifier
        self.classifier = ClassifierMLP(config.d_model, config.dropout)

    def forward(self, images=None, audio=None, text=None, metadata=None,
                return_domain_logits=True):
        """
        Forward pass through the model.

        Args:
            images: (batch, num_frames, C, H, W) or None
            audio: (batch, num_chunks, samples) or None
            text: List of strings or None
            metadata: Dict of categorical features or None
            return_domain_logits: Whether to compute domain logits

        Returns:
            dict with keys:
                - 'logits': (batch, 1) - fake/real classification logits
                - 'domain_logits': (batch, n_domains) - domain classification
                  logits, or None when ``return_domain_logits`` is False
                - 'fused_vector': (batch, d_model) - fused representation

        Raises:
            ValueError: If no modality yields any tokens (all inputs are None
                or empty).
        """
        # Encode each modality; a None input short-circuits to "not available".
        visual_tokens, visual_avail = self.visual_encoder(images) if images is not None else (None, False)
        audio_tokens, audio_avail = self.audio_encoder(audio) if audio is not None else (None, False)
        text_tokens, text_avail = self.text_encoder(text) if text is not None else (None, False)
        meta_tokens, meta_avail = self.meta_encoder(metadata) if metadata is not None else (None, False)

        # Fail with a clear message here rather than with an opaque
        # AttributeError raised from inside the fusion module.
        if not (visual_avail or audio_avail or text_avail or meta_avail):
            raise ValueError("At least one modality must be provided")

        # Fuse modalities (unavailable ones are passed as None)
        fused_vector, all_tokens = self.fusion(
            visual_tokens=visual_tokens if visual_avail else None,
            audio_tokens=audio_tokens if audio_avail else None,
            text_tokens=text_tokens if text_avail else None,
            meta_tokens=meta_tokens if meta_avail else None
        )

        # Classification
        class_logits = self.classifier(fused_vector)

        # Domain classification with GRL: the discriminator sees the fused
        # vector unchanged, while the gradient flowing back is reversed.
        domain_logits = None
        if return_domain_logits:
            reversed_features = self.grl(fused_vector)
            domain_logits = self.domain_discriminator(reversed_features)

        return {
            'logits': class_logits,
            'domain_logits': domain_logits,
            'fused_vector': fused_vector
        }

    def set_grl_alpha(self, alpha):
        """Update GRL alpha for domain adaptation scheduling"""
        self.grl.set_alpha(alpha)


# =============================================================================
# Dataset Classes
# =============================================================================

# Cell-completion marker: confirms MultimodalDeepfakeDetector was defined.
print("Complete model defined!")



Complete model defined!


In [None]:
class EnhancedMultimodalDataset(Dataset):
    """Multi-source deepfake dataset that balances and sub-samples itself.

    On construction the dataset scans ``data_root`` for eleven known sources
    (images, audio, videos), records one dict per file in ``self.samples``
    (keys: ``path``, ``type``, ``label``, ``domain``, ``dataset``), undersamples
    fakes towards a 1:1.33 real:fake ratio and finally keeps ``sample_fraction``
    of every domain. No ``Subset`` wrapping is needed afterwards.

    Labels: ``0`` = real, ``1`` = fake.

    NOTE(review): most loaders ignore ``self.split`` (only DFD faces, FoR audio
    and 140k faces select a split folder), so a 'train' and a 'test' instance
    built from the same ``data_root`` can contain the same files - confirm this
    overlap is intended before trusting test metrics.

    Args:
        data_root: Directory that contains the dataset folders.
        config: Object providing ``image_size`` and ``sample_rate``.
        split: 'train', 'test' or 'val'; used by the split-aware loaders.
        sample_fraction: Fraction of each domain to keep (values >= 1.0 keep all).
    """

    _IMAGE_PATTERNS = ('*.jpg', '*.png', '*.jpeg')

    def __init__(self, data_root, config, split='train', sample_fraction=0.25):
        self.data_root = Path(data_root)
        self.config = config
        self.split = split
        self.samples = []

        print(f"\n{'='*60}")
        print(f"Loading {split.upper()} split")
        print(f"{'='*60}")

        # Scan the 11 supported sources (the former "archive" source is disabled).
        self._load_deepfake_images()
        self._load_faceforensics()
        self._load_celebdf()
        self._load_kaggle_audio()
        self._load_demo_audio()
        self._load_fakeavceleb()
        self._load_dfd_faces()
        self._load_dfd_sequences()
        self._load_for_audio()
        self._load_140k_faces()
        self._load_youtube_faces()

        print(f"\n‚úÖ Total loaded: {len(self.samples):,} samples")

        # Apply intelligent balancing (1:1.33 ratio)
        self._apply_intelligent_balancing()

        # STRATIFIED REDUCTION DONE INTERNALLY (not with Subset!)
        if sample_fraction < 1.0:
            self._stratified_sample(sample_fraction)

        # Print final statistics
        self._print_statistics()

    # ------------------------------------------------------------------
    # Sampling / balancing
    # ------------------------------------------------------------------

    def _stratified_sample(self, fraction):
        """Keep ``fraction`` of the samples of every domain (at least one each).

        Stratification is by ``domain`` only; the label ratio produced by the
        balancing step is not explicitly preserved. The RNG is re-seeded with 42
        for every domain so the selection is reproducible.
        """
        import random
        from collections import defaultdict

        print(f"\n‚öñÔ∏è Stratified sampling to {fraction*100:.0f}% of data...")

        # Group sample indices by domain
        domain_samples = defaultdict(list)
        for idx, sample in enumerate(self.samples):
            domain_samples[sample['domain']].append(idx)

        # Sample from each domain proportionally
        selected_indices = []
        for indices in domain_samples.values():
            # Clamp so a fraction > 1 cannot request more items than exist.
            n_samples = min(len(indices), max(1, int(len(indices) * fraction)))
            random.seed(42)
            selected_indices.extend(random.sample(indices, n_samples))

        # Update samples to only include selected ones
        self.samples = [self.samples[i] for i in selected_indices]

        print(f"  Reduced to {len(self.samples):,} samples")

    def _apply_intelligent_balancing(self):
        """Undersample fakes so that real:fake is at most 1:1.33, then shuffle.

        When the split contains no real samples there is nothing to balance
        against, so the fake samples are kept as they are instead of being
        reduced to zero.
        """
        import random

        real_samples = [s for s in self.samples if s['label'] == 0]
        fake_samples = [s for s in self.samples if s['label'] == 1]

        print(f"\nüìä Before Balancing:")
        print(f"  Real: {len(real_samples):,}")
        print(f"  Fake: {len(fake_samples):,}")
        print(f"  Ratio: 1:{len(fake_samples)/max(1,len(real_samples)):.2f}")

        # Target ratio: 1:1.33
        target_ratio = 1.33
        target_fake_count = int(len(real_samples) * target_ratio)

        # Undersample fakes if too many (skipped when there are no reals at all)
        if real_samples and len(fake_samples) > target_fake_count:
            print(f"\n‚öñÔ∏è Undersampling Fake samples to achieve 1:{target_ratio} ratio")
            random.seed(42)
            fake_samples = random.sample(fake_samples, target_fake_count)

        # Combine balanced samples
        self.samples = real_samples + fake_samples
        random.seed(42)
        random.shuffle(self.samples)

        print(f"\n‚úÖ After Balancing:")
        print(f"  Real: {len(real_samples):,}")
        print(f"  Fake: {len(fake_samples):,}")
        print(f"  Ratio: 1:{len(fake_samples)/max(1,len(real_samples)):.2f}")
        print(f"  Total: {len(self.samples):,}")

    def _print_statistics(self):
        """Print totals, label counts and per-type / per-dataset counts."""
        if len(self.samples) == 0:
            return

        # Count by dataset
        dataset_counts = {}
        for sample in self.samples:
            ds = sample['dataset']
            dataset_counts[ds] = dataset_counts.get(ds, 0) + 1

        # Count by type
        type_counts = {}
        for sample in self.samples:
            t = sample['type']
            type_counts[t] = type_counts.get(t, 0) + 1

        # Count labels
        fake_count = sum(1 for s in self.samples if s['label'] == 1)
        real_count = len(self.samples) - fake_count

        print(f"\n{'='*60}")
        print(f"üìä FINAL Dataset Statistics ({self.split}):")
        print(f"  Total: {len(self.samples):,} samples")
        print(f"  Real: {real_count:,} | Fake: {fake_count:,}")
        print(f"  Ratio: 1:{fake_count/max(1,real_count):.2f}")
        print(f"\n  By Type:")
        for t, count in type_counts.items():
            print(f"    {t}: {count:,}")
        print(f"\n  By Dataset:")
        for ds, count in sorted(dataset_counts.items()):
            print(f"    {ds}: {count:,}")
        print(f"{'='*60}\n")

    # ------------------------------------------------------------------
    # Dataset protocol
    # ------------------------------------------------------------------

    def __len__(self):
        """Number of samples remaining after balancing and sub-sampling."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with keys 'image', 'audio', 'text', 'metadata', 'label',
            'domain'. Exactly one of 'image' / 'audio' is a tensor, the other
            modalities are ``None``. Videos contribute a single frame.

        Raises:
            ValueError: if the sample record has an unknown 'type'.
        """
        sample = self.samples[idx]
        kind = sample['type']
        path = sample['path']

        image = None
        audio = None
        if kind == 'image':
            image = self._load_image(path)
        elif kind == 'audio':
            audio = self._load_audio(path)
        elif kind == 'video':
            # YouTube Faces clips are stored as .npz archives, which
            # cv2.VideoCapture cannot decode; route them to the .npz reader.
            if path.endswith('.npz'):
                image = self._load_youtube_npz(path)
            else:
                image = self._load_video_frame(path)
        else:
            raise ValueError(f"Unknown sample type: {kind!r}")

        return {
            'image': image,
            'audio': audio,
            'text': None,
            'metadata': None,
            'label': sample['label'],
            'domain': sample['domain']
        }

    # ------------------------------------------------------------------
    # Source scanners (each appends records to self.samples)
    # ------------------------------------------------------------------

    def _collect_files(self, directory, patterns, *, sample_type, label, domain,
                       dataset, recursive=False):
        """Append one record per file in ``directory`` matching ``patterns``.

        Patterns are processed in the given order. A missing directory adds
        nothing.

        Returns:
            Number of records appended.
        """
        if not directory.exists():
            return 0

        added = 0
        for pattern in patterns:
            matches = directory.rglob(pattern) if recursive else directory.glob(pattern)
            for file_path in matches:
                self.samples.append({
                    'path': str(file_path),
                    'type': sample_type,
                    'label': label,
                    'domain': domain,
                    'dataset': dataset
                })
                added += 1
        return added

    def _load_deepfake_images(self):
        """Load Deepfake image detection dataset - train, test and sample folders."""
        base = self.data_root / 'Deepfake image detection dataset'
        if not base.exists():
            print(f"  ‚úó Deepfake Images not found")
            return

        count = 0
        split_roots = [
            base / 'train-20250112T065955Z-001' / 'train',
            base / 'test-20250112T065939Z-001' / 'test',
        ]
        for split_root in split_roots:
            for label_name in ['fake', 'real']:
                count += self._collect_files(
                    split_root / label_name, self._IMAGE_PATTERNS,
                    sample_type='image', label=1 if label_name == 'fake' else 0,
                    domain=0, dataset='DeepfakeImages')

        # Loose sample images are all fakes
        count += self._collect_files(
            base / 'Sample_fake_images', self._IMAGE_PATTERNS,
            sample_type='image', label=1, domain=0, dataset='DeepfakeImages')

        print(f"  ‚úì DeepfakeImages: {count} samples")

    def _load_faceforensics(self):
        """Load FaceForensics++ (C23) videos; 'original' is real, the rest fake."""
        base = self.data_root / 'FaceForensics++' / 'FaceForensics++_C23'
        if not base.exists():
            print(f"  ‚úó FaceForensics++ not found")
            return

        count = 0
        for manip_type in ['Deepfakes', 'Face2Face', 'FaceSwap', 'NeuralTextures', 'FaceShifter', 'original']:
            count += self._collect_files(
                base / manip_type, ['*.mp4'],
                sample_type='video', label=0 if manip_type == 'original' else 1,
                domain=2, dataset='FaceForensics++')
        print(f"  ‚úì FaceForensics++: {count} samples")

    def _load_celebdf(self):
        """Load Celeb-DF V2 videos; 'Celeb-synthesis' is fake, the rest real."""
        base = self.data_root / 'Celeb V2'
        if not base.exists():
            print(f"  ‚úó Celeb-DF V2 not found")
            return

        count = 0
        for split_type in ['Celeb-synthesis', 'Celeb-real', 'YouTube-real']:
            count += self._collect_files(
                base / split_type, ['*.mp4'],
                sample_type='video', label=1 if 'synthesis' in split_type else 0,
                domain=3, dataset='Celeb-DF')
        print(f"  ‚úì Celeb-DF V2: {count} samples")

    def _load_kaggle_audio(self):
        """Load KAGGLE audio clips from the FAKE / REAL folders."""
        base = self.data_root / 'DeepFake_AudioDataset' / 'KAGGLE' / 'AUDIO'
        if not base.exists():
            print(f"  ‚úó KAGGLE Audio not found")
            return

        count = 0
        for label_name in ['FAKE', 'REAL']:
            count += self._collect_files(
                base / label_name, ['*.wav'],
                sample_type='audio', label=1 if label_name == 'FAKE' else 0,
                domain=4, dataset='KAGGLE_Audio')
        print(f"  ‚úì KAGGLE Audio: {count} samples")

    def _load_demo_audio(self):
        """Load DEMONSTRATION audio; a file is fake when its stem contains 'to'."""
        base = self.data_root / 'DeepFake_AudioDataset' / 'DEMONSTRATION' / 'DEMONSTRATION'
        if not base.exists():
            print(f"  ‚úó DEMONSTRATION Audio not found")
            return

        count = 0
        for audio in base.glob('*.mp3'):
            # The label depends on each file name, so the shared helper is not used.
            self.samples.append({
                'path': str(audio),
                'type': 'audio',
                'label': 1 if 'to' in audio.stem else 0,
                'domain': 5,
                'dataset': 'DEMO_Audio'
            })
            count += 1
        print(f"  ‚úì DEMONSTRATION Audio: {count} samples")

    def _load_fakeavceleb(self):
        """Load FakeAVCeleb videos; any category containing 'Fake' is fake."""
        base = self.data_root / 'FakeAVCeleb' / 'FakeAVCeleb_v1.2' / 'FakeAVCeleb_v1.2'
        if not base.exists():
            print(f"  ‚úó FakeAVCeleb not found")
            return

        count = 0
        for category in ['FakeVideo-FakeAudio', 'FakeVideo-RealAudio', 'RealVideo-FakeAudio', 'RealVideo-RealAudio']:
            count += self._collect_files(
                base / category, ['*.mp4'],
                sample_type='video', label=1 if 'Fake' in category else 0,
                domain=6, dataset='FakeAVCeleb', recursive=True)
        print(f"  ‚úì FakeAVCeleb: {count} samples")

    def _load_dfd_faces(self):
        """Load DFD face crops for the current split (searched recursively)."""
        base = self.data_root / 'dfd_faces'
        if not base.exists():
            print(f"  ‚úó DFD Faces not found")
            return

        split_dir = base / self.split
        if not split_dir.exists():
            print(f"  ‚úó DFD Faces {self.split} split not found")
            return

        count = 0
        for label_name in ['fake', 'real']:
            count += self._collect_files(
                split_dir / label_name, ['*.jpg', '*.png'],
                sample_type='image', label=1 if label_name == 'fake' else 0,
                domain=7, dataset='DFD_Faces', recursive=True)
        print(f"  ‚úì DFD Faces: {count} samples")

    def _load_dfd_sequences(self):
        """Load DFD manipulated (fake) and original (real) video sequences."""
        base = self.data_root / 'DFD'
        if not base.exists():
            print(f"  ‚úó DFD sequences not found")
            return

        count = 0
        # Manipulated sequences
        count += self._collect_files(
            base / 'DFD_manipulated_sequences' / 'DFD_manipulated_sequences', ['*.mp4'],
            sample_type='video', label=1, domain=8, dataset='DFD_Sequences',
            recursive=True)
        # Original sequences
        count += self._collect_files(
            base / 'DFD_original sequences' / 'DFD_original_sequences', ['*.mp4'],
            sample_type='video', label=0, domain=8, dataset='DFD_Sequences',
            recursive=True)

        print(f"  ‚úì DFD sequences: {count} samples")

    def _load_for_audio(self):
        """Load the FoR (Fake-or-Real) audio dataset: 'for-norm' and 'for-2sec'.

        An unrecognised split name falls back to the 'training' folder.
        """
        base = self.data_root / 'The Fake-or-Real (FoR) Dataset (deepfake audio)'
        if not base.exists():
            print(f"  ‚úó FoR Audio not found")
            return

        count = 0
        versions = {
            'for-norm': 'for-norm/for-norm',
            'for-2sec': 'for-2sec/for-2seconds',
        }
        split_map = {'train': 'training', 'test': 'testing', 'val': 'validation'}

        for version_name, version_path in versions.items():
            split_dir = base / version_path / split_map.get(self.split, 'training')
            for label_name in ['fake', 'real']:
                count += self._collect_files(
                    split_dir / label_name, ['*.wav', '*.mp3', '*.flac'],
                    sample_type='audio', label=1 if label_name == 'fake' else 0,
                    domain=9, dataset=f'FoR_Audio_{version_name}')

        print(f"  ‚úì FoR Audio: {count} samples")

    def _load_140k_faces(self):
        """Load the 140k Real and Fake Faces images for the current split.

        An unrecognised split name falls back to the 'train' folder.
        """
        base = self.data_root / '140k Real and Fake Faces' / 'real_vs_fake' / 'real-vs-fake'
        if not base.exists():
            print(f"  ‚úó 140k Faces not found")
            return

        split_map = {'train': 'train', 'test': 'test', 'val': 'valid'}
        split_dir = base / split_map.get(self.split, 'train')
        if not split_dir.exists():
            print(f"  ‚úó 140k Faces {self.split} split not found")
            return

        count = 0
        for label_name in ['fake', 'real']:
            count += self._collect_files(
                split_dir / label_name, self._IMAGE_PATTERNS,
                sample_type='image', label=1 if label_name == 'fake' else 0,
                domain=10, dataset='140k_Faces')

        print(f"  ‚úì 140k Faces: {count} samples")

    def _load_youtube_faces(self):
        """Load YouTube Faces .npz clips (all real) from the four archive folders."""
        base = self.data_root / 'YouTube Faces With Facial Keypoints'
        if not base.exists():
            print(f"  ‚úó YouTube Faces not found")
            return

        count = 0
        for folder_num in range(1, 5):
            folder_name = f'youtube_faces_with_keypoints_full_{folder_num}'
            count += self._collect_files(
                base / folder_name / folder_name, ['*.npz'],
                sample_type='video', label=0, domain=11, dataset='YouTube_Faces')

        print(f"  ‚úì YouTube Faces: {count} samples (REAL videos - critical for balancing!)")

    # ------------------------------------------------------------------
    # Media decoding (best effort: unreadable media becomes a zero tensor)
    # ------------------------------------------------------------------

    def _blank_image(self):
        """Zero image tensor of shape (3, image_size, image_size)."""
        return torch.zeros(3, self.config.image_size, self.config.image_size)

    def _frame_to_tensor(self, rgb_frame):
        """Resize an RGB array to the configured size and normalise it.

        Returns a float (3, H, W) tensor scaled to [0, 1] and then normalised
        with mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225).
        """
        size = self.config.image_size
        resized = cv2.resize(rgb_frame, (size, size))
        tensor = torch.from_numpy(resized).permute(2, 0, 1).float() / 255.0
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        return (tensor - mean) / std

    def _load_image(self, path):
        """Read an image file; returns a zero tensor if it cannot be decoded."""
        try:
            # cv2.imread returns None for unreadable files, which makes
            # cvtColor raise and lands in the fallback below.
            img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
            return self._frame_to_tensor(img)
        except Exception:
            return self._blank_image()

    def _load_audio(self, path):
        """Load up to 10 s of audio, padded/truncated to exactly 10 s of samples."""
        target_length = self.config.sample_rate * 10
        try:
            waveform, _ = librosa.load(path, sr=self.config.sample_rate, duration=10)
            if len(waveform) < target_length:
                waveform = np.pad(waveform, (0, target_length - len(waveform)))
            else:
                waveform = waveform[:target_length]
            return torch.from_numpy(waveform).float()
        except Exception:
            return torch.zeros(target_length)

    def _load_video_frame(self, path):
        """Decode the first frame of a video; zero tensor if it cannot be read."""
        try:
            cap = cv2.VideoCapture(path)
            try:
                ret, frame = cap.read()
            finally:
                # Always release the capture handle, even if read() raises.
                cap.release()
            if ret:
                return self._frame_to_tensor(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        except Exception:
            pass
        return self._blank_image()

    def _load_youtube_npz(self, path):
        """Decode the first frame stored under 'colorImages' in a .npz archive.

        Both a frames-first (N, H, W, 3) and a frames-last (H, W, 3, N) array
        are accepted.
        NOTE(review): the frames-last layout is what the public YouTube Faces
        With Facial Keypoints release is understood to use - confirm against
        the actual files.
        """
        try:
            # The context manager closes the archive's file handle.
            with np.load(path) as archive:
                frames = archive['colorImages'] if 'colorImages' in archive else None

            if frames is not None and frames.size > 0:
                frames_last = (frames.ndim == 4 and frames.shape[2] == 3
                               and frames.shape[-1] != 3)
                frame = frames[..., 0] if frames_last else frames[0]
                frame = np.ascontiguousarray(frame)

                # Convert to RGB if needed
                if frame.shape[-1] != 3:
                    frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
                return self._frame_to_tensor(frame)
        except Exception as e:
            print(f"Error loading YouTube .npz: {e}")

        return self._blank_image()

# ============================================================================
# CREATE DATASETS - NO SUBSET WRAPPING!
# ============================================================================

# Banner announcing the dataset-construction step.
print("\n" + "="*60, "CREATING DATASETS WITH INTERNAL SAMPLING", "="*60, sep="\n")

# Both splits keep 25% of their data; the reduction happens inside the
# dataset constructor, so no Subset wrapper is applied afterwards.
train_dataset = EnhancedMultimodalDataset(
    Path('../'),
    config,
    sample_fraction=0.25,
    split='train',
)

test_dataset = EnhancedMultimodalDataset(
    Path('../'),
    config,
    sample_fraction=0.25,
    split='test',
)

# Summary of the resulting dataset sizes.
print(
    f"\n‚úÖ Datasets created!",
    f"   Train: {len(train_dataset):,} samples",
    f"   Test:  {len(test_dataset):,} samples",
    sep="\n",
)

# ============================================================================
# FOCAL LOSS
# ============================================================================

class FocalLoss(nn.Module):
    """Focal variant of binary cross-entropy computed on raw logits.

    Each element's BCE is scaled by ``alpha * (1 - pt) ** gamma`` where
    ``pt = exp(-BCE)``, and the mean over all elements is returned. ``alpha``
    is applied as one uniform factor to every element.

    Args:
        alpha: Constant multiplier applied to every element's loss.
        gamma: Exponent of the ``(1 - pt)`` modulating term.
    """

    def __init__(self, alpha=0.75, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss of logits ``inputs`` against ``targets``."""
        elementwise_bce = F.binary_cross_entropy_with_logits(
            inputs, targets, reduction='none'
        )
        # exp(-BCE) is the probability the model assigns to the true class.
        prob_of_truth = torch.exp(-elementwise_bce)
        modulating = (1 - prob_of_truth) ** self.gamma
        weighted = self.alpha * modulating * elementwise_bce
        return weighted.mean()


criterion = FocalLoss(gamma=2.0, alpha=0.75)
print("‚úÖ FocalLoss created")

# ============================================================================
# WEIGHTED RANDOM SAMPLER - DIRECT ACCESS
# ============================================================================

def get_class_weights_direct(dataset):
    """Compute one inverse-frequency weight per sample from ``dataset.samples``.

    A sample whose label occurs ``c`` times among ``n`` samples receives the
    weight ``n / c``.

    Args:
        dataset: Object exposing ``samples``, a sequence of dicts with a
            'label' key.

    Returns:
        torch.DoubleTensor with one weight per sample, in sample order.
    """
    label_sequence = [record['label'] for record in dataset.samples]
    n_total = len(label_sequence)

    # Tally how often each label occurs.
    occurrences = {}
    for lbl in label_sequence:
        occurrences[lbl] = occurrences.get(lbl, 0) + 1

    inverse_frequency = {lbl: n_total / n for lbl, n in occurrences.items()}
    return torch.DoubleTensor([inverse_frequency[lbl] for lbl in label_sequence])

# Create sampler
# WeightedRandomSampler is not part of the notebook's import cell (which only
# brings in Dataset and DataLoader), so it is imported here to keep this cell
# runnable on a fresh kernel.
from torch.utils.data import WeightedRandomSampler

# One weight per training sample (inverse class frequency); drawing with
# replacement lets the minority class be seen as often as the majority one.
sample_weights = get_class_weights_direct(train_dataset)
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

print(f"‚úÖ WeightedRandomSampler created with {len(sample_weights):,} samples")

# ============================================================================
# DATALOADERS - SIMPLE & DIRECT (NO SUBSET!)
# ============================================================================

# NOTE(review): `multimodal_collate_fn` is not defined in the visible part of
# the notebook (a later cell defines `collate_fn`) - confirm it exists before
# running this cell.

# Training loader: batch order is driven by the class-balancing sampler.
train_loader = DataLoader(
    dataset=train_dataset,
    sampler=sampler,
    batch_size=config.batch_size,
    num_workers=0,
    pin_memory=True,
    collate_fn=multimodal_collate_fn,
)

# Evaluation loader: fixed, unshuffled order.
test_loader = DataLoader(
    dataset=test_dataset,
    shuffle=False,
    batch_size=config.batch_size,
    num_workers=0,
    pin_memory=True,
    collate_fn=multimodal_collate_fn,
)

print(
    f"\n‚úÖ DataLoaders Created:",
    f"   Train batches: {len(train_loader):,}",
    f"   Test batches:  {len(test_loader):,}",
    "="*60 + "\n",
    sep="\n",
)

In [9]:
# class EnhancedMultimodalDataset(Dataset):
#     """
#     Enhanced dataset that loads ALL available datasets.
#     Supports: Images, Audio, Video from 12 major sources.
#     """
    
#     def __init__(self, data_root, config, split='train'):
#         self.data_root = Path(data_root)
#         self.config = config
#         self.split = split
#         self.samples = []
        
#         print(f"\nüìÇ Scanning for datasets in: {data_root}")
#         self._scan_all_datasets()
#         print(f"\n‚úÖ Loaded {len(self.samples)} samples for {split} split")
#         self._print_statistics()
    
#     def _scan_all_datasets(self):
#         """Scan and load all available datasets"""
        
#         # 1. Deepfake image detection dataset
#         self._load_deepfake_images()
        
#         # 2. Archive dataset - DISABLED (not available)
#         # self._load_archive_dataset()
        
#         # 3. FaceForensics++
#         self._load_faceforensics()
        
#         # 4. Celeb-DF V2
#         self._load_celebdf()
        
#         # 5. KAGGLE Audio
#         self._load_kaggle_audio()
        
#         # 6. DEMONSTRATION Audio
#         self._load_demo_audio()
        
#         # 7. FakeAVCeleb
#         self._load_fakeavceleb()
        
#         # 8. DFD faces
#         self._load_dfd_faces()
        
#         # 9. DFD sequences
#         self._load_dfd_sequences()
        
#         # 10. FoR Audio Dataset (4 versions)
#         self._load_for_audio()
        
#         # 11. 140k Real and Fake Faces
#         self._load_140k_faces()
        
#         # 12. YouTube Faces videos
#         self._load_youtube_faces()
        
#         # Apply intelligent balancing to achieve 1:2 to 1:2.5 Real:Fake ratio
#         if self.split == 'train':
#             self._apply_intelligent_balancing()
    
#     def _load_deepfake_images(self):
#         """Load Deepfake image detection dataset - ALL SUBFOLDERS"""
#         base = self.data_root / 'Deepfake image detection dataset'
#         if not base.exists():
#             print(f"  ‚úó Deepfake Images not found")
#             return
        
#         count = 0
        
#         # Load from train-20250112T065955Z-001/train/
#         train_base = base / 'train-20250112T065955Z-001' / 'train'
#         if train_base.exists():
#             for label_name in ['fake', 'real']:
#                 label_dir = train_base / label_name
#                 if label_dir.exists():
#                     for ext in ['*.jpg', '*.png', '*.jpeg']:
#                         for img in label_dir.glob(ext):
#                             self.samples.append({
#                                 'path': str(img),
#                                 'type': 'image',
#                                 'label': 1 if label_name == 'fake' else 0,
#                                 'domain': 0,
#                                 'dataset': 'DeepfakeImages'
#                             })
#                             count += 1
        
#         # Load from test-20250112T065939Z-001/test/
#         test_base = base / 'test-20250112T065939Z-001' / 'test'
#         if test_base.exists():
#             for label_name in ['fake', 'real']:
#                 label_dir = test_base / label_name
#                 if label_dir.exists():
#                     for ext in ['*.jpg', '*.png', '*.jpeg']:
#                         for img in label_dir.glob(ext):
#                             self.samples.append({
#                                 'path': str(img),
#                                 'type': 'image',
#                                 'label': 1 if label_name == 'fake' else 0,
#                                 'domain': 0,
#                                 'dataset': 'DeepfakeImages'
#                             })
#                             count += 1
        
#         # Load from Sample_fake_images/
#         sample_base = base / 'Sample_fake_images'
#         if sample_base.exists():
#             for ext in ['*.jpg', '*.png', '*.jpeg']:
#                 for img in sample_base.glob(ext):
#                     self.samples.append({
#                         'path': str(img),
#                         'type': 'image',
#                         'label': 1,
#                         'domain': 0,
#                         'dataset': 'DeepfakeImages'
#                     })
#                     count += 1
        
#         print(f"  ‚úì DeepfakeImages: {count} samples")
    
#     def _load_faceforensics(self):
#         """Load FaceForensics++ dataset"""
#         base = self.data_root / 'FaceForensics++' / 'FaceForensics++_C23'
        
#         if not base.exists():
#             print(f"  ‚úó FaceForensics++ not found")
#             return
        
#         count = 0
#         for manip_type in ['Deepfakes', 'Face2Face', 'FaceSwap', 'NeuralTextures', 'FaceShifter', 'original']:
#             manip_dir = base / manip_type
#             if manip_dir.exists():
#                 for vid in manip_dir.glob('*.mp4'):
#                     self.samples.append({
#                         'path': str(vid),
#                         'type': 'video',
#                         'label': 0 if manip_type == 'original' else 1,
#                         'domain': 2,
#                         'dataset': 'FaceForensics++'
#                     })
#                     count += 1
#         print(f"  ‚úì FaceForensics++: {count} samples")
    
#     def _load_celebdf(self):
#         """Load Celeb-DF V2 dataset"""
#         base = self.data_root / 'Celeb V2'
        
#         if not base.exists():
#             print(f"  ‚úó Celeb-DF V2 not found")
#             return
        
#         count = 0
#         for split_type in ['Celeb-synthesis', 'Celeb-real', 'YouTube-real']:
#             split_dir = base / split_type
#             if split_dir.exists():
#                 for vid in split_dir.glob('*.mp4'):
#                     self.samples.append({
#                         'path': str(vid),
#                         'type': 'video',
#                         'label': 1 if 'synthesis' in split_type else 0,
#                         'domain': 3,
#                         'dataset': 'Celeb-DF'
#                     })
#                     count += 1
#         print(f"  ‚úì Celeb-DF V2: {count} samples")
    
#     def _load_kaggle_audio(self):
#         """Load KAGGLE Audio dataset"""
#         base = self.data_root / 'DeepFake_AudioDataset' / 'KAGGLE' / 'AUDIO'
#         if not base.exists():
#             print(f"  ‚úó KAGGLE Audio not found")
#             return
        
#         count = 0
#         for label_name in ['FAKE', 'REAL']:
#             label_dir = base / label_name
#             if label_dir.exists():
#                 for audio in label_dir.glob('*.wav'):
#                     self.samples.append({
#                         'path': str(audio),
#                         'type': 'audio',
#                         'label': 1 if label_name == 'FAKE' else 0,
#                         'domain': 4,
#                         'dataset': 'KAGGLE_Audio'
#                     })
#                     count += 1
#         print(f"  ‚úì KAGGLE Audio: {count} samples")
    
#     def _load_demo_audio(self):
#         """Load DEMONSTRATION Audio"""
#         base = self.data_root / 'DeepFake_AudioDataset' / 'DEMONSTRATION' / 'DEMONSTRATION'
#         if not base.exists():
#             print(f"  ‚úó DEMONSTRATION Audio not found")
#             return
        
#         count = 0
#         for audio in base.glob('*.mp3'):
#             label = 1 if 'to' in audio.stem else 0
#             self.samples.append({
#                 'path': str(audio),
#                 'type': 'audio',
#                 'label': label,
#                 'domain': 5,
#                 'dataset': 'DEMO_Audio'
#             })
#             count += 1
#         print(f"  ‚úì DEMONSTRATION Audio: {count} samples")
    
#     def _load_fakeavceleb(self):
#         """Load FakeAVCeleb dataset"""
#         base = self.data_root / 'FakeAVCeleb' / 'FakeAVCeleb_v1.2' / 'FakeAVCeleb_v1.2'
        
#         if not base.exists():
#             print(f"  ‚úó FakeAVCeleb not found")
#             return
        
#         count = 0
#         for category in ['FakeVideo-FakeAudio', 'FakeVideo-RealAudio', 'RealVideo-FakeAudio', 'RealVideo-RealAudio']:
#             cat_dir = base / category
#             if cat_dir.exists():
#                 for vid in cat_dir.rglob('*.mp4'):
#                     label = 1 if 'Fake' in category else 0
#                     self.samples.append({
#                         'path': str(vid),
#                         'type': 'video',
#                         'label': label,
#                         'domain': 6,
#                         'dataset': 'FakeAVCeleb'
#                     })
#                     count += 1
#         print(f"  ‚úì FakeAVCeleb: {count} samples")
    
#     def _load_dfd_faces(self):
#         """Load DFD faces (extracted frames)"""
#         base = self.data_root / 'dfd_faces'
#         if not base.exists():
#             print(f"  ‚úó DFD Faces not found")
#             return
        
#         split_dir = base / self.split
#         if not split_dir.exists():
#             print(f"  ‚úó DFD Faces {self.split} split not found")
#             return
        
#         count = 0
#         for label_name in ['fake', 'real']:
#             label_dir = split_dir / label_name
#             if label_dir.exists():
#                 for img in label_dir.rglob('*.jpg'):
#                     self.samples.append({
#                         'path': str(img),
#                         'type': 'image',
#                         'label': 1 if label_name == 'fake' else 0,
#                         'domain': 7,
#                         'dataset': 'DFD_Faces'
#                     })
#                     count += 1
#                 for img in label_dir.rglob('*.png'):
#                     self.samples.append({
#                         'path': str(img),
#                         'type': 'image',
#                         'label': 1 if label_name == 'fake' else 0,
#                         'domain': 7,
#                         'dataset': 'DFD_Faces'
#                     })
#                     count += 1
#         print(f"  ‚úì DFD Faces: {count} samples")
    
#     def _load_dfd_sequences(self):
#         """Load DFD sequences"""
#         base = self.data_root / 'DFD'
#         if not base.exists():
#             print(f"  ‚úó DFD sequences not found")
#             return
        
#         count = 0
#         # Manipulated sequences
#         manip_dir = base / 'DFD_manipulated_sequences' / 'DFD_manipulated_sequences'
#         if manip_dir.exists():
#             for vid in manip_dir.rglob('*.mp4'):
#                 self.samples.append({
#                     'path': str(vid),
#                     'type': 'video',
#                     'label': 1,
#                     'domain': 8,
#                     'dataset': 'DFD_Sequences'
#                 })
#                 count += 1
        
#         # Original sequences
#         orig_dir = base / 'DFD_original sequences' / 'DFD_original_sequences'
#         if orig_dir.exists():
#             for vid in orig_dir.rglob('*.mp4'):
#                 self.samples.append({
#                     'path': str(vid),
#                     'type': 'video',
#                     'label': 0,
#                     'domain': 8,
#                     'dataset': 'DFD_Sequences'
#                 })
#                 count += 1
        
#         print(f"  ‚úì DFD sequences: {count} samples")
    
#     def _load_for_audio(self):
#         """Load FoR (Fake-or-Real) Audio Dataset - 4 versions"""
#         base = self.data_root / 'The Fake-or-Real (FoR) Dataset (deepfake audio)'
        
#         if not base.exists():
#             print(f"  ‚úó FoR Audio not found")
#             return
        
#         count = 0
#         versions = {
#             'for-norm': 'for-norm/for-norm',
#             'for-2sec': 'for-2sec/for-2seconds',
#         }
        
#         for version_name, version_path in versions.items():
#             version_base = base / version_path
#             if not version_base.exists():
#                 continue
            
#             split_map = {'train': 'training', 'test': 'testing', 'val': 'validation'}
#             split_dir = version_base / split_map.get(self.split, 'training')
            
#             if not split_dir.exists():
#                 continue
            
#             for label_name in ['fake', 'real']:
#                 label_dir = split_dir / label_name
#                 if label_dir.exists():
#                     for ext in ['*.wav', '*.mp3', '*.flac']:
#                         for audio in label_dir.glob(ext):
#                             self.samples.append({
#                                 'path': str(audio),
#                                 'type': 'audio',
#                                 'label': 1 if label_name == 'fake' else 0,
#                                 'domain': 9,
#                                 'dataset': f'FoR_Audio_{version_name}'
#                             })
#                             count += 1
        
#         print(f"  ‚úì FoR Audio: {count} samples")
    
#     def _load_140k_faces(self):
#         """Load 140k Real and Fake Faces dataset"""
#         base = self.data_root / '140k Real and Fake Faces' / 'real_vs_fake' / 'real-vs-fake'
        
#         if not base.exists():
#             print(f"  ‚úó 140k Faces not found")
#             return
        
#         count = 0
#         split_map = {'train': 'train', 'test': 'test', 'val': 'valid'}
#         split_dir = base / split_map.get(self.split, 'train')
        
#         if not split_dir.exists():
#             print(f"  ‚úó 140k Faces {self.split} split not found")
#             return
        
#         for label_name in ['fake', 'real']:
#             label_dir = split_dir / label_name
#             if label_dir.exists():
#                 for ext in ['*.jpg', '*.png', '*.jpeg']:
#                     for img in label_dir.glob(ext):
#                         self.samples.append({
#                             'path': str(img),
#                             'type': 'image',
#                             'label': 1 if label_name == 'fake' else 0,
#                             'domain': 10,
#                             'dataset': '140k_Faces'
#                         })
#                         count += 1
        
#         print(f"  ‚úì 140k Faces: {count} samples")
    
#     def _load_youtube_faces(self):
#         """Load YouTube Faces Dataset with Facial Keypoints"""
#         base = self.data_root / 'YouTube Faces With Facial Keypoints'
        
#         if not base.exists():
#             print(f"  ‚úó YouTube Faces not found")
#             return
        
#         count = 0
#         for folder_num in range(1, 5):
#             folder = base / f'youtube_faces_with_keypoints_full_{folder_num}' / f'youtube_faces_with_keypoints_full_{folder_num}'
#             if folder.exists():
#                 for npz_file in folder.glob('*.npz'):
#                     self.samples.append({
#                         'path': str(npz_file),
#                         'type': 'video',
#                         'label': 0,
#                         'domain': 11,
#                         'dataset': 'YouTube_Faces'
#                     })
#                     count += 1
        
#         print(f"  ‚úì YouTube Faces: {count} samples (REAL videos - critical for balancing!)")
    
#     def _apply_intelligent_balancing(self):
#         """Apply intelligent balancing to achieve 1:2 to 1:2.5 Real:Fake ratio"""
#         real_samples = [s for s in self.samples if s['label'] == 0]
#         fake_samples = [s for s in self.samples if s['label'] == 1]
        
#         print(f"\nüìä Before Balancing:")
#         print(f"  Real: {len(real_samples):,}")
#         print(f"  Fake: {len(fake_samples):,}")
#         print(f"  Ratio: 1:{len(fake_samples)/len(real_samples):.2f}")
        
#         # Target ratio: 1:2.25 (middle of 1:2 to 1:2.5)
#         target_ratio = 2.25
#         target_fake_count = int(len(real_samples) * target_ratio)
        
#         # If we have too many fakes, undersample
#         if len(fake_samples) > target_fake_count:
#             print(f"\n‚öñÔ∏è Undersampling Fake samples to achieve 1:{target_ratio} ratio")
#             from sklearn.utils import resample
#             fake_samples = resample(fake_samples, 
#                                    n_samples=target_fake_count,
#                                    random_state=42,
#                                    replace=False)
        
#         # Combine balanced samples
#         self.samples = real_samples + fake_samples
        
#         print(f"\n‚úÖ After Balancing:")
#         print(f"  Real: {len(real_samples):,}")
#         print(f"  Fake: {len(fake_samples):,}")
#         print(f"  Ratio: 1:{len(fake_samples)/len(real_samples):.2f}")
#         print(f"  Total: {len(self.samples):,}")
    
#     def _print_statistics(self):
#         """Print dataset statistics"""
#         if len(self.samples) == 0:
#             return
        
#         # Count by dataset
#         dataset_counts = {}
#         for sample in self.samples:
#             ds = sample['dataset']
#             dataset_counts[ds] = dataset_counts.get(ds, 0) + 1
        
#         # Count by type
#         type_counts = {}
#         for sample in self.samples:
#             t = sample['type']
#             type_counts[t] = type_counts.get(t, 0) + 1
        
#         # Count labels
#         fake_count = sum(1 for s in self.samples if s['label'] == 1)
#         real_count = len(self.samples) - fake_count
        
#         print(f"\nüìä Dataset Statistics:")
#         print(f"  Total: {len(self.samples)} samples")
#         print(f"  Real: {real_count} | Fake: {fake_count}")
#         print(f"\n  By Type:")
#         for t, count in type_counts.items():
#             print(f"    {t}: {count}")
#         print(f"\n  By Dataset:")
#         for ds, count in sorted(dataset_counts.items()):
#             print(f"    {ds}: {count}")
    
#     def __len__(self):
#         return len(self.samples)
    
#     def __getitem__(self, idx):
#         sample = self.samples[idx]
        
#         # Load data based on type
#         if sample['type'] == 'image':
#             image = self._load_image(sample['path'])
#             return {
#                 'image': image,
#                 'audio': None,
#                 'text': None,
#                 'metadata': None,
#                 'label': sample['label'],
#                 'domain': sample['domain']
#             }
#         elif sample['type'] == 'audio':
#             audio = self._load_audio(sample['path'])
#             return {
#                 'image': None,
#                 'audio': audio,
#                 'text': None,
#                 'metadata': None,
#                 'label': sample['label'],
#                 'domain': sample['domain']
#             }
#         elif sample['type'] == 'video':
#             # For videos, extract first frame for now
#             image = self._load_video_frame(sample['path'])
#             return {
#                 'image': image,
#                 'audio': None,
#                 'text': None,
#                 'metadata': None,
#                 'label': sample['label'],
#                 'domain': sample['domain']
#             }
    
#     def _load_image(self, path):
#         try:
#             img = cv2.imread(path)
#             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#             img = cv2.resize(img, (self.config.image_size, self.config.image_size))
#             img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
#             mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
#             std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
#             img = (img - mean) / std
#             return img
#         except:
#             return torch.zeros(3, self.config.image_size, self.config.image_size)
    
#     def _load_audio(self, path):
#         try:
#             waveform, sr = librosa.load(path, sr=self.config.sample_rate, duration=10)
#             target_length = self.config.sample_rate * 10
#             if len(waveform) < target_length:
#                 waveform = np.pad(waveform, (0, target_length - len(waveform)))
#             else:
#                 waveform = waveform[:target_length]
#             return torch.from_numpy(waveform).float()
#         except:
#             return torch.zeros(self.config.sample_rate * 10)
    
#     def _load_video_frame(self, path):
#         try:
#             cap = cv2.VideoCapture(path)
#             ret, frame = cap.read()
#             if ret:
#                 frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
#                 frame = cv2.resize(frame, (self.config.image_size, self.config.image_size))
#                 frame = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
#                 mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
#                 std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
#                 frame = (frame - mean) / std
#                 cap.release()
#                 return frame
#             else:
#                 cap.release()
#                 return torch.zeros(3, self.config.image_size, self.config.image_size)
#         except:
#             return torch.zeros(3, self.config.image_size, self.config.image_size)

# print("‚úÖ Enhanced dataset loader defined!")

‚úÖ Enhanced dataset loader defined!


In [10]:
# Add this method to EnhancedMultimodalDataset class to handle .npz files

def _load_youtube_npz(self, path):
    """Load YouTube Faces .npz file containing video frames"""
    try:
        # Load .npz file
        data = np.load(path)
        
        # YouTube Faces .npz contains 'colorImages' key with video frames
        if 'colorImages' in data:
            frames = data['colorImages']
            
            # Select first frame or random frame
            if len(frames) > 0:
                frame_idx = 0  # or: np.random.randint(0, len(frames))
                frame = frames[frame_idx]
                
                # Convert to RGB if needed
                if frame.shape[-1] != 3:
                    frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
                
                # Resize and normalize
                frame = cv2.resize(frame, (self.config.image_size, self.config.image_size))
                frame = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
                mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
                std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
                frame = (frame - mean) / std
                return frame
    except Exception as e:
        print(f"Error loading YouTube .npz: {e}")
        pass
    
    return torch.zeros(3, self.config.image_size, self.config.image_size)

print("‚úÖ YouTube Faces .npz handler defined!")

‚úÖ YouTube Faces .npz handler defined!


In [11]:
def collate_fn(batch):
    """Collate samples whose image/audio modality may be missing.

    A missing modality is replaced by a zero tensor so every batch can be
    stacked. The placeholder takes the shape and dtype of the first real
    tensor of that modality in the batch; when the batch holds none, it falls
    back to (3, 224, 224) for images and 16000 * 10 samples for audio.

    Args:
        batch: List of dicts with keys 'image', 'audio', 'label', 'domain'
            ('image' / 'audio' may be ``None``).

    Returns:
        dict with 'images' and 'audio' (stacked tensors, or ``None`` for an
        empty batch), 'text' and 'metadata' (always ``None``), 'labels'
        (float32 tensor) and 'domains' (long tensor).
    """
    # Build placeholders that match the real tensors in this batch, so that
    # torch.stack does not fail when the configured sizes differ from the
    # fallback shapes.
    first_image = next((item['image'] for item in batch if item['image'] is not None), None)
    first_audio = next((item['audio'] for item in batch if item['audio'] is not None), None)
    image_placeholder = (torch.zeros_like(first_image) if first_image is not None
                         else torch.zeros(3, 224, 224))
    audio_placeholder = (torch.zeros_like(first_audio) if first_audio is not None
                         else torch.zeros(16000 * 10))

    images, audios = [], []
    labels, domains = [], []

    for item in batch:
        # Always append labels and domains
        labels.append(item['label'])
        domains.append(item['domain'])

        # Append modality data (zeros if not available)
        images.append(item['image'] if item['image'] is not None else image_placeholder)
        audios.append(item['audio'] if item['audio'] is not None else audio_placeholder)

    return {
        'images': torch.stack(images) if images else None,
        'audio': torch.stack(audios) if audios else None,
        'text': None,
        'metadata': None,
        'labels': torch.tensor(labels, dtype=torch.float32),
        'domains': torch.tensor(domains, dtype=torch.long)
    }

print("‚úÖ Collate function defined!")

‚úÖ Collate function defined!


In [12]:
def train_epoch(model, dataloader, optimizer, scaler, config, epoch):
    """Run one mixed-precision training epoch.

    The domain-adversarial strength ``alpha`` follows a sigmoid ramp driven by
    ``epoch / config.epochs`` and scaled by ``config.alpha_domain``. It is passed
    to ``model.set_grl_alpha`` and also weights the domain loss.

    Args:
        model: module returning a dict with ``'logits'`` and ``'domain_logits'``
            (the latter may be ``None``) and exposing ``set_grl_alpha``.
        dataloader: yields batches shaped like the output of ``collate_fn``.
        optimizer: optimizer stepped through ``scaler``.
        scaler: gradient scaler used for the AMP backward/step/update cycle.
        config: object providing ``epochs`` and ``alpha_domain``.
        epoch: zero-based epoch index.

    Returns:
        ``(mean_loss, accuracy_percent)`` over the processed batches, or
        ``(0.0, 0.0)`` when the dataloader yields nothing.
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    step = 0

    # Update GRL alpha
    progress = epoch / config.epochs
    alpha = config.alpha_domain * (2 / (1 + np.exp(-10 * progress)) - 1)
    model.set_grl_alpha(alpha)

    pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}')
    for step, batch in enumerate(pbar, start=1):
        images = batch['images'].to(device) if batch['images'] is not None else None
        audio = batch['audio'].to(device) if batch['audio'] is not None else None
        labels = batch['labels'].to(device)
        domains = batch['domains'].to(device)

        # Clear gradients up front so nothing left over from before this call
        # (or from a previous step) is folded into the update.
        optimizer.zero_grad()

        with autocast():
            outputs = model(images=images, audio=audio, text=None, metadata=None)
            # reshape(-1) keeps a 1-D tensor even for a batch of one, where a
            # bare squeeze() would produce a 0-d tensor and break the loss.
            logits = outputs['logits'].reshape(-1)
            cls_loss = F.binary_cross_entropy_with_logits(logits, labels)
            dom_loss = F.cross_entropy(outputs['domain_logits'], domains) if outputs['domain_logits'] is not None else 0
            loss = cls_loss + alpha * dom_loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        preds = (torch.sigmoid(logits) > 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        # Running mean over the batches seen so far (not over the epoch length).
        pbar.set_postfix({'loss': total_loss / step, 'acc': 100. * correct / max(total, 1)})

    if step == 0:
        return 0.0, 0.0
    return total_loss / step, 100. * correct / max(total, 1)

def evaluate(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` and return binary metrics in percent.

    Predictions are ``sigmoid(logits) > 0.5``. Logits are flattened to 1-D
    before comparison so the accumulated prediction and label lists are flat
    sequences of scalars, which is the layout the sklearn metrics expect.

    Args:
        model: module returning a dict with ``'logits'``; called with
            ``return_domain_logits=False``.
        dataloader: yields batches shaped like the output of ``collate_fn``.

    Returns:
        dict with ``'accuracy'``, ``'precision'``, ``'recall'`` and ``'f1'``
        (all scaled to 0-100). All zeros when the dataloader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Evaluating'):
            images = batch['images'].to(device) if batch['images'] is not None else None
            audio = batch['audio'].to(device) if batch['audio'] is not None else None
            labels = batch['labels'].to(device)

            outputs = model(images=images, audio=audio, text=None, metadata=None, return_domain_logits=False)
            preds = (torch.sigmoid(outputs['logits'].reshape(-1)) > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    if total == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        return {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

    acc = 100. * correct / total
    precision = precision_score(all_labels, all_preds, zero_division=0) * 100
    recall = recall_score(all_labels, all_preds, zero_division=0) * 100
    f1 = f1_score(all_labels, all_preds, zero_division=0) * 100

    return {'accuracy': acc, 'precision': precision, 'recall': recall, 'f1': f1}

print("Training functions defined!")


Training functions defined!


In [13]:
# ===========================
# CLASS BALANCING UTILITIES
# ===========================

from torch.utils.data import WeightedRandomSampler

class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    The per-element BCE is down-weighted by ``(1 - p_t) ** gamma``, where
    ``p_t = exp(-bce)``, so confidently-correct elements contribute little.
    ``alpha`` is applied as one uniform scale factor to every element; it is
    not a per-class weight in this implementation.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss of logits ``inputs`` against ``targets``."""
        per_element_bce = F.binary_cross_entropy_with_logits(
            inputs, targets, reduction='none'
        )
        # exp(-BCE) recovers the probability assigned to the true class.
        prob_true_class = torch.exp(-per_element_bce)
        modulating = (1 - prob_true_class) ** self.gamma
        return (self.alpha * modulating * per_element_bce).mean()

def get_class_weights(dataset):
    """Compute inverse-frequency sampling weights, one per sample.

    Labels are read from ``dataset.samples`` (each entry must provide a
    ``'label'`` key), so no files are loaded.

    Args:
        dataset: object exposing ``samples``, a sequence of dict-like entries.

    Returns:
        A float64 tensor with one weight per sample, equal to
        ``1 / count(label)`` for that sample's class.
    """
    labels_array = np.array([s['label'] for s in dataset.samples]).astype(int)
    # minlength=2 guarantees both the Real(0) and Fake(1) slots exist even when
    # one class is absent, so the indexing below cannot raise IndexError.
    class_counts = np.bincount(labels_array, minlength=2)

    num_real, num_fake = class_counts[0], class_counts[1]
    if num_real > 0:
        ratio = num_fake / num_real
    else:
        ratio = float('inf') if num_fake > 0 else float('nan')

    print(f"\nüìä Class Distribution:")
    print(f"   Real (0): {num_real:,} samples")
    print(f"   Fake (1): {num_fake:,} samples")
    print(f"   Imbalance Ratio: {ratio:.2f}:1")

    # Inverse frequency; classes with no samples keep weight 0 (never looked up).
    weights = np.zeros(len(class_counts), dtype=np.float64)
    present = class_counts > 0
    weights[present] = 1.0 / class_counts[present]
    sample_weights = weights[labels_array]

    return torch.from_numpy(sample_weights)

def calculate_pos_weight(dataset):
    """Compute the positive-class weight ``num_real / num_fake`` for BCE loss.

    Labels are read from ``dataset.samples`` (each entry must provide a
    ``'label'`` key), so no files are loaded.

    Args:
        dataset: object exposing ``samples``, a sequence of dict-like entries.

    Returns:
        A one-element tensor holding the weight. When the dataset contains no
        fake (label 1) samples the ratio is undefined and a neutral weight of
        1.0 is returned instead.
    """
    labels_array = np.array([s['label'] for s in dataset.samples]).astype(int)
    # minlength=2 keeps the Fake(1) slot present even if no fake sample exists.
    class_counts = np.bincount(labels_array, minlength=2)

    num_real = class_counts[0]
    num_fake = class_counts[1]
    if num_fake > 0:
        pos_weight = num_real / num_fake
    else:
        print("\nNo fake samples found; using neutral pos_weight of 1.0")
        pos_weight = np.float64(1.0)

    print(f"\n‚öñÔ∏è Pos Weight for BCE Loss: {pos_weight:.4f}")

    return torch.tensor([pos_weight])

print("‚úÖ Class balancing utilities defined (OPTIMIZED VERSION)!")
print("   - FocalLoss: Handles hard examples")
print("   - get_class_weights: For balanced sampling")
print("   - calculate_pos_weight: For weighted BCE loss")

‚úÖ Class balancing utilities defined (OPTIMIZED VERSION)!
   - FocalLoss: Handles hard examples
   - get_class_weights: For balanced sampling
   - calculate_pos_weight: For weighted BCE loss


In [14]:
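# # ===========================
# # CREATE DATASETS WITH AUTOMATIC BALANCING
# # NOTE: this whole cell is commented out, but the training cell further down
# # still uses train_loader, test_loader and focal_loss_fn, which are only
# # defined here. Re-enable this cell (or define those names elsewhere) before
# # running the notebook from a fresh kernel.
# # ===========================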

# print("="*60)
# print("LOADING & BALANCING DATASETS")
# print("="*60)

# # Create datasets
# print("\nüìÇ Loading datasets...")
# train_dataset = EnhancedMultimodalDataset('../', config, split='train')
# test_dataset = EnhancedMultimodalDataset('../', config, split='test')

# # ===========================
# # STRATIFIED DATASET REDUCTION (25% FROM EACH DATASET) - DO THIS FIRST!
# # ===========================

# from torch.utils.data import Subset
# import random
# from collections import defaultdict

# print("\n" + "="*60)
# print("‚ö†Ô∏è REDUCING DATASET SIZE (STRATIFIED)")
# print("="*60)

# # ========== REDUCE TRAINING SET ==========
# print("\nüìä Training Set Reduction:")

# # Group samples by dataset
# dataset_samples = defaultdict(list)
# for idx, sample in enumerate(train_dataset.samples):
#     dataset_samples[sample['dataset']].append(idx)

# # Take 25% from EACH dataset
# selected_indices = []
# for dataset_name, indices in dataset_samples.items():
#     n_samples = max(1, len(indices) // 4)
#     random.seed(42)
#     selected = random.sample(indices, n_samples)
#     selected_indices.extend(selected)
#     print(f"   {dataset_name}: {len(indices):,} ‚Üí {len(selected):,}")

# original_train_size = len(train_dataset.samples)
# train_dataset = Subset(train_dataset, selected_indices)

# print(f"\n   Total: {original_train_size:,} ‚Üí {len(train_dataset):,} samples")

# # ========== REDUCE TEST SET ==========
# print("\nüìä Test Set Reduction:")

# # Group test samples by dataset
# test_dataset_samples = defaultdict(list)
# for idx, sample in enumerate(test_dataset.samples):
#     test_dataset_samples[sample['dataset']].append(idx)

# # Take 25% from EACH dataset
# test_selected_indices = []
# for dataset_name, indices in test_dataset_samples.items():
#     n_samples = max(1, len(indices) // 4)
#     random.seed(42)
#     selected = random.sample(indices, n_samples)
#     test_selected_indices.extend(selected)
#     print(f"   {dataset_name}: {len(indices):,} ‚Üí {len(selected):,}")

# original_test_size = len(test_dataset.samples)
# test_dataset = Subset(test_dataset, test_selected_indices)

# print(f"\n   Total: {original_test_size:,} ‚Üí {len(test_dataset):,} samples")

# # ===========================
# # NOW BALANCE BOTH SETS TO 1:1.33 RATIO
# # ===========================

# def balance_subset_dataset(subset_dataset, target_ratio=1.33, name="Dataset"):
#     """Balance a Subset dataset to target ratio"""
#     base_dataset = subset_dataset.dataset
#     indices = subset_dataset.indices
    
#     # Get labels for subset indices
#     labels = [base_dataset.samples[i]['label'] for i in indices]
    
#     # Separate indices by label
#     real_indices = [idx for idx, label in zip(indices, labels) if label == 0]
#     fake_indices = [idx for idx, label in zip(indices, labels) if label == 1]
    
#     current_ratio = len(fake_indices) / len(real_indices) if len(real_indices) > 0 else 0
    
#     print(f"\n‚öñÔ∏è {name} Balancing:")
#     print(f"  Before: Real={len(real_indices):,}, Fake={len(fake_indices):,}, Ratio=1:{current_ratio:.2f}")
    
#     # Undersample fakes to match target ratio
#     target_fake_count = int(len(real_indices) * target_ratio)
    
#     if len(fake_indices) > target_fake_count:
#         random.seed(42)
#         fake_indices = random.sample(fake_indices, target_fake_count)
    
#     # Combine and create new subset
#     balanced_indices = real_indices + fake_indices
#     random.shuffle(balanced_indices)
    
#     # Create new Subset with balanced indices
#     balanced_subset = Subset(base_dataset, balanced_indices)
    
#     print(f"  After:  Real={len(real_indices):,}, Fake={len(fake_indices):,}, Ratio=1:{target_ratio:.2f}")
#     print(f"  Total:  {len(balanced_subset):,} samples")
#     print(f"  ‚úÖ {name} balanced!")
    
#     return balanced_subset

# # Balance both datasets
# train_dataset = balance_subset_dataset(train_dataset, target_ratio=1.33, name="Training Set")
# test_dataset = balance_subset_dataset(test_dataset, target_ratio=1.33, name="Test Set")

# print(f"\n‚úÖ Both datasets reduced AND balanced!")
# print(f"‚ö° Estimated time: ~15-20 min training + 2-3 min eval per epoch")
# print("="*60 + "\n")

# # ===========================
# # SUBSET-AWARE HELPER FUNCTIONS
# # ===========================

# def get_class_weights_from_subset(subset_dataset):
#     """Calculate class weights from a Subset object"""
#     base_dataset = subset_dataset.dataset
#     indices = subset_dataset.indices
    
#     labels = [base_dataset.samples[i]['label'] for i in indices]
#     labels_array = np.array(labels)
#     class_counts = np.bincount(labels_array.astype(int))
    
#     print(f"\nüìä Class Distribution:")
#     print(f"   Real (0): {class_counts[0]:,} samples")
#     print(f"   Fake (1): {class_counts[1]:,} samples")
#     print(f"   Imbalance Ratio: {class_counts[1]/class_counts[0]:.2f}:1")
    
#     weights = 1. / class_counts
#     sample_weights = [weights[int(label)] for label in labels]
    
#     return torch.DoubleTensor(sample_weights)

# def calculate_pos_weight_from_subset(subset_dataset):
#     """Calculate pos_weight from a Subset object"""
#     base_dataset = subset_dataset.dataset
#     indices = subset_dataset.indices
    
#     labels = [base_dataset.samples[i]['label'] for i in indices]
#     labels_array = np.array(labels)
#     class_counts = np.bincount(labels_array.astype(int))
    
#     num_real = class_counts[0]
#     num_fake = class_counts[1]
#     pos_weight = num_real / num_fake
    
#     print(f"\n‚öñÔ∏è Pos Weight for BCE Loss: {pos_weight:.4f}")
    
#     return torch.tensor([pos_weight])

# # ===========================
# # CREATE DATALOADERS
# # ===========================

# import gc

# torch.cuda.empty_cache()
# gc.collect()

# print("\nüîÑ Setting up balanced sampling for training...")

# sample_weights = get_class_weights_from_subset(train_dataset)

# sampler = WeightedRandomSampler(
#     sample_weights, 
#     num_samples=len(train_dataset),
#     replacement=True
# )

# pos_weight = calculate_pos_weight_from_subset(train_dataset).to(device)

# train_loader = DataLoader(
#     train_dataset, 
#     batch_size=config.batch_size, 
#     sampler=sampler,
#     collate_fn=collate_fn, 
#     num_workers=0,
#     pin_memory=True,
#     drop_last=True
# )

# test_loader = DataLoader(
#     test_dataset, 
#     batch_size=config.batch_size, 
#     shuffle=False,
#     collate_fn=collate_fn, 
#     num_workers=0, 
#     pin_memory=True
# )

# focal_loss_fn = FocalLoss(alpha=0.75, gamma=2.0).to(device)

# # ===========================
# # SUMMARY
# # ===========================

# print(f"\n{'='*60}")
# print("OPTIMIZED DATALOADER SUMMARY")
# print("="*60)
# print(f"\nüìä Training Set:")
# print(f"  Total samples: {len(train_dataset):,}")
# print(f"  Batches per epoch: {len(train_loader):,}")
# print(f"  Batch size: {config.batch_size}")
# print(f"  Sampling: WeightedRandomSampler (balanced)")

# print(f"\nüìä Test Set:")
# print(f"  Total samples: {len(test_dataset):,}")
# print(f"  Batches: {len(test_loader):,}")

# print(f"\nüéØ Loss Configuration:")
# print(f"  Loss Function: Focal Loss")
# print(f"  Alpha (Œ±): 0.75, Gamma (Œ≥): 2.0")
# print(f"  pos_weight: {pos_weight.item():.4f}")

# print(f"\n‚úÖ GUARANTEED 1:1.33 ratio in BOTH train and test!")
# print("="*60)

LOADING & BALANCING DATASETS

üìÇ Loading datasets...

üìÇ Scanning for datasets in: ../
  ‚úì DeepfakeImages: 978 samples
  ‚úì FaceForensics++: 6000 samples
  ‚úì Celeb-DF V2: 6529 samples
  ‚úì KAGGLE Audio: 64 samples
  ‚úó DEMONSTRATION Audio not found
  ‚úì FakeAVCeleb: 21560 samples
  ‚úì DFD Faces: 7808 samples
  ‚úì DFD sequences: 3068 samples
  ‚úì FoR Audio: 67824 samples
  ‚úì 140k Faces: 100000 samples
  ‚úì YouTube Faces: 2194 samples (REAL videos - critical for balancing!)

üìä Before Balancing:
  Real: 92,659
  Fake: 123,366
  Ratio: 1:1.33

‚úÖ After Balancing:
  Real: 92,659
  Fake: 123,366
  Ratio: 1:1.33
  Total: 216,025

‚úÖ Loaded 216025 samples for train split

üìä Dataset Statistics:
  Total: 216025 samples
  Real: 92659 | Fake: 123366

  By Type:
    image: 108786
    video: 39351
    audio: 67888

  By Dataset:
    140k_Faces: 100000
    Celeb-DF: 6529
    DFD_Faces: 7808
    DFD_Sequences: 3068
    DeepfakeImages: 978
    FaceForensics++: 6000
    FakeAVC

In [None]:
import gc
import torch

print("="*60)
print("AGGRESSIVE MEMORY OPTIMIZATION")
print("="*60)

# Release Python garbage first, then any cached CUDA blocks.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# cuDNN: disable autotuning and request deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# The per-process memory cap and the device queries below require a CUDA
# device; skip them on the CPU fallback so this cell does not raise there.
if torch.cuda.is_available():
    torch.cuda.set_per_process_memory_fraction(0.90)

print(f"\n‚úÖ Memory optimization configured")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Total Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("   GPU: none detected (running on CPU, no memory cap applied)")
print("="*60)

# ---- Model ----
# n_domains is the size of the domain-classifier head passed to the model.
print("\nBuilding model...")
n_domains = 12
model = MultimodalDeepfakeDetector(config, n_domains=n_domains)
model = model.to(device)

# Parameter census: everything vs. only the weights that will receive gradients.
total_params = sum(weight.numel() for weight in model.parameters())
trainable_params = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)

print(f"‚úÖ Model built successfully!")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# ---- Optimisation ----
# AdamW + cosine decay over config.epochs; GradScaler drives the AMP steps.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=config.epochs,
)
scaler = GradScaler()

print("\nOptimizer ready!")

# ===========================
# TRAINING LOOP WITH ANTI-FREEZE FIXES
# ===========================
# Each batch runs inside its own try/except so a failing batch is reported and
# skipped instead of aborting the run, and the CUDA cache is emptied regularly.
# Relies on train_loader, test_loader and focal_loss_fn already being defined.

# Imported once here rather than on every epoch. This rebinds `tqdm` to the
# plain (console) implementation, exactly as the per-epoch import used to do.
from tqdm import tqdm
from sklearn.metrics import classification_report

print("\n" + "="*60)
print("STARTING TRAINING WITH ANTI-FREEZE PROTECTION")
print("="*60)
print(f"‚ö° FIXES: Reduced batch processing, aggressive memory clearing")
print("="*60 + "\n")

best_acc = 0
best_f1 = 0
results_history = []

# Running metrics are printed every this many batches (nothing is written to disk).
checkpoint_interval = 100

for epoch in range(config.epochs):
    print(f"\n{'='*60}")
    print(f"EPOCH {epoch+1}/{config.epochs}")
    print(f"{'='*60}")

    # Clear memory before epoch
    torch.cuda.empty_cache()
    gc.collect()

    # ==================== TRAINING ====================
    print("\n[TRAINING]")
    model.train()
    total_loss = 0
    total_cls_loss = 0
    correct = 0
    total = 0

    # Update GRL alpha (sigmoid ramp over training progress)
    progress = epoch / config.epochs
    alpha = config.alpha_domain * (2 / (1 + np.exp(-10 * progress)) - 1)
    model.set_grl_alpha(alpha)
    print(f"  GRL Alpha: {alpha:.4f}")

    # Calculate total batches
    total_batches = len(train_loader)
    print(f"  Total batches: {total_batches:,}")

    # Manual iterator so that an exception raised while fetching a batch is
    # handled by the same try/except as the training step itself.
    train_iter = iter(train_loader)

    step = 0

    pbar = tqdm(total=total_batches, desc='  Training', ncols=100, leave=True)

    while step < total_batches:
        try:
            # Start every step from clean gradients: if the previous batch
            # failed after backward(), its gradients must not leak into this one.
            optimizer.zero_grad(set_to_none=True)

            batch = next(train_iter)

            # Move data to GPU
            images = batch['images'].to(device, non_blocking=True) if batch['images'] is not None else None
            audio = batch['audio'].to(device, non_blocking=True) if batch['audio'] is not None else None
            labels = batch['labels'].to(device, non_blocking=True)
            domains = batch['domains'].to(device, non_blocking=True)

            with autocast():
                outputs = model(images=images, audio=audio, text=None, metadata=None)
                # reshape(-1) stays 1-D even for a batch of one, where a bare
                # squeeze() would give a 0-d tensor that no longer matches labels.
                logits = outputs['logits'].reshape(-1)
                cls_loss = focal_loss_fn(logits, labels)
                dom_loss = F.cross_entropy(outputs['domain_logits'], domains) if outputs['domain_logits'] is not None else 0
                loss = cls_loss + alpha * dom_loss

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Track metrics
            loss_val = loss.item() if isinstance(loss, torch.Tensor) else 0
            cls_loss_val = cls_loss.item() if isinstance(cls_loss, torch.Tensor) else 0

            preds = (torch.sigmoid(logits) > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            # Drop references right away so the tensors can be freed.
            del images, audio, labels, domains, outputs, logits, loss, cls_loss, preds
            if isinstance(dom_loss, torch.Tensor):
                del dom_loss

            total_loss += loss_val
            total_cls_loss += cls_loss_val

            step += 1

            # Clear memory every 20 steps
            if step % 20 == 0:
                torch.cuda.empty_cache()
                gc.collect()

            # Print running metrics every `checkpoint_interval` batches
            if step % checkpoint_interval == 0:
                current_acc = 100. * correct / total if total > 0 else 0
                print(f"\n  Checkpoint {step}/{total_batches} - Loss: {total_loss/step:.4f}, Acc: {current_acc:.2f}%")

            # Update progress bar
            pbar.update(1)
            pbar.set_postfix({
                'loss': f'{total_loss/step:.4f}',
                'acc': f'{100.*correct/total:.1f}%'
            })

        except StopIteration:
            break
        except Exception as e:
            # Best effort: report, free what we can, and move on to the next
            # batch. `step` is not advanced for a failed batch.
            print(f"\n‚ö†Ô∏è Error at batch {step}: {e}")
            torch.cuda.empty_cache()
            gc.collect()
            continue

    pbar.close()

    train_loss = total_loss / step if step > 0 else 0
    train_acc = 100. * correct / total if total > 0 else 0

    print(f"\n  >>> TRAINING RESULTS:")
    print(f"      Loss:     {train_loss:.4f}")
    print(f"      Accuracy: {train_acc:.2f}%")

    torch.cuda.empty_cache()
    gc.collect()

    # ==================== EVALUATION ====================
    print(f"\n[EVALUATION]")
    model.eval()
    correct = 0
    total = 0
    all_preds, all_labels, all_probs = [], [], []

    eval_batches = len(test_loader)
    print(f"  Eval batches: {eval_batches:,}")

    pbar = tqdm(test_loader, desc='  Testing', total=eval_batches, ncols=100, leave=True)

    with torch.no_grad():
        for batch_idx, batch in enumerate(pbar):
            try:
                images = batch['images'].to(device, non_blocking=True) if batch['images'] is not None else None
                audio = batch['audio'].to(device, non_blocking=True) if batch['audio'] is not None else None
                labels = batch['labels'].to(device, non_blocking=True)

                outputs = model(images=images, audio=audio, text=None, metadata=None, return_domain_logits=False)
                # Flatten so the accumulated lists hold plain scalars, which is
                # the 1-D layout the sklearn metrics expect.
                probs = torch.sigmoid(outputs['logits'].reshape(-1))
                preds = (probs > 0.5).float()

                correct += (preds == labels).sum().item()
                total += labels.size(0)

                all_preds.extend(preds.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())
                all_probs.extend(probs.cpu().tolist())

                del images, audio, labels, outputs, probs, preds

                # Clear memory every 50 batches during eval
                if (batch_idx + 1) % 50 == 0:
                    torch.cuda.empty_cache()

                pbar.set_postfix({'acc': f'{100.*correct/total:.1f}%'})

            except Exception as e:
                print(f"\n‚ö†Ô∏è Eval error at batch {batch_idx}: {e}")
                continue

    pbar.close()

    # Calculate metrics
    if total > 0:
        test_acc = 100. * correct / total
        test_precision = precision_score(all_labels, all_preds, zero_division=0) * 100
        test_recall = recall_score(all_labels, all_preds, zero_division=0) * 100
        test_f1 = f1_score(all_labels, all_preds, zero_division=0) * 100

        print(f"\n  >>> DETAILED METRICS:")
        # labels=[0, 1] keeps the report aligned with target_names even when
        # only one class shows up among the labels and predictions.
        print(classification_report(all_labels, all_preds,
                                  labels=[0, 1],
                                  target_names=['Real', 'Fake'],
                                  digits=2,
                                  zero_division=0))
    else:
        # Every eval batch failed (or the loader was empty): nothing to score.
        print("\n  No evaluation batches completed; metrics set to 0")
        test_acc = test_precision = test_recall = test_f1 = 0

    print(f"\n  >>> TEST RESULTS:")
    print(f"      Accuracy:  {test_acc:.2f}%")
    print(f"      Precision: {test_precision:.2f}%")
    print(f"      Recall:    {test_recall:.2f}%")
    print(f"      F1 Score:  {test_f1:.2f}%")

    scheduler.step()

    results_history.append({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'test_acc': test_acc,
        'test_f1': test_f1
    })

    # Keep the checkpoint with the best test F1 seen so far.
    if test_f1 > best_f1:
        best_f1 = test_f1
        best_acc = test_acc
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'test_f1': test_f1,
            'config': config
        }, 'best_multimodal_balanced.pth')
        print(f"\n  ‚úÖ NEW BEST! F1: {best_f1:.2f}%")

    torch.cuda.empty_cache()
    gc.collect()

    print(f"\n{'='*60}\n")

print("\n" + "="*60)
print("TRAINING COMPLETE!")
print("="*60)
print(f"Best F1: {best_f1:.2f}% | Best Acc: {best_acc:.2f}%")

AGGRESSIVE MEMORY OPTIMIZATION

‚úÖ Memory optimization configured
   GPU: NVIDIA RTX A6000
   Total Memory: 48.31 GB

Building model...
See the vulnerability report here https://nvd.nist.gov/vuln/detail/CVE-2025-32434
Using fallback CNN encoder
‚úÖ Model built successfully!
Total parameters: 124,507,853
Trainable parameters: 15,995,981

Optimizer ready!

STARTING TRAINING WITH ANTI-FREEZE PROTECTION
‚ö° FIXES: Reduced batch processing, aggressive memory clearing


EPOCH 1/10

[TRAINING]
  GRL Alpha: 0.0000
  Total batches: 6,700


  Training:   1%|‚ñé                       | 100/6700 [00:24<20:48,  5.29it/s, loss=0.1317, acc=56.0%]


  Checkpoint 100/6700 - Loss: 0.1317, Acc: 56.00%


  Training:   3%|‚ñã                       | 201/6700 [00:42<14:55,  7.25it/s, loss=0.1262, acc=60.4%]


  Checkpoint 200/6700 - Loss: 0.1262, Acc: 60.44%


  Training:   4%|‚ñà                       | 301/6700 [01:01<16:31,  6.46it/s, loss=0.1209, acc=63.9%]


  Checkpoint 300/6700 - Loss: 0.1209, Acc: 63.96%


  Training:   6%|‚ñà‚ñç                      | 401/6700 [01:20<35:31,  2.96it/s, loss=0.1188, acc=65.2%]


  Checkpoint 400/6700 - Loss: 0.1188, Acc: 65.19%


  Training:   7%|‚ñà‚ñã                      | 467/6700 [01:29<18:40,  5.56it/s, loss=0.1179, acc=66.2%]