In [None]:
import os
import torch
import numpy as np
import torchaudio
import cv2
import av
import moviepy.editor as mp
from transformers import (
    VivitImageProcessor, 
    VivitForVideoClassification,
    Wav2Vec2Processor, 
    Wav2Vec2ForSequenceClassification
)
import torchvision.transforms as T
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn as nn
import tempfile
import time
import sys
import gc
from collections import deque
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
                            QHBoxLayout, QLabel, QPushButton, QSlider, QStyle,
                            QFileDialog, QFrame, QProgressBar, QComboBox, QMessageBox,
                            QCheckBox)
from PyQt5.QtCore import Qt, QTimer, pyqtSignal, QThread, QMutex
from PyQt5.QtGui import QImage, QPixmap, QPainter, QColor, QPen, QBrush

In [None]:
def normalize_path(path):
    """Return ``path`` with every backslash replaced by a forward slash.

    Falsy values (``None``, the empty string) are handed back unchanged.
    """
    return path.replace('\\', '/') if path else path

In [None]:
class BidirectionalCrossModalAttention(nn.Module):
    """Two-way cross-attention block that fuses a visual and an audio stream.

    Each modality is linearly projected and layer-normalised, then each one
    queries the other through its own multi-head attention layer. The two
    attended results are concatenated on the feature axis and projected back
    to ``dim`` features.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        # Per-modality projections, then the (2 * dim -> dim) output mixer.
        self.visual_proj = nn.Linear(dim, dim)
        self.audio_proj = nn.Linear(dim, dim)
        self.output_proj = nn.Linear(2 * dim, dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # One attention layer per direction; both use batch-first layout.
        attention_kwargs = dict(embed_dim=dim, num_heads=num_heads, batch_first=True)
        self.v2a_attention = nn.MultiheadAttention(**attention_kwargs)
        self.a2v_attention = nn.MultiheadAttention(**attention_kwargs)

    def forward(self, visual, audio):
        """Fuse ``visual`` and ``audio`` with cross-attention in both directions.

        Both inputs must carry ``dim`` features on their last axis. The two
        attended outputs are concatenated on the last axis, so their remaining
        axes have to match.
        """
        vis = self.norm1(self.visual_proj(visual))
        aud = self.norm2(self.audio_proj(audio))

        # Visual queries attend over audio keys/values, and vice versa.
        visual_attends_audio = self.v2a_attention(vis, aud, aud)[0]
        audio_attends_visual = self.a2v_attention(aud, vis, vis)[0]

        stacked = torch.cat((visual_attends_audio, audio_attends_visual), dim=-1)
        return self.output_proj(stacked)

In [None]:
class IntraModalityFusion(nn.Module):
    """Merge RGB and optical-flow feature vectors into one visual feature.

    Each stream gets its own linear projection and layer norm; the projected
    streams are concatenated on the feature axis and reduced back to ``dim``
    by a Linear -> GELU -> LayerNorm -> Dropout stack.
    """

    def __init__(self, dim):
        super().__init__()
        self.rgb_proj = nn.Linear(dim, dim)
        self.flow_proj = nn.Linear(dim, dim)
        self.fusion = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.GELU(),
            nn.LayerNorm(dim),
            nn.Dropout(0.1),
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def _project_and_fuse(self, rgb, flow):
        """Project both streams, concatenate on the last axis and fuse."""
        projected = (self.norm1(self.rgb_proj(rgb)), self.norm2(self.flow_proj(flow)))
        return self.fusion(torch.cat(projected, dim=-1))

    def forward(self, rgb_features, flow_features):
        """Fuse ``rgb_features`` with ``flow_features``.

        A 3-D ``rgb_features`` tensor ``[batch, seq, hidden]`` is flattened to
        ``[batch * seq, hidden]`` for the fusion and reshaped back afterwards;
        any other rank is fused as given.
        """
        if len(rgb_features.shape) != 3:
            return self._project_and_fuse(rgb_features, flow_features)

        batch_size, seq_len, hidden_dim = rgb_features.shape
        fused_flat = self._project_and_fuse(
            rgb_features.reshape(-1, hidden_dim),
            flow_features.reshape(-1, hidden_dim),
        )
        return fused_flat.reshape(batch_size, seq_len, hidden_dim)

In [None]:
class InterModalityFusion(nn.Module):
    """Cross-modal fusion between visual and audio features.

    Both modalities are projected and layer-normalised, each one attends to
    the other through its own multi-head attention layer, and the two attended
    results are concatenated on the feature axis and reduced back to ``dim``.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.visual_proj = nn.Linear(dim, dim)
        self.audio_proj = nn.Linear(dim, dim)

        # Bidirectional attention mechanisms (batch-first layout)
        self.v2a_attention = nn.MultiheadAttention(
            embed_dim=dim, num_heads=num_heads, batch_first=True)
        self.a2v_attention = nn.MultiheadAttention(
            embed_dim=dim, num_heads=num_heads, batch_first=True)

        # Fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.GELU(),
            nn.LayerNorm(dim),
            nn.Dropout(0.2)
        )

        # Normalization layers
        self.norm_visual = nn.LayerNorm(dim)
        self.norm_audio = nn.LayerNorm(dim)

    def forward(self, visual_features, audio_features):
        """Apply bi-directional attention between visual and audio modalities.

        Args:
            visual_features: ``[batch, seq, dim]`` for batched input; any other
                rank is treated as one un-batched sequence.
            audio_features: same layout as ``visual_features``. With 3-D visual
                input, a 2-D ``[batch, dim]`` tensor is repeated along the
                visual sequence axis.

        Returns:
            Fused features with the same leading shape as ``visual_features``.

        Batched inputs are attended per sample. The earlier implementation
        flattened ``[batch, seq]`` into a single sequence before attention, so
        tokens of one sample attended to tokens of other samples in the batch.
        Results for batch size 1 are unchanged.
        """
        batched = len(visual_features.shape) == 3

        # If audio is not temporal, repeat it along the visual temporal axis.
        if batched and len(audio_features.shape) == 2:
            seq_len = visual_features.shape[1]
            audio_features = audio_features.unsqueeze(1).expand(-1, seq_len, -1)

        # Linear and LayerNorm act on the last axis only, so no flattening is needed.
        visual_proj = self.norm_visual(self.visual_proj(visual_features))
        audio_proj = self.norm_audio(self.audio_proj(audio_features))

        if not batched:
            # Give the un-batched sequence a leading batch axis for attention.
            visual_proj = visual_proj.unsqueeze(0)
            audio_proj = audio_proj.unsqueeze(0)

        v2a_feat, _ = self.v2a_attention(visual_proj, audio_proj, audio_proj)
        a2v_feat, _ = self.a2v_attention(audio_proj, visual_proj, visual_proj)

        if not batched:
            v2a_feat = v2a_feat.squeeze(0)
            a2v_feat = a2v_feat.squeeze(0)

        # Combine bidirectional features
        combined = torch.cat([v2a_feat, a2v_feat], dim=-1)
        return self.fusion(combined)

In [None]:
class DividedSpaceTimeAttention(nn.Module):
    """Residual self-attention, a first-token attention refinement, and a feed-forward block.

    The "spatial" step is a pre-norm residual self-attention over the sequence
    axis. The "temporal" step, applied only when ``seq_len > 1``, runs a second
    attention layer on the first token alone and adds the result back to it.
    A pre-norm residual feed-forward network follows.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.spatial_attention = nn.MultiheadAttention(
            embed_dim=dim, num_heads=num_heads, batch_first=True)
        self.temporal_attention = nn.MultiheadAttention(
            embed_dim=dim, num_heads=num_heads, batch_first=True)

        # Feed-forward network for final fusion
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(dim * 4, dim)
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, batch_size=1, seq_len=None):
        """Apply the attention stack to ``x`` and return a tensor of the same shape.

        Args:
            x: ``[batch, seq, dim]`` tensor, or any tensor that can be reshaped
                to ``[batch_size, seq_len, -1]``.
            batch_size: used only when ``x`` is not 3-D, or for logging when
                ``seq_len`` is given explicitly.
            seq_len: sequence length. When omitted it is taken from a 3-D
                ``x``; otherwise it defaults to 1.

        Returns:
            Tensor with the same shape as the input ``x``.
        """
        # Save original shape for later restoration
        orig_shape = x.shape

        if seq_len is None:
            if len(orig_shape) == 3:
                # Input is already [batch, seq, features]
                batch_size, seq_len, _ = orig_shape
            else:
                # Default to single frame if not specified
                seq_len = 1

        print(f"DSTA - Input shape: {orig_shape}, Using batch_size={batch_size}, seq_len={seq_len}")

        # Reshape to [batch, sequence, features] if needed. ``reshape`` (not
        # ``view``) so non-contiguous inputs are accepted too.
        if len(orig_shape) != 3:
            x = x.reshape(batch_size, seq_len, -1)

        # 1. Spatial attention: pre-norm residual self-attention.
        spatial_norm = self.norm1(x)
        attended_spatial, _ = self.spatial_attention(
            spatial_norm, spatial_norm, spatial_norm
        )
        spatial_out = x + attended_spatial

        # 2. Temporal attention (only applied when multiple frames are present)
        if seq_len <= 1:
            print("DSTA - Skipping temporal attention as seq_len <= 1")
            temporal_out = spatial_out
        else:
            print(f"DSTA - Applying temporal attention across {seq_len} frames")

            # Only the first token goes through the temporal attention layer.
            # NOTE(review): query/key/value here are a length-1 sequence, so
            # the attention has nothing to attend across — confirm this is the
            # intended design.
            first_token = spatial_out[:, :1, :]  # [batch, 1, feat_dim]
            normalized = self.norm2(first_token)
            attended_temporal, _ = self.temporal_attention(
                normalized, normalized, normalized
            )

            # Rebuild the sequence out-of-place. Writing into
            # ``spatial_out[:, 0, :]`` would modify a tensor whose view was
            # saved by ``norm2`` for the backward pass, which makes autograd
            # fail with an "modified by an inplace operation" error.
            temporal_out = torch.cat(
                [first_token + attended_temporal, spatial_out[:, 1:, :]], dim=1
            )

        # 3. Feed-forward network with normalization
        output = temporal_out + self.feed_forward(self.norm3(temporal_out))

        # Ensure output matches original shape
        if output.shape != orig_shape:
            output = output.reshape(orig_shape)

        return output

In [None]:
class AdvancedTemporalFilter:
    """Adaptive temporal smoothing with class-specific window sizes.

    Keeps a short history of prediction dicts and, once enough history is
    available for the current primary class, returns a recency-weighted
    average of the raw scores, boosted by their temporal consistency.
    """

    def __init__(self, base_window_size=3, threshold=0.5, label_map=None):
        """
        Args:
            base_window_size: window used for classes not in ``class_windows``.
            threshold: score at or above which a class counts as predicted.
            label_map: index -> label lookup used to name the smoothed classes;
                when ``None`` the first class of the incoming prediction is reused.
        """
        self.base_window_size = base_window_size
        self.threshold = threshold
        self.history = []
        self.label_map = label_map

        # Class-specific window sizes (Abuse handled via Fighting)
        self.class_windows = {
            'Explosion': 2,      # Brief events need shorter windows
            'Shooting': 2,
            'Fighting': 5,       # Sustained events need longer windows
            'Riot': 5,
            'Car Accident': 3,   # Medium events
        }

    def update(self, prediction_dict):
        """Record ``prediction_dict`` and return a smoothed prediction.

        Args:
            prediction_dict: dict with ``'raw_predictions'`` (numeric array)
                and ``'predicted_classes'`` (sequence of labels, may be empty).

        Returns:
            ``prediction_dict`` itself while the history is shorter than the
            window of its primary class; otherwise a new dict with keys
            ``'raw_predictions'``, ``'predicted_indices'`` and
            ``'predicted_classes'``.
        """
        # Store prediction
        self.history.append(prediction_dict)

        # Determine primary class
        primary_class = None
        if prediction_dict['predicted_classes']:
            primary_class = prediction_dict['predicted_classes'][0]
            # Remap Abuse to Fighting in case any still exists
            if primary_class == 'Abuse':
                primary_class = 'Fighting'
                print("DEBUG: Remapping Abuse to Fighting in temporal filter")

        # Get appropriate window size for this class
        window_size = self.class_windows.get(primary_class, self.base_window_size)

        # Keep enough history for the largest window in use. The base window
        # is included so that a base_window_size larger than every class
        # window can still fill up; trimming to the class windows alone made
        # smoothing unreachable in that case.
        max_history = max([self.base_window_size, *self.class_windows.values()])
        excess = len(self.history) - max_history
        if excess > 0:
            del self.history[:excess]

        # Not enough history for this class
        if len(self.history) < window_size:
            return prediction_dict

        # Get relevant history window for this class
        relevant_history = self.history[-window_size:]

        # Calculate weighted average (more recent predictions have higher weight)
        weights = np.linspace(0.7, 1.0, len(relevant_history))
        weights = weights / weights.sum()  # Normalize

        avg_predictions = np.zeros_like(relevant_history[0]['raw_predictions'])
        for weight, pred in zip(weights, relevant_history):
            avg_predictions += weight * pred['raw_predictions']

        # Boost scores by up to 10% where predictions were stable over the window
        consistency = self._calculate_consistency(relevant_history)
        boosted_predictions = avg_predictions * (1.0 + 0.1 * consistency)

        # Threshold and create result
        smoothed_indices = np.where(boosted_predictions >= self.threshold)[0]

        # If no class is above threshold, take the highest probability class
        if len(smoothed_indices) == 0:
            smoothed_indices = [np.argmax(boosted_predictions)]

        # Map indices to class names
        if self.label_map is not None:
            smoothed_classes = []
            for idx in smoothed_indices:
                if idx < len(self.label_map):
                    label = self.label_map[idx]
                    # Remap Abuse to Fighting
                    if label == 'Abuse':
                        label = 'Fighting'
                        print("DEBUG: Remapping Abuse to Fighting in temporal filter output")
                    smoothed_classes.append(label)
        else:
            # No label map: reuse the incoming prediction's first class.
            # Slicing avoids an IndexError when that prediction has no classes.
            smoothed_classes = list(prediction_dict['predicted_classes'][:1])

        return {
            'raw_predictions': boosted_predictions,
            'predicted_indices': smoothed_indices,
            'predicted_classes': smoothed_classes
        }

    def _calculate_consistency(self, history):
        """Return ``1 - std`` of the raw scores over ``history``, per class.

        Returns the scalar 0 for an empty history.
        """
        if not history:
            return 0

        # Extract raw predictions
        all_preds = np.array([h['raw_predictions'] for h in history])

        # Calculate standard deviation across time for each class
        consistency = 1.0 - np.std(all_preds, axis=0)
        return consistency

In [None]:
class ViolenceDetectionPipeline:
    def __init__(self, rgb_model_path=None, flow_model_path=None, audio_model_path=None):
        """Configure the pipeline and load all models.

        Args:
            rgb_model_path: checkpoint for the RGB ViViT model. Defaults to the
                original hard-coded location when omitted.
            flow_model_path: checkpoint for the optical-flow ViViT model
                (same default rule).
            audio_model_path: checkpoint for the Wav2Vec2 audio model
                (same default rule).

        Construction is expensive: it ends by calling ``_init_models()``.
        """
        # Path configurations (normalize paths). The defaults are the original
        # machine-specific locations; pass explicit paths on any other machine.
        self.RGB_MODEL_PATH = normalize_path(
            rgb_model_path or 'F:/SRC_Bhuvaneswari/typpo/Crimenet/VisTra/Checkpoints/v1.0/best_model_acc.pt')
        self.FLOW_MODEL_PATH = normalize_path(
            flow_model_path or 'F:/SRC_Bhuvaneswari/typpo/Crimenet/VisTra/Checkpoints/flow/best_model_acc.pt')
        self.AUDIO_MODEL_PATH = normalize_path(
            audio_model_path or 'F:/SRC_Bhuvaneswari/typpo/Crimenet/W2V/Checkpoint/wav2vec2_epoch_10.pt')

        # Constants - Keep Abuse in audio model for compatibility
        self.VISUAL_LABEL_MAP = {0: 'Normal', 1: 'Explosion', 2: 'Fighting', 3: 'Car Accident', 4: 'Shooting', 5: 'Riot'}
        self.AUDIO_LABEL_MAP = {0: 'Normal', 1: 'Abuse', 2: 'Explosion', 3: 'Fighting', 4: 'Car Accident', 5: 'Shooting', 6: 'Riot'}

        # Combined label map for final classification with Abuse->Fighting mapping
        self.COMBINED_LABEL_MAP = {
            0: 'Normal',
            1: 'Fighting',  # Abuse mapped to Fighting
            2: 'Explosion',
            3: 'Fighting',
            4: 'Car Accident',
            5: 'Shooting',
            6: 'Riot'
        }

        self.CLIP_LEN = 32
        self.FRAME_SAMPLE_RATE = 1
        self.SAMPLING_RATE = 16000
        self.WINDOW_SIZE_SECONDS = 3

        # Fixed frame batch size: always 32 frames for ViViT compatibility
        self.FRAME_BATCH_SIZE = 32

        # Default fusion method (can be changed via UI)
        self.fusion_method = 'proposed'  # Options: 'advanced', 'majority', 'rgb', 'flow', 'audio', 'proposed'

        # Optimization options
        self.low_memory_mode = False  # Enable for devices with limited VRAM
        self.use_mixed_precision = True  # Use FP16 for inference
        self.audio_chunk_size = 32000  # Process audio in chunks (2 seconds at 16 kHz)

        # Adaptive thresholding parameters
        self.base_threshold = 0.5
        self.threshold_range = (0.35, 0.65)  # Min/max threshold range

        # Initialize multi-scale detection
        self.multi_scale_enabled = True
        self.scale_weights = {
            'Explosion': (0.6, 0.3, 0.1),  # (short, medium, long) windows
            'Shooting': (0.6, 0.3, 0.1),
            'Fighting': (0.2, 0.3, 0.5),   # Will also handle remapped Abuse
            'Riot': (0.2, 0.3, 0.5),
            'Car Accident': (0.3, 0.5, 0.2)
        }

        # Temporal filter with class-specific settings
        self.temporal_filter = AdvancedTemporalFilter(
            base_window_size=3,
            threshold=self.base_threshold,
            label_map=self.COMBINED_LABEL_MAP
        )

        # Diagnostic mode
        self.diagnostics_enabled = False

        # Context tracking
        self.context_events = {}

        # Set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device} (Low memory mode: {self.low_memory_mode}, Mixed precision: {self.use_mixed_precision})")

        # Initialize models
        self._init_models()
        
    def _init_models(self):
        """Load the RGB, optical-flow and audio models, then build the fusion heads.

        Sets ``rgb_processor``/``rgb_model``, ``flow_processor``/``flow_model``
        and ``audio_processor``/``audio_model``. Every model is moved to
        ``self.device`` and put in eval mode. A failed checkpoint load is
        reported and the model keeps its pretrained/default weights.
        """
        print("\n==== LOADING MODELS ====")

        # Clear CUDA cache before starting
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # RGB ViVit model
        print("Loading RGB ViVit model...")
        self.rgb_processor, self.rgb_model, rgb_loaded = self._build_vivit_backbone(
            self.RGB_MODEL_PATH, "RGB ViVit")
        if rgb_loaded:
            print("RGB model loaded successfully")
        else:
            print("WARNING: RGB model is running WITHOUT fine-tuned weights")

        # Clear cache after loading RGB model
        if torch.cuda.is_available() and self.low_memory_mode:
            torch.cuda.empty_cache()

        # Optical Flow ViVit model
        print("\nLoading Flow ViVit model...")
        self.flow_processor, self.flow_model, flow_loaded = self._build_vivit_backbone(
            self.FLOW_MODEL_PATH, "Flow ViVit")
        if flow_loaded:
            print("Flow model loaded successfully")
        else:
            print("WARNING: Flow model is running WITHOUT fine-tuned weights")

        # Clear cache after loading Flow model
        if torch.cuda.is_available() and self.low_memory_mode:
            torch.cuda.empty_cache()

        # Audio Wav2Vec model - Keep 7 classes for model compatibility
        print("\nLoading Audio Wav2Vec2 model...")
        self.audio_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
        self.audio_model = Wav2Vec2ForSequenceClassification.from_pretrained(
            "facebook/wav2vec2-base",
            num_labels=7,  # Keep 7 labels for model compatibility
            ignore_mismatched_sizes=True
        )

        print(f"Expected Audio model output shape: 7 classes (including Abuse)")
        # _load_model_weights reads the checkpoint and handles its own errors,
        # so the checkpoint is loaded exactly once (it used to be read twice).
        audio_loaded = self._load_model_weights(
            self.audio_model, self.AUDIO_MODEL_PATH, "Audio Wav2Vec2")

        self.audio_model.to(self.device)
        self.audio_model.eval()
        if audio_loaded:
            print("Audio model loaded successfully")
        else:
            print("WARNING: Audio model is running WITHOUT fine-tuned weights")

        # Clear cache after loading Audio model
        if torch.cuda.is_available() and self.low_memory_mode:
            torch.cuda.empty_cache()

        # Initialize fusion components with enhanced attention
        self._init_fusion_components()

        # Final cache clear after all models loaded
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _build_vivit_backbone(self, checkpoint_path, model_name):
        """Create one ViViT processor/classifier pair and load its checkpoint.

        Returns:
            ``(processor, model, loaded)`` where ``loaded`` is the boolean
            returned by ``_load_model_weights``.
        """
        processor = VivitImageProcessor.from_pretrained(
            "google/vivit-b-16x2", do_rescale=None, offset=None)
        model = VivitForVideoClassification.from_pretrained(
            "google/vivit-b-16x2",
            num_labels=len(self.VISUAL_LABEL_MAP),
            ignore_mismatched_sizes=True
        )

        loaded = self._load_model_weights(model, checkpoint_path, model_name)

        model.to(self.device)
        model.eval()

        # NOTE(review): gradient checkpointing saves memory during backprop;
        # the model is in eval mode here, so this presumably has no effect on
        # inference — confirm before relying on it for VRAM savings.
        if hasattr(model, 'gradient_checkpointing_enable'):
            model.gradient_checkpointing_enable()

        return processor, model, loaded
    
    def _load_model_weights(self, model, path, model_name):
        """Robust model weight loading with error checking"""
        try:
            # Try to load state dict directly
            state_dict = torch.load(path, map_location=self.device)
            
            # Print model structure before loading
            print(f"\nModel structure for {model_name}:")
            model_keys = set(model.state_dict().keys())
            print(f"  Model has {len(model_keys)} keys")
            
            # Check if it's a complete state dict or just weights
            if isinstance(state_dict, dict) and 'state_dict' in state_dict:
                state_dict = state_dict['state_dict']
            
            # Print checkpoint structure
            checkpoint_keys = set(state_dict.keys())
            print(f"  Checkpoint has {len(checkpoint_keys)} keys")
            
            # Find missing and unexpected keys
            missing = model_keys - checkpoint_keys
            unexpected = checkpoint_keys - model_keys
            matching = model_keys.intersection(checkpoint_keys)
            
            if missing:
                print(f"  Missing keys: {', '.join(list(missing)[:5])}{'...' if len(missing) > 5 else ''}")
            if unexpected:
                print(f"  Unexpected keys: {', '.join(list(unexpected)[:5])}{'...' if len(unexpected) > 5 else ''}")
            print(f"  {len(matching)} matching keys")
            
            # Load with strict=False to handle mismatches
            result = model.load_state_dict(state_dict, strict=False)
            print(f"  Loaded with {len(result.missing_keys)} missing and {len(result.unexpected_keys)} unexpected keys")
            return True
        except Exception as e:
            print(f"ERROR loading {model_name} weights: {e}")
            print("Model will use default initialization!")
            return False
    
    def _init_fusion_components(self):
        """Create every fusion/attention head on ``self.device`` in eval mode.

        The modules are freshly initialised here; no checkpoint is loaded for
        them in this method. They are put in eval mode, like the three
        backbone models, so their Dropout layers are inactive and repeated
        inference on the same input gives the same result.
        """
        print("\nInitializing fusion components with all methods...")

        hidden_dim = 768

        # Original components for existing fusion methods
        self.cross_attention = BidirectionalCrossModalAttention(
            dim=hidden_dim,
            num_heads=8
        ).to(self.device)

        self.temporal_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=8,
            batch_first=True
        ).to(self.device)

        self.visual_fusion = nn.Sequential(
            nn.Linear(hidden_dim*2, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.2)
        ).to(self.device)

        # New components for proposed methodology
        self.intra_modality_fusion = IntraModalityFusion(
            dim=hidden_dim
        ).to(self.device)

        self.inter_modality_fusion = InterModalityFusion(
            dim=hidden_dim,
            num_heads=8
        ).to(self.device)

        self.space_time_attention = DividedSpaceTimeAttention(
            dim=hidden_dim,
            num_heads=8
        ).to(self.device)

        # Final classifier with GELU activation
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim*2, 512),
            nn.GELU(),
            nn.LayerNorm(512),
            nn.Dropout(0.3),
            nn.Linear(512, 7),  # 7 classes output for model compatibility
            nn.Sigmoid()
        ).to(self.device)

        # nn.Module instances start in training mode; switch every head to
        # eval so Dropout does not randomise inference outputs.
        for fusion_module in (
            self.cross_attention,
            self.temporal_attention,
            self.visual_fusion,
            self.intra_modality_fusion,
            self.inter_modality_fusion,
            self.space_time_attention,
            self.classifier,
        ):
            fusion_module.eval()

        print("All fusion components initialized successfully")

        if torch.cuda.is_available() and self.low_memory_mode:
            torch.cuda.empty_cache()
    
    def extract_frames(self, video_path):
        """Sample ``CLIP_LEN`` RGB frames from a video.

        Frames are resized to 256, centre-cropped to 224x224 and converted
        with ``ToTensor``; each is returned as an HWC numpy array.

        Sampling:
            * enough frames: ``CLIP_LEN`` indices spread evenly over the video;
            * short video: indices wrap around (``% total_frames``), so frames
              repeat until ``CLIP_LEN`` frames are produced;
            * frame count unavailable (reported as 0): the leading frames are
              taken at ``FRAME_SAMPLE_RATE`` stride.

        Returns:
            List of numpy arrays, or ``None`` when nothing could be decoded or
            an error occurred (the error is printed).
        """
        video_path = normalize_path(video_path)
        transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor()
        ])

        try:
            decoded = {}
            # The context manager closes the container even if decoding raises.
            with av.open(video_path) as container:
                total_frames = container.streams.video[0].frames
                span = self.CLIP_LEN * self.FRAME_SAMPLE_RATE

                # Calculate indices to sample
                if total_frames >= span:
                    indices = np.linspace(0, total_frames - 1, self.CLIP_LEN, dtype=int)
                elif total_frames > 0:
                    indices = np.arange(0, span, self.FRAME_SAMPLE_RATE) % total_frames
                else:
                    # Unknown frame count: avoid a modulo by zero.
                    indices = np.arange(0, span, self.FRAME_SAMPLE_RATE)

                indices = [int(i) for i in indices]
                wanted = set(indices)  # O(1) membership test per decoded frame

                # Decode once, keeping only the frames that are needed.
                container.seek(0)
                for i, frame in enumerate(container.decode(video=0)):
                    if i in wanted:
                        decoded[i] = transform(frame.to_image())
                        if len(decoded) == len(wanted):
                            break

            # Assemble in index order; repeated indices repeat the same frame,
            # which is what the wrap-around sampling asks for.
            frames = [decoded[i] for i in indices if i in decoded]
            if frames:
                return [frame.permute(1, 2, 0).numpy() for frame in frames]

        except Exception as e:
            print(f"Error extracting frames from {video_path}: {e}")

        return None
    
    def extract_frames_batch(self, video_path, batch_size=None):
        """Extract exactly ``CLIP_LEN`` RGB frames for temporal processing.

        Args:
            video_path: path to the video file.
            batch_size: accepted for interface compatibility but ignored;
                ``CLIP_LEN`` frames are always produced.

        Returns:
            List of ``CLIP_LEN`` HWC numpy arrays (resized to 256,
            centre-cropped to 224x224). The list is padded by repeating the
            last frame, or with an all-zero frame when nothing was decoded.
            ``None`` on error (error and traceback are printed).
        """
        # Always use CLIP_LEN frames for ViViT
        batch_size = self.CLIP_LEN

        video_path = normalize_path(video_path)
        transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor()
        ])

        try:
            batch_frames = []

            # The context manager closes the container even if decoding raises.
            with av.open(video_path) as container:
                stream = container.streams.video[0]
                total_frames = stream.frames
                fps = stream.average_rate

                print(f"Video has {total_frames} frames at {fps} FPS")

                # NOTE(review): the step is sized for 1.5x batch_size samples
                # but decoding stops after batch_size frames, so roughly the
                # last third of a long video is never sampled — confirm intended.
                step = max(1, int(total_frames / (batch_size * 1.5)))  # Extract 1.5x for safety
                print(f"Using step size of {step} frames")

                container.seek(0)
                for i, frame in enumerate(container.decode(video=0)):
                    if i % step == 0:
                        img = frame.to_image()
                        batch_frames.append(transform(img))

                        if len(batch_frames) >= batch_size:
                            break

            # Handle case where we didn't get enough frames
            while len(batch_frames) < batch_size:
                if len(batch_frames) == 0:
                    print("Could not extract any frames, creating empty frame")
                    batch_frames.append(torch.zeros(3, 224, 224))
                else:
                    print(f"Only got {len(batch_frames)} frames, duplicating last frame")
                    batch_frames.append(batch_frames[-1])

            # Convert to numpy for model input
            frames_numpy = [frame.permute(1, 2, 0).numpy() for frame in batch_frames]

            print(f"Successfully extracted batch of {len(frames_numpy)} frames")
            return frames_numpy

        except Exception as e:
            print(f"Error extracting frame batch from {video_path}: {e}")
            import traceback
            traceback.print_exc()

        return None
    
    def compute_optical_flow(self, video_path):
        """Compute ``CLIP_LEN`` optical-flow images from a video.

        Farneback dense flow is computed between consecutive *sampled* frames
        and rendered as an RGB image (hue = direction, value = min-max
        normalised magnitude). Each image is resized to 256 and centre-cropped
        to 224x224.

        Returns:
            List of ``CLIP_LEN`` HWC numpy arrays, padded by repeating the
            last flow image (or with zeros if none was produced). ``None``
            when the video cannot be opened or read, or on error.
        """
        video_path = normalize_path(video_path)
        frames = []
        transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor()
        ])

        try:
            cap = cv2.VideoCapture(video_path)
            try:
                if not cap.isOpened():
                    print(f"Could not open video {video_path}")
                    return None

                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                span = self.CLIP_LEN * self.FRAME_SAMPLE_RATE

                # One more index than CLIP_LEN: every flow image needs a frame pair.
                if total_frames >= span:
                    indices = np.linspace(0, total_frames - 1, self.CLIP_LEN + 1, dtype=int)
                elif total_frames > 0:
                    indices = np.arange(0, span + 1, self.FRAME_SAMPLE_RATE) % total_frames
                else:
                    # Frame count unavailable: avoid a modulo by zero and
                    # sample the leading frames instead.
                    indices = np.arange(0, span + 1, self.FRAME_SAMPLE_RATE)
                wanted = {int(i) for i in indices}  # O(1) membership per frame

                # Read first frame
                ret, prev_frame = cap.read()
                if not ret:
                    return None

                prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)

                # Process frames
                frame_idx = 1
                while True:
                    ret, frame = cap.read()
                    if not ret or len(frames) >= self.CLIP_LEN:
                        break

                    if frame_idx in wanted:
                        # Convert to grayscale
                        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

                        # Calculate optical flow
                        flow = cv2.calcOpticalFlowFarneback(
                            prev_gray, gray, None, 0.5, 3, 15, 3, 5, 1.2, 0
                        )

                        # Convert flow to RGB image via HSV
                        mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
                        hsv = np.zeros((frame.shape[0], frame.shape[1], 3), dtype=np.uint8)
                        hsv[..., 0] = ang * 180 / np.pi / 2
                        hsv[..., 1] = 255
                        hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
                        flow_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

                        # Transform and append
                        flow_pil = T.ToPILImage()(flow_rgb)
                        frames.append(transform(flow_pil))

                        # The next flow is measured from this sampled frame.
                        prev_gray = gray

                    frame_idx += 1
            finally:
                # Release the capture on every path, including errors.
                cap.release()

            # Handle case where we didn't get enough frames
            while len(frames) < self.CLIP_LEN:
                # Duplicate the last frame if needed
                frames.append(frames[-1] if frames else torch.zeros(3, 224, 224))

            # Convert to numpy for model input
            frames_numpy = [frame.permute(1, 2, 0).numpy() for frame in frames]
            return frames_numpy

        except Exception as e:
            print(f"Error computing optical flow from {video_path}: {e}")
            return None
    
    def compute_optical_flow_batch(self, video_path, batch_size=None):
        """Compute a ``CLIP_LEN``-frame batch of optical-flow images from a video.

        Every ``step``-th frame is decoded (the others are only grabbed), dense
        Farneback flow is computed against the previously decoded frame and
        rendered as an RGB image (hue = direction, value = min-max normalised
        magnitude).

        Args:
            video_path: Path to the video file (backslashes are normalised).
            batch_size: Accepted for interface compatibility but ignored; the
                batch always holds ``self.CLIP_LEN`` frames.

        Returns:
            A list of ``self.CLIP_LEN`` arrays in height x width x channel
            layout, padded by repeating the last flow image (or zeros if none
            was produced). ``None`` if the video cannot be opened, its first
            frame cannot be read, or any error occurs while processing.
        """
        # Always use CLIP_LEN frames for ViViT
        batch_size = self.CLIP_LEN

        video_path = normalize_path(video_path)
        frames = []
        transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor()
        ])

        cap = None
        try:
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                print(f"Could not open video {video_path}")
                return None

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)

            print(f"Video has {total_frames} frames at {fps} FPS")

            # Calculate optimal step for the requested batch size
            step = max(1, int(total_frames / (batch_size * 1.5)))  # Extract 1.5x for safety
            print(f"Using step size of {step} frames")

            # The first frame only serves as the reference for the first flow field
            ret, prev_frame = cap.read()
            if not ret:
                return None

            prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)

            # Process frames
            frame_idx = 1
            while frame_idx < total_frames:
                # Skip frames according to step
                if frame_idx % step != 0:
                    ret = cap.grab()  # Just grab frame without decoding
                    if not ret:
                        break
                    frame_idx += 1
                    continue

                # Read and process frame
                ret, frame = cap.retrieve() if cap.grab() else (False, None)
                if not ret:
                    break

                # Convert to grayscale
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

                # Calculate optical flow
                flow = cv2.calcOpticalFlowFarneback(
                    prev_gray, gray, None, 0.5, 3, 15, 3, 5, 1.2, 0
                )

                # Convert flow to RGB image (OpenCV 8-bit hue spans 0..180)
                mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
                hsv = np.zeros((frame.shape[0], frame.shape[1], 3), dtype=np.uint8)
                hsv[..., 0] = ang * 180 / np.pi / 2
                hsv[..., 1] = 255
                hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
                flow_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

                # Transform and append
                flow_pil = T.ToPILImage()(flow_rgb)
                flow_tensor = transform(flow_pil)
                frames.append(flow_tensor)

                # Update previous frame
                prev_gray = gray

                # Check if we have enough frames
                if len(frames) >= batch_size:
                    break

                frame_idx += 1

            # Handle case where we didn't get enough frames
            while len(frames) < batch_size:
                # Duplicate the last frame if needed
                if len(frames) == 0:
                    print("Could not extract any flow frames, creating empty frame")
                    frames.append(torch.zeros(3, 224, 224))
                else:
                    print(f"Only got {len(frames)} flow frames, duplicating last frame")
                    frames.append(frames[-1])

            # Convert to numpy for model input (channels-first -> channels-last)
            frames_numpy = [frame.permute(1, 2, 0).numpy() for frame in frames]

            print(f"Successfully extracted batch of {len(frames_numpy)} flow frames")
            return frames_numpy

        except Exception as e:
            print(f"Error computing optical flow batch from {video_path}: {e}")
            import traceback
            traceback.print_exc()
            return None

        finally:
            # Release the capture on every exit path, including exceptions
            if cap is not None:
                cap.release()
    
    def extract_audio(self, video_path):
        """Return the video's audio track as a mono waveform.

        The audio is written to a WAV file in a fresh temporary directory at
        ``self.SAMPLING_RATE``, loaded back, resampled if the file's rate
        differs, and averaged over channels. The temporary file and directory
        are removed on every exit path.

        Args:
            video_path: Path to the video file (backslashes are normalised).

        Returns:
            A 1-D numpy array of samples; an all-zero array spanning the clip
            duration when the video has no audio track; ``None`` on any error.
        """
        video_path = normalize_path(video_path)
        scratch_dir = tempfile.mkdtemp()
        wav_path = os.path.join(scratch_dir, "temp_audio.wav")

        try:
            # The context manager closes the clip's readers when the block exits
            with mp.VideoFileClip(video_path) as clip:
                audio_track = clip.audio
                if audio_track is None:
                    print(f"Warning: No audio track in video {video_path}")
                    # Stand in with silence of matching length
                    n_samples = int(clip.duration * self.SAMPLING_RATE)
                    return np.zeros(n_samples)

                audio_track.write_audiofile(wav_path, fps=self.SAMPLING_RATE, verbose=False, logger=None)

            samples, native_rate = torchaudio.load(wav_path)

            # Bring the signal to the target rate when the file disagrees
            if native_rate != self.SAMPLING_RATE:
                to_target_rate = torchaudio.transforms.Resample(orig_freq=native_rate, new_freq=self.SAMPLING_RATE)
                samples = to_target_rate(samples)

            # Collapse channels to mono and hand back a numpy array
            return samples.mean(dim=0).numpy()

        except Exception as e:
            print(f"Error extracting audio from {video_path}: {e}")
            return None

        finally:
            # Best-effort removal of the scratch file and its directory
            try:
                if os.path.exists(wav_path):
                    os.remove(wav_path)
                if os.path.exists(scratch_dir):
                    os.rmdir(scratch_dir)
            except Exception as cleanup_error:
                print(f"Cleanup error: {cleanup_error}")
    
    def process_rgb_frames(self, frames):
        """Run an RGB clip through the RGB model and return its token-0 features.

        Args:
            frames: Sequence of at least ``self.CLIP_LEN`` frames accepted by
                ``self.rgb_processor``.

        Returns:
            ``hidden_states[-1][:, 0, :]`` of the RGB model (the token at
            index 0, used as the CLS token), or ``None`` when too few frames
            are supplied.
        """
        has_full_clip = frames is not None and len(frames) >= self.CLIP_LEN
        if not has_full_clip:
            print("Insufficient RGB frames for processing")
            return None

        # Inference only: optional mixed precision, no gradient tracking
        with torch.cuda.amp.autocast(enabled=self.use_mixed_precision), torch.no_grad():
            batch = self.rgb_processor(frames, return_tensors="pt")
            batch = {name: tensor.to(self.device) for name, tensor in batch.items()}
            model_out = self.rgb_model(**batch, output_hidden_states=True)

            # Keep only the first token of the final layer
            final_layer = model_out.hidden_states[-1]
            return final_layer[:, 0, :]
    
    def process_rgb_frames_batch(self, frames):
        """Run a batch of RGB frames through the RGB model, keeping every token.

        Args:
            frames: Non-empty sequence of frames accepted by
                ``self.rgb_processor``.

        Returns:
            The full last hidden state of the RGB model (all tokens), or
            ``None`` when ``frames`` is ``None`` or empty.
        """
        if frames is None or len(frames) == 0:
            print("Insufficient RGB frames for batch processing")
            return None

        # Inference only: optional mixed precision, no gradient tracking
        with torch.cuda.amp.autocast(enabled=self.use_mixed_precision), torch.no_grad():
            # NOTE(review): ``interpolate_pos_encoding`` is handed to the image
            # processor here, not to the model call — confirm the processor
            # actually honours it.
            batch = self.rgb_processor(frames, return_tensors="pt", interpolate_pos_encoding=True)
            batch = {name: tensor.to(self.device) for name, tensor in batch.items()}
            model_out = self.rgb_model(**batch, output_hidden_states=True)

            # All tokens of the final layer are kept for later temporal fusion
            token_features = model_out.hidden_states[-1]
            print(f"RGB features extracted: shape {token_features.shape}")

        return token_features
        
    def process_flow_frames(self, frames):
        """Run an optical-flow clip through the flow model and return token-0 features.

        Args:
            frames: Sequence of at least ``self.CLIP_LEN`` flow images accepted
                by ``self.flow_processor``.

        Returns:
            ``hidden_states[-1][:, 0, :]`` of the flow model (the token at
            index 0, used as the CLS token), or ``None`` when too few frames
            are supplied.
        """
        has_full_clip = frames is not None and len(frames) >= self.CLIP_LEN
        if not has_full_clip:
            print("Insufficient optical flow frames for processing")
            return None

        # Inference only: optional mixed precision, no gradient tracking
        with torch.cuda.amp.autocast(enabled=self.use_mixed_precision), torch.no_grad():
            batch = self.flow_processor(frames, return_tensors="pt")
            batch = {name: tensor.to(self.device) for name, tensor in batch.items()}
            model_out = self.flow_model(**batch, output_hidden_states=True)

            # Keep only the first token of the final layer
            final_layer = model_out.hidden_states[-1]
            return final_layer[:, 0, :]
    
    def process_flow_frames_batch(self, frames):
        """Run a batch of optical-flow frames through the flow model, keeping every token.

        Args:
            frames: Non-empty sequence of flow images accepted by
                ``self.flow_processor``.

        Returns:
            The full last hidden state of the flow model (all tokens), or
            ``None`` when ``frames`` is ``None`` or empty.
        """
        if frames is None or len(frames) == 0:
            print("Insufficient flow frames for batch processing")
            return None

        # Inference only: optional mixed precision, no gradient tracking
        with torch.cuda.amp.autocast(enabled=self.use_mixed_precision), torch.no_grad():
            # NOTE(review): ``interpolate_pos_encoding`` is handed to the image
            # processor here, not to the model call — confirm the processor
            # actually honours it.
            batch = self.flow_processor(frames, return_tensors="pt", interpolate_pos_encoding=True)
            batch = {name: tensor.to(self.device) for name, tensor in batch.items()}
            model_out = self.flow_model(**batch, output_hidden_states=True)

            # All tokens of the final layer are kept to preserve temporal info
            token_features = model_out.hidden_states[-1]
            print(f"Flow features extracted: shape {token_features.shape}")

        return token_features
        
    def process_audio(self, audio_waveform):
        """Embed an audio waveform with the audio model.

        When ``self.low_memory_mode`` is on and the waveform is longer than
        ``self.audio_chunk_size`` samples, it is split into consecutive chunks
        of at most that size, each chunk is embedded separately, and the chunk
        embeddings are averaged. Otherwise the whole waveform is embedded in a
        single pass.

        Args:
            audio_waveform: 1-D sequence of samples at ``self.SAMPLING_RATE``.

        Returns:
            ``hidden_states[-1][:, 0, :]`` of the audio model (averaged over
            chunks when chunking is used), or ``None`` if the waveform is
            ``None``.
        """
        if audio_waveform is None:
            print("Audio waveform is None")
            return None

        def _embed(segment):
            # Processor -> device -> model; keep index 0 of the final layer
            encoded = self.audio_processor(
                segment,
                sampling_rate=self.SAMPLING_RATE,
                return_tensors="pt",
                padding=True
            )
            encoded = {key: value.to(self.device) for key, value in encoded.items()}
            result = self.audio_model(**encoded, output_hidden_states=True)
            return result.hidden_states[-1][:, 0, :]

        with torch.cuda.amp.autocast(enabled=self.use_mixed_precision), torch.no_grad():
            total_samples = len(audio_waveform)
            chunked = total_samples > self.audio_chunk_size and self.low_memory_mode

            if not chunked:
                # Short clip or memory is not a concern: single pass
                return _embed(audio_waveform)

            per_chunk = []
            for offset in range(0, total_samples, self.audio_chunk_size):
                # Slicing clamps at the end, so the final chunk may be shorter
                segment = audio_waveform[offset:offset + self.audio_chunk_size]
                per_chunk.append(_embed(segment))

                # Free cached GPU memory between chunks in low-memory mode
                if torch.cuda.is_available() and self.low_memory_mode:
                    torch.cuda.empty_cache()

            # Average the per-chunk embeddings into one vector per batch item
            return torch.mean(torch.stack(per_chunk, dim=0), dim=0)
        
    def apply_bidirectional_cross_attention(self, rgb_features, flow_features, audio_features):
        """Fuse RGB, flow and audio features with visual fusion plus cross-attention.

        RGB and flow features are concatenated along dim 1 and passed through
        ``self.visual_fusion``; the result is cross-attended with the audio
        features via ``self.cross_attention``; finally the fused visual
        features and the cross-attended output are concatenated along dim 1.

        Returns:
            The concatenated multimodal features, or ``None`` if any input is
            ``None``.
        """
        modalities = (rgb_features, flow_features, audio_features)
        if any(item is None for item in modalities):
            print("Missing features for cross-attention fusion")
            return None

        # Inference only: optional mixed precision, no gradient tracking
        with torch.cuda.amp.autocast(enabled=self.use_mixed_precision), torch.no_grad():
            # Visual stream: RGB and flow side by side, then fused
            fused_visual = self.visual_fusion(torch.cat([rgb_features, flow_features], dim=1))

            # Let the visual and audio streams attend to each other
            attended = self.cross_attention(fused_visual, audio_features)

            # Keep the visual representation alongside the attended one
            return torch.cat([fused_visual, attended], dim=1)
    
    def apply_intra_modality_fusion(self, rgb_features, flow_features):
        """Fuse RGB and optical-flow features via ``self.intra_modality_fusion``.

        Returns:
            The module's output, or ``None`` if either input is ``None``.
        """
        both_present = rgb_features is not None and flow_features is not None
        if not both_present:
            print("Missing features for intra-modality fusion")
            return None

        # Delegate directly to the fusion module
        return self.intra_modality_fusion(rgb_features, flow_features)
    
    def apply_inter_modality_fusion(self, visual_features, audio_features):
        """Fuse visual and audio features via ``self.inter_modality_fusion``.

        Returns:
            The module's output, or ``None`` if either input is ``None``.
        """
        both_present = visual_features is not None and audio_features is not None
        if not both_present:
            print("Missing features for inter-modality fusion")
            return None

        # Delegate directly to the fusion module
        return self.inter_modality_fusion(visual_features, audio_features)
    
    def apply_divided_space_time_attention(self, features):
        """Pass features through ``self.space_time_attention``.

        The module is always called with ``batch_size=1`` (a single video).
        ``seq_len`` is the size of dim 1 for 3-D input and ``1`` otherwise, in
        which case a warning is printed.

        Returns:
            The module's output, or ``None`` if ``features`` is ``None``.
        """
        if features is None:
            print("Features for divided space-time attention are None")
            return None

        rank = len(features.shape)
        if rank == 3:
            # Already laid out as [batch, seq, features]
            sequence_length = features.shape[1]
        else:
            sequence_length = 1
            print(f"Warning: Reshaping features of shape {features.shape} for space-time attention")

        # One video per call, so the batch size is fixed at 1
        return self.space_time_attention(features, batch_size=1, seq_len=sequence_length)
    
    def apply_proposed_methodology(self, rgb_features, flow_features, audio_features):
        """Run the full fusion pipeline and build the classifier input.

        Steps: intra-modality fusion (RGB + flow), inter-modality fusion
        (visual + audio), divided space-time attention, then — when the
        attended output is 3-D — the token at index 0 is taken from both the
        attended and the fused visual features before concatenating them
        along dim 1.

        Returns:
            The concatenated ``[visual, attended]`` features, or ``None`` if
            any input is ``None``.
        """
        if any(item is None for item in (rgb_features, flow_features, audio_features)):
            print("Missing features for proposed methodology")
            return None

        # Inference only: optional mixed precision, no gradient tracking
        with torch.cuda.amp.autocast(enabled=self.use_mixed_precision), torch.no_grad():
            # Shape trace of the inputs for debugging
            print(f"RGB features shape: {rgb_features.shape}")
            print(f"Flow features shape: {flow_features.shape}")
            print(f"Audio features shape: {audio_features.shape}")

            # Stage 1: merge the two visual streams
            fused_visual = self.apply_intra_modality_fusion(rgb_features, flow_features)
            print(f"After intra-modality fusion: {fused_visual.shape}")

            # Stage 2: bring audio into the visual representation
            fused_av = self.apply_inter_modality_fusion(fused_visual, audio_features)
            print(f"After inter-modality fusion: {fused_av.shape}")

            # Stage 3: divided space-time attention over the fused tokens
            attended = self.apply_divided_space_time_attention(fused_av)
            print(f"After divided space-time attention: {attended.shape}")

            # Stage 4: reduce token sequences to their first token so the
            # classifier receives one vector per batch item
            if len(attended.shape) == 3:
                attended = attended[:, 0, :]
                fused_visual = fused_visual[:, 0, :]

            classifier_input = torch.cat([fused_visual, attended], dim=1)
            print(f"Final features shape for classification: {classifier_input.shape}")

        return classifier_input
        
    def classify(self, features):
        """Run ``self.classifier`` on the fused features.

        Returns:
            The classifier output, or ``None`` if ``features`` is ``None``.
        """
        if features is None:
            print("Features for classification are None")
            return None

        # Inference only: optional mixed precision, no gradient tracking
        with torch.cuda.amp.autocast(enabled=self.use_mixed_precision), torch.no_grad():
            return self.classifier(features)
    
    def get_adaptive_threshold(self, scene_complexity=0.5, motion_level=0.5, audio_clarity=0.5):
        """Derive a detection threshold from content-quality scores.

        Starting from ``self.base_threshold``:
          * ``scene_complexity > 0.7`` lowers it by 0.1 (avoid missing events),
          * ``motion_level > 0.8`` raises it by 0.05 (fewer false positives),
          * ``audio_clarity < 0.4`` raises it by 0.1 (audio is unreliable).

        Returns:
            The adjusted threshold clamped to ``self.threshold_range``
            (``[0]`` is the lower bound, ``[1]`` the upper bound).
        """
        # (condition, delta) pairs, applied in this order
        adjustments = (
            (scene_complexity > 0.7, -0.1),
            (motion_level > 0.8, 0.05),
            (audio_clarity < 0.4, 0.1),
        )

        threshold = self.base_threshold
        for applies, delta in adjustments:
            if applies:
                threshold += delta

        lower_bound = self.threshold_range[0]
        upper_bound = self.threshold_range[1]
        capped = min(upper_bound, threshold)
        return max(lower_bound, capped)
    
    def apply_multi_scale_detection(self, predictions, confidence_boost=None):
        """Re-weight per-class scores and recompute the predicted classes.

        Args:
            predictions: Dict with a ``'raw_predictions'`` array indexed by the
                keys of ``self.COMBINED_LABEL_MAP``.
            confidence_boost: Optional ``{class_name: boost}`` mapping; each
                boost is added to the score of the matching class.

        Returns:
            ``predictions`` unchanged when ``self.multi_scale_enabled`` is
            false. Otherwise a new dict with ``'raw_predictions'`` (adjusted
            scores), ``'predicted_indices'`` (indices at or above the adaptive
            threshold, or the single arg-max if none qualifies) and
            ``'predicted_classes'`` (labels, with "Abuse" remapped to
            "Fighting"). The caller's ``'raw_predictions'`` array is not
            modified.
        """
        if not self.multi_scale_enabled:
            return predictions

        # Work on a copy: adjusting the caller's array in place would make
        # boosts and weights compound if the same predictions are passed again.
        raw_preds = np.array(predictions['raw_predictions'], copy=True)

        # Apply confidence boost from context if available
        if confidence_boost:
            for class_name, boost in confidence_boost.items():
                # Find index for this class
                for idx, label in self.COMBINED_LABEL_MAP.items():
                    if label == class_name:
                        raw_preds[idx] += boost

        # Apply different weighting based on class
        for class_idx, class_name in self.COMBINED_LABEL_MAP.items():
            if class_name in self.scale_weights:
                # Get weights for this class
                short_w, med_w, long_w = self.scale_weights[class_name]

                # Scale the score by the summed short/medium/long weights
                weighted_pred = raw_preds[class_idx] * (short_w + med_w + long_w)

                # Weighted scores are capped at 1.0
                raw_preds[class_idx] = min(1.0, weighted_pred)

        # Recalculate predicted indices based on threshold
        threshold = self.get_adaptive_threshold()
        predicted_indices = np.where(raw_preds >= threshold)[0]

        # If no class is above threshold, take the highest probability class
        if len(predicted_indices) == 0:
            predicted_indices = [np.argmax(raw_preds)]

        # Get class names with Abuse remapping
        predicted_classes = []
        for idx in predicted_indices:
            label = self.COMBINED_LABEL_MAP.get(idx, "Unknown")
            # Remap Abuse to Fighting if it ever shows up
            if label == "Abuse":
                label = "Fighting"
                print("DEBUG: Remapped Abuse to Fighting in multi-scale detection")
            predicted_classes.append(label)

        # Return updated predictions
        return {
            'raw_predictions': raw_preds,
            'predicted_indices': predicted_indices,
            'predicted_classes': predicted_classes
        }
    
    def update_context(self, predictions):
        """Map detected events to follow-up confidence boosts.

        Args:
            predictions: Prediction dict containing ``'predicted_classes'``.

        Returns:
            ``(context_boost, duration)`` where ``context_boost`` maps class
            names to score boosts and ``duration`` is how long (in seconds)
            the boost should apply. "Explosion" takes precedence over
            "Car Accident"; with neither present the result is ``({}, 0)``.
            Returns ``(None, 0)`` when ``predictions`` is empty or lacks
            ``'predicted_classes'``.
        """
        if not predictions or 'predicted_classes' not in predictions:
            return None, 0

        detected = predictions['predicted_classes']

        # Explosion: boost related follow-up events for the next 30 seconds
        if 'Explosion' in detected:
            return {'Fighting': 0.1, 'Car Accident': 0.15}, 30

        # Car accident: subsequent fighting is slightly more likely for 20 seconds
        if 'Car Accident' in detected:
            return {'Fighting': 0.05}, 20

        # No contextual relationship applies
        return {}, 0
    
    def run_diagnostics(self, temp_path):
        """Run each single-modality model on a video segment and print its probabilities.

        Extracts RGB frames, optical-flow frames and audio, prints the softmax
        probabilities of the RGB, flow and audio classifiers, and reports the
        majority vote over their top classes (audio index 1, "Abuse", is
        remapped to "Fighting").

        Args:
            temp_path: Path to the video segment to diagnose.

        Returns:
            A dict with keys ``"rgb"``, ``"flow"``, ``"audio"`` (each a
            ``(class_name, probabilities)`` tuple) and ``"majority"`` (the
            majority-vote class name), or ``None`` if any input could not be
            extracted.
        """
        print("\n===== PREDICTION DIAGNOSTICS =====")

        # Extract inputs
        rgb_frames = self.extract_frames(temp_path)
        flow_frames = self.compute_optical_flow(temp_path)
        audio = self.extract_audio(temp_path)

        # The extractors signal failure with None; bail out instead of crashing
        # inside the processors (same handling and message as predict()).
        if rgb_frames is None or flow_frames is None or audio is None:
            print("Error: Could not process inputs")
            return None

        # Get raw predictions from individual models
        with torch.cuda.amp.autocast(enabled=self.use_mixed_precision):
            with torch.no_grad():
                # RGB
                rgb_inputs = self.rgb_processor(rgb_frames, return_tensors="pt")
                rgb_inputs = {k: v.to(self.device) for k, v in rgb_inputs.items()}
                rgb_outputs = self.rgb_model(**rgb_inputs)
                rgb_probs = torch.softmax(rgb_outputs.logits, dim=-1)[0].cpu().numpy()

                # Flow
                flow_inputs = self.flow_processor(flow_frames, return_tensors="pt")
                flow_inputs = {k: v.to(self.device) for k, v in flow_inputs.items()}
                flow_outputs = self.flow_model(**flow_inputs)
                flow_probs = torch.softmax(flow_outputs.logits, dim=-1)[0].cpu().numpy()

                # Audio
                audio_inputs = self.audio_processor(
                    audio, sampling_rate=self.SAMPLING_RATE, return_tensors="pt", padding=True
                )
                audio_inputs = {k: v.to(self.device) for k, v in audio_inputs.items()}
                audio_outputs = self.audio_model(**audio_inputs)
                audio_probs = torch.softmax(audio_outputs.logits, dim=-1)[0].cpu().numpy()

        # Print raw probabilities for each model
        print("RGB Probabilities:")
        for i, label in self.VISUAL_LABEL_MAP.items():
            print(f"  {label}: {rgb_probs[i]:.4f}")
        print("\nFlow Probabilities:")
        for i, label in self.VISUAL_LABEL_MAP.items():
            print(f"  {label}: {flow_probs[i]:.4f}")
        print("\nAudio Probabilities:")
        for i, label in self.AUDIO_LABEL_MAP.items():
            # Add special note if this is abuse class being remapped
            if i == 1:  # Abuse index
                print(f"  {label} (remapped to Fighting): {audio_probs[i]:.4f}")
            else:
                print(f"  {label}: {audio_probs[i]:.4f}")

        # Calculate majority vote as a simple reference
        rgb_class = self.VISUAL_LABEL_MAP[np.argmax(rgb_probs)]
        flow_class = self.VISUAL_LABEL_MAP[np.argmax(flow_probs)]

        # Handle special remapping for audio class
        audio_idx = np.argmax(audio_probs)
        audio_class = self.AUDIO_LABEL_MAP[audio_idx]
        if audio_idx == 1:  # Abuse index
            print(f"NOTE: Audio detected 'Abuse' which is being remapped to 'Fighting'")
            audio_class = "Fighting"  # Remap

        print(f"\nModel predictions: RGB={rgb_class}, Flow={flow_class}, Audio={audio_class}")

        # Try majority fusion for comparison; restore the configured fusion
        # method even if the vote raises.
        current_method = self.fusion_method
        self.fusion_method = 'majority'
        try:
            majority_result = self.majority_vote_prediction(rgb_class, flow_class, audio_class)
            print(f"Majority vote result: {majority_result['predicted_classes'][0]}")
        finally:
            self.fusion_method = current_method

        # Return diagnostics for reference
        return {
            "rgb": (rgb_class, rgb_probs),
            "flow": (flow_class, flow_probs),
            "audio": (audio_class, audio_probs),
            "majority": majority_result['predicted_classes'][0]
        }
    
    def majority_vote_prediction(self, rgb_class, flow_class, audio_class):
        """Fuse three single-modality class labels by majority vote.

        An audio label of "Abuse" is counted as "Fighting". On a tie the label
        that appears first in (rgb, flow, audio) order wins. A winner that is
        absent from ``self.COMBINED_LABEL_MAP`` falls back to index 0.

        Returns:
            A dict with ``'raw_predictions'`` (synthetic scores: 0.9 for the
            winner, the remaining 0.1 split evenly over the other classes),
            ``'predicted_indices'`` and ``'predicted_classes'`` (one entry
            each).
        """
        if audio_class == "Abuse":
            audio_class = "Fighting"
            print("DEBUG: Remapped Abuse to Fighting in majority voting")

        ballots = [rgb_class, flow_class, audio_class]
        # max() returns the first ballot with the highest count, which gives
        # ties to the earliest modality
        winner = max(ballots, key=ballots.count)

        # Locate the winner in the combined label space
        label_map = self.COMBINED_LABEL_MAP
        winner_idx = next((idx for idx, name in label_map.items() if name == winner), -1)
        if winner_idx == -1:
            # Unknown label: fall back to the class stored at index 0
            winner_idx = 0
            winner = label_map[0]

        # Synthetic score vector with the winner at 0.9
        n_classes = len(label_map)
        raw_preds = np.full(n_classes, 0.1) / (n_classes - 1)
        raw_preds[winner_idx] = 0.9

        return {
            'raw_predictions': raw_preds,
            'predicted_indices': [winner_idx],
            'predicted_classes': [winner]
        }
    
    def _create_result_from_single(self, pred_class, label_map):
        """Create a standardized result from a single model prediction"""
        # Remap Abuse to Fighting if present
        if pred_class == "Abuse":
            pred_class = "Fighting"
            print("DEBUG: Remapped Abuse to Fighting in single-modality result")
            
        # Find the equivalent in the combined map
        combined_idx = -1
        for idx, label in self.COMBINED_LABEL_MAP.items():
            if label == pred_class:
                combined_idx = idx
                break
        
        # Create fake probabilities with prediction at 0.9
        raw_preds = np.ones(len(self.COMBINED_LABEL_MAP)) * 0.1 / (len(self.COMBINED_LABEL_MAP) - 1)
        if combined_idx >= 0:
            raw_preds[combined_idx] = 0.9
        
        return {
            'raw_predictions': raw_preds,
            'predicted_indices': [combined_idx if combined_idx >= 0 else 0],
            'predicted_classes': [pred_class]
        }
    
    def predict(self, video_path, threshold=None):
        """Run the full prediction pipeline on one video file.

        The path is normalized, RGB frames, optical-flow frames and audio are
        extracted (frame extraction uses ``self.FRAME_BATCH_SIZE``), and each
        single-modality classifier is run. The result then depends on
        ``self.fusion_method`` (read once per call):

        * ``'rgb'`` / ``'flow'`` / ``'audio'``: that model's class wrapped by
          ``_create_result_from_single``.
        * ``'majority'``: ``majority_vote_prediction`` over the three classes.
        * ``'proposed'``: batched feature extraction followed by
          ``apply_proposed_methodology``; no context boost is passed to
          ``apply_multi_scale_detection``.
        * any other value: ``apply_bidirectional_cross_attention`` with the
          still-active boosts from ``self.context_events`` (expired entries
          are removed on the way).

        Both fusion paths classify the fused features, post-process them with
        ``_process_predictions``, ``apply_multi_scale_detection`` and
        ``self.temporal_filter.update``, and may register a new entry in
        ``self.context_events`` based on ``update_context``.

        Args:
            video_path: Path of the video to analyse.
            threshold: Score threshold handed to ``_process_predictions``.
                When None, ``get_adaptive_threshold()`` supplies it. Only the
                two fusion paths use it.

        Returns:
            The result dictionary produced by the selected method, or None
            when any input could not be extracted or an exception occurred
            (the traceback is printed).
        """
        video_path = normalize_path(video_path)
        print(f"\n==== PREDICTING VIOLENCE IN: {video_path} ====")

        def release_cuda_cache():
            """Release cached CUDA memory when a CUDA device is available."""
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        def top_class_index(model, encoded):
            """Move processor output to the pipeline device and return the argmax class index."""
            encoded = {k: v.to(self.device) for k, v in encoded.items()}
            return torch.argmax(model(**encoded).logits, dim=-1).item()

        try:
            # Frame extraction is batched with FRAME_BATCH_SIZE for every fusion method.
            print(f"Using batched frame extraction with {self.FRAME_BATCH_SIZE} frames for proposed fusion method")
            rgb_frames = self.extract_frames_batch(video_path, self.FRAME_BATCH_SIZE)
            flow_frames = self.compute_optical_flow_batch(video_path, self.FRAME_BATCH_SIZE)
            audio = self.extract_audio(video_path)

            if rgb_frames is None or flow_frames is None or audio is None:
                print("Error: Could not process inputs")
                return None

            # Individual model predictions (always computed, whatever the fusion method).
            with torch.cuda.amp.autocast(enabled=self.use_mixed_precision):
                with torch.no_grad():
                    # NOTE(review): interpolate_pos_encoding is handed to the image
                    # processor here, not to the model call - confirm it has the
                    # intended effect.
                    rgb_pred = top_class_index(
                        self.rgb_model,
                        self.rgb_processor(rgb_frames, return_tensors="pt", interpolate_pos_encoding=True),
                    )
                    rgb_class = self.VISUAL_LABEL_MAP[rgb_pred]

                    flow_pred = top_class_index(
                        self.flow_model,
                        self.flow_processor(flow_frames, return_tensors="pt", interpolate_pos_encoding=True),
                    )
                    flow_class = self.VISUAL_LABEL_MAP[flow_pred]

                    audio_pred = top_class_index(
                        self.audio_model,
                        self.audio_processor(
                            audio, sampling_rate=self.SAMPLING_RATE, return_tensors="pt", padding=True
                        ),
                    )
                    audio_class = self.AUDIO_LABEL_MAP[audio_pred]

                    # Special handling for Abuse -> Fighting remapping
                    if audio_pred == 1:  # Abuse index
                        print("DEBUG: Audio predicted Abuse, remapping to Fighting")
                        audio_class = "Fighting"

            # The encoded inputs were local to top_class_index and are already released.
            if self.low_memory_mode:
                release_cuda_cache()

            # Read the setting once so a concurrent change cannot switch branches mid-call.
            fusion_method = self.fusion_method
            print(f"Using fusion method: {fusion_method}")

            if fusion_method == 'rgb':
                return self._create_result_from_single(rgb_class, self.VISUAL_LABEL_MAP)
            if fusion_method == 'flow':
                return self._create_result_from_single(flow_class, self.VISUAL_LABEL_MAP)
            if fusion_method == 'audio':
                return self._create_result_from_single(audio_class, self.AUDIO_LABEL_MAP)
            if fusion_method == 'majority':
                return self.majority_vote_prediction(rgb_class, flow_class, audio_class)

            # Feature-level fusion: 'proposed' uses the batched extractors and
            # its own fusion routine; every other value uses the cross-attention path.
            use_proposed = fusion_method == 'proposed'

            if use_proposed:
                rgb_features = self.process_rgb_frames_batch(rgb_frames)
            else:
                rgb_features = self.process_rgb_frames(rgb_frames)
            if self.low_memory_mode:
                del rgb_frames
                release_cuda_cache()

            if use_proposed:
                flow_features = self.process_flow_frames_batch(flow_frames)
            else:
                flow_features = self.process_flow_frames(flow_frames)
            if self.low_memory_mode:
                del flow_frames
                release_cuda_cache()

            audio_features = self.process_audio(audio)
            if self.low_memory_mode:
                del audio
                release_cuda_cache()

            if use_proposed:
                multimodal_features = self.apply_proposed_methodology(
                    rgb_features, flow_features, audio_features
                )
            else:
                multimodal_features = self.apply_bidirectional_cross_attention(
                    rgb_features, flow_features, audio_features
                )
            if self.low_memory_mode:
                del rgb_features, flow_features, audio_features
                release_cuda_cache()

            # Get adaptive threshold if not provided
            if threshold is None:
                threshold = self.get_adaptive_threshold()

            # Final classification and thresholding
            predictions = self.classify(multimodal_features)
            raw_result = self._process_predictions(predictions, threshold)

            # Only the cross-attention path feeds active context boosts into
            # multi-scale detection; the proposed path always passes None.
            context_boost = None
            if not use_proposed and self.context_events:
                current_time = time.time()
                context_boost = {}
                for context_type, (boost, end_time) in list(self.context_events.items()):
                    if current_time < end_time:
                        context_boost.update(boost)
                    else:
                        # Expired context
                        del self.context_events[context_type]

            scaled_result = self.apply_multi_scale_detection(raw_result, context_boost)

            # Temporal consistency filtering
            result = self.temporal_filter.update(scaled_result)

            # Log results for debugging
            print("Proposed Method Results:" if use_proposed else "Advanced Method Results:")
            for i, prob in enumerate(result['raw_predictions']):
                class_name = self.COMBINED_LABEL_MAP.get(i, f"Unknown_{i}")
                print(f"  {class_name}: {prob:.4f}")
            print(f"Predicted classes: {result['predicted_classes']}")

            # Update context based on new predictions
            new_context, duration = self.update_context(result)
            if new_context and duration:
                context_key = f"context_{int(time.time())}"
                self.context_events[context_key] = (new_context, time.time() + duration)

            return result

        except Exception as e:
            print(f"Error in prediction pipeline: {e}")
            import traceback
            traceback.print_exc()
            return None
    
    def _process_predictions(self, predictions, threshold):
        """Process raw predictions to get final output with class labels."""
        if predictions is None:
            return None
            
        # Convert predictions to numpy and apply threshold
        predictions_np = predictions.cpu().numpy()[0]
        
        # Get predicted classes with Abuse remapping
        predicted_indices = np.where(predictions_np >= threshold)[0]
        predicted_classes = []
        
        for idx in predicted_indices:
            if idx < len(self.COMBINED_LABEL_MAP):
                label = self.COMBINED_LABEL_MAP[idx]
                if label == "Abuse":  # Remap Abuse to Fighting
                    label = "Fighting"
                    print(f"DEBUG: Remapping Abuse (index {idx}) to Fighting in final output")
                predicted_classes.append(label)
            else:
                print(f"WARNING: Prediction index {idx} is out of range for label map")
        
        # If no class is above threshold, take the highest probability class
        if len(predicted_classes) == 0:
            max_idx = np.argmax(predictions_np)
            if max_idx < len(self.COMBINED_LABEL_MAP):
                label = self.COMBINED_LABEL_MAP[max_idx]
                if label == "Abuse":  # Remap Abuse to Fighting
                    label = "Fighting" 
                    print(f"DEBUG: Remapping max-probability Abuse to Fighting in final output")
                predicted_classes = [label]
                predicted_indices = [max_idx]
            else:
                print(f"WARNING: Max probability index {max_idx} is out of range")
                predicted_classes = ["Normal"]  # Fallback
                predicted_indices = [0]
        
        # Create result dictionary
        result = {
            'raw_predictions': predictions_np,
            'predicted_indices': predicted_indices,
            'predicted_classes': predicted_classes
        }
        
        return result

In [None]:
# Display colour for each event label (there is no 'Abuse' entry: it is remapped to Fighting).
# 'None' is the colour used for areas without any classification data.
COLOR_MAP = {
    event_name: QColor(red, green, blue)
    for event_name, (red, green, blue) in (
        ('Normal', (50, 200, 50)),         # Green
        ('Explosion', (255, 127, 0)),      # Orange
        ('Fighting', (200, 50, 50)),       # Red
        ('Car Accident', (50, 50, 200)),   # Blue
        ('Shooting', (200, 0, 200)),       # Purple
        ('Riot', (153, 51, 255)),          # Purple-blue
        ('None', (100, 100, 100)),         # Gray
    )
}

In [None]:
class MultimodalPredictionWorker(QThread):
    """Worker thread that classifies queued video segments one at a time.

    Each queued ``(video_path, start_time, end_time)`` is cut into a temporary
    MP4, classified by the pipeline's RGB, optical-flow and audio models, and
    finally by the pipeline's current fusion method. Every predicted class is
    published through ``predictionReady``; progress is published through
    ``segmentProcessed`` and ``progressUpdated``.
    """
    predictionReady = pyqtSignal(float, float, str, str)  # start_time, end_time, modality, anomaly_type
    segmentProcessed = pyqtSignal(int, int)  # progress signal: current, total
    progressUpdated = pyqtSignal(int, int)  # additional signal for UI updates

    def __init__(self, pipeline):
        """Store the pipeline and create an empty, mutex-protected segment queue."""
        super().__init__()
        self.pipeline = pipeline
        self.video_segments = deque()
        self.running = True
        self.mutex = QMutex()
        self.debug_mode = True  # Enable debug mode
        self.processing = False  # Flag to track if currently processing
        self.low_memory_mode = pipeline.low_memory_mode

    def add_video_segment(self, video_path, start_time, end_time):
        """Queue one segment for processing and emit ``progressUpdated(0, queue_size)``."""
        self.mutex.lock()
        try:
            self.video_segments.append((video_path, start_time, end_time))
            total_segments = len(self.video_segments)
        finally:
            self.mutex.unlock()

        if self.debug_mode:
            print(f"Added segment for processing: {start_time:.2f}-{end_time:.2f} from {video_path}")
            print(f"Queue size: {total_segments} segments")

        # Update progress immediately
        self.progressUpdated.emit(0, total_segments)

    def clear_queue(self):
        """Clear all pending work from the queue and emit ``progressUpdated(0, 0)``."""
        self.mutex.lock()
        try:
            queue_size = len(self.video_segments)
            self.video_segments.clear()
        finally:
            self.mutex.unlock()
        print(f"Cleared {queue_size} pending video segments from worker queue")

        # Reset progress
        self.progressUpdated.emit(0, 0)

    def stop(self):
        """Ask the ``run`` loop to exit after the current iteration."""
        self.running = False

    def run(self):
        """Process queued segments until ``stop()`` is called, reporting progress throughout."""
        processed_segments = 0

        while self.running:
            # Inspect and dequeue inside ONE critical section. Checking the
            # length and popping under separate locks lets clear_queue() empty
            # the deque in between, and popleft() on an empty deque raises
            # IndexError, which would terminate this thread.
            job = None
            remaining = 0
            self.mutex.lock()
            try:
                pending = len(self.video_segments)
                if pending > 0 and not self.processing:
                    job = self.video_segments.popleft()
                    remaining = len(self.video_segments)
            finally:
                self.mutex.unlock()

            # Emit a progress update even when idle.
            self._report_progress(processed_segments, processed_segments + pending)

            if job is not None:
                self.processing = True
                video_path, start_time, end_time = job
                processed_segments += 1
                total = processed_segments + remaining

                # Report progress - emit both signals
                self.segmentProcessed.emit(processed_segments, total)
                self._report_progress(processed_segments, total)

                try:
                    self._process_segment(video_path, start_time, end_time, processed_segments, total)
                    # Final progress update after segment completion
                    self._report_progress(processed_segments, total)
                finally:
                    self.processing = False
                time.sleep(0.5)  # 500ms delay between segments

            # Short sleep so the loop does not spin while the queue is empty.
            time.sleep(0.05)
            QApplication.processEvents()

    def _report_progress(self, done, total):
        """Emit ``progressUpdated(done, total)`` and process pending events.

        NOTE(review): ``QApplication.processEvents()`` is called from the
        worker thread here. Qt documents it as processing events of the
        calling thread, so confirm whether it is needed for the UI to refresh.
        """
        self.progressUpdated.emit(done, total)
        QApplication.processEvents()

    def _process_segment(self, video_path, start_time, end_time, done, total):
        """Cut one segment to a temp file, run every model on it and emit the predictions.

        Any exception is printed with its traceback and swallowed; the
        temporary file and directory are removed afterwards in all cases.
        """
        temp_dir = None
        temp_path = None

        try:
            # Clear CUDA cache before processing new segment
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Create temporary directory and files
            temp_dir = tempfile.mkdtemp()
            temp_path = os.path.join(temp_dir, f"temp_segment_{start_time:.2f}_{end_time:.2f}.mp4")
            if self.debug_mode:
                print(f"Creating temp segment at: {temp_path}")

            with mp.VideoFileClip(video_path) as video:
                segment = video.subclip(start_time, end_time)
                segment.write_videofile(temp_path, codec='libx264',
                                        audio_codec='aac',
                                        verbose=False,
                                        logger=None)
                segment.close()

            # Force garbage collection after writing
            gc.collect()

            # DEBUG: Extract frame and audio info
            if self.debug_mode:
                with mp.VideoFileClip(temp_path) as clip:
                    print(f"Temp clip duration: {clip.duration:.2f}s, FPS: {clip.fps}")
                    if clip.audio:
                        print(f"Audio detected in clip: {clip.audio.fps}Hz")
                    else:
                        print("WARNING: No audio detected in clip")

            self._report_progress(done, total)

            # One modality at a time to manage memory. The extracted data is
            # passed straight into the helper so no reference lingers here.
            print("\n--- RGB MODEL PROCESSING ---")
            rgb_class = self._classify_visual(
                "RGB", "rgb", self.pipeline.extract_frames(temp_path), start_time, end_time)
            self._report_progress(done, total)

            print("\n--- FLOW MODEL PROCESSING ---")
            flow_class = self._classify_visual(
                "Flow", "flow", self.pipeline.compute_optical_flow(temp_path), start_time, end_time)
            self._report_progress(done, total)

            print("\n--- AUDIO MODEL PROCESSING ---")
            audio_class = self._classify_audio(
                self.pipeline.extract_audio(temp_path), start_time, end_time)
            self._report_progress(done, total)

            self._emit_combined_prediction(
                temp_path, rgb_class, flow_class, audio_class, start_time, end_time)

        except Exception as e:
            print(f"ERROR in prediction worker: {e}")
            import traceback
            traceback.print_exc()

        finally:
            # Force garbage collection
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Add a delay to ensure resources are released
            time.sleep(0.5)

            self._remove_temp_files(temp_path, temp_dir)

    def _classify_visual(self, display_name, modality, frames, start_time, end_time):
        """Classify RGB or flow frames, emit ``predictionReady`` and return the label.

        ``modality`` ("rgb" or "flow") selects the pipeline's
        ``<modality>_processor`` / ``<modality>_model`` pair and is also the
        modality string sent with the signal. Returns "Normal" without
        emitting anything when ``frames`` is None.
        """
        if frames is None:
            print(f"ERROR: {display_name} frames extraction failed")
            return "Normal"  # Default

        print(f"{display_name} frames extracted: {len(frames)} frames")
        pipeline = self.pipeline

        # Use mixed precision for inference
        with torch.cuda.amp.autocast(enabled=pipeline.use_mixed_precision):
            with torch.no_grad():
                processor = getattr(pipeline, f"{modality}_processor")
                inputs = processor(frames, return_tensors="pt")
                inputs = {k: v.to(pipeline.device) for k, v in inputs.items()}
                outputs = getattr(pipeline, f"{modality}_model")(**inputs)
                logits = outputs.logits.detach().cpu()
                probs = torch.softmax(logits, dim=-1)[0]
                pred = torch.argmax(logits, dim=-1).item()
                label = pipeline.VISUAL_LABEL_MAP[pred]

        print(f"{display_name} Prediction: {label} (class {pred}) with confidence {probs[pred]:.4f}")
        print(f"{display_name} Class Probabilities:")
        for i, prob in enumerate(probs):
            print(f"  {pipeline.VISUAL_LABEL_MAP[i]}: {prob:.4f}")

        self.predictionReady.emit(start_time, end_time, modality, label)
        print(f"{display_name} signal emitted: {label} for {start_time:.2f}-{end_time:.2f}")

        # Drop the tensors before emptying the cache so their memory can be freed.
        del inputs, outputs, logits, probs, frames
        if torch.cuda.is_available() and self.low_memory_mode:
            torch.cuda.empty_cache()
        return label

    def _classify_audio(self, audio, start_time, end_time):
        """Classify the audio track, emit ``predictionReady`` and return the label.

        Class index 1 ("Abuse") is reported as "Fighting". Returns "Normal"
        without emitting anything when ``audio`` is None.
        """
        if audio is None:
            print("ERROR: Audio extraction failed")
            return "Normal"  # Default

        print(f"Audio extracted: {len(audio)} samples")
        pipeline = self.pipeline

        with torch.cuda.amp.autocast(enabled=pipeline.use_mixed_precision):
            with torch.no_grad():
                audio_inputs = pipeline.audio_processor(
                    audio, sampling_rate=pipeline.SAMPLING_RATE, return_tensors="pt", padding=True
                )
                audio_inputs = {k: v.to(pipeline.device) for k, v in audio_inputs.items()}
                audio_outputs = pipeline.audio_model(**audio_inputs)
                audio_logits = audio_outputs.logits.detach().cpu()
                audio_probs = torch.softmax(audio_logits, dim=-1)[0]
                audio_pred = torch.argmax(audio_logits, dim=-1).item()
                audio_class = pipeline.AUDIO_LABEL_MAP[audio_pred]

                # Special handling for Abuse -> Fighting remapping
                if audio_pred == 1:  # Abuse index
                    print("DEBUG: Audio detected Abuse, remapping to Fighting")
                    audio_class = "Fighting"

        print(f"Audio Prediction: {audio_class} (class {audio_pred}) with confidence {audio_probs[audio_pred]:.4f}")
        print("Audio Class Probabilities:")
        for i, prob in enumerate(audio_probs):
            if i == 1:  # Abuse class
                print(f"  {pipeline.AUDIO_LABEL_MAP[i]} (remapped to Fighting): {prob:.4f}")
            else:
                print(f"  {pipeline.AUDIO_LABEL_MAP[i]}: {prob:.4f}")

        self.predictionReady.emit(start_time, end_time, "audio", audio_class)
        print(f"Audio signal emitted: {audio_class} for {start_time:.2f}-{end_time:.2f}")

        # Drop the tensors before emptying the cache so their memory can be freed.
        del audio_inputs, audio_outputs, audio_logits, audio_probs, audio
        if torch.cuda.is_available() and self.low_memory_mode:
            torch.cuda.empty_cache()
        return audio_class

    def _emit_combined_prediction(self, temp_path, rgb_class, flow_class, audio_class,
                                  start_time, end_time):
        """Compute the fused prediction and emit one "combined" signal per predicted class.

        'rgb', 'flow', 'audio' and 'majority' reuse the single-modality
        classes; every other fusion method calls ``pipeline.predict`` on the
        temp file, which re-reads it. Errors raised while fusing are printed
        and swallowed.
        """
        pipeline = self.pipeline
        print(f"\n--- COMBINED MODEL PROCESSING USING {pipeline.fusion_method.upper()} FUSION ---")
        try:
            # Use current fusion method
            fusion_method = pipeline.fusion_method

            if fusion_method == 'rgb':
                result = pipeline._create_result_from_single(rgb_class, pipeline.VISUAL_LABEL_MAP)
            elif fusion_method == 'flow':
                result = pipeline._create_result_from_single(flow_class, pipeline.VISUAL_LABEL_MAP)
            elif fusion_method == 'audio':
                result = pipeline._create_result_from_single(audio_class, pipeline.AUDIO_LABEL_MAP)
            elif fusion_method == 'majority':
                result = pipeline.majority_vote_prediction(rgb_class, flow_class, audio_class)
            else:
                # 'proposed', 'advanced' and any unknown value: full pipeline on the temp file
                result = pipeline.predict(temp_path)

            if result and result['predicted_classes']:
                print(f"Combined Prediction ({fusion_method}): {result['predicted_classes']}")
                print("Combined Raw Probabilities:")
                for i, prob in enumerate(result['raw_predictions']):
                    label = pipeline.COMBINED_LABEL_MAP.get(i, f"Class_{i}")
                    # Special handling for Abuse label
                    if label == "Abuse":
                        print(f"  {label} (remapped to Fighting): {prob:.4f}")
                    else:
                        print(f"  {label}: {prob:.4f}")

                # Print verification message for console and timeline match
                print("\nVERIFICATION: Prediction in console matches what will appear in timeline:")
                for cls in result['predicted_classes']:
                    print(f"  TIMELINE WILL SHOW: {cls} for {start_time:.2f}-{end_time:.2f}")
                    print(f"  EMITTING SIGNAL NOW: combined, {cls}, {start_time:.2f}-{end_time:.2f}")
                    self.predictionReady.emit(start_time, end_time, "combined", cls)
            else:
                print("ERROR: No combined predictions generated")

        except Exception as fusion_error:
            print(f"Error during fusion: {fusion_error}")
            import traceback
            traceback.print_exc()

    def _remove_temp_files(self, temp_path, temp_dir):
        """Delete the temporary segment file and its directory; failures are only printed."""
        try:
            if temp_path and os.path.exists(temp_path):
                os.remove(temp_path)
                if self.debug_mode:
                    print(f"Removed temp file: {temp_path}")
            if temp_dir and os.path.exists(temp_dir):
                os.rmdir(temp_dir)
                if self.debug_mode:
                    print(f"Removed temp dir: {temp_dir}")
        except Exception as cleanup_error:
            print(f"Cleanup error: {cleanup_error}")

In [None]:
class TimelineWidget(QFrame):
    """Timeline strip that paints anomaly segments over a tick-marked time axis.

    Segments are ``(start_seconds, end_seconds, label)`` tuples. When a
    filtered list has been supplied through ``set_filtered_anomalies`` it is
    drawn instead of the full list.
    """

    def __init__(self):
        super().__init__()
        self.setMinimumHeight(80)
        self.setStyleSheet("background-color: #2d2d2d;")
        self.anomalies = []
        # None means "no filter active"; a list (even empty) replaces the full view.
        self.filtered_anomalies = None
        self.duration = 0
        # Console tracing of updates and paint passes.
        self.debug_mode = True

    def set_anomalies(self, anomalies):
        """Replace the full segment list, drop any active filter and repaint."""
        if self.debug_mode:
            total = len(anomalies)
            print(f"\n==== TIMELINE WIDGET UPDATE ====")
            print(f"Received {total} anomalies to display")
            # The first ten entries are listed; the eleventh prints one summary line.
            for index, (start, end, anomaly) in enumerate(anomalies):
                if index == 10:
                    print(f"  ... and {total-10} more")
                elif index < 10:
                    print(f"  {index+1}. {anomaly} at {start:.2f}-{end:.2f}")

        self.filtered_anomalies = None
        self.anomalies = anomalies
        self.update()

    def set_filtered_anomalies(self, anomalies):
        """Show only the given segments until the full list is replaced again."""
        self.filtered_anomalies = anomalies
        self.update()

    def set_duration(self, duration):
        """Set the time span (in seconds) represented by the widget's width."""
        if self.debug_mode:
            print(f"Timeline duration set to: {duration:.2f} seconds")
        self.duration = duration
        self.update()

    def paintEvent(self, event):
        """Paint the axis, eleven time ticks with mm:ss labels, and the segments."""
        if self.duration <= 0:
            # Nothing can be positioned on the axis without a positive duration.
            if self.debug_mode:
                print("Timeline paint skipped: Duration is zero")
            return

        width = self.width()
        height = self.height()
        y_middle = height // 2

        painter = QPainter(self)
        painter.setRenderHint(QPainter.Antialiasing)

        # Background fill.
        painter.fillRect(self.rect(), QBrush(QColor(45, 45, 45)))

        # Horizontal base line.
        painter.setPen(QPen(QColor(200, 200, 200), 1))
        painter.drawLine(0, y_middle, width, y_middle)

        # Ticks at every tenth of the width, each labelled with its time.
        painter.setPen(QPen(QColor(150, 150, 150), 1))
        tick_spacing = width / 10
        for tick in range(11):
            tick_x = int(tick * tick_spacing)
            painter.drawLine(tick_x, y_middle - 5, tick_x, y_middle + 5)

            tick_seconds = (tick / 10) * self.duration
            minutes = int(tick_seconds / 60)
            seconds = int(tick_seconds % 60)
            painter.drawText(tick_x - 15, y_middle + 20, f"{minutes:02d}:{seconds:02d}")

        # The filtered view wins whenever one has been set.
        if self.filtered_anomalies is None:
            active_segments = self.anomalies
        else:
            active_segments = self.filtered_anomalies

        painted = 0
        for seg_start, seg_end, label in active_segments:
            # Segments starting at or after the end of the video are not drawn.
            if seg_start >= self.duration:
                continue

            left = int((seg_start / self.duration) * width)
            right = int((min(seg_end, self.duration) / self.duration) * width)
            span = right - left

            fill = COLOR_MAP.get(label, QColor(100, 100, 100))
            painter.fillRect(left, 5, span, height - 10, QBrush(fill))

            # Only label segments wide enough to hold some text.
            if span > 50:
                painter.setPen(QPen(QColor(255, 255, 255), 1))
                painter.drawText(left + 5, y_middle + 5, label)

            painted += 1

        if self.debug_mode and painted > 0 and painted != len(active_segments):
            print(f"Timeline painted with {painted} visible segments out of {len(active_segments)}")

In [None]:
class InteractiveLegendWidget(QFrame):
    """Row of colour-coded checkboxes, one per event category.

    Emits ``legendToggled(category, is_visible)`` whenever a checkbox changes
    and mirrors the current state in ``self.visibility``.
    """

    legendToggled = pyqtSignal(str, bool)  # Signal: category, is_visible

    def __init__(self, color_map):
        super().__init__()
        self.color_map = color_map
        self.setMinimumHeight(40)
        self.setMaximumHeight(60)
        self.setStyleSheet("background-color: #333333; border-radius: 5px;")

        # Every category in the colour map starts out visible.
        self.visibility = dict.fromkeys(self.color_map, True)

        # Horizontal layout holding the title followed by the legend entries.
        self.layout = QHBoxLayout(self)
        self.layout.setSpacing(15)
        self.layout.setContentsMargins(10, 5, 10, 5)

        heading = QLabel("Legend:")
        heading.setStyleSheet("color: white; font-weight: bold;")
        self.layout.addWidget(heading)

        # One entry per category; "None" gets no legend entry.
        self.checkboxes = {}
        for label, color in self.color_map.items():
            if label == "None":
                continue
            self.layout.addWidget(self._build_legend_item(label, color))

        # Trailing stretch keeps the entries left-aligned.
        self.layout.addStretch(1)

    def _build_legend_item(self, label, color):
        """Create the checkbox + colour swatch + caption widget for one category."""
        entry = QWidget()
        entry_layout = QHBoxLayout(entry)
        entry_layout.setContentsMargins(0, 0, 0, 0)
        entry_layout.setSpacing(5)

        # Checkbox that toggles the category.
        toggle = QCheckBox()
        toggle.setChecked(True)
        toggle.setStyleSheet("QCheckBox::indicator { width: 12px; height: 12px; }")
        toggle.stateChanged.connect(lambda state: self.toggle_category(label, state))
        self.checkboxes[label] = toggle

        # Colour swatch.
        swatch = QLabel()
        swatch.setFixedSize(16, 16)
        swatch.setStyleSheet(f"background-color: rgb({color.red()}, {color.green()}, {color.blue()}); border-radius: 3px;")

        # Caption.
        caption = QLabel(label)
        caption.setStyleSheet("color: white; font-size: 12px;")

        for child in (toggle, swatch, caption):
            entry_layout.addWidget(child)

        return entry

    def toggle_category(self, label, state):
        """Record the new checkbox state for ``label`` and notify listeners."""
        shown = state == Qt.Checked
        self.visibility[label] = shown
        self.legendToggled.emit(label, shown)

In [None]:
class MultimodalTimelinePlayerWindow(QMainWindow):
    def __init__(self):
        """Build the player window, wire up its signals and start the worker.

        Creates the detection pipeline, lays out the video view, per-modality
        labels, timeline, legend, fusion/threshold controls, progress display
        and transport controls, then starts the background prediction worker
        and schedules the "open file" dialog shortly after launch.
        """
        super().__init__()

        # Initialize pipeline
        self.pipeline = ViolenceDetectionPipeline()
        self.current_video_path = None

        # Setup UI
        self.setWindowTitle("Multimodal Violence Detection Video Player")
        self.setGeometry(100, 100, 1200, 800)  # Make window taller for extra controls

        # Create central widget and layout
        self.central_widget = QWidget()
        self.setCentralWidget(self.central_widget)
        self.layout = QVBoxLayout(self.central_widget)

        # Video display area
        self.video_frame = QLabel()
        self.video_frame.setAlignment(Qt.AlignCenter)
        self.video_frame.setMinimumSize(640, 360)
        self.video_frame.setStyleSheet("background-color: black;")
        self.layout.addWidget(self.video_frame)

        # Current classification displays (one row with different modalities)
        self.classifications_layout = QHBoxLayout()

        # RGB classification
        self.rgb_label = QLabel("RGB: None")
        self.rgb_label.setAlignment(Qt.AlignCenter)
        self.rgb_label.setStyleSheet("font-size: 14px; font-weight: bold;")
        self.classifications_layout.addWidget(self.rgb_label)

        # Flow classification
        self.flow_label = QLabel("Flow: None")
        self.flow_label.setAlignment(Qt.AlignCenter)
        self.flow_label.setStyleSheet("font-size: 14px; font-weight: bold;")
        self.classifications_layout.addWidget(self.flow_label)

        # Audio classification
        self.audio_label = QLabel("Audio: None")
        self.audio_label.setAlignment(Qt.AlignCenter)
        self.audio_label.setStyleSheet("font-size: 14px; font-weight: bold;")
        self.classifications_layout.addWidget(self.audio_label)

        # Combined classification (larger)
        self.combined_label = QLabel("Combined: None")
        self.combined_label.setAlignment(Qt.AlignCenter)
        self.combined_label.setStyleSheet("font-size: 18px; font-weight: bold;")

        self.layout.addLayout(self.classifications_layout)
        self.layout.addWidget(self.combined_label)

        # Timeline widget
        self.timeline_widget = TimelineWidget()
        self.layout.addWidget(self.timeline_widget)

        # Add interactive legends widget right after timeline
        self.legend_widget = InteractiveLegendWidget(COLOR_MAP)
        self.legend_widget.legendToggled.connect(self.toggle_category_visibility)
        self.layout.addWidget(self.legend_widget)

        # Fusion method selector.
        # The label order mirrors the index -> method list used by
        # change_fusion_method: advanced, majority, rgb, flow, audio, proposed.
        # Previously "Proposed" and "Simple (Majority Vote)" were swapped, so
        # the label shown did not match the method that index selected.
        self.fusion_layout = QHBoxLayout()
        self.fusion_label = QLabel("Fusion Method:")
        self.fusion_dropdown = QComboBox()
        self.fusion_dropdown.addItems([
            "Advanced (Cross-Attention)",
            "Simple (Majority Vote)",
            "RGB Only",
            "Flow Only",
            "Audio Only",
            "Proposed (Intra/Inter Fusion)"
        ])
        # Default to the proposed method. The index is set before the signal is
        # connected, so this only changes what the dropdown displays.
        self.fusion_dropdown.setCurrentIndex(5)
        self.fusion_dropdown.currentIndexChanged.connect(self.change_fusion_method)
        self.fusion_layout.addWidget(self.fusion_label)
        self.fusion_layout.addWidget(self.fusion_dropdown)
        self.layout.addLayout(self.fusion_layout)

        # Threshold slider (slider value is the threshold in percent)
        self.threshold_layout = QHBoxLayout()
        self.threshold_label = QLabel(f"Threshold: {self.pipeline.base_threshold:.2f}")
        self.threshold_slider = QSlider(Qt.Horizontal)
        self.threshold_slider.setRange(0, 100)
        self.threshold_slider.setValue(int(self.pipeline.base_threshold * 100))
        self.threshold_slider.valueChanged.connect(self.update_threshold)
        self.threshold_layout.addWidget(self.threshold_label)
        self.threshold_layout.addWidget(self.threshold_slider)
        self.layout.addLayout(self.threshold_layout)

        # Progress bar for segment processing (value is a 0-100 percentage)
        self.progress_layout = QHBoxLayout()
        self.progress_label = QLabel("Processing segments: 0/0")
        self.progress_bar = QProgressBar()
        self.progress_bar.setRange(0, 100)
        self.progress_bar.setValue(0)
        self.progress_layout.addWidget(self.progress_label)
        self.progress_layout.addWidget(self.progress_bar)
        self.layout.addLayout(self.progress_layout)

        # Controls layout
        self.controls_layout = QHBoxLayout()

        # Play/Pause button
        self.play_button = QPushButton()
        self.play_button.setIcon(self.style().standardIcon(QStyle.SP_MediaPlay))
        self.play_button.clicked.connect(self.toggle_play)
        self.controls_layout.addWidget(self.play_button)

        # Time display
        self.time_label = QLabel("00:00 / 00:00")
        self.controls_layout.addWidget(self.time_label)

        # Position slider
        self.position_slider = QSlider(Qt.Horizontal)
        self.position_slider.sliderMoved.connect(self.set_position)
        self.controls_layout.addWidget(self.position_slider)

        # Open file button
        self.open_button = QPushButton("Open Video")
        self.open_button.clicked.connect(self.open_file)
        self.controls_layout.addWidget(self.open_button)

        # Diagnostic button
        self.diagnostic_button = QPushButton("Run Diagnostics")
        self.diagnostic_button.setStyleSheet("background-color: #17a2b8; color: white; font-weight: bold; padding: 8px;")
        self.diagnostic_button.clicked.connect(self.run_diagnostics)
        self.controls_layout.addWidget(self.diagnostic_button)

        self.layout.addLayout(self.controls_layout)

        # Start processing button
        self.process_layout = QHBoxLayout()
        self.start_processing_button = QPushButton("▶ Start Processing")
        self.start_processing_button.setStyleSheet("background-color: #28a745; color: white; font-weight: bold; padding: 8px;")
        self.start_processing_button.clicked.connect(self.toggle_processing)
        self.process_layout.addWidget(self.start_processing_button)
        self.layout.addLayout(self.process_layout)

        # Video playback state
        self.cap = None
        self.timer = QTimer(self)
        self.timer.timeout.connect(self.update_frame)
        self.current_frame = 0
        self.fps = 0
        self.total_frames = 0
        self.playing = False

        # Prediction state: lists of (start, end, label) segments per modality
        self.rgb_segments = []
        self.flow_segments = []
        self.audio_segments = []
        self.combined_segments = []
        self.current_time = 0
        self.last_prediction_time = 0

        # Set debug mode
        self.debug_mode = True

        # Create UI refresh timer
        self.ui_refresh_timer = QTimer(self)
        self.ui_refresh_timer.timeout.connect(lambda: QApplication.processEvents())
        self.ui_refresh_timer.start(100)  # Refresh UI every 100ms

        # Start prediction thread
        self.prediction_worker = MultimodalPredictionWorker(self.pipeline)

        # Queued connections so worker signals are delivered via the receiver's event loop
        self.prediction_worker.predictionReady.connect(self.update_anomaly, Qt.QueuedConnection)
        self.prediction_worker.segmentProcessed.connect(self.update_progress, Qt.QueuedConnection)
        self.prediction_worker.progressUpdated.connect(self.update_progress_bar, Qt.QueuedConnection)
        self.prediction_worker.start()

        # Show file dialog upon launch
        QTimer.singleShot(100, self.open_file)
    
    def update_progress_bar(self, current, total):
        """Show ``current``/``total`` as a percentage and force an immediate redraw.

        Does nothing when ``total`` is not positive.
        """
        if not total > 0:
            return

        percent_done = int((current / total) * 100)

        # Bar value and descriptive label.
        self.progress_bar.setValue(percent_done)
        self.progress_label.setText(f"Processing segments: {current}/{total} ({percent_done}%)")

        # Console trace so updates can be verified.
        print(f"Progress update: {percent_done}% ({current}/{total})")

        # Repaint right away and let pending events run.
        self.progress_bar.repaint()
        QApplication.processEvents()
    
    def toggle_category_visibility(self, category, is_visible):
        """Re-filter the timeline after a legend checkbox changed.

        Args:
            category: Label of the category whose checkbox was toggled.
            is_visible: New visibility of that category.

        The filter honours every category currently hidden in the legend, not
        only the one that was just toggled. Previously, hiding a second
        category made the first hidden category reappear on the timeline.
        """
        # Start from the legend's full visibility state ...
        hidden_categories = {
            label
            for label, shown in self.legend_widget.visibility.items()
            if not shown
        }
        # ... and let the signal payload decide for the toggled category.
        if is_visible:
            hidden_categories.discard(category)
        else:
            hidden_categories.add(category)

        visible_segments = [
            (start, end, event_type)
            for start, end, event_type in self.combined_segments
            if event_type not in hidden_categories
        ]

        # Update timeline with filtered segments
        self.timeline_widget.set_filtered_anomalies(visible_segments)
    
    def update_threshold(self, value):
        """Apply a slider value (percent) as the new detection threshold."""
        new_threshold = value / 100.0

        # Push the value to the pipeline and its temporal filter.
        self.pipeline.base_threshold = new_threshold
        self.pipeline.temporal_filter.threshold = new_threshold

        self.threshold_label.setText(f"Threshold: {new_threshold:.2f}")
        print(f"Updated detection threshold to {new_threshold:.2f}")
    
    def change_fusion_method(self, index):
        """Switch the pipeline's fusion method to the one selected in the dropdown.

        Args:
            index: Index of the selected dropdown entry.

        The method is resolved from the entry's label rather than from its
        position, so the method applied always matches the text the user picked
        regardless of the order in which the entries were added. An unknown
        label (for example index -1, which yields an empty label) leaves the
        pipeline unchanged.
        """
        method_by_label = {
            "Advanced (Cross-Attention)": "advanced",
            "Proposed (Intra/Inter Fusion)": "proposed",
            "RGB Only": "rgb",
            "Flow Only": "flow",
            "Audio Only": "audio",
            "Simple (Majority Vote)": "majority",
        }

        selected_label = self.fusion_dropdown.itemText(index)
        method = method_by_label.get(selected_label)
        if method is None:
            print(f"Unknown fusion method selection (index {index}): {selected_label!r}")
            return

        self.pipeline.fusion_method = method
        print(f"Changed fusion method to: {self.pipeline.fusion_method}")

        # Single-modality methods and the majority vote use a higher threshold;
        # the learned fusion methods use the default one.
        if method in ("rgb", "flow", "audio", "majority"):
            threshold_percent = 70
        else:
            threshold_percent = 50

        # setValue() only emits valueChanged when the value actually changes,
        # so the explicit call guarantees the pipeline is always updated.
        self.threshold_slider.setValue(threshold_percent)
        self.update_threshold(threshold_percent)
    
    def toggle_processing(self):
        """Toggle background segment analysis based on the button's current caption.

        A caption starting with "▶" means analysis is idle: it is started if a
        video is loaded. Any other caption means it is running: the button is
        reset and the worker's queue is cleared.
        """
        button = self.start_processing_button

        if not button.text().startswith("▶"):
            # Pause: restore the "start" look and drop everything still queued.
            button.setText("▶ Start Processing")
            button.setStyleSheet("background-color: #28a745; color: white; font-weight: bold; padding: 8px;")
            self.prediction_worker.clear_queue()
            return

        # Start: nothing to do without a loaded video.
        if not self.current_video_path:
            return

        button.setText("⏸ Pause Processing")
        button.setStyleSheet("background-color: #dc3545; color: white; font-weight: bold; padding: 8px;")

        # Zero the progress display before queuing work.
        self.progress_bar.setValue(0)
        self.progress_label.setText("Processing segments: 0/0")
        QApplication.processEvents()  # Force immediate UI update

        self.start_background_analysis(self.current_video_path)
    
    def run_diagnostics(self):
        """Run all models on a short clip around the current frame and show the result.

        A clip of up to 3 seconds centred on the current playback position is
        written to a temporary directory, passed to the pipeline's
        ``run_diagnostics``, and the per-model probabilities are plotted and
        shown in a message box together with a text summary.

        Errors are printed with a traceback rather than raised. The figure and
        the temporary directory are always released, whatever happens.
        """
        if not self.cap or not self.current_video_path:
            print("No video loaded. Please load a video first.")
            return
        if not self.fps > 0:
            # Without a frame rate the frame index cannot be converted to seconds.
            print("Cannot run diagnostics: video frame rate is unknown.")
            return

        # Function-scope import, same pattern as the traceback import below.
        import shutil

        print("\n===== RUNNING MODEL DIAGNOSTICS =====")

        # Create a short temp segment from current position
        current_time = self.current_frame / self.fps
        temp_dir = tempfile.mkdtemp()
        temp_path = os.path.join(temp_dir, "diagnostic_segment.mp4")
        diag_image_path = os.path.join(temp_dir, "diagnostics.png")
        fig = None

        try:
            # Extract 3-second clip around current position
            with mp.VideoFileClip(self.current_video_path) as video:
                start_time = max(0, current_time - 1.5)
                end_time = min(video.duration, current_time + 1.5)
                segment = video.subclip(start_time, end_time)
                segment.write_videofile(temp_path, codec='libx264',
                                        audio_codec='aac',
                                        verbose=False,
                                        logger=None)

            # Run diagnostics
            results = self.pipeline.run_diagnostics(temp_path)

            rgb_class, rgb_probs = results["rgb"]
            flow_class, flow_probs = results["flow"]
            audio_class, audio_probs = results["audio"]
            majority_vote = results["majority"]

            # Label lists are built once and reused for plotting and lookup.
            visual_labels = list(self.pipeline.VISUAL_LABEL_MAP.values())
            audio_labels = list(self.pipeline.AUDIO_LABEL_MAP.values())

            # One bar chart per model, stacked vertically.
            fig = plt.figure(figsize=(10, 8))
            panels = (
                (311, "RGB Predictions", visual_labels, rgb_probs),
                (312, "Flow Predictions", visual_labels, flow_probs),
                (313, "Audio Predictions", audio_labels, audio_probs),
            )
            for position, title, labels, probs in panels:
                ax = fig.add_subplot(position)
                ax.bar(labels, probs)
                ax.set_title(title)
                ax.set_ylim(0, 1)

            fig.tight_layout()

            # Save the diagnostic figure
            fig.savefig(diag_image_path)

            # Create diagnostic message
            diagnostic_text = (
                f"DIAGNOSTIC RESULTS at {self.format_time(current_time)}:\n\n"
                f"RGB Model: {rgb_class} ({rgb_probs[visual_labels.index(rgb_class)]:.3f})\n"
                f"Flow Model: {flow_class} ({flow_probs[visual_labels.index(flow_class)]:.3f})\n"
                f"Audio Model: {audio_class} ({audio_probs[audio_labels.index(audio_class)]:.3f})\n\n"
                f"Majority Vote: {majority_vote}\n"
                f"Current Fusion Method: {self.pipeline.fusion_method}\n"
                f"Current Threshold: {self.pipeline.base_threshold:.2f}"
            )

            # Show dialog with results (parented to this window)
            msg_box = QMessageBox(self)
            msg_box.setWindowTitle("Diagnostic Results")
            msg_box.setText(diagnostic_text)
            msg_box.setIconPixmap(QPixmap(diag_image_path).scaled(600, 400, Qt.KeepAspectRatio))
            msg_box.exec_()

        except Exception as e:
            print(f"Error running diagnostics: {e}")
            import traceback
            traceback.print_exc()

        finally:
            # Close exactly the figure created here, even after an error.
            if fig is not None:
                plt.close(fig)
            # Remove the whole temp directory, including any stray files a
            # writer may have left behind; failures are ignored (best effort).
            shutil.rmtree(temp_dir, ignore_errors=True)
    
    def update_progress(self, current, total):
        """Show ``current``/``total`` on the progress bar and label (0% when total <= 0)."""
        if total > 0:
            percent_done = int((current / total) * 100)
        else:
            percent_done = 0

        self.progress_bar.setValue(percent_done)
        self.progress_label.setText(f"Processing segments: {current}/{total} ({percent_done}%)")
    
    def open_file(self):
        """Ask the user for a video file and load it; a cancelled dialog does nothing."""
        name_filters = "Video Files (*.mp4 *.avi *.mkv *.mov);;All Files (*)"
        selection = QFileDialog.getOpenFileName(self, "Open Video File", "", name_filters)

        chosen_path = selection[0]
        if not chosen_path:
            return

        self.load_video(normalize_path(chosen_path))
    
    def load_video(self, video_path):
        """Open ``video_path`` for playback and reset all per-video state.

        Args:
            video_path: Path of the video file to open.

        Playback is stopped, leftover temporary resources are cleaned up, the
        previous capture is released and the UI is reset before the new file
        is opened. If the file cannot be opened the window is left in a
        consistent "no video loaded" state instead of keeping an unusable
        capture and the previous video's path. A missing or invalid frame rate
        falls back to 30 FPS so durations can still be computed.
        """
        print(f"\n==== LOADING VIDEO: {video_path} ====")

        # Stop current video if playing (and show the "play" icon again)
        if self.playing:
            self.timer.stop()
            self.playing = False
            self.play_button.setIcon(self.style().standardIcon(QStyle.SP_MediaPlay))

        # Clean up resources from previous video
        self.cleanup_temp_resources()

        # Release previous capture if any
        if self.cap is not None:
            self.cap.release()
            self.cap = None

        # Reset all UI elements
        self.reset_ui_state()

        # Cleanup above cleared the worker queue, so the button returns to "start".
        self.start_processing_button.setText("▶ Start Processing")
        self.start_processing_button.setStyleSheet("background-color: #28a745; color: white; font-weight: bold; padding: 8px;")

        # Forget everything that belonged to the previous video, so a failed
        # open cannot leave stale segments or a stale path behind.
        self.current_video_path = None
        self.current_frame = 0
        self.current_time = 0
        self.last_prediction_time = 0
        self.rgb_segments = []
        self.flow_segments = []
        self.audio_segments = []
        self.combined_segments = []
        print("Cleared previous anomalies")

        # Open new video
        capture = cv2.VideoCapture(video_path)
        if not capture.isOpened():
            print(f"ERROR: Could not open video {video_path}")
            capture.release()
            return

        # Get video properties
        fps = capture.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Some containers report 0 (or NaN); dividing by it would crash.
            print(f"WARNING: Invalid frame rate reported ({fps}); assuming 30 FPS")
            fps = 30.0

        self.cap = capture
        self.fps = fps
        self.total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = self.total_frames / self.fps
        print(f"Video properties: {self.total_frames} frames, {self.fps} FPS")
        print(f"Video duration: {self.format_time(duration)}")

        self.current_video_path = video_path

        # Initialize timeline
        self.timeline_widget.set_anomalies([])
        self.timeline_widget.set_duration(duration)

        # Reset UI
        self.position_slider.setRange(0, self.total_frames)
        duration_str = self.format_time(duration)
        self.time_label.setText(f"00:00 / {duration_str}")

        # Show first frame as a preview, then rewind so playback starts at
        # frame 0 and the capture position stays in step with current_frame.
        ret, frame = self.cap.read()
        if ret:
            self.display_frame(frame)
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

        print("Video loaded successfully")
    
    def cleanup_temp_resources(self):
        """Cancel queued predictions and remove leftover temporary media directories.

        Scans the system temp folder for ``tmp*`` directories and deletes those
        that look like this application's leftovers. A directory qualifies only
        if it holds at least one ``.mp4``/``.wav`` file AND nothing but regular
        media/image files. Previously a single media file was enough to wipe
        every file in the directory, including unrelated ones.

        NOTE(review): the default ``tmp`` prefix is shared with other programs,
        so this remains a heuristic; a dedicated prefix on the directories this
        application creates would make the match exact.
        """
        print("Cleaning up temporary resources...")

        # Cancel any pending predictions
        self.prediction_worker.clear_queue()

        temp_root = tempfile.gettempdir()
        segment_suffixes = ('.mp4', '.wav')
        allowed_suffixes = segment_suffixes + ('.png', '.m4a', '.mp3')

        # A failure to list the temp folder must not abort the caller.
        try:
            candidates = [
                name for name in os.listdir(temp_root)
                if name.startswith('tmp') and os.path.isdir(os.path.join(temp_root, name))
            ]
        except OSError as e:
            print(f"Error listing temp folder {temp_root}: {e}")
            candidates = []

        handles_released = False
        for name in candidates:
            temp_path = os.path.join(temp_root, name)
            try:
                entries = os.listdir(temp_path)

                # Only regular media files, and at least one video/audio segment.
                only_media_files = all(
                    entry.endswith(allowed_suffixes)
                    and os.path.isfile(os.path.join(temp_path, entry))
                    for entry in entries
                )
                has_segment = any(entry.endswith(segment_suffixes) for entry in entries)
                if not (only_media_files and has_segment):
                    continue

                # Force garbage collection once to release any file handles
                if not handles_released:
                    gc.collect()
                    handles_released = True

                # Remove files in directory
                for entry in entries:
                    try:
                        os.remove(os.path.join(temp_path, entry))
                    except Exception as e:
                        print(f"Error removing file {entry}: {e}")

                # Remove directory
                os.rmdir(temp_path)
                print(f"Cleaned up temporary directory: {temp_path}")
            except Exception as e:
                print(f"Error cleaning temp directory {name}: {e}")

        # Release cached GPU memory if a GPU is in use
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("Cleanup complete")
    
    def reset_ui_state(self):
        """Put labels, progress display, video view, timeline and slider back to defaults."""
        print("Resetting UI state...")

        # Per-modality labels share the same default style.
        modality_defaults = (
            (self.rgb_label, "RGB: None"),
            (self.flow_label, "Flow: None"),
            (self.audio_label, "Audio: None"),
        )
        for widget, caption in modality_defaults:
            widget.setText(caption)
            widget.setStyleSheet("font-size: 14px; font-weight: bold;")

        # The combined label uses the larger font.
        self.combined_label.setText("Combined: None")
        self.combined_label.setStyleSheet("font-size: 18px; font-weight: bold;")

        # Progress display.
        self.progress_bar.setValue(0)
        self.progress_label.setText("Processing segments: 0/0")

        # Blank the video view.
        self.video_frame.clear()
        self.video_frame.setPixmap(QPixmap())

        # Empty timeline with no duration.
        self.timeline_widget.set_anomalies([])
        self.timeline_widget.set_duration(0)

        # Time readout.
        self.time_label.setText("00:00 / 00:00")

        # Slider back to the start, with a placeholder range until a video is loaded.
        self.position_slider.setValue(0)
        self.position_slider.setRange(0, 100)
    
    def start_background_analysis(self, video_path):
        """Split the video into overlapping segments and queue them for analysis.

        Args:
            video_path: Path of the video to analyse.

        Segment length depends on the video duration (3 s up to one minute,
        4 s up to three minutes, 5 s beyond) with 50% overlap. Errors are
        printed with a traceback rather than raised.

        The progress bar keeps its 0-100 range because every progress handler
        writes a percentage. Setting the range to the segment count made those
        values fall outside the range, and a count of zero turned the bar into
        a busy indicator. Clips shorter than half a segment are analysed as a
        single segment instead of being skipped.
        """
        try:
            # Get video duration with proper resource management
            with mp.VideoFileClip(video_path) as video:
                duration = video.duration

            # Segment length grows with the duration; overlap is always half of it.
            if duration <= 60:  # For short videos (under 1 minute)
                segment_length = 3
            elif duration <= 180:  # For medium videos (1-3 minutes)
                segment_length = 4
            else:  # For longer videos
                segment_length = 5
            overlap = segment_length / 2

            # Plain Python floats, so the worker never receives numpy scalars.
            start_times = [
                float(t)
                for t in np.arange(0, duration - segment_length / 2, segment_length - overlap)
            ]
            if not start_times and duration > 0:
                # Very short clip: analyse it as one segment.
                start_times = [0.0]

            total_segments = len(start_times)
            print(f"Preparing to analyze video in {total_segments} segments...")

            # Initialize progress tracking (percentage scale)
            self.progress_bar.setRange(0, 100)
            self.progress_bar.setValue(0)
            self.progress_label.setText(f"Processing segments: 0/{total_segments} (0%)")

            # Queue segments for processing
            for start_time in start_times:
                end_time = min(start_time + segment_length, duration)
                self.prediction_worker.add_video_segment(video_path, start_time, end_time)

        except Exception as e:
            print(f"Error preparing video analysis: {e}")
            import traceback
            traceback.print_exc()
    
    def toggle_play(self):
        """Start or stop playback and swap the play/pause icon.

        Does nothing when no usable video is loaded. The playback timer runs at
        the video's own frame rate: each tick shows exactly one frame, so a
        fixed 30 Hz timer played any non-30-FPS video at the wrong speed. When
        the frame rate is unknown, 30 FPS is used.
        """
        if self.cap is None or not self.cap.isOpened():
            return

        if self.playing:
            self.timer.stop()
            self.play_button.setIcon(self.style().standardIcon(QStyle.SP_MediaPlay))
        else:
            playback_fps = self.fps if self.fps and self.fps > 0 else 30
            # Timer interval in whole milliseconds, at least 1 ms.
            interval_ms = max(1, int(round(1000 / playback_fps)))
            self.timer.start(interval_ms)
            self.play_button.setIcon(self.style().standardIcon(QStyle.SP_MediaPause))

        self.playing = not self.playing
    
    def update_frame(self):
        """Timer slot: show the next frame and refresh position and classification labels.

        At the end of the video, playback simply stops. It no longer calls
        ``cleanup_temp_resources()`` at that point: that call clears the
        prediction queue and deletes temporary segment files, which silently
        aborted a background analysis still running when playback finished.
        Leftover temporary files are still cleaned up when the next video is
        loaded.
        """
        if self.cap is None or not self.playing:
            return

        ret, frame = self.cap.read()
        if not ret:
            # End of video: stop playback but leave queued analysis alone.
            print("End of video reached")
            self.timer.stop()
            self.playing = False
            self.play_button.setIcon(self.style().standardIcon(QStyle.SP_MediaPlay))
            return

        # Display the frame
        self.display_frame(frame)

        # Update current time
        self.current_time = self.current_frame / self.fps

        # Periodic debug output (every 30 frames)
        log_now = self.debug_mode and self.current_frame % 30 == 0
        if log_now:
            print(f"\n==== CURRENT POSITION: {self.format_time(self.current_time)} ====")
            print(f"RGB anomalies: {len(self.rgb_segments)}")
            print(f"Flow anomalies: {len(self.flow_segments)}")
            print(f"Audio anomalies: {len(self.audio_segments)}")
            print(f"Combined anomalies: {len(self.combined_segments)}")

        # Update slider and time label
        self.position_slider.setValue(self.current_frame)
        current_time_str = self.format_time(self.current_time)
        duration_str = self.format_time(self.total_frames / self.fps)
        self.time_label.setText(f"{current_time_str} / {duration_str}")

        # Look up the classification of each modality at the current time
        rgb_class = self.get_classification_at_time(self.current_time, self.rgb_segments)
        flow_class = self.get_classification_at_time(self.current_time, self.flow_segments)
        audio_class = self.get_classification_at_time(self.current_time, self.audio_segments)
        combined_class = self.get_classification_at_time(self.current_time, self.combined_segments)

        if log_now:
            print(f"Current classifications at {self.format_time(self.current_time)}:")
            print(f"  RGB: {rgb_class}")
            print(f"  Flow: {flow_class}")
            print(f"  Audio: {audio_class}")
            print(f"  Combined: {combined_class}")

        # Text and colour of each classification label
        displays = (
            (self.rgb_label, "RGB", rgb_class),
            (self.flow_label, "Flow", flow_class),
            (self.audio_label, "Audio", audio_class),
            (self.combined_label, "Combined", combined_class),
        )
        for widget, prefix, class_name in displays:
            widget.setText(f"{prefix}: {class_name}")
            self.set_label_color(widget, class_name)

        # Increment frame counter
        self.current_frame += 1
    
    def set_label_color(self, label, class_name):
        """Restyle ``label`` so its text colour reflects ``class_name``.

        The combined label gets an 18px bold font, every other label 14px
        bold.  For the class name "None" no colour is added to the style;
        otherwise the colour is looked up in ``COLOR_MAP``, falling back to
        grey (100, 100, 100) for names that are not in the map.
        """
        font_px = 18 if label == self.combined_label else 14
        base_style = f"font-size: {font_px}px; font-weight: bold;"

        if class_name == "None":
            label.setStyleSheet(base_style)
            return

        shade = COLOR_MAP.get(class_name, QColor(100, 100, 100))
        rgb = f"rgb({shade.red()}, {shade.green()}, {shade.blue()})"
        label.setStyleSheet(f"{base_style} color: {rgb}")
    
    def display_frame(self, frame):
        """Show one BGR frame in the video widget.

        The frame is converted from BGR to RGB, wrapped in a ``QImage``,
        turned into a ``QPixmap`` and scaled to the current size of
        ``self.video_frame`` with the aspect ratio kept and smooth
        transformation.
        """
        rgb_pixels = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        height, width, channels = rgb_pixels.shape
        stride = channels * width  # bytes per image row

        image = QImage(rgb_pixels.data, width, height, stride, QImage.Format_RGB888)
        pixmap = QPixmap.fromImage(image)
        fitted = pixmap.scaled(
            self.video_frame.width(),
            self.video_frame.height(),
            Qt.KeepAspectRatio,
            Qt.SmoothTransformation,
        )
        self.video_frame.setPixmap(fitted)
    
    def set_position(self, position):
        """Seek the capture to frame index ``position`` and refresh the UI.

        Does nothing when no capture is open.  Otherwise the capture is
        positioned, ``current_frame`` and ``current_time`` are updated, the
        frame at that position is shown if it can be read, and the time label
        and the four classification labels are refreshed for the new time.
        """
        if self.cap is None:
            return

        self.cap.set(cv2.CAP_PROP_POS_FRAMES, position)
        self.current_frame = position
        self.current_time = position / self.fps

        # Preview the frame at the new position when one is available.
        ok, image = self.cap.read()
        if ok:
            self.display_frame(image)

        elapsed = self.format_time(self.current_time)
        total = self.format_time(self.total_frames / self.fps)
        self.time_label.setText(f"{elapsed} / {total}")

        # One row per modality: (label widget, display title, segment list).
        views = (
            (self.rgb_label, "RGB", self.rgb_segments),
            (self.flow_label, "Flow", self.flow_segments),
            (self.audio_label, "Audio", self.audio_segments),
            (self.combined_label, "Combined", self.combined_segments),
        )
        verdicts = [
            self.get_classification_at_time(self.current_time, segments)
            for _, _, segments in views
        ]

        # Texts for every label first, then colours for every label.
        for (label, title, _), verdict in zip(views, verdicts):
            label.setText(f"{title}: {verdict}")
        for (label, _, _), verdict in zip(views, verdicts):
            self.set_label_color(label, verdict)
    
    def update_anomaly(self, start_time, end_time, modality, anomaly_type):
        """Store a reported anomaly segment and refresh the affected widgets.

        ``modality`` selects which segment list receives the
        ``(start_time, end_time, anomaly_type)`` tuple: "rgb", "flow",
        "audio" or "combined".  Any other value stores nothing.  A "combined"
        segment also pushes the combined list to the timeline widget.  When
        ``current_time`` lies inside ``[start_time, end_time]`` the label of
        that modality is updated straight away.  Debug output is printed only
        when ``debug_mode`` is set.
        """
        if self.debug_mode:
            print(f"\n==== ANOMALY RECEIVED: {modality.upper()} ====")
            print(f"Time: {start_time:.2f}-{end_time:.2f}, Type: {anomaly_type}")

        # modality -> (segment list, label widget, display title)
        targets = {
            "rgb": (self.rgb_segments, self.rgb_label, "RGB"),
            "flow": (self.flow_segments, self.flow_label, "Flow"),
            "audio": (self.audio_segments, self.audio_label, "Audio"),
            "combined": (self.combined_segments, self.combined_label, "Combined"),
        }
        target = targets.get(modality)

        if target is not None:
            bucket = target[0]
            bucket.append((start_time, end_time, anomaly_type))
            if self.debug_mode:
                print(f"Added to {modality}_segments. Total: {len(bucket)}")

        # Only combined segments are drawn on the timeline.
        if modality == "combined":
            if self.debug_mode:
                print(f"Updating timeline with {len(self.combined_segments)} combined segments")
            self.timeline_widget.set_anomalies(self.combined_segments)

        # Refresh the label immediately when playback is inside the segment.
        if start_time <= self.current_time <= end_time:
            if self.debug_mode:
                print(f"Updating UI labels for current time ({self.current_time:.2f})")
            if target is not None:
                _, label, title = target
                label.setText(f"{title}: {anomaly_type}")
                self.set_label_color(label, anomaly_type)
    
    def get_classification_at_time(self, time_point, segments):
        """Return the anomaly type active at ``time_point``.

        ``segments`` holds ``(start_time, end_time, anomaly_type)`` tuples and
        both endpoints are inclusive.  The list is scanned in reverse, so when
        several segments cover ``time_point`` the one nearest the end of the
        list wins.  Returns the string "None" when ``segments`` is empty or
        nothing covers ``time_point``.
        """
        if not len(segments):
            return "None"

        # Every covering segment's type, last-listed segment first.
        hits = [
            kind
            for begin, finish, kind in reversed(segments)
            if begin <= time_point <= finish
        ]
        return hits[0] if hits else "None"
    
    def format_time(self, seconds):
        """Format a duration in seconds as ``MM:SS``, or ``HH:MM:SS`` from one hour up.

        Args:
            seconds: Duration in seconds (int or float).  The fractional part
                is truncated.  Negative values are clamped to zero; without
                the clamp, floor division would render e.g. ``-1`` as
                ``-1:59:59``.

        Returns:
            str: Zero-padded time string.
        """
        whole_seconds = max(0, int(seconds))
        minutes, secs = divmod(whole_seconds, 60)
        hours, minutes = divmod(minutes, 60)
        if hours:
            return f"{hours:02d}:{minutes:02d}:{secs:02d}"
        return f"{minutes:02d}:{secs:02d}"
    
    def closeEvent(self, event):
        """Handle the window close event: stop playback, free resources, accept.

        Steps, in order: stop the playback timer if playing, call
        ``cleanup_temp_resources``, release the video capture, stop and wait
        for the prediction worker if there is one, empty the CUDA cache when
        CUDA is available, and finally accept the event.

        Args:
            event: The close event; ``accept()`` is called on it at the end.
        """
        print("Application closing, performing final cleanup...")

        # Stop playback
        if self.playing:
            self.timer.stop()
            self.playing = False

        # Clean up all temporary resources
        # NOTE(review): this runs before the worker is stopped below; confirm
        # cleanup_temp_resources does not remove anything the worker still uses.
        self.cleanup_temp_resources()

        # Release video capture
        if self.cap is not None:
            self.cap.release()
            self.cap = None

        # Stop the prediction worker.  The attribute may be absent or set to
        # None; a plain hasattr() check would let None through and crash on
        # .stop() before the event is accepted.
        worker = getattr(self, 'prediction_worker', None)
        if worker is not None:
            worker.stop()
            worker.wait()

        # Release CUDA memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print("Application cleanup complete")
        event.accept()

In [None]:
# Main entry point
if __name__ == "__main__":
    # Only one QApplication may exist per process.  When this cell is re-run
    # in a live kernel the previous instance is still around, so reuse it
    # instead of constructing a second one.
    app = QApplication.instance()
    if app is None:
        app = QApplication(sys.argv)
    window = MultimodalTimelinePlayerWindow()
    window.show()
    # Run the Qt event loop and exit with its return code.
    sys.exit(app.exec_())