In [None]:
# ═══════════════════════════════════════════════════════════════════════════════
# 📹 ULTIMATE TWO-CLASS VIDEO CLASSIFICATION - SINGLE CELL COLAB SCRIPT
# ═══════════════════════════════════════════════════════════════════════════════

# 1️⃣ ───── INSTALL & IMPORTS ─────
import subprocess
import sys
import os


def install_package(package):
    """Install one package into the environment of the running interpreter.

    Invokes ``<sys.executable> -m pip install <package>`` so the package is
    installed for the same Python that is executing this script.

    Args:
        package: Requirement string handed to ``pip install`` unchanged.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    subprocess.check_call(pip_command)


# Install all required packages
# Each entry is passed verbatim to `pip install` by install_package().
required_packages = [
    "pytorchvideo",
    "timesformer-pytorch",
    "torchmetrics",
    "rich",
    "matplotlib",
    "seaborn",
    "scikit-learn",
    "decord",
    "PyYAML",
    "einops",  # Required for ViViT and other transformers
    "av",  # Alternative video reader
]

# Packages are installed one at a time so that a failure can be reported
# against the specific package that caused it.
print("🔧 Installing required packages...")
for package in required_packages:
    try:
        install_package(package)
        print(f"✅ {package} installed successfully")
    except Exception as e:
        # Fail fast: the imports further down need every package in the list,
        # so the script stops at the first failed install.
        print(f"❌ Failed to install {package}: {e}")
        sys.exit(1)

# Mount Google Drive
# `google.colab` can only be imported inside a Colab runtime; anywhere else the
# import raises ImportError and the mount step is skipped.
try:
    from google.colab import drive, files

    drive.mount("/content/drive")
    print("📁 Google Drive mounted successfully")
except ImportError:
    print("⚠️ Not running in Colab - Drive mount skipped")
    # Keep the name `files` defined so later code can test it against None
    # instead of hitting a NameError outside Colab.
    files = None

# Core imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.io import VideoReader
import torch.cuda.amp as amp
from torchmetrics import Accuracy, Precision, Recall, F1Score, AUROC, ConfusionMatrix
from torchmetrics.classification import BinarySpecificity
from torchmetrics.classification import Recall

# To use it for binary sensitivity, you would instantiate it like this:
binary_sensitivity = Recall(task="binary")

import yaml
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from rich.console import Console
from rich.progress import Progress, TaskID
from rich.table import Table
import warnings
import random
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Union
import time
import gc
from collections import defaultdict
from einops import rearrange, repeat  # For transformer models
from einops.layers.torch import Rearrange

# Suppress warnings
warnings.filterwarnings("ignore")

# 2️⃣ ───── PARSE YAML CONFIG ─────
config_text = """
# ═══════════════════════════════════════════════════════════════════════════════
# 📋 CONFIGURATION FILE - COMPLETE DOCUMENTATION
# ═══════════════════════════════════════════════════════════════════════════════
# This configuration file controls all aspects of the video classification pipeline.
# Each section is thoroughly documented with exact keywords and valid options.

# ───── PATHS CONFIGURATION ─────
# Configure input/output directories for your project
paths:
  # Root directory containing your video dataset
  # Structure must be: data_root/class_name/video_files.*
  # Supported video formats: .mp4, .avi, .mov, .mkv
  data_root: "/content/drive/MyDrive/SkillDataset_2Clusters"

  # Directory where all outputs will be saved (logs, checkpoints, plots, predictions)
  # A timestamped subdirectory will be created for each run
  output_root: "/content/drive/MyDrive/2ClusterVideos/outputs"

# ───── HARDWARE CONFIGURATION ─────
# Control GPU usage and memory optimization
hardware:
  # Number of GPUs to use
  # Options: 0 (CPU only), 1+ (number of CUDA devices to use)
  gpus: 1

  # Enable Automatic Mixed Precision (AMP) for faster training and lower memory usage
  # Options: true, false
  mixed_precision: true

  # Soft GPU memory limit in GB - script will auto-tune batch size if exceeded
  # Recommended: 8, 12, 16, 24, 40, 80 (based on your GPU)
  max_gpu_mem_gb: 12

# ───── DATA LOADING & AUGMENTATION ─────
# Configure how videos are loaded and processed
data:
  # Frames per second to sample from videos
  # Lower values = fewer frames, faster processing
  # Typical values: 1, 5, 10, 15, 30
  frame_rate: 10

  # Number of frames per video snippet/clip
  # Must be compatible with model architecture
  # Common values: 8, 16, 32, 64, 128, 256
  clip_len: 50

  # Overlap between consecutive snippets (in frames)
  # 0 = no overlap, clip_len/2 = 50% overlap
  # Increases data but also training time
  snippet_overlap: 10

  # Number of CPU workers for data loading
  # Recommended: 2-8 (depends on CPU cores)
  num_workers: 4

  # ───── TRAIN/VAL/TEST SPLIT CONFIGURATION ─────
  # Two modes available for splitting your dataset:

  # Split mode - controls how data is divided
  # Options:
  # - "stratified": Maintains class proportions across splits (uses global_split_pct)
  # - "manual": Allows custom class ratios per split (uses class_ratios & manual_split_sizes)
  split_mode: "stratified"

  # ── STRATIFIED MODE SETTINGS ──
  # Used only when split_mode = "stratified"
  # Percentages must sum to 100
  global_split_pct:
    train: 80    # 80% of each class → training
    val:   10    # 10% of each class → validation
    test:  10    # 10% of each class → testing

  # ── MANUAL MODE SETTINGS ──
  # Used only when split_mode = "manual"

  # Class ratios - controls relative proportion of each class within each split
  # Values represent percentages and should sum to 100 for each split
  # Example: {cluster_0: 80, cluster_1: 20} = 80% class 0, 20% class 1
  class_ratios:
    train: {cluster_0: 50, cluster_1: 50}    # Balanced training set
    val: {cluster_0: 50, cluster_1: 50}      # Balanced validation set
    test: {cluster_0: 50, cluster_1: 50}     # Balanced test set

  # Total number of samples per split (used only in manual mode)
  # Adjust based on your dataset size
  manual_split_sizes:
    train: 140   # Total training samples
    val: 10      # Total validation samples
    test: 10     # Total test samples

# ───── MODEL CONFIGURATION ─────
# Select and configure the neural network architecture
model:
  # Model architecture to use
  # Available options (exact keywords):
  #
  # CNN-based models:
  # - "x3d_m": X3D-Medium (efficient 3D CNN)
  # - "slow_r50": Slow pathway ResNet-50
  # - "slowfast_r50": SlowFast ResNet-50 (dual pathway)
  # - "r2plus1d": R(2+1)D-18 (decomposed 3D convolutions)
  # - "r3d_18": ResNet 3D-18 (full 3D convolutions)
  #
  # Hybrid CNN-RNN models:
  # - "cnn_lstm": CNN backbone + Bidirectional LSTM
  # - "cnn_gru": CNN backbone + Bidirectional GRU
  #
  # Transformer-based models:
  # - "timesformer": TimeSformer (divided space-time attention)
  # - "mvit": Multiscale Vision Transformer
  # - "videomae": Video Masked Autoencoder V2
  # - "vivit": Video Vision Transformer
  model_name: "x3d_m"

  # Freeze pretrained backbone weights (train only classifier head)
  # Options: true (faster, less memory), false (better accuracy)
  freeze_backbone: false

  # Dropout rate for classifier head (0.0-1.0)
  # Higher values = more regularization
  dropout: 0.25

# ───── TRAINING HYPERPARAMETERS ─────
# Configure the training process
train:
  # Maximum number of training epochs
  epochs: 25

  # Batch size (will be auto-tuned if GPU memory exceeded)
  batch_size: 4

  # Learning rate (typical range: 1e-5 to 1e-2)
  lr: 1.0e-4

  # Weight decay for AdamW optimizer (L2 regularization)
  weight_decay: 1.0e-4

  # Learning rate scheduler
  # Options: "cosine" (smooth decay), "step" (sudden drops)
  scheduler: "cosine"

  # Step scheduler parameters (only used if scheduler = "step")
  step_size: 10    # Epochs between LR drops
  gamma: 0.1       # LR multiplication factor

  # Gradient accumulation steps (simulates larger batch size)
  # Use values > 1 if running out of memory
  gradient_accumulation: 1

  # Early stopping patience (epochs without improvement)
  # Set to -1 to disable early stopping
  early_stop_patience: 5

  # Random seed for reproducibility
  seed: 42

# ───── METRICS CONFIGURATION ─────
# Metrics to track during training/evaluation
# Order matters for display in logs and plots
# Available metrics: loss, accuracy, precision, recall, f1, sensitivity, specificity, auc
metrics: [loss, accuracy, precision, recall, f1, sensitivity, specificity, auc]

# ───── LOGGING CONFIGURATION ─────
# Control output verbosity and style
logging:
  # Print frequency (batches between console updates)
  print_freq: 5

  # Enable emoji in console output
  emojis: true

  # Enable colored console output
  colour: true

  # Save all console output to log file
  save_stdout: true

  # Save detailed per-sample predictions (JSON format)
  save_detailed_predictions: true

# ───── RUN MODES ─────
# Toggle different pipeline stages
modes:
  # Train the model
  run_training: true

  # Evaluate best model on test set
  run_eval: true

  # Run inference demo on single video
  run_inference: true

  # Path to video for inference demo
  inference_video: "/content/drive/MyDrive/2CusterVideos/cluster_0/SK_0002_S1_1006_Capsulorhexis.avi"

# ───── CLASS HANDLING ─────
# Override automatic class detection
override:
  # Custom class names (null = auto-detect from folder names)
  # Example: ["healthy", "diseased"]
  class_names: null

  # Pretty names for display (maps class index to display name)
  # Example: {0: "Healthy 🌿", 1: "Diseased 🍂"}
  class_name_map: {}
"""

config = yaml.safe_load(config_text)

# 3️⃣ ───── UTILITIES ─────
console = Console()


def seed_everything(seed: int) -> None:
    """Seed every random number generator used by the pipeline.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), then switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: Seed value applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN autotuning may select different kernels from run to run, so it is
    # turned off while deterministic kernels are requested.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def check_gpu_memory() -> Tuple[bool, float]:
    """Report whether CUDA is usable and how much memory device 0 has.

    Returns:
        ``(True, total_gb)`` when CUDA is available, where ``total_gb`` is the
        total memory of CUDA device 0 in units of 1e9 bytes (total capacity,
        not the amount currently free); ``(False, 0.0)`` otherwise.
    """
    if not torch.cuda.is_available():
        return False, 0.0

    total_bytes = torch.cuda.get_device_properties(0).total_memory
    return True, total_bytes / 1e9


def print_section(title: str, emoji: str = "🔹") -> None:
    """Print an upper-cased section header in bold cyan.

    Args:
        title: Section title; it is upper-cased before printing.
        emoji: Symbol placed before the header. It is only shown when
            ``config["logging"]["emojis"]`` is truthy.
    """
    prefix = f"{emoji} " if config["logging"]["emojis"] else ""
    console.print(f"\n{prefix}───── {title.upper()} ─────", style="bold cyan")


def setup_output_dirs(output_root: str) -> Dict[str, Path]:
    """Create a timestamped run directory with its standard sub-folders.

    The run directory is ``<output_root>/run_<YYYYmmdd_HHMMSS>`` and contains
    ``checkpoints``, ``logs``, ``plots`` and ``predictions``. Missing parents
    are created and directories that already exist are left untouched.

    Args:
        output_root: Directory under which the run directory is created.

    Returns:
        Mapping with the keys ``"root"``, ``"checkpoints"``, ``"logs"``,
        ``"plots"`` and ``"predictions"`` pointing at the created paths.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = Path(output_root) / f"run_{stamp}"

    dirs = {"root": run_dir}
    for sub_name in ("checkpoints", "logs", "plots", "predictions"):
        dirs[sub_name] = run_dir / sub_name

    for target in dirs.values():
        target.mkdir(parents=True, exist_ok=True)

    return dirs


# 4️⃣ ───── SPLIT ENGINE (FIXED) ─────
def create_splits(data_root: str, config: Dict) -> Dict:
    """Build train/val/test video lists from a ``data_root/<class>/<video>`` tree.

    Classes come from ``config["override"]["class_names"]`` when set, otherwise
    from the sorted names of the non-hidden sub-directories of ``data_root``.
    Videos are the ``.mp4``, ``.avi``, ``.mov`` and ``.mkv`` files directly
    inside each class folder. They are sorted before any shuffling so that a
    fixed NumPy seed yields the same splits regardless of filesystem order.

    Split modes (``config["data"]["split_mode"]``):

    * ``"stratified"``: each class is shuffled and cut according to
      ``global_split_pct``; whatever is left after train and val goes to test.
    * anything else (manual mode): each split receives
      ``manual_split_sizes[split] * ratio / sum(ratios)`` videos per class,
      using ``class_ratios[split]``. Every class is shuffled once and the
      splits draw from that ordering without sharing videos, so train, val
      and test do not overlap. Val and test are served first so evaluation
      data stays unique when a class is small. If a split cannot get enough
      unseen videos, it oversamples (with replacement) the videos it did get.
      Only when no unseen video is left at all does it reuse videos from other
      splits, and a warning is printed in that case.

    Args:
        data_root: Dataset root directory.
        config: Parsed configuration; the ``"override"`` and ``"data"``
            sections are read.

    Returns:
        Dict with the keys ``"splits"`` (split name -> list of
        ``{"video_path", "class_idx", "class_name"}`` records),
        ``"class_names"``, ``"num_classes"``, ``"split_mode"`` and
        ``"videos_per_class"`` (class name -> number of videos found).

    Raises:
        ValueError: In manual mode, if the class ratios of a split do not sum
            to a positive value.
    """
    data_path = Path(data_root)

    def make_entry(video_file, class_idx, class_name):
        """Build the per-video record stored in every split."""
        return {
            "video_path": str(video_file),
            "class_idx": class_idx,
            "class_name": class_name,
        }

    # Auto-detect classes
    class_folders = [
        d for d in data_path.iterdir() if d.is_dir() and not d.name.startswith(".")
    ]
    class_names = config["override"]["class_names"] or sorted(
        d.name for d in class_folders
    )

    # Collect all videos per class
    videos_per_class = {}
    for class_idx, class_name in enumerate(class_names):
        class_folder = data_path / class_name
        if not class_folder.exists():
            continue

        # glob() order depends on the filesystem; sorting keeps seeded
        # shuffles reproducible across machines.
        video_files = sorted(
            path
            for pattern in ("*.mp4", "*.avi", "*.mov", "*.mkv")
            for path in class_folder.glob(pattern)
        )

        videos_per_class[class_name] = {"files": video_files, "class_idx": class_idx}

        console.print(f"📁 Found {len(video_files)} videos in class '{class_name}'")

    splits = {"train": [], "val": [], "test": []}

    if config["data"]["split_mode"] == "stratified":
        # Stratified mode - preserve original class distribution
        pct = config["data"]["global_split_pct"]
        for class_name, class_data in videos_per_class.items():
            video_files = class_data["files"]
            class_idx = class_data["class_idx"]
            n_videos = len(video_files)

            if n_videos == 0:
                continue

            # Shuffle videos
            indices = list(range(n_videos))
            np.random.shuffle(indices)

            # Calculate split sizes; the test split takes the remainder
            n_train = int(n_videos * pct["train"] / 100)
            n_val = int(n_videos * pct["val"] / 100)

            indices_per_split = {
                "train": indices[:n_train],
                "val": indices[n_train : n_train + n_val],
                "test": indices[n_train + n_val :],
            }
            for split_name, split_indices in indices_per_split.items():
                splits[split_name].extend(
                    make_entry(video_files[idx], class_idx, class_name)
                    for idx in split_indices
                )

    else:  # manual mode - control class distribution within each split
        # One shuffled ordering per class. `cursor` is the first position that
        # has not been handed to any split yet, which keeps splits disjoint.
        pools = {
            class_name: {
                "order": np.random.permutation(len(class_data["files"])),
                "cursor": 0,
            }
            for class_name, class_data in videos_per_class.items()
        }

        # Evaluation splits are served first so they get unique videos even
        # when a class is too small to satisfy every split.
        for split_name in ["val", "test", "train"]:
            total_samples = config["data"]["manual_split_sizes"][split_name]
            class_ratios = config["data"]["class_ratios"][split_name]

            # Ratios are normalised by their sum, so they need not add to 100
            ratio_sum = sum(class_ratios.values())
            if ratio_sum <= 0:
                raise ValueError(
                    f"class_ratios for split '{split_name}' must sum to a "
                    f"positive value, got {ratio_sum}"
                )

            for class_name in class_names:
                if class_name not in videos_per_class:
                    continue

                class_data = videos_per_class[class_name]
                video_files = class_data["files"]
                class_idx = class_data["class_idx"]

                # Number of samples for this class in this split
                ratio = class_ratios.get(class_name, 0)
                n_samples = int(total_samples * ratio / ratio_sum)
                if not video_files or n_samples <= 0:
                    continue

                # Take videos no other split has used yet
                pool = pools[class_name]
                start = pool["cursor"]
                fresh = pool["order"][start : start + n_samples]
                pool["cursor"] = start + len(fresh)

                chosen = list(fresh)
                shortfall = n_samples - len(chosen)
                if shortfall > 0:
                    if len(fresh) > 0:
                        # Oversample inside this split only, so no video
                        # crosses a split boundary.
                        source = fresh
                    else:
                        # Nothing unseen is left: reuse is unavoidable.
                        source = pool["order"]
                        console.print(
                            f"⚠️ Class '{class_name}' has too few videos: "
                            f"'{split_name}' split reuses videos assigned to "
                            f"other splits",
                            style="yellow",
                        )
                    chosen.extend(np.random.choice(source, shortfall, replace=True))

                splits[split_name].extend(
                    make_entry(video_files[idx], class_idx, class_name)
                    for idx in chosen
                )

        for split_name in ["train", "val", "test"]:
            # Shuffle the split to mix classes
            np.random.shuffle(splits[split_name])

            split_size = len(splits[split_name])
            console.print(
                f"📊 {split_name.capitalize()} split: {split_size} samples"
            )

            # Print class distribution
            class_counts = defaultdict(int)
            for item in splits[split_name]:
                class_counts[item["class_name"]] += 1

            for class_name, count in class_counts.items():
                percentage = (count / split_size) * 100
                console.print(f"   - {class_name}: {count} samples ({percentage:.1f}%)")

    # Save splits info
    splits_info = {
        "splits": splits,
        "class_names": class_names,
        "num_classes": len(class_names),
        "split_mode": config["data"]["split_mode"],
        "videos_per_class": {k: len(v["files"]) for k, v in videos_per_class.items()},
    }

    return splits_info


# 5️⃣ ───── DATASET & DATALOADERS (FIXED WITH OVERLAP) ─────
class VideoDataset(Dataset):
    """Dataset of fixed-length clips ("snippets") cut from labelled videos.

    Each item is ``(video_tensor, label, metadata)`` where ``video_tensor`` has
    shape ``(3, clip_len, 224, 224)``, ``label`` is the class index and
    ``metadata`` holds ``video_path``, ``snippet_idx`` and ``class_name``.

    Frames are sampled at ``frame_rate`` frames per second, so frame ``k`` of
    a snippet is read at ``(start_frame + k) / frame_rate`` seconds. When a
    frame cannot be read, the previous frame is repeated (or a black frame is
    used if none was read yet). When the whole video fails to load, a clip
    built from black frames is returned so that a batch never crashes.

    Args:
        video_list: Records with the keys ``"video_path"``, ``"class_idx"``
            and ``"class_name"``.
        clip_len: Number of frames per snippet.
        frame_rate: Sampling rate in frames per second; must be positive.
        overlap: Frames shared by consecutive snippets. With ``overlap > 0``
            three snippets per video are indexed, ``clip_len - overlap``
            frames apart; otherwise one snippet starting at frame 0.
        augment: Enable random crop / flip / colour jitter. The same random
            parameters are applied to every frame of a clip.

    Raises:
        ValueError: If ``frame_rate`` is not positive.
    """

    def __init__(
        self,
        video_list: List[Dict],
        clip_len: int,
        frame_rate: int,
        overlap: int = 0,
        augment: bool = False,
    ):
        if frame_rate <= 0:
            raise ValueError(f"frame_rate must be positive, got {frame_rate}")

        self.video_list = video_list
        self.clip_len = clip_len
        self.frame_rate = frame_rate
        self.overlap = overlap
        self.augment = augment

        # Calculate stride between snippets
        self.stride = max(1, clip_len - overlap)

        # Build snippet index.
        # NOTE: the real video length is not inspected here. The snippet count
        # is a fixed simplification; snippets that start past the end of a
        # short video are padded by the per-frame fallback in _load_clip().
        n_snippets = 3 if overlap > 0 else 1
        self.snippets = [
            {
                "video_path": video_info["video_path"],
                "class_idx": video_info["class_idx"],
                "class_name": video_info["class_name"],
                "snippet_idx": snippet_idx,
                "start_frame": snippet_idx * self.stride,
            }
            for video_info in video_list
            for snippet_idx in range(n_snippets)
        ]

        # Normalization with the per-channel mean/std used by the
        # Kinetics-pretrained PyTorchVideo models (these are not the ImageNet
        # statistics).
        self.normalize = transforms.Normalize(
            mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225]
        )

        # Augmentation transforms
        if augment:
            self.spatial_transform = transforms.Compose(
                [
                    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.ColorJitter(
                        brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
                    ),
                ]
            )
        else:
            self.spatial_transform = transforms.Resize((224, 224))

    def __len__(self) -> int:
        """Return the number of indexed snippets."""
        return len(self.snippets)

    def _frame_timestamp(self, frame_idx: int) -> float:
        """Convert a sampled-frame index into a position in seconds."""
        return frame_idx / self.frame_rate

    def _dummy_clip(self) -> torch.Tensor:
        """Return a normalized clip of black frames, shape (3, clip_len, 224, 224)."""
        blank = self.normalize(torch.zeros(3, 224, 224))
        return torch.stack([blank] * self.clip_len, dim=1)

    def _load_clip(self, video_path: str, start_frame: int) -> torch.Tensor:
        """Decode, transform and normalize one snippet as a (C, T, H, W) tensor."""
        vr = VideoReader(video_path, "video")
        to_pil = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()

        # The random transforms draw their parameters from torch's global RNG.
        # Rewinding the RNG to the same state before each frame makes every
        # frame of the clip receive the same crop / flip / jitter; the state
        # left after the last frame is what the next clip starts from.
        rng_state = torch.get_rng_state() if self.augment else None

        video_frames = []
        for i in range(self.clip_len):
            try:
                # VideoReader.seek() expects a time in seconds, not an index
                vr.seek(self._frame_timestamp(start_frame + i))
                frame = next(vr)["data"]  # (C, H, W) uint8 tensor

                # Convert to PIL for transforms
                frame = to_pil(frame)

                # Apply spatial transforms
                if rng_state is not None:
                    torch.set_rng_state(rng_state)
                frame = self.spatial_transform(frame)

                # Convert back to a float tensor
                video_frames.append(to_tensor(frame))
            except Exception:
                # Includes StopIteration when reading past the end of the
                # video: repeat the last valid frame, or start with black.
                if video_frames:
                    video_frames.append(video_frames[-1].clone())
                else:
                    video_frames.append(torch.zeros(3, 224, 224))

        normalized_frames = [self.normalize(frame) for frame in video_frames]
        return torch.stack(normalized_frames, dim=1)  # (C, T, H, W)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int, Dict]:
        """Return ``(video_tensor, label, metadata)`` for snippet ``idx``."""
        snippet_info = self.snippets[idx]
        video_path = snippet_info["video_path"]
        label = snippet_info["class_idx"]

        metadata = {
            "video_path": video_path,
            "snippet_idx": snippet_info["snippet_idx"],
            "class_name": snippet_info["class_name"],
        }

        try:
            video_tensor = self._load_clip(video_path, snippet_info["start_frame"])
        except Exception as e:
            # Best effort: report the failure and substitute a dummy clip so
            # that one unreadable file does not abort the run.
            console.print(f"❌ Error processing video {video_path}: {e}", style="red")
            video_tensor = self._dummy_clip()

        return video_tensor, label, metadata


# 6️⃣ ───── MODEL IMPLEMENTATIONS ─────


# ===== CNN-RNN HYBRID MODELS =====
class CNNFeatureExtractor(nn.Module):
    """Per-frame ResNet-50 feature extractor used by the CNN-RNN models.

    The network is loaded through ``torch.hub`` from ``pytorch/vision`` and
    its last child module (the classification layer) is dropped.
    ``feature_dim`` is the feature size reported for each frame.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        backbone = torch.hub.load("pytorch/vision", "resnet50", pretrained=pretrained)
        # Keep every child except the last one (the final FC layer)
        trunk_layers = list(backbone.children())[:-1]
        self.features = nn.Sequential(*trunk_layers)
        self.feature_dim = 2048

    def forward(self, x):
        """Map a clip batch ``(B, C, T, H, W)`` to frame features ``(B, T, D)``."""
        batch_size, _, num_frames, _, _ = x.shape

        # Fold time into the batch axis so the 2D CNN sees single frames
        frames = rearrange(x, "b c t h w -> (b t) c h w")
        pooled = self.features(frames)  # (B*T, 2048, 1, 1)
        flat = pooled.squeeze(-1).squeeze(-1)  # (B*T, 2048)

        # Restore the separate batch and time axes
        return rearrange(flat, "(b t) d -> b t d", b=batch_size, t=num_frames)


class CNN_LSTM(nn.Module):
    """Frame-wise CNN features followed by a bidirectional LSTM classifier.

    Args:
        num_classes: Size of the output logit vector.
        hidden_dim: Hidden size of each LSTM direction.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout before the classifier; also used between LSTM layers
            when ``num_layers > 1``.
    """

    def __init__(self, num_classes, hidden_dim=512, num_layers=2, dropout=0.5):
        super().__init__()
        self.cnn = CNNFeatureExtractor(pretrained=True)

        # Inter-layer dropout is only applied for stacked LSTMs
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_size=self.cnn.feature_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=inter_layer_dropout,
        )
        self.dropout = nn.Dropout(dropout)
        # Forward and backward states are concatenated, hence 2 * hidden_dim
        self.classifier = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits ``(B, num_classes)`` for clips ``(B, C, T, H, W)``."""
        frame_features = self.cnn(x)  # (B, T, D)

        # h_n has shape (num_layers * 2, B, hidden_dim); its last two entries
        # are the top layer's forward and backward final states.
        _, (h_n, _) = self.lstm(frame_features)
        summary = torch.cat((h_n[-2], h_n[-1]), dim=1)  # (B, 2*hidden_dim)

        return self.classifier(self.dropout(summary))


class CNN_GRU(nn.Module):
    """Frame-wise CNN features followed by a bidirectional GRU classifier.

    Args:
        num_classes: Size of the output logit vector.
        hidden_dim: Hidden size of each GRU direction.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout before the classifier; also used between GRU layers
            when ``num_layers > 1``.
    """

    def __init__(self, num_classes, hidden_dim=512, num_layers=2, dropout=0.5):
        super().__init__()
        self.cnn = CNNFeatureExtractor(pretrained=True)

        # Inter-layer dropout is only applied for stacked GRUs
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.gru = nn.GRU(
            input_size=self.cnn.feature_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=inter_layer_dropout,
        )
        self.dropout = nn.Dropout(dropout)
        # Forward and backward states are concatenated, hence 2 * hidden_dim
        self.classifier = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits ``(B, num_classes)`` for clips ``(B, C, T, H, W)``."""
        frame_features = self.cnn(x)  # (B, T, D)

        # h_n has shape (num_layers * 2, B, hidden_dim); its last two entries
        # are the top layer's forward and backward final states.
        _, h_n = self.gru(frame_features)
        summary = torch.cat((h_n[-2], h_n[-1]), dim=1)  # (B, 2*hidden_dim)

        return self.classifier(self.dropout(summary))


# ===== TRANSFORMER MODELS =====


class MViT(nn.Module):
    """Transformer video classifier with periodic token pooling.

    Every frame is cut into ``patch_size`` x ``patch_size`` patches by a
    ``Conv3d`` with temporal kernel size 1. A class token is prepended and a
    learned positional embedding is added. After every third transformer
    block (indices 0, 3, 6, ...) the non-class tokens are halved with
    ``MaxPool1d``. The class token is normalized and fed to a linear head.

    The positional embedding holds ``num_frames * num_patches + 1`` positions
    and is sliced to the input's token count, so an input that produces more
    tokens than that will not broadcast.
    """

    def __init__(
        self,
        num_classes,
        img_size=224,
        patch_size=16,
        num_frames=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.1,
    ):
        super().__init__()
        self.num_frames = num_frames
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # Per-frame patch embedding: the kernel does not span time
        patch_shape = (1, patch_size, patch_size)
        self.patch_embed = nn.Conv3d(
            3, embed_dim, kernel_size=patch_shape, stride=patch_shape
        )

        # One position per patch of every frame, plus one for the class token
        total_positions = num_frames * self.num_patches + 1
        self.pos_embed = nn.Parameter(torch.zeros(1, total_positions, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        encoder_blocks = [
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ]
        self.blocks = nn.ModuleList(encoder_blocks)

        # Token pooling after blocks 0, 3, 6, ...; identity elsewhere
        pooling = []
        for layer_idx in range(depth):
            if layer_idx % 3 == 0:
                pooling.append(nn.MaxPool1d(kernel_size=2, stride=2))
            else:
                pooling.append(nn.Identity())
        self.pool_layers = nn.ModuleList(pooling)

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialize the learned embeddings
        for embedding in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(embedding, std=0.02)

    def forward(self, x):
        """Return class logits ``(B, num_classes)`` for clips ``(B, C, T, H, W)``."""
        batch_size = x.shape[0]

        # Patch embedding -> flat token sequence
        patches = self.patch_embed(x)  # (B, embed_dim, T, H', W')
        tokens = rearrange(patches, "b d t h w -> b (t h w) d")

        # Prepend the class token, then add positions
        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)
        tokens = tokens + self.pos_embed[:, : tokens.size(1)]

        for block, pool in zip(self.blocks, self.pool_layers):
            tokens = block(tokens)
            if isinstance(pool, nn.Identity) or tokens.size(1) <= 1:
                continue

            # Pool only the patch tokens; the class token is carried through
            cls_part, patch_part = tokens[:, :1], tokens[:, 1:]
            pooled = pool(patch_part.transpose(1, 2)).transpose(1, 2)
            tokens = torch.cat([cls_part, pooled], dim=1)

        tokens = self.norm(tokens)

        # Classify from the class token
        return self.head(tokens[:, 0])


class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block: self-attention, then an MLP.

    Both sub-layers are wrapped in a residual connection and receive a
    layer-normalized input.

    Args:
        dim: Token embedding size.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden size of the MLP as a multiple of ``dim``.
        dropout: Dropout used inside attention and after each MLP linear layer.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        hidden_features = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            dim, num_heads, dropout=dropout, batch_first=True
        )
        self.norm2 = nn.LayerNorm(dim)

        feed_forward = [
            nn.Linear(dim, hidden_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_features, dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Apply the block to tokens ``(B, N, dim)``; the shape is preserved."""
        # Self-attention sub-layer (attention weights are discarded)
        normed = self.norm1(x)
        attended = self.attn(normed, normed, normed)[0]
        after_attention = x + attended

        # MLP sub-layer
        return after_attention + self.mlp(self.norm2(after_attention))


class VideoMAE(nn.Module):
    """Transformer encoder over 3D tubelets with a linear classification head.

    A ``Conv3d`` with kernel and stride ``(2, patch_size, patch_size)`` turns
    the clip into tubelet tokens. A class token and a learned positional
    embedding are added, the tokens pass through ``depth`` transformer blocks,
    and the normalized class token is classified.

    ``decoder_embed_dim``, ``decoder_depth`` and ``decoder_num_heads`` are
    accepted but not used: this class builds no decoder and does no masking.
    """

    def __init__(
        self,
        num_classes,
        img_size=224,
        patch_size=16,
        num_frames=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_depth=4,
        decoder_num_heads=8,
        mlp_ratio=4.0,
        dropout=0.1,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.num_frames = num_frames

        # Tubelet embedding: each token covers 2 frames x one spatial patch
        tubelet_shape = (2, patch_size, patch_size)
        self.patch_embed = nn.Conv3d(
            3, embed_dim, kernel_size=tubelet_shape, stride=tubelet_shape
        )

        patches_per_frame = (img_size // patch_size) ** 2
        num_patches = (num_frames // 2) * patches_per_frame
        self.num_patches = num_patches

        # Class token and positions (one extra position for the class token)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        # Transformer encoder
        encoder_blocks = [
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ]
        self.blocks = nn.ModuleList(encoder_blocks)

        self.norm = nn.LayerNorm(embed_dim)

        # Classification head
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialize the learned embeddings
        for embedding in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(embedding, std=0.02)

    def forward(self, x):
        """Return class logits ``(B, num_classes)`` for clips ``(B, C, T, H, W)``."""
        batch_size = x.shape[0]

        # Tubelet embedding -> flat token sequence
        tubelets = self.patch_embed(x)  # (B, embed_dim, T', H', W')
        tokens = rearrange(tubelets, "b d t h w -> b (t h w) d")

        # Prepend the class token, then add positions
        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)
        tokens = tokens + self.pos_embed[:, : tokens.size(1)]

        # Transformer encoder
        for block in self.blocks:
            tokens = block(tokens)

        tokens = self.norm(tokens)

        # Classify from the class token
        return self.head(tokens[:, 0])


class ViViT(nn.Module):
    """Video Vision Transformer with factorized self-attention.

    A clip is cut into tubelets (2 frames x ``patch_size`` x ``patch_size``
    pixels) by a strided 3D convolution. Learned spatial and temporal position
    embeddings are added, the tokens run through ``depth``
    ``FactorizedTransformerBlock`` layers, and the clip is classified from the
    mean of the final patch tokens.

    Args:
        num_classes: Size of the output logit vector.
        img_size: Frame height/width used to size the spatial position table.
        patch_size: Spatial edge length of one tubelet.
        num_frames: Clip length used to size the temporal position table.
        embed_dim: Token dimension.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: Hidden/embedding width ratio of each block's MLP.
        dropout: Dropout rate forwarded to every block.
    """

    def __init__(
        self,
        num_classes,
        img_size=224,
        patch_size=16,
        num_frames=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.1,
    ):
        super().__init__()
        self.num_frames = num_frames
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # Tubelet embedding (3D patches): two frames per tubelet, no overlap.
        self.patch_embed = nn.Conv3d(
            3,
            embed_dim,
            kernel_size=(2, patch_size, patch_size),
            stride=(2, patch_size, patch_size),
        )

        # Calculate number of spatiotemporal patches
        self.num_time_patches = num_frames // 2
        self.num_space_patches = self.num_patches

        # Positional embeddings (separate for space and time).
        # The class token is kept so the token layout expected by
        # FactorizedTransformerBlock and existing state_dict keys stay valid.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.space_pos_embed = nn.Parameter(
            torch.zeros(1, self.num_space_patches, embed_dim)
        )
        self.time_pos_embed = nn.Parameter(
            torch.zeros(1, self.num_time_patches, embed_dim)
        )

        # Transformer blocks with factorized attention
        self.blocks = nn.ModuleList(
            [
                FactorizedTransformerBlock(
                    embed_dim,
                    num_heads,
                    self.num_time_patches,
                    self.num_space_patches,
                    mlp_ratio,
                    dropout,
                )
                for _ in range(depth)
            ]
        )

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialize weights
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.space_pos_embed, std=0.02)
        nn.init.trunc_normal_(self.time_pos_embed, std=0.02)

    def forward(self, x):
        """Return class logits for a batch of clips.

        Args:
            x: Clip tensor laid out as (B, C, T, H, W).

        Returns:
            Tensor of shape (B, num_classes).
        """
        B = x.shape[0]

        # Tubelet embedding
        x = self.patch_embed(x)  # (B, embed_dim, T', H', W')
        T_new, H_new, W_new = x.shape[2:]
        # Same layout as "b d t h w -> b (t h w) d": time-major token order.
        x = x.flatten(2).transpose(1, 2)

        # Add spatial and temporal position embeddings
        space_pos = self.space_pos_embed.repeat(
            1, T_new, 1
        )  # Repeat for each time step
        time_pos = self.time_pos_embed.repeat_interleave(
            H_new * W_new, dim=1
        )  # Repeat for each spatial position
        x = x + space_pos[:, : x.size(1)] + time_pos[:, : x.size(1)]

        # Prepend the class token (the blocks expect it at index 0).
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Apply transformer blocks
        for block in self.blocks:
            x = block(x, T_new, H_new * W_new)

        x = self.norm(x)

        # FactorizedTransformerBlock runs attention over the patch tokens only;
        # the class token just passes through the per-token MLP and never sees
        # the video. Classifying from it would make the logits independent of
        # the input, so the clip is summarised by the mean of the patch tokens.
        x = x[:, 1:].mean(dim=1)
        x = self.head(x)

        return x


class FactorizedTransformerBlock(nn.Module):
    """Pre-norm transformer layer that attends over space, then over time.

    The first token of the input is treated as a class token: it is set aside
    during both attention steps and only rejoins the sequence for the final
    per-token MLP.

    Args:
        dim: Token dimension.
        num_heads: Heads used by both attention modules.
        num_time_patches: Stored on the instance; ``forward`` takes the actual
            temporal length as an argument instead.
        num_space_patches: Stored on the instance; ``forward`` takes the actual
            spatial length as an argument instead.
        mlp_ratio: Hidden/embedding width ratio of the MLP.
        dropout: Dropout rate for the attention modules and the MLP.
    """

    def __init__(
        self,
        dim,
        num_heads,
        num_time_patches,
        num_space_patches,
        mlp_ratio=4.0,
        dropout=0.1,
    ):
        super().__init__()
        self.num_time_patches = num_time_patches
        self.num_space_patches = num_space_patches

        # Spatial branch: tokens of one time step attend to each other.
        self.norm1 = nn.LayerNorm(dim)
        self.spatial_attn = nn.MultiheadAttention(
            dim, num_heads, dropout=dropout, batch_first=True
        )

        # Temporal branch: one spatial location attends across time steps.
        self.norm2 = nn.LayerNorm(dim)
        self.temporal_attn = nn.MultiheadAttention(
            dim, num_heads, dropout=dropout, batch_first=True
        )

        # Feed-forward branch applied to every token.
        self.norm3 = nn.LayerNorm(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, dim),
            nn.Dropout(dropout),
        )

    def _spatial_update(self, patches, batch, T, HW):
        """Self-attention inside each time step; returns the residual update."""
        per_frame = rearrange(patches, "b (t hw) d -> (b t) hw d", t=T, hw=HW)
        per_frame = self.norm1(per_frame)
        attended, _ = self.spatial_attn(per_frame, per_frame, per_frame)
        return rearrange(attended, "(b t) hw d -> b (t hw) d", b=batch, t=T)

    def _temporal_update(self, patches, batch, T, HW):
        """Self-attention across time per location; returns the residual update."""
        per_location = rearrange(patches, "b (t hw) d -> (b hw) t d", t=T, hw=HW)
        per_location = self.norm2(per_location)
        attended, _ = self.temporal_attn(per_location, per_location, per_location)
        return rearrange(attended, "(b hw) t d -> b (t hw) d", b=batch, hw=HW)

    def forward(self, x, T, HW):
        """Apply spatial attention, temporal attention and the MLP.

        Args:
            x: Token tensor with three dimensions; index 0 along the token axis
                is the class token, the rest are ``T * HW`` patch tokens in
                time-major order.
            T: Number of time steps in the patch tokens.
            HW: Number of spatial positions per time step.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, _, _ = x.shape

        # The class token is excluded from both attention steps.
        cls_token = x[:, :1]
        patches = x[:, 1:]

        patches = patches + self._spatial_update(patches, batch, T, HW)
        patches = patches + self._temporal_update(patches, batch, T, HW)

        # Re-attach the class token before the per-token MLP.
        tokens = torch.cat([cls_token, patches], dim=1)
        return tokens + self.mlp(self.norm3(tokens))


# 7️⃣ ───── MODEL FACTORY (COMPLETED WITH ALL MODELS) ─────
def create_model(
    model_name: str,
    num_classes: int,
    clip_len: int,
    freeze_backbone: bool = False,
    dropout: float = 0.5,
) -> nn.Module:
    """
    Build the network selected by ``model_name``.

    Available models:
    - CNN-based: x3d_m, slow_r50, slowfast_r50, r2plus1d, r3d_18
    - CNN-RNN: cnn_lstm, cnn_gru (both bidirectional)
    - Transformer: timesformer, mvit, videomae, vivit

    Args:
        model_name: One of the identifiers listed above.
        num_classes: Width of the final classification layer.
        clip_len: Frames per clip; used as ``num_frames`` by the transformers.
        freeze_backbone: If True, hub CNN models keep gradients only for
            parameters whose name contains "proj" or "fc". Ignored for the
            transformer and CNN-RNN models.
        dropout: Dropout rate for the new head (hub models) or passed to the
            model constructor (all others).

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: If ``model_name`` is not a known identifier.
    """

    def new_head(in_features):
        # Replacement classifier shared by all hub-loaded CNNs.
        return nn.Sequential(nn.Dropout(dropout), nn.Linear(in_features, num_classes))

    pytorchvideo_models = ("x3d_m", "slow_r50", "slowfast_r50")
    torchvision_models = ("r2plus1d", "r3d_18")
    never_frozen = (
        "timesformer",
        "mvit",
        "videomae",
        "vivit",
        "cnn_lstm",
        "cnn_gru",
    )

    # CNN-based models (pretrained weights fetched through torch.hub)
    if model_name in pytorchvideo_models:
        model = torch.hub.load(
            "facebookresearch/pytorchvideo", model_name, pretrained=True
        )
        final_block = model.blocks[-1]
        in_features = final_block.proj.in_features
        final_block.proj = new_head(in_features)

    elif model_name in torchvision_models:
        hub_entry = "r2plus1d_18" if model_name == "r2plus1d" else "r3d_18"
        model = torch.hub.load("pytorch/vision", hub_entry, pretrained=True)
        in_features = model.fc.in_features
        model.fc = new_head(in_features)

    # CNN-RNN models
    elif model_name == "cnn_lstm":
        model = CNN_LSTM(
            num_classes=num_classes, hidden_dim=512, num_layers=2, dropout=dropout
        )

    elif model_name == "cnn_gru":
        model = CNN_GRU(
            num_classes=num_classes, hidden_dim=512, num_layers=2, dropout=dropout
        )

    # Transformer models
    elif model_name == "timesformer":
        # Imported lazily so the package is only needed when this model is used.
        from timesformer_pytorch import TimeSformer

        class TimeSformerWrapper(nn.Module):
            """Wrapper to permute input tensor dimensions for TimeSformer."""

            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, x):
                # Permute from (B, C, T, H, W) to (B, T, C, H, W)
                return self.model(x.permute(0, 2, 1, 3, 4))

        model = TimeSformerWrapper(
            TimeSformer(
                dim=512,
                image_size=224,
                patch_size=16,
                num_frames=clip_len,
                num_classes=num_classes,
                depth=12,
                heads=8,
                dim_head=64,
                attn_dropout=dropout,
                ff_dropout=dropout,
            )
        )

    elif model_name in ("mvit", "videomae", "vivit"):
        # Hyper-parameters common to the three in-file transformer models.
        shared_kwargs = dict(
            num_classes=num_classes,
            img_size=224,
            patch_size=16,
            num_frames=clip_len,
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4.0,
            dropout=dropout,
        )
        if model_name == "mvit":
            model = MViT(**shared_kwargs)
        elif model_name == "videomae":
            model = VideoMAE(
                decoder_embed_dim=512,
                decoder_depth=4,
                decoder_num_heads=8,
                **shared_kwargs,
            )
        else:
            model = ViViT(**shared_kwargs)

    else:
        raise ValueError(f"Unknown model: {model_name}")

    # Freeze backbone if requested (except for transformer models and CNN-RNN)
    if freeze_backbone and model_name not in never_frozen:
        head_markers = ("proj", "fc")
        for name, param in model.named_parameters():
            if not any(marker in name for marker in head_markers):
                param.requires_grad = False

    return model


# 8️⃣ ───── DYNAMIC BATCH TUNER ─────
def tune_batch_size(
    model: nn.Module,
    device: torch.device,
    initial_batch_size: int,
    clip_len: int,
    model_name: str,
) -> Tuple[int, int]:
    """Dynamically tune batch size and clip length to fit in GPU memory.

    Runs gradient-free forward passes on random 224x224 clips, shrinking the
    batch size (and then the clip length) after out-of-memory errors, and
    doubling the clip length when a Slow/SlowFast model reports that its input
    is smaller than a kernel.

    Args:
        model: Network to probe. It is switched to eval mode and left there.
        device: Device the dummy input is moved to.
        initial_batch_size: Starting batch size (halved up front for the
            transformer models).
        clip_len: Starting number of frames per clip.
        model_name: Model identifier; selects the minimum clip length and the
            SlowFast two-pathway input format.

    Returns:
        ``(batch_size, clip_len)`` of the first configuration whose forward
        pass succeeded. If none is found, or ``clip_len`` starts below the
        model's minimum so that no pass is attempted, returns
        ``(1, min_clip_len)``.

    Raises:
        RuntimeError: Any forward-pass error that is neither an out-of-memory
            nor a "smaller than kernel size" error is re-raised unchanged.
    """
    model.eval()
    batch_size = initial_batch_size
    current_clip_len = clip_len

    # Adjust initial batch size based on model type
    if model_name in ["timesformer", "mvit", "videomae", "vivit"]:
        batch_size = max(1, batch_size // 2)  # Transformers need more memory

    # SlowFast and Slow models need at least 32 frames
    needs_long_clips = model_name in ["slowfast_r50", "slow_r50"]
    min_clip_len = 32 if needs_long_clips else 8

    # Configurations that already failed. Without this, "grow clip on kernel
    # error" and "shrink clip on OOM" can bounce between the same two
    # configurations forever.
    attempted = set()

    while batch_size >= 1 and current_clip_len >= min_clip_len:
        candidate = (batch_size, current_clip_len)
        if candidate in attempted:
            break
        attempted.add(candidate)

        dummy_input = None
        try:
            # Test forward pass
            dummy_input = torch.randn(batch_size, 3, current_clip_len, 224, 224).to(
                device
            )
            with torch.no_grad():
                if model_name.startswith("slowfast"):
                    # SlowFast takes [slow, fast]; slow keeps every 4th frame.
                    alpha = 4
                    _ = model([dummy_input[:, :, ::alpha, :, :], dummy_input])
                else:
                    _ = model(dummy_input)
        except RuntimeError as e:
            error_text = str(e)
            if (
                "out of memory" not in error_text
                and "smaller than kernel size" not in error_text
            ):
                raise
        else:
            console.print(
                f"✅ Optimal batch_size: {batch_size}, clip_len: {current_clip_len}",
                style="green",
            )
            return batch_size, current_clip_len

        # Recovery happens outside the ``except`` clause on purpose: while the
        # exception is alive its traceback keeps the failed forward pass's
        # tensors referenced, so the allocator cannot release them.
        del dummy_input
        gc.collect()
        torch.cuda.empty_cache()

        if "smaller than kernel size" in error_text and needs_long_clips:
            # SlowFast/Slow models need more frames, double the clip length
            current_clip_len *= 2
            console.print(
                f"⚠️ {model_name} needs more frames, increasing clip_len to {current_clip_len}",
                style="yellow",
            )
        elif batch_size > 1:
            batch_size //= 2
            console.print(
                f"⚠️ OOM detected, reducing batch_size to {batch_size}",
                style="yellow",
            )
        else:
            # Already at batch size 1: shorten the clip and start over.
            current_clip_len //= 2
            batch_size = initial_batch_size
            console.print(
                f"⚠️ OOM detected, reducing clip_len to {current_clip_len}",
                style="yellow",
            )

    console.print(
        "❌ Could not find suitable batch_size/clip_len combination", style="red"
    )
    return 1, min_clip_len


# 9️⃣ ───── TRAINING & VALIDATION LOOPS (WITH DETAILED LOGGING) ─────
class EarlyStopping:
    """Signal when validation loss has stopped improving.

    A call counts as an improvement only if the loss drops more than
    ``min_delta`` below the best loss seen so far. Every other call increments
    a counter, and the instance reports "stop" once that counter reaches
    ``patience``.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.001):
        self.patience, self.min_delta = patience, min_delta
        self.best_loss = float("inf")
        self.counter = 0

    def __call__(self, val_loss: float) -> bool:
        """Record ``val_loss`` and return True when training should stop."""
        improved = val_loss < self.best_loss - self.min_delta
        if not improved:
            self.counter += 1
            return self.counter >= self.patience

        # New best: remember it and restart the patience window.
        self.best_loss = val_loss
        self.counter = 0
        return self.counter >= self.patience


def collate_fn(batch):
    """Collate ``(video, label, metadata)`` samples into one batch.

    Videos are stacked into one tensor and labels become a 1-D tensor; the
    metadata entries are returned as a plain list, one per sample.
    """
    clips, targets, infos = [], [], []
    for sample in batch:
        clips.append(sample[0])
        targets.append(sample[1])
        infos.append(sample[2])
    return torch.stack(clips), torch.tensor(targets), infos


def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    scaler: amp.GradScaler,
    config: Dict,
    epoch: int,
    output_dirs: Dict,
) -> Dict[str, float]:
    """Train for one epoch with detailed logging.

    Loss and accuracy are averaged over batches. Precision, recall, F1,
    sensitivity, specificity and AUC are computed once from the predictions of
    the whole epoch, the same way ``validate_epoch`` does. Averaging per-batch
    values would score every batch without positive samples as 0, which makes
    the training curves incomparable with the validation curves.

    Args:
        model: Network to train; switched to train mode.
        dataloader: Yields ``(videos, labels, metadata)`` batches.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batch tensors are moved to.
        scaler: Gradient scaler used for the backward pass and optimizer step.
        config: Reads ``hardware.mixed_precision``, ``model.model_name``,
            ``logging.print_freq``, ``logging.emojis``,
            ``logging.save_detailed_predictions`` and ``metrics``.
        epoch: Epoch number, used in the predictions file name.
        output_dirs: ``output_dirs["predictions"]`` receives the JSON dump.

    Returns:
        One value per name in ``config["metrics"]``. A metric that could not be
        computed (e.g. binary metrics with more than two label values) is NaN.
    """
    model.train()

    batch_losses, batch_accuracies = [], []
    seen_labels, seen_preds, seen_probs = [], [], []

    # For detailed logging
    epoch_predictions = []

    with Progress() as progress:
        task = progress.add_task("Training...", total=len(dataloader))

        for batch_idx, (videos, labels, metadata) in enumerate(dataloader):
            videos, labels = videos.to(device), labels.to(device)

            optimizer.zero_grad()

            with amp.autocast(enabled=config["hardware"]["mixed_precision"]):
                if config["model"]["model_name"].startswith("slowfast"):
                    # SlowFast takes [slow, fast]; slow keeps every 4th frame.
                    alpha = 4
                    slow_pathway = videos[:, :, ::alpha, :, :]
                    fast_pathway = videos
                    outputs = model([slow_pathway, fast_pathway])
                else:
                    outputs = model(videos)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            with torch.no_grad():
                probs = torch.softmax(outputs, dim=1)
                preds = torch.argmax(probs, dim=1)

                # Move everything to Python objects once per batch instead of
                # issuing one device sync per sample.
                loss_value = loss.item()
                label_list = labels.tolist()
                pred_list = preds.tolist()
                prob_list = probs.cpu().tolist()

                # Save detailed predictions
                for meta, true_label, pred_label, sample_probs in zip(
                    metadata, label_list, pred_list, prob_list
                ):
                    epoch_predictions.append(
                        {
                            "batch_idx": batch_idx,
                            "video_path": meta["video_path"],
                            "snippet_idx": meta["snippet_idx"],
                            "class_name": meta["class_name"],
                            "true_label": true_label,
                            "predicted_label": pred_label,
                            "probabilities": sample_probs,
                            "loss": loss_value,
                        }
                    )

                seen_labels.extend(label_list)
                seen_preds.extend(pred_list)
                seen_probs.extend(prob_list)

                batch_losses.append(loss_value)
                batch_accuracies.append((preds == labels).float().mean().item())

            progress.advance(task)

            if batch_idx % config["logging"]["print_freq"] == 0:
                emoji = "🔥" if config["logging"]["emojis"] else ""
                console.print(
                    f"{emoji} Batch {batch_idx}/{len(dataloader)}, Loss: {loss_value:.4f}"
                )

    # Save detailed predictions
    if config["logging"]["save_detailed_predictions"]:
        predictions_file = (
            output_dirs["predictions"] / f"train_epoch_{epoch}_predictions.json"
        )
        with open(predictions_file, "w") as f:
            json.dump(epoch_predictions, f, indent=2)

    epoch_metrics = {
        "loss": np.mean(batch_losses),
        "accuracy": np.mean(batch_accuracies),
    }

    y_true = np.asarray(seen_labels)
    y_pred = np.asarray(seen_preds)
    n_label_values = len(np.unique(y_true))

    # Binary classification metrics, from epoch-level confusion counts.
    if 0 < n_label_values <= 2:
        tp = int(((y_pred == 1) & (y_true == 1)).sum())
        tn = int(((y_pred == 0) & (y_true == 0)).sum())
        fp = int(((y_pred == 1) & (y_true == 0)).sum())
        fn = int(((y_pred == 0) & (y_true == 1)).sum())

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = (
            2 * precision * recall / (precision + recall)
            if (precision + recall) > 0
            else 0
        )
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

        # ROC AUC needs both classes and finite scores; otherwise fall back to
        # the chance level.
        scores = np.asarray(seen_probs, dtype=float)
        auc = 0.5
        if (
            n_label_values == 2
            and scores.ndim == 2
            and scores.shape[1] == 2
            and np.isfinite(scores).all()
        ):
            from sklearn.metrics import roc_auc_score

            auc = roc_auc_score(y_true, scores[:, 1])

        epoch_metrics.update(
            {
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "sensitivity": recall,
                "specificity": specificity,
                "auc": auc,
            }
        )

    return {
        name: epoch_metrics.get(name, float("nan")) for name in config["metrics"]
    }


@torch.no_grad()
def validate_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    config: Dict,
    epoch: int,
    output_dirs: Dict,
    split_name: str = "val",
) -> Dict[str, float]:
    """Validate for one epoch with detailed logging.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Yields ``(videos, labels, metadata)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batch tensors are moved to.
        config: Reads ``hardware.mixed_precision``, ``model.model_name``,
            ``logging.save_detailed_predictions`` and ``metrics``.
        epoch: Epoch number, used in the predictions file name.
        output_dirs: ``output_dirs["predictions"]`` receives the JSON dump.
        split_name: Label used in the progress bar and the file name.

    Returns:
        With at most two distinct labels in the split: loss, accuracy,
        precision, recall, f1, sensitivity, specificity and auc. AUC is 0.5
        when it is undefined (only one class present, no two-column
        probabilities, or non-finite scores). Otherwise the mean of each list
        in ``config["metrics"]`` (NaN for metrics that were never filled).
    """
    model.eval()
    metrics = {name: [] for name in config["metrics"]}

    all_preds, all_labels, all_probs = [], [], []
    epoch_predictions = []

    with Progress() as progress:
        # The original label concatenated "<Split>" + "ating...", which showed
        # up as "Valating..." / "Testating...".
        task = progress.add_task(
            f"Evaluating ({split_name})...", total=len(dataloader)
        )

        for batch_idx, (videos, labels, metadata) in enumerate(dataloader):
            videos, labels = videos.to(device), labels.to(device)

            with amp.autocast(enabled=config["hardware"]["mixed_precision"]):
                if config["model"]["model_name"].startswith("slowfast"):
                    # SlowFast takes [slow, fast]; slow keeps every 4th frame.
                    alpha = 4
                    slow_pathway = videos[:, :, ::alpha, :, :]
                    fast_pathway = videos
                    outputs = model([slow_pathway, fast_pathway])
                else:
                    outputs = model(videos)
                loss = criterion(outputs, labels)

            probs = torch.softmax(outputs, dim=1)
            preds = torch.argmax(probs, dim=1)

            # Move everything to Python objects once per batch instead of
            # issuing one device sync per sample.
            loss_value = loss.item()
            label_list = labels.tolist()
            pred_list = preds.tolist()
            prob_list = probs.cpu().tolist()

            # Save detailed predictions
            for meta, true_label, pred_label, sample_probs in zip(
                metadata, label_list, pred_list, prob_list
            ):
                epoch_predictions.append(
                    {
                        "batch_idx": batch_idx,
                        "video_path": meta["video_path"],
                        "snippet_idx": meta["snippet_idx"],
                        "class_name": meta["class_name"],
                        "true_label": true_label,
                        "predicted_label": pred_label,
                        "probabilities": sample_probs,
                        "loss": loss_value,
                    }
                )

            all_preds.extend(pred_list)
            all_labels.extend(label_list)
            all_probs.extend(prob_list)

            metrics["loss"].append(loss_value)
            metrics["accuracy"].append((preds == labels).float().mean().item())

            progress.advance(task)

    # Save detailed predictions
    if config["logging"]["save_detailed_predictions"]:
        predictions_file = (
            output_dirs["predictions"] / f"{split_name}_epoch_{epoch}_predictions.json"
        )
        with open(predictions_file, "w") as f:
            json.dump(epoch_predictions, f, indent=2)

    # Calculate overall metrics
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs, dtype=float)
    n_label_values = len(np.unique(all_labels))

    # Binary classification metrics
    if n_label_values <= 2:
        tp = ((all_preds == 1) & (all_labels == 1)).sum()
        tn = ((all_preds == 0) & (all_labels == 0)).sum()
        fp = ((all_preds == 1) & (all_labels == 0)).sum()
        fn = ((all_preds == 0) & (all_labels == 1)).sum()

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = (
            2 * precision * recall / (precision + recall)
            if (precision + recall) > 0
            else 0
        )
        sensitivity = recall
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

        # roc_auc_score raises ValueError when only one class is present or
        # the scores are not finite; report chance level instead of aborting.
        auc = 0.5
        if (
            n_label_values == 2
            and all_probs.ndim == 2
            and all_probs.shape[1] == 2
            and np.isfinite(all_probs).all()
        ):
            from sklearn.metrics import roc_auc_score

            auc = roc_auc_score(all_labels, all_probs[:, 1])

        return {
            "loss": np.mean(metrics["loss"]),
            "accuracy": np.mean(metrics["accuracy"]),
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "sensitivity": sensitivity,
            "specificity": specificity,
            "auc": auc,
        }

    return {k: np.mean(v) for k, v in metrics.items()}


# 🔟 ───── PLOT PDFS ─────
def create_plots(
    train_history: Dict, val_history: Dict, output_dirs: Dict, config: Dict
) -> None:
    """Create training plots and save as PDFs.

    Writes ``all_metrics.pdf`` (one subplot per entry of ``config["metrics"]``)
    and one ``<metric>.pdf`` per metric into ``output_dirs["plots"]``. A metric
    is drawn only if it is present in both histories.

    Args:
        train_history: Maps metric name to its per-epoch training values.
        val_history: Maps metric name to its per-epoch validation values.
        output_dirs: ``output_dirs["plots"]`` is the target directory.
        config: ``config["metrics"]`` lists the metrics and their order.
    """
    plt.style.use("default")

    metric_names = list(config["metrics"])

    # Four columns with as many rows as needed (never fewer than the original
    # 2x4 layout), so more than eight metrics no longer index past the grid.
    n_cols = 4
    n_rows = max(2, (len(metric_names) + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
    axes = axes.flatten()

    for idx, metric in enumerate(metric_names):
        ax = axes[idx]
        if metric not in train_history or metric not in val_history:
            # Nothing to draw: hide the axis instead of leaving an empty frame.
            ax.set_visible(False)
            continue

        epochs = range(1, len(train_history[metric]) + 1)

        ax.plot(epochs, train_history[metric], "b-", label="Train", linewidth=2)
        ax.plot(epochs, val_history[metric], "r-", label="Val", linewidth=2)

        ax.set_title(f"{metric.title()}", fontsize=14, fontweight="bold")
        ax.set_xlabel("Epochs", fontsize=12)
        ax.set_ylabel(metric.title(), fontsize=12)
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

    # Hide unused subplots
    for ax in axes[len(metric_names):]:
        ax.set_visible(False)

    fig.tight_layout()
    fig.savefig(
        output_dirs["plots"] / "all_metrics.pdf",
        format="pdf",
        bbox_inches="tight",
        dpi=300,
    )
    plt.close(fig)

    # Also save individual metric plots
    for metric in metric_names:
        if metric not in train_history or metric not in val_history:
            continue

        plt.figure(figsize=(10, 6))
        epochs = range(1, len(train_history[metric]) + 1)

        plt.plot(
            epochs,
            train_history[metric],
            "b-",
            label=f"Train {metric.title()}",
            linewidth=2,
        )
        plt.plot(
            epochs,
            val_history[metric],
            "r-",
            label=f"Val {metric.title()}",
            linewidth=2,
        )

        plt.title(f"{metric.title()} vs Epochs", fontsize=16, fontweight="bold")
        plt.xlabel("Epochs", fontsize=12)
        plt.ylabel(metric.title(), fontsize=12)
        plt.legend(fontsize=12)
        plt.grid(True, alpha=0.3)

        plt.savefig(
            output_dirs["plots"] / f"{metric}.pdf",
            format="pdf",
            bbox_inches="tight",
            dpi=300,
        )
        plt.close()


def create_confusion_matrix(
    y_true: np.ndarray, y_pred: np.ndarray, class_names: List[str], output_dirs: Dict
) -> None:
    """Create confusion matrix plots.

    Writes ``confmat_counts.pdf`` (raw counts) and ``confmat_pct.pdf``
    (row-normalised percentages) into ``output_dirs["plots"]``.

    Args:
        y_true: True class indices.
        y_pred: Predicted class indices.
        class_names: Tick labels; entry ``i`` names class index ``i``.
        output_dirs: ``output_dirs["plots"]`` is the target directory.
    """
    # Pin the label set so the matrix always has one row/column per class
    # name, even when a class is missing from both y_true and y_pred.
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids or None)

    # Row-wise percentages; a class with no true samples yields a row of
    # zeros instead of NaN from a division by zero.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_pct = (
        np.divide(
            cm.astype("float"),
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals > 0,
        )
        * 100
    )

    variants = (
        (cm, "d", "Confusion Matrix (Counts)", "confmat_counts.pdf"),
        (cm_pct, ".1f", "Confusion Matrix (Percentages)", "confmat_pct.pdf"),
    )
    for matrix, number_format, title, file_name in variants:
        plt.figure(figsize=(8, 6))
        sns.heatmap(
            matrix,
            annot=True,
            fmt=number_format,
            cmap="Blues",
            xticklabels=class_names,
            yticklabels=class_names,
        )
        plt.title(title, fontsize=16, fontweight="bold")
        plt.ylabel("True Label", fontsize=12)
        plt.xlabel("Predicted Label", fontsize=12)
        plt.savefig(
            output_dirs["plots"] / file_name,
            format="pdf",
            bbox_inches="tight",
            dpi=300,
        )
        plt.close()


# ❶❶ ───── INFERENCE HELPER ─────
@torch.no_grad()
def run_inference(
    model: nn.Module,
    video_path: str,
    device: torch.device,
    class_names: List[str],
    config: Dict,
) -> Dict:
    """Classify one video file and report per-class probabilities.

    Only the first sample produced by ``VideoDataset`` for the video is
    evaluated.

    Args:
        model: Trained network; switched to eval mode.
        video_path: Path of the video to classify.
        device: Device the input tensor is moved to.
        class_names: Class names indexed by class id.
        config: Reads ``data.clip_len``, ``data.frame_rate``,
            ``hardware.mixed_precision`` and ``model.model_name``.

    Returns:
        Dict with ``predicted_class``, ``confidence`` and
        ``all_probabilities`` (class name -> probability).
    """
    model.eval()

    data_cfg = config["data"]
    # The class fields are dummies; only the path matters for inference.
    sample_spec = {"video_path": video_path, "class_idx": 0, "class_name": "unknown"}
    dataset = VideoDataset(
        [sample_spec],
        data_cfg["clip_len"],
        data_cfg["frame_rate"],
        overlap=0,  # No overlap for single inference
        augment=False,
    )

    clip, _, _ = dataset[0]
    clip = clip.unsqueeze(0).to(device)  # add the batch dimension

    with amp.autocast(enabled=config["hardware"]["mixed_precision"]):
        if config["model"]["model_name"].startswith("slowfast"):
            # SlowFast takes [slow, fast]; slow keeps every 4th frame.
            alpha = 4
            outputs = model([clip[:, :, ::alpha, :, :], clip])
        else:
            outputs = model(clip)

    probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
    best = np.argmax(probs)

    return {
        "predicted_class": class_names[best],
        "confidence": float(probs[best]),
        "all_probabilities": {
            name: float(probs[i]) for i, name in enumerate(class_names)
        },
    }


# ❶❷ ───── MAIN FUNCTION ─────
# Architectures that main() rebuilds when the tuned clip length differs from
# the configured one.
_CLIP_LEN_DEPENDENT_MODELS = ("timesformer", "mvit", "videomae", "vivit")


def _select_device(has_gpu: bool, gpus_requested: int) -> torch.device:
    """Return CUDA only when a GPU is both present and requested; warn otherwise."""
    if gpus_requested > 0 and not has_gpu:
        console.print(
            "⚠️ No GPU detected but GPU requested. Switching to CPU.", style="yellow"
        )
    return torch.device("cuda" if has_gpu and gpus_requested > 0 else "cpu")


def _build_datasets(splits: Dict, clip_len: int, data_cfg: Dict):
    """Build the (train, val, test) VideoDataset triple for one clip length."""
    # NOTE(review): the training split is built with augment=False, exactly
    # like val/test. Confirm that disabling augmentation is intended.
    return tuple(
        VideoDataset(
            splits[split_name],
            clip_len,
            data_cfg["frame_rate"],
            data_cfg["snippet_overlap"],
            augment=False,
        )
        for split_name in ("train", "val", "test")
    )


def _build_loaders(datasets, config: Dict, pin_memory: bool):
    """Wrap the (train, val, test) datasets in DataLoaders; only train is shuffled."""
    train_dataset, val_dataset, test_dataset = datasets
    shared = {
        "batch_size": config["train"]["batch_size"],
        "num_workers": config["data"]["num_workers"],
        "pin_memory": pin_memory,
        "collate_fn": collate_fn,
    }
    return (
        DataLoader(train_dataset, shuffle=True, **shared),
        DataLoader(val_dataset, shuffle=False, **shared),
        DataLoader(test_dataset, shuffle=False, **shared),
    )


def _create_model_on_device(num_classes, clip_len, device, config: Dict):
    """Instantiate the configured model for ``clip_len`` and move it to ``device``."""
    model_cfg = config["model"]
    model = create_model(
        model_cfg["model_name"],
        num_classes,
        clip_len,
        model_cfg["freeze_backbone"],
        model_cfg["dropout"],
    )
    return model.to(device)


def _forward(model, videos, config: Dict):
    """Run one forward pass under autocast, feeding [slow, fast] to SlowFast models."""
    with amp.autocast(enabled=config["hardware"]["mixed_precision"]):
        if config["model"]["model_name"].startswith("slowfast"):
            alpha = 4  # slow pathway keeps every 4th index along dim 2
            return model([videos[:, :, ::alpha, :, :], videos])
        return model(videos)


def _load_best_checkpoint(model, checkpoint_dir, device) -> Dict:
    """Load ``best.pth`` into ``model`` and return the checkpoint dict.

    Raises:
        FileNotFoundError: if no best checkpoint has been written yet, e.g.
            evaluation/inference was requested without a prior training run.
    """
    best_path = checkpoint_dir / "best.pth"
    if not best_path.exists():
        raise FileNotFoundError(
            f"Best checkpoint not found at {best_path}. Enable "
            "config['modes']['run_training'] or place a trained 'best.pth' "
            "there before running evaluation or inference."
        )
    # weights_only=False unpickles arbitrary objects (the checkpoint stores the
    # config dict); only load checkpoints produced by a trusted run.
    checkpoint = torch.load(best_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint


def _run_training(model, train_loader, val_loader, device, output_dirs, config):
    """Train with per-epoch validation, checkpointing, history logging and early stopping."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config["train"]["lr"],
        weight_decay=config["train"]["weight_decay"],
    )

    # Any scheduler name other than "cosine" falls back to StepLR.
    if config["train"]["scheduler"] == "cosine":
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config["train"]["epochs"]
        )
    else:
        scheduler = optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config["train"]["step_size"],
            gamma=config["train"]["gamma"],
        )

    scaler = amp.GradScaler(enabled=config["hardware"]["mixed_precision"])
    early_stopping = EarlyStopping(patience=config["train"]["early_stop_patience"])

    train_history = {metric: [] for metric in config["metrics"]}
    val_history = {metric: [] for metric in config["metrics"]}
    best_val_loss = float("inf")
    total_epochs = config["train"]["epochs"]

    for epoch in range(1, total_epochs + 1):
        epoch_start_time = time.time()

        emoji = "🔥" if config["logging"]["emojis"] else ""
        console.print(f"\n{emoji} Epoch {epoch}/{total_epochs}", style="bold magenta")

        train_metrics = train_epoch(
            model,
            train_loader,
            optimizer,
            criterion,
            device,
            scaler,
            config,
            epoch,
            output_dirs,
        )
        val_metrics = validate_epoch(
            model,
            val_loader,
            criterion,
            device,
            config,
            epoch,
            output_dirs,
            "val",
        )

        # Only metrics actually reported by the epoch functions are recorded.
        for metric in config["metrics"]:
            if metric in train_metrics:
                train_history[metric].append(train_metrics[metric])
            if metric in val_metrics:
                val_history[metric].append(val_metrics[metric])

        # Step before checkpointing so the saved scheduler state is the one
        # the next epoch would start from.
        scheduler.step()

        epoch_time = time.time() - epoch_start_time

        table = Table(title=f"Epoch {epoch} Summary")
        table.add_column("Metric", style="cyan")
        table.add_column("Train", style="green")
        table.add_column("Val", style="red")
        for metric in config["metrics"]:
            if metric in train_metrics and metric in val_metrics:
                table.add_row(
                    metric.title(),
                    f"{train_metrics[metric]:.4f}",
                    f"{val_metrics[metric]:.4f}",
                )
        console.print(table)
        console.print(f"⏱️ Epoch time: {epoch_time:.2f}s")

        # Per-epoch checkpoint.
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "train_metrics": train_metrics,
                "val_metrics": val_metrics,
                "config": config,
            },
            output_dirs["checkpoints"] / f"epoch_{epoch}.pth",
        )

        # Best checkpoint, selected by lowest validation loss.
        if val_metrics["loss"] < best_val_loss:
            best_val_loss = val_metrics["loss"]
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "val_loss": val_metrics["loss"],
                    "config": config,
                },
                output_dirs["checkpoints"] / "best.pth",
            )
            star = "⭐" if config["logging"]["emojis"] else "*"
            console.print(f"{star} New best model saved!", style="bold green")

        # Rewrite the history file every epoch so a crash loses nothing.
        history_data = {
            "train": train_history,
            "val": val_history,
            "current_epoch": epoch,
        }
        with open(output_dirs["logs"] / "training_history.json", "w") as f:
            json.dump(history_data, f, indent=2)

        if early_stopping(val_metrics["loss"]):
            stop_emoji = "🛑" if config["logging"]["emojis"] else "STOP"
            console.print(f"{stop_emoji} Early stopping triggered!", style="bold red")
            break

    create_plots(train_history, val_history, output_dirs, config)
    console.print("✅ Training completed!", style="bold green")


def _run_evaluation(model, test_loader, device, class_names, output_dirs, config):
    """Evaluate the best checkpoint on the test split and write metrics and reports."""
    checkpoint = _load_best_checkpoint(model, output_dirs["checkpoints"], device)
    console.print("✅ Best model loaded for evaluation")

    criterion = nn.CrossEntropyLoss()
    test_metrics = validate_epoch(
        model,
        test_loader,
        criterion,
        device,
        config,
        checkpoint["epoch"],
        output_dirs,
        "test",
    )

    # Second pass over the test set: validate_epoch's return value is used
    # only as a metrics mapping here, so the raw predictions needed for the
    # confusion matrix and classification report are collected separately.
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for videos, labels, _metadata in test_loader:
            outputs = _forward(model, videos.to(device), config)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    create_confusion_matrix(
        np.array(all_labels), np.array(all_preds), class_names, output_dirs
    )

    metrics_df = pd.DataFrame([test_metrics])
    metrics_df.to_csv(output_dirs["root"] / "test_metrics.csv", index=False)

    table = Table(title="Test Set Results")
    table.add_column("Metric", style="cyan")
    table.add_column("Value", style="green")
    for metric, value in test_metrics.items():
        table.add_row(metric.title(), f"{value:.4f}")
    console.print(table)

    # Passing the full label set keeps the report aligned with class_names
    # even when a class never appears in the test labels or predictions;
    # without it sklearn raises ValueError on the target_names mismatch.
    label_ids = list(range(len(class_names)))
    report = classification_report(
        all_labels,
        all_preds,
        labels=label_ids,
        target_names=class_names,
        output_dict=True,
    )
    with open(output_dirs["logs"] / "classification_report.json", "w") as f:
        json.dump(report, f, indent=2)

    console.print("\n📋 Detailed Classification Report:")
    print(
        classification_report(
            all_labels, all_preds, labels=label_ids, target_names=class_names
        )
    )

    console.print("✅ Evaluation completed!", style="bold green")


def _run_inference_demo(
    model, device, class_names, output_dirs, config, model_already_loaded
):
    """Classify the configured demo video, print the result table and save it as JSON."""
    inference_video = config["modes"]["inference_video"]

    if not os.path.exists(inference_video):
        console.print(f"❌ Inference video not found: {inference_video}", style="red")
        return

    # Evaluation (when enabled) has already loaded the best weights.
    if not model_already_loaded:
        _load_best_checkpoint(model, output_dirs["checkpoints"], device)

    results = run_inference(model, inference_video, device, class_names, config)

    table = Table(title="Inference Results")
    table.add_column("Class", style="cyan")
    table.add_column("Probability", style="green")
    for class_name, prob in results["all_probabilities"].items():
        style = "bold green" if class_name == results["predicted_class"] else "white"
        table.add_row(class_name, f"{prob:.4f}", style=style)
    console.print(table)

    pred_emoji = "🎯" if config["logging"]["emojis"] else ">>>"
    console.print(
        f"{pred_emoji} Predicted: {results['predicted_class']} "
        f"(Confidence: {results['confidence']:.4f})",
        style="bold green",
    )

    inference_results = {
        "video_path": inference_video,
        "results": results,
        "timestamp": datetime.now().isoformat(),
    }
    with open(output_dirs["logs"] / "inference_results.json", "w") as f:
        json.dump(inference_results, f, indent=2)


def _print_tree(path, prefix="", is_last=True):
    """Recursively print ``path`` as a box-drawing tree, children in sorted order."""
    connector = "└── " if is_last else "├── "
    if path.is_file():
        console.print(f"{prefix}{connector}{path.name}")
    elif path.is_dir():
        console.print(f"{prefix}{connector}{path.name}/")
        # iterdir() yields entries in arbitrary order; sort for stable output.
        children = sorted(path.iterdir())
        child_prefix = prefix + ("    " if is_last else "│   ")
        for index, child in enumerate(children):
            _print_tree(child, child_prefix, index == len(children) - 1)


def main():
    """Run the pipeline end to end, driven by the module-level ``config``.

    Stages, in order: seeding and device selection, output-directory setup,
    split creation, dataset/model construction, optional CUDA batch-size and
    clip-length tuning, then the training / evaluation / inference stages
    enabled under ``config["modes"]``, and finally a JSON run summary plus a
    printed tree of the output directory.

    Raises:
        FileNotFoundError: if evaluation or inference is enabled but no
            ``best.pth`` checkpoint exists in the checkpoints directory.
    """
    print_section("INITIALIZATION", "🚀")

    seed_everything(config["train"]["seed"])

    has_gpu, gpu_mem = check_gpu_memory()
    device = _select_device(has_gpu, config["hardware"]["gpus"])

    console.print(f"🖥️ Device: {device}")
    if has_gpu:
        console.print(f"💾 GPU Memory: {gpu_mem:.1f} GB")

    output_dirs = setup_output_dirs(config["paths"]["output_root"])
    console.print(f"📁 Output directory: {output_dirs['root']}")

    # Snapshot of the config as supplied; tuning below may still change
    # batch_size / clip_len (the final values go into run_summary.json).
    with open(output_dirs["root"] / "config.yaml", "w") as f:
        yaml.dump(config, f, default_flow_style=False)

    print_section("DATA PREPARATION", "📊")

    splits_info = create_splits(config["paths"]["data_root"], config)
    with open(output_dirs["root"] / "splits.json", "w") as f:
        json.dump(splits_info, f, indent=2)

    class_names = splits_info["class_names"]
    num_classes = splits_info["num_classes"]

    console.print(f"📋 Classes: {class_names}")
    console.print(f"🔢 Number of classes: {num_classes}")

    datasets = _build_datasets(
        splits_info["splits"], config["data"]["clip_len"], config["data"]
    )
    train_dataset, val_dataset, test_dataset = datasets

    console.print(f"📈 Train snippets: {len(train_dataset)}")
    console.print(f"📊 Val snippets: {len(val_dataset)}")
    console.print(f"📉 Test snippets: {len(test_dataset)}")

    print_section("MODEL CREATION", "🧠")

    model = _create_model_on_device(
        num_classes, config["data"]["clip_len"], device, config
    )

    console.print(f"🎯 Model: {config['model']['model_name']}")
    console.print(f"🔒 Freeze backbone: {config['model']['freeze_backbone']}")

    # Dynamic batch size tuning (CUDA only).
    if device.type == "cuda":
        original_clip_len = config["data"]["clip_len"]
        optimal_batch_size, optimal_clip_len = tune_batch_size(
            model,
            device,
            config["train"]["batch_size"],
            original_clip_len,
            config["model"]["model_name"],
        )
        config["train"]["batch_size"] = optimal_batch_size

        if optimal_clip_len != original_clip_len:
            console.print(
                f"🔄 Clip length auto-adjusted from {original_clip_len} to {optimal_clip_len}",
                style="yellow",
            )
            config["data"]["clip_len"] = optimal_clip_len

            # Datasets depend on the clip length, so rebuild them.
            datasets = _build_datasets(
                splits_info["splits"], optimal_clip_len, config["data"]
            )

            # Recreate model if needed (for transformers that depend on clip_len)
            if config["model"]["model_name"] in _CLIP_LEN_DEPENDENT_MODELS:
                model = _create_model_on_device(
                    num_classes, optimal_clip_len, device, config
                )

    train_loader, val_loader, test_loader = _build_loaders(
        datasets, config, pin_memory=device.type == "cuda"
    )

    if config["modes"]["run_training"]:
        print_section("TRAINING", "🏋️")
        _run_training(model, train_loader, val_loader, device, output_dirs, config)

    if config["modes"]["run_eval"]:
        print_section("EVALUATION", "📊")
        _run_evaluation(model, test_loader, device, class_names, output_dirs, config)

    if config["modes"]["run_inference"]:
        print_section("INFERENCE DEMO", "🎯")
        _run_inference_demo(
            model,
            device,
            class_names,
            output_dirs,
            config,
            model_already_loaded=config["modes"]["run_eval"],
        )

    print_section("COMPLETION", "🎉")

    console.print("🏁 All tasks completed successfully!", style="bold green")
    console.print(f"📁 Check outputs at: {output_dirs['root']}")

    # Final summary report (reflects any tuned batch_size / clip_len).
    summary = {
        "run_timestamp": datetime.now().isoformat(),
        "output_directory": str(output_dirs["root"]),
        "config": config,
        "splits_info": {
            "num_classes": splits_info["num_classes"],
            "class_names": splits_info["class_names"],
            "videos_per_class": splits_info["videos_per_class"],
            "split_mode": splits_info["split_mode"],
        },
        "model_info": {
            "model_name": config["model"]["model_name"],
            "num_classes": num_classes,
            "clip_len": config["data"]["clip_len"],
            "batch_size": config["train"]["batch_size"],
        },
    }
    with open(output_dirs["root"] / "run_summary.json", "w") as f:
        json.dump(summary, f, indent=2)

    tree_emoji = "🌳" if config["logging"]["emojis"] else "TREE"
    console.print(f"\n{tree_emoji} Output file structure:")
    _print_tree(output_dirs["root"])


# Entry point: run the full pipeline only when this file is executed directly
# (as a script or notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()