# FVC Binary Video Classifier

This notebook wires together the FVC video library to:

1. Run `setup_fvc_dataset.py` to prepare the dataset.
2. Build augmented train/val loaders.
3. Train the variable-aspect-ratio 3D CNN binary classifier.
4. Evaluate on the validation set, log metrics, and save intermediates (splits, augmented clips) and model checkpoints.

All warnings are captured into the Python logger so they are logged but do not stop execution.


In [None]:
import os
import sys
import logging
import warnings
import subprocess

# --- Logging -----------------------------------------------------------------
# Python 3.6 has no basicConfig(force=True), and basicConfig() is a no-op while
# the root logger already has handlers (e.g. when this cell is re-run), so
# detach every existing root handler first.
root_logger = logging.getLogger()
while root_logger.handlers:
    root_logger.removeHandler(root_logger.handlers[0])

logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("fvc_notebook")

# Send warnings.warn(...) output through logging, and use the "default" filter
# action so warnings are reported rather than raised or silenced.
logging.captureWarnings(True)
warnings.filterwarnings("default")

# --- Project layout ----------------------------------------------------------
# Resolution order: $SLURM_SUBMIT_DIR, then $SLURM_TMPDIR (project copied
# there), then the current directory - or its parent when the notebook is
# being run from one level below the project root.
for _root_env in ("SLURM_SUBMIT_DIR", "SLURM_TMPDIR"):
    if _root_env in os.environ:
        PROJECT_ROOT = os.environ[_root_env]
        break
else:
    cwd = os.getcwd()
    _notebook_rel = os.path.join("src", "fvc_binary_classifier.ipynb")
    _notebook_in_cwd = os.path.exists(os.path.join(cwd, _notebook_rel))
    if not _notebook_in_cwd and os.path.exists(os.path.join(cwd, "..", _notebook_rel)):
        PROJECT_ROOT = os.path.dirname(cwd)
    else:
        PROJECT_ROOT = cwd

PROJECT_ROOT = os.path.abspath(PROJECT_ROOT)
DATA_CSV = os.path.join(PROJECT_ROOT, "data", "video_index_input.csv")
INTERMEDIATE_DIR = os.path.join(PROJECT_ROOT, "data", "intermediates")
MODELS_DIR = os.path.join(PROJECT_ROOT, "models")

for _output_dir in (INTERMEDIATE_DIR, MODELS_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# --- Interpreter checks ------------------------------------------------------
_py = sys.version_info
logger.info("Python version: %s", sys.version)
logger.info("Python version info: %s.%s.%s", _py.major, _py.minor, _py.micro)

if _py < (3, 6):
    logger.error("Python 3.6+ required, found %s.%s", _py.major, _py.minor)
    raise RuntimeError("Python 3.6+ required")

# dataclasses ships with the standard library from Python 3.7 onward.
try:
    from dataclasses import dataclass
except ImportError:
    logger.error("dataclasses module not available (requires Python 3.7+)")
    raise RuntimeError("dataclasses module required but not available")
else:
    logger.info("✓ dataclasses module available")

for _fmt, _value in (
    ("Current working directory: %s", os.getcwd()),
    ("Project root: %s", PROJECT_ROOT),
    ("Data CSV: %s", DATA_CSV),
    ("Intermediates dir: %s", INTERMEDIATE_DIR),
    ("Models dir: %s", MODELS_DIR),
):
    logger.info(_fmt, _value)

# Fail fast when the dataset index is missing; later cells load metadata from it.
if os.path.exists(DATA_CSV):
    logger.info("✓ Data CSV exists")
else:
    logger.error("Data CSV not found at %s", DATA_CSV)
    raise FileNotFoundError("Data CSV not found: {}".format(DATA_CSV))

# Confirm the heavy third-party dependencies import before doing any work.
try:
    import polars as pl
    import torch
except ImportError as e:
    logger.error("Failed to import required package: %s", e)
    raise
else:
    logger.info("✓ Core packages (polars, torch) importable")


In [None]:
# 1. Run setup_fvc_dataset.py to prepare the dataset


def _decode_process_output(raw):
    """Decode captured subprocess output (bytes or None) into text.

    Undecodable bytes are replaced rather than raising, so a stray non-UTF-8
    byte in the script's output cannot mask its real result or error.
    """
    return raw.decode("utf-8", errors="replace") if raw else ""


SETUP_SCRIPT = os.path.join(PROJECT_ROOT, "src", "setup_fvc_dataset.py")

try:
    logger.info("Running setup_fvc_dataset.py...")
    # Run the script with the interpreter executing this kernel. A bare
    # "python" is resolved through PATH and may belong to a different
    # environment that lacks the project's packages. sys.executable can be
    # empty in embedded interpreters, hence the fallback.
    # stdout/stderr=PIPE (instead of capture_output) keeps Python 3.6 support.
    result = subprocess.run(
        [sys.executable or "python", SETUP_SCRIPT],
        check=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    stdout_text = _decode_process_output(result.stdout)
    stderr_text = _decode_process_output(result.stderr)

    logger.info("setup_fvc_dataset.py stdout:\n%s", stdout_text)
    if stderr_text:
        logger.warning("setup_fvc_dataset.py stderr:\n%s", stderr_text)
    logger.info("setup_fvc_dataset.py completed successfully.")
except subprocess.CalledProcessError as e:
    # check=True raises on a non-zero exit; surface everything the script
    # printed before stopping the notebook.
    stdout_text = _decode_process_output(e.stdout)
    stderr_text = _decode_process_output(e.stderr)
    logger.error("setup_fvc_dataset.py failed with return code %s", e.returncode)
    logger.error("stdout:\n%s", stdout_text)
    logger.error("stderr:\n%s", stderr_text)
    raise


In [None]:
# 2. Build augmented train/val loaders using the FVC video library

import polars as pl
import torch
from pathlib import Path

from lib.video_data import SplitConfig, load_metadata, train_val_test_split, filter_existing_videos, make_balanced_batch_sampler
from lib.video_modeling import VideoConfig, VideoDataset, variable_ar_collate
from lib.video_training import OptimConfig, TrainConfig, fit
from lib.video_augmentation_pipeline import pregenerate_augmented_dataset
from torch.utils.data import DataLoader

logger.info("Loading metadata with Polars from %s", DATA_CSV)
meta_df = load_metadata(DATA_CSV)
logger.info("Metadata rows: %d", meta_df.height)

# Drop rows whose video file is missing on disk *before* splitting, so the
# train/val split is computed over usable videos only.
logger.info("Filtering out missing video files...")
try:
    meta_df_filtered = filter_existing_videos(meta_df, PROJECT_ROOT)
except ValueError as e:
    logger.error("Failed to filter videos: %s", str(e))
    logger.error("Cannot proceed with training - no valid videos found.")
    raise

missing_count = meta_df.height - meta_df_filtered.height
if missing_count > 0:
    logger.warning("Filtered out %d missing video files (keeping %d)", missing_count, meta_df_filtered.height)
else:
    logger.info("All %d video files exist", meta_df_filtered.height)

# Validate we have enough videos for training
if meta_df_filtered.height < 2:
    raise ValueError(
        f"Not enough videos for training. Found {meta_df_filtered.height} videos, "
        f"but need at least 2 (one for train, one for val)."
    )

splits = train_val_test_split(
    meta_df_filtered,
    SplitConfig(val_size=0.2, test_size=0.0),
    save_dir=INTERMEDIATE_DIR,
)
train_df = splits["train"]
val_df = splits["val"]

logger.info("Train rows: %d, Val rows: %d", train_df.height, val_df.height)

# Final validation before creating datasets
if train_df.height == 0:
    raise ValueError("Train dataset is empty after filtering. Cannot create DataLoader.")
if val_df.height == 0:
    logger.warning("Validation dataset is empty. Consider reducing val_size or checking data splits.")

# Fixed-size downscaling plus a broad set of augmentations.
# fixed_size=224 makes every clip 224x224 (letterboxed to keep aspect ratio).
# Augmentations: geometric, color, noise, blur, cutout and temporal.
video_cfg = VideoConfig(
    num_frames=16,  # fixed_size=224 keeps 16 frames memory-efficient
    fixed_size=224,  # Fixed 224x224 size with letterboxing (maintains aspect ratio)
    # Spatial augmentations
    augmentation_config={
        'rotation_degrees': 15.0,  # Small rotations
        'rotation_p': 0.5,
        'affine_p': 0.3,  # Translation, scale, shear
        'gaussian_noise_std': 0.1,  # Noise injection
        'gaussian_noise_p': 0.3,
        'gaussian_blur_p': 0.3,  # Blur augmentation
        'cutout_p': 0.5,  # Random erasing
        'cutout_max_size': 32,  # Max cutout size
        'elastic_transform_p': 0.2,  # Elastic deformation
        'color_jitter_brightness': 0.3,  # Enhanced color jitter
        'color_jitter_contrast': 0.3,
        'color_jitter_saturation': 0.3,
        'color_jitter_hue': 0.1,
    },
    # Temporal augmentations
    temporal_augmentation_config={
        'frame_drop_prob': 0.1,  # Randomly drop frames
        'frame_dup_prob': 0.1,  # Randomly duplicate frames (slow motion)
        'reverse_prob': 0.1,  # Reverse temporal order
    }
)
logger.info("Video config: num_frames=%d, fixed_size=%s (GPU-optimized with comprehensive augmentations)",
            video_cfg.num_frames, video_cfg.fixed_size)
logger.info("Augmentations enabled: geometric, color, noise, blur, cutout, temporal")

# STEP 1: Pre-generate augmented clips BEFORE training.
# Clips are stored on disk so training does not augment on the fly.
AUGMENTED_DATA_DIR = os.path.join(INTERMEDIATE_DIR, "augmented_clips")
AUGMENTED_METADATA_CSV = os.path.join(AUGMENTED_DATA_DIR, "augmented_train_metadata.csv")
NUM_AUGMENTATIONS_PER_VIDEO = 3  # Generate 3 augmented versions per video

logger.info("=" * 80)
logger.info("STEP 1: Pre-generating augmented clips...")
logger.info("This will create %d augmented versions per training video", NUM_AUGMENTATIONS_PER_VIDEO)
logger.info("Output directory: %s", AUGMENTED_DATA_DIR)

# True only when train_df_final holds pre-generated clips. It decides whether
# the training dataset has to apply augmentations itself (see STEP 2).
using_pregenerated_clips = False
try:
    # A cache hit needs both the clip files and the metadata CSV indexing them.
    # If only clips exist, regenerate rather than failing over to
    # un-augmented training.
    # NOTE(review): the cache is not checked against the current train split;
    # delete AUGMENTED_DATA_DIR whenever the split changes.
    cache_ready = (
        os.path.isfile(AUGMENTED_METADATA_CSV)
        and any(Path(AUGMENTED_DATA_DIR).glob("*.pt"))
    )
    if cache_ready:
        logger.info("✓ Augmented clips already exist. Skipping generation.")
        logger.info("  To regenerate, delete: %s", AUGMENTED_DATA_DIR)
        augmented_train_df = pl.read_csv(AUGMENTED_METADATA_CSV)
        logger.info("Loaded %d pre-generated augmented clips", augmented_train_df.height)
    else:
        logger.info("Generating augmented clips (this may take a while)...")
        augmented_train_df = pregenerate_augmented_dataset(
            train_df,
            PROJECT_ROOT,
            video_cfg,
            output_dir=AUGMENTED_DATA_DIR,
            num_augmentations_per_video=NUM_AUGMENTATIONS_PER_VIDEO,
        )
        # Persist metadata so the next run can reuse the clips.
        augmented_train_df.write_csv(AUGMENTED_METADATA_CSV)
        logger.info("✓ Generated %d augmented clips", augmented_train_df.height)
        logger.info("✓ Saved metadata to: %s", AUGMENTED_METADATA_CSV)

    train_df_final = augmented_train_df
    using_pregenerated_clips = True
    logger.info("Using pre-generated augmented clips for training")
except Exception as e:
    # Deliberate best-effort: pre-generation is an optimisation, so any failure
    # falls back to the raw training videos. The traceback is logged.
    logger.exception("Failed to pre-generate augmentations: %s", str(e))
    logger.warning("Falling back to on-the-fly augmentations during training")
    train_df_final = train_df

logger.info("=" * 80)

# STEP 2: Create datasets.
# Pre-generated clips are already augmented, so the training dataset augments
# on the fly only when we fell back to the raw train_df. Validation always
# uses the original videos without augmentation.
train_ds = VideoDataset(
    train_df_final,
    PROJECT_ROOT,
    config=video_cfg,
    train=not using_pregenerated_clips,
)
val_ds = VideoDataset(val_df, PROJECT_ROOT, config=video_cfg, train=False)

logger.info(
    "Train dataset: %d samples (%s)",
    len(train_ds),
    "pre-generated augmentations" if using_pregenerated_clips else "on-the-fly augmentations",
)
logger.info("Val dataset: %d samples (original videos)", len(val_ds))

num_workers = 4  # Parallel data loading for better throughput
pin_memory = torch.cuda.is_available()  # Faster host->GPU copies when CUDA exists

# DataLoader options shared by the train and val loaders. persistent_workers
# and prefetch_factor only apply to worker processes, so they are omitted
# entirely when num_workers == 0.
loader_kwargs = {
    "num_workers": num_workers,
    "pin_memory": pin_memory,
    "collate_fn": variable_ar_collate,
}
if num_workers > 0:
    loader_kwargs.update(persistent_workers=True, prefetch_factor=2)

# Try balanced batches (equal real/fake per batch) from largest to smallest.
# The sampler MUST be built from train_df_final, the same frame train_ds
# indexes; otherwise its row indices select the wrong samples and labels.
# NOTE: this ladder only reacts to failures while building the sampler or
# loader (e.g. too few samples of a class for the requested batch). No forward
# pass runs here, so a CUDA OOM during training is NOT handled by it.
BALANCED_BATCH_SIZES = (32, 16, 8, 4, 2)
for batch_size in BALANCED_BATCH_SIZES:
    samples_per_class = batch_size // 2
    try:
        balanced_sampler = make_balanced_batch_sampler(
            train_df_final,
            batch_size=batch_size,
            samples_per_class=samples_per_class,
            shuffle=True,
            random_state=42,
        )
        train_loader = DataLoader(train_ds, batch_sampler=balanced_sampler, **loader_kwargs)
    except (ValueError, RuntimeError) as e:
        logger.warning("⚠ Balanced batch sampling with batch_size=%d failed: %s", batch_size, str(e))
        continue
    logger.info("✓ Using balanced batch sampler: %d samples per class per batch (batch_size=%d)",
                samples_per_class, batch_size)
    if batch_size <= 2:
        logger.warning("⚠ Using small batch_size=%d. Consider reducing num_frames if OOM occurs.", batch_size)
    break
else:
    # Every balanced configuration failed: use plain shuffled, unbalanced
    # batches of one sample (the training cell compensates with accumulation).
    batch_size = 1
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, **loader_kwargs)
    logger.warning("⚠ Using batch_size=1 (unbalanced batches). Gradient accumulation will be used to simulate larger batches.")

# Validation keeps natural order (no balanced sampling) and reuses the chosen
# training batch_size.
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, **loader_kwargs)

logger.info("Data loaders ready: %d train batches, %d val batches", len(train_loader), len(val_loader))
logger.info("Final config: batch_size=%d, num_frames=%d, num_workers=%d",
            batch_size, video_cfg.num_frames, num_workers)


In [None]:
# 3. Train the binary classifier (pretrained 3D ResNet backbone + Inception head)

from lib.video_modeling import PretrainedInceptionVideoModel


def _torch_at_least(major, minor):
    """Return True when the installed torch release is at least ``major.minor``.

    Unparseable version strings count as "too old", so optional settings are
    skipped rather than risked.
    """
    try:
        parts = torch.__version__.split("+")[0].split(".")
        return (int(parts[0]), int(parts[1])) >= (major, minor)
    except (IndexError, ValueError):
        return False


# The CUDA caching allocator reads PYTORCH_CUDA_ALLOC_CONF when it initialises.
# It must therefore be set before the first CUDA cache/allocation call below
# (empty_cache, model.to(device)); setting it after those calls has no effect.
# setdefault() keeps any value already exported by the job script.
# NOTE(review): expandable_segments is assumed to need PyTorch >= 2.1, and
# older allocators may reject unknown options - confirm the minimum version.
if _torch_at_least(2, 1):
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info("Using device: %s", device)
use_cuda = device.startswith("cuda")

# Clear CUDA cache before loading model
if use_cuda:
    torch.cuda.empty_cache()
    # "Free" here is total device memory minus what this process has allocated.
    logger.info("CUDA memory cleared. Free: %.2f GB",
                torch.cuda.get_device_properties(0).total_memory / 1e9 -
                torch.cuda.memory_allocated(0) / 1e9)

# Pretrained r3d_18 backbone with an Inception-like head.
# Backbone layers are frozen (freeze_backbone=True); only the head is trained.
model = PretrainedInceptionVideoModel(freeze_backbone=True)
model = model.to(device)

# Clear cache after model loading
if use_cuda:
    torch.cuda.empty_cache()
    logger.info("Model loaded. CUDA memory allocated: %.2f GB",
                torch.cuda.memory_allocated(0) / 1e9)

optim_cfg = OptimConfig(lr=1e-4, weight_decay=1e-4)

# Gradient accumulation only for the small batch sizes the loader cell can
# fall back to: >=4 needs none, 2-3 accumulates 2 steps, 1 accumulates 4.
if batch_size >= 4:
    gradient_accumulation_steps = 1
elif batch_size >= 2:
    gradient_accumulation_steps = 2
else:
    gradient_accumulation_steps = 4
effective_batch_size = batch_size * gradient_accumulation_steps

train_cfg = TrainConfig(
    num_epochs=10,
    device=device,
    log_interval=10,
    use_class_weights=True,
    use_amp=True,
    checkpoint_dir=MODELS_DIR,
    early_stopping_patience=3,
    gradient_accumulation_steps=gradient_accumulation_steps,
)

if gradient_accumulation_steps > 1:
    logger.info("Using gradient accumulation: %d steps (effective batch_size=%d)",
                gradient_accumulation_steps, effective_batch_size)
else:
    logger.info("Using balanced batches: batch_size=%d (no gradient accumulation needed)",
                batch_size)

logger.info("Starting training with pretrained backbone...")

# Monitor initial memory
if use_cuda:
    torch.cuda.empty_cache()
    allocated = torch.cuda.memory_allocated(0) / 1e9
    reserved = torch.cuda.memory_reserved(0) / 1e9
    total = torch.cuda.get_device_properties(0).total_memory / 1e9
    logger.info("GPU memory before training: %.2f GB allocated, %.2f GB reserved, %.2f GB total",
                allocated, reserved, total)

try:
    model = fit(
        model,
        train_loader=train_loader,
        val_loader=val_loader,
        optim_cfg=optim_cfg,
        train_cfg=train_cfg,
    )
    logger.info("Training completed successfully.")
except RuntimeError as e:
    # CUDA OOM surfaces as a RuntimeError (or subclass); add actionable hints
    # before re-raising.
    if "out of memory" in str(e).lower() or "OOM" in str(e):
        logger.error("CUDA OOM error during training: %s", str(e))
        if use_cuda:
            torch.cuda.empty_cache()
            allocated = torch.cuda.memory_allocated(0) / 1e9
            logger.error("GPU memory after OOM: %.2f GB still allocated", allocated)
        logger.error("Suggestions:")
        logger.error("  1. Reduce batch_size further (currently %d)", batch_size)
        logger.error("  2. Reduce num_frames further (currently %d)", video_cfg.num_frames)
        logger.error("  3. Use gradient accumulation to simulate larger batches")
        logger.error("  4. Consider using a smaller model")
    raise
except Exception as e:
    logger.exception("Training failed with exception: %s", e)
    raise
finally:
    # Release cached blocks whether training succeeded or not.
    if use_cuda:
        torch.cuda.empty_cache()
        logger.info("GPU memory cleaned up after training")


In [None]:
# 4. Evaluate on validation set and log metrics

from lib.video_metrics import (
    basic_classification_metrics,
    collect_logits_and_labels,
    confusion_matrix,
    roc_auc,
)

logger.info("Evaluating model on validation set...")

try:
    # One pass over val_loader; every metric below reuses the same outputs.
    logits, labels = collect_logits_and_labels(model, val_loader, device=device)

    metrics = basic_classification_metrics(logits, labels)
    cm = confusion_matrix(logits, labels)
    auc = roc_auc(logits, labels)

    logger.info("Validation metrics: %s", metrics)
    logger.info("Validation ROC-AUC: %.4f", auc)
    # .numpy() for a plain multi-line array dump
    # (assumes cm is a CPU torch tensor - TODO confirm).
    logger.info("Confusion matrix:\n%s", cm.numpy())
except Exception as exc:
    # Keep the traceback in the log, then stop the run.
    logger.exception("Evaluation failed with exception: %s", exc)
    raise
