# 🎬 Video Moment Retrieval Demo with Switch-NET

This is a demo interface for video moment retrieval based on the Switch-NET model.

- **Model**: Switch-NET (Loop Decoder DETR)
- **Function**: Locate relevant moments in a video given a natural language query
- **Interface**: Gradio interactive UI

## 1. Import dependencies

Import required libraries: PyTorch, Gradio, NumPy, etc.

In [1]:
import os
import sys
import time
import json
import random
import tempfile
import subprocess
import ast
import nbformat
from collections import OrderedDict
import numpy as np
from typing import Optional, List, Dict, Tuple, Type, Any
import copy
import math
import inspect
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

import gradio as gr
import clip
from PIL import Image
try:
    import cv2
except ImportError:
    cv2 = None

PROJECT_ROOT = "/Users/shuqi/Desktop/FYP"

print("Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Cache for CLIP components to avoid reloading
_clip_components = {
    "model": None,
    "preprocess": None,
    "device": None,
}

  Referenced from: <BC9ECF8E-76F4-3E60-B608-FAB81953710D> /opt/anaconda3/envs/fyp/lib/python3.8/site-packages/torchvision/image.so
  warn(


Libraries imported successfully!
PyTorch version: 2.4.1
CUDA available: False
Using device: cpu


In [2]:
# Global feature cache to reuse extracted video features across queries
MAX_CACHED_VIDEOS = 5
_feature_cache: "OrderedDict[Tuple, Dict[str, Any]]" = OrderedDict()

def _make_cache_key(video_path: str, max_v_l: int, vid_dim: int, clip_duration: Optional[float]) -> Optional[Tuple]:
    """Build a stable cache key for the given video + model configuration."""
    if not video_path:
        return None
    if vid_dim == 514:
        return (video_path, max_v_l, "clip_only")
    duration_value = 0.0 if clip_duration is None else float(clip_duration)
    return (video_path, max_v_l, round(duration_value, 4), vid_dim)

def _cache_get(key: Optional[Tuple]) -> Optional[Dict[str, Any]]:
    """Look up `key` in the feature cache and mark a hit as most recently used.

    Returns the stored entry, or None for a None key or a miss.
    """
    if key is None:
        return None
    hit = _feature_cache.get(key)
    if hit is not None:
        _feature_cache.move_to_end(key)  # last=True is the default: newest position
    return hit

def _cache_put(key: Optional[Tuple], entry: Dict[str, Any]) -> None:
    """Store `entry` as the most recently used item, evicting the oldest beyond MAX_CACHED_VIDEOS.

    A None key is ignored (nothing is stored).
    """
    if key is not None:
        _feature_cache[key] = entry
        _feature_cache.move_to_end(key)
        overflow = len(_feature_cache) - MAX_CACHED_VIDEOS
        for _ in range(overflow):
            oldest_key, _ = _feature_cache.popitem(last=False)
            print(f"[Feature Cache] Evicted cached features for {oldest_key}")

def get_cached_video_features(key: Optional[Tuple], vid_dim: int) -> Optional[Dict[str, Any]]:
    """Return the cached feature entry for `key`, but only if it was stored for `vid_dim`.

    Returns None on a miss or when the stored "vid_dim" differs.
    """
    entry = _cache_get(key)
    if entry is not None and entry.get("vid_dim") == vid_dim:
        return entry
    return None

def cache_video_features(key: Optional[Tuple], fused_features: np.ndarray, fused_valid_len: int, vid_dim: int) -> None:
    """Cache a copy of the fused features, plus their valid length and vid_dim, under `key`.

    A None key is a no-op. The array is copied so later in-place edits by the caller do not
    leak into the cache.
    """
    if key is None:
        return
    _cache_put(key, dict(
        fused_features=np.array(fused_features, copy=True),
        fused_valid_len=int(fused_valid_len),
        vid_dim=int(vid_dim),
    ))

## 2. Define Switch-NET model architecture

This section loads the Switch-NET model definitions by executing the definition cells of the training notebook (`switch-net.ipynb`). These include:
- Position Encoding
- Multi-Head Attention
- Transformer Encoder/Decoder
- The main Switch-NET model

In [4]:
def load_Switch_net_from_notebook(notebook_path: str) -> Tuple[int, Type[nn.Module]]:
    """
    Execute definition cells from the training notebook to populate Switch-Net classes.

    Only top-level imports, assignments, functions, and classes are executed to avoid
    running training loops or dataset instantiation code.
    Returns a tuple of (executed_blocks, resolved_model_class).
    """
    if not os.path.exists(notebook_path):
        raise FileNotFoundError(f"Notebook not found: {notebook_path}")
    
    nb_data = nbformat.read(notebook_path, as_version=4)
    exec_globals = globals()
    executed = 0
    
    allowed_nodes = (
        ast.Import,
        ast.ImportFrom,
        ast.FunctionDef,
        ast.AsyncFunctionDef,
        ast.ClassDef,
        ast.Assign,
        ast.AnnAssign,
    )
    
    def execute_node(node: ast.AST):
        nonlocal executed
        module = ast.Module(body=[node], type_ignores=[])
        ast.fix_missing_locations(module)
        try:
            code_obj = compile(module, notebook_path, "exec")
            exec(code_obj, exec_globals)
        except NameError as err:
            print(f"[Notebook Loader] Skipping statement due to NameError: {err}")
            return
        except Exception as err:
            print(f"[Notebook Loader] Skipping statement due to {err.__class__.__name__}: {err}")
            return
        executed += 1
    
    for cell in nb_data.get("cells", []):
        if cell.get("cell_type") != "code":
            continue
        source = cell.get("source", "")
        if isinstance(source, list):
            source = "".join(source)
        if not source.strip():
            continue
        try:
            tree = ast.parse(source, filename=notebook_path)
        except SyntaxError:
            continue
        for node in tree.body:
            if isinstance(node, allowed_nodes):
                execute_node(node)
            elif isinstance(node, ast.Try):
                safe = all(isinstance(stmt, allowed_nodes) for stmt in node.body)
                if safe:
                    for stmt in node.body:
                        execute_node(stmt)
            # Skip other node types (loops, with-statements, main guards, etc.)
    
    resolved_model = exec_globals.get("SwitchNet") or exec_globals.get("LD_DETR")
    if resolved_model is None:
        raise RuntimeError("Switch-Net definition not found after executing notebook definitions.")
    
    exec_globals["SwitchNet"] = resolved_model
    
    return executed, resolved_model

# Training notebook whose definition cells supply the Switch-Net classes.
MODEL_SOURCE_NOTEBOOK = os.path.join(PROJECT_ROOT, "switch-net.ipynb")
# Execute those definitions into this namespace. The loader returns
# (number of statements executed, resolved model class) and raises if no
# `SwitchNet` / `LD_DETR` class ends up defined.
_blocks_executed, SwitchNET = load_Switch_net_from_notebook(MODEL_SOURCE_NOTEBOOK)
print(f"Loaded Switch-Net definitions from training notebook (definition blocks executed: {_blocks_executed})")

[Notebook Loader] Skipping statement due to FileNotFoundError: [Errno 2] No such file or directory: 'output/models/best_model.pth'
[Notebook Loader] Skipping statement due to NameError: name 'test_loader' is not defined
Loaded Switch-Net definitions from training notebook (definition blocks executed: 112)


  "id": "70efac55",


In [5]:
# Confirmation only: the classes were already loaded by load_Switch_net_from_notebook() above.
print("Switch-NET model components loaded from training notebook definitions")

Switch-NET model components loaded from training notebook definitions


## 3. Model loading function

Define a helper to load a trained Switch-NET checkpoint.

In [6]:
def load_trained_model(checkpoint_path: str, device: torch.device):
    """
    Load a trained Switch-Net checkpoint and return the model plus its resolved config.

    Args:
        checkpoint_path: Filesystem path to the serialized `.pth` checkpoint.
        device: Target device onto which the model should be moved.

    Returns:
        A tuple of (model, resolved_config_dict). The model is ready for eval().
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    
    load_kwargs = {"map_location": device}
    checkpoint = None
    if "weights_only" in inspect.signature(torch.load).parameters:
        try:
            checkpoint = torch.load(checkpoint_path, weights_only=True, **load_kwargs)
        except Exception as exc:
            warnings.warn(
                f"Falling back to torch.load(weights_only=False) because safe loading failed with: {exc}",
                RuntimeWarning,
            )
            checkpoint = torch.load(checkpoint_path, weights_only=False, **load_kwargs)
    else:
        checkpoint = torch.load(checkpoint_path, **load_kwargs)
    
    raw_config = checkpoint.get("config", {})
    if hasattr(raw_config, "_asdict"):
        raw_config = raw_config._asdict()
    elif hasattr(raw_config, "__dict__"):
        raw_config = vars(raw_config)
    
    if not isinstance(raw_config, dict):
        raw_config = {}
    
    def to_builtin(value):
        if isinstance(value, torch.Tensor):
            return value.item()
        return value
    
    field_aliases = {
        "txt_dim": (["txt_dim", "query_dim", "text_dim"], 512),
        "vid_dim": (["vid_dim", "video_dim", "video_feature_dim"], 512),
        "hidden_dim": (["hidden_dim", "d_model"], 256),
        "num_queries": (["num_queries"], 10),
        "aux_loss": (["aux_loss"], True),
        "position_embedding": (["position_embedding"], "sine"),
        "max_v_l": (["max_v_l", "max_video_len"], 75),
        "max_q_l": (["max_q_l", "max_query_len"], 32),
        "span_loss_type": (["span_loss_type"], "l1"),
        "use_txt_pos": (["use_txt_pos"], False),
        "aud_dim": (["aud_dim"], 0),
        "queue_length": (["queue_length"], 65536),
        "momentum": (["momentum"], 0.995),
        "distillation_coefficient": (["distillation_coefficient"], 0.4),
        "num_v2t_encoder_layers": (["num_v2t_encoder_layers"], 2),
        "num_encoder1_layers": (["num_encoder1_layers"], 2),
        "num_convolutional_blocks": (["num_convolutional_blocks"], 5),
        "num_encoder2_layers": (["num_encoder2_layers"], 2),
        "num_decoder_layers": (["num_decoder_layers"], 2),
        "num_decoder_loops": (["num_decoder_loops"], 3),
        "clip_len": (["clip_len"], 2),
    }
    
    def resolve_field(keys, default_value):
        for key in keys:
            if key in raw_config and raw_config[key] is not None:
                return to_builtin(raw_config[key])
        return default_value
    
    model_args = {
        field: resolve_field(keys, default)
        for field, (keys, default) in field_aliases.items()
    }
    
    model = SwitchNET(**model_args)
    state_dict = checkpoint.get("model_state_dict", checkpoint)
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(f"Missing keys during load: {load_result.missing_keys[:5]}")
        if len(load_result.missing_keys) > 5:
            print("(truncated)")
    if load_result.unexpected_keys:
        print(f"Unexpected keys during load: {load_result.unexpected_keys[:5]}")
        if len(load_result.unexpected_keys) > 5:
            print("(truncated)")
    
    model = model.to(device)
    model.eval()
    
    print(f"Model loaded from {checkpoint_path}")
    print(f"Resolved config: {model_args}")
    print(f"Trained for {checkpoint.get('epoch', 'unknown')} epochs")
    
    return model, model_args

print("Model loading function defined!")

Model loading function defined!


## 4. Feature extraction functions
This section gathers the encoding modules used during inference:
- **CLIP text/image features**: reuse the same ViT-B/32 used during training to align language and visual semantics;
- **TEF (Temporal Endpoint Features)**: provide normalized start/end position information to help the model interpret time indices;
- **SlowFast video features**: capture rich spatio-temporal motion representations, used for the 2818-d multimodal input configuration.

In [7]:
CLIP_MODEL_NAME = "ViT-B/32"  # same CLIP backbone as used during training
TEF_DIM = 2  # temporal endpoint features: (start, end) per clip

def get_clip_components():
    """
    Return (model, preprocess, device) for CLIP, loading it only on the first call.

    The loaded objects are memoized in `_clip_components`, so later calls are cheap.
    """
    cache = _clip_components
    if cache["model"] is None:
        clip_model, clip_preprocess = clip.load(CLIP_MODEL_NAME, device=device, jit=False)
        clip_model.eval()
        cache.update(model=clip_model, preprocess=clip_preprocess, device=device)
        print(f"Loaded CLIP model ({CLIP_MODEL_NAME}) on {device}")
    return cache["model"], cache["preprocess"], cache["device"]

def build_tef_features(num_clips: int) -> np.ndarray:
    """
    Build temporal endpoint features (TEF) matching the training logic.

    Row i is [i / num_clips, (i + 1) / num_clips] as float32, i.e. the normalized start and
    end of clip i. Returns an empty (0, TEF_DIM) array when `num_clips <= 0`.
    """
    if num_clips > 0:
        step = 1.0 / float(num_clips)
        starts = np.arange(0, num_clips, dtype=np.float32) / float(num_clips)
        ends = starts + step
        return np.stack([starts, ends], axis=1).astype(np.float32)
    return np.zeros((0, TEF_DIM), dtype=np.float32)

def extract_video_clip_features(video_path: str, num_clips: int = 64, feature_dim: int = 512) -> Tuple[np.ndarray, int]:
    """
    Encode `num_clips` evenly spaced frames of a video with the cached CLIP image encoder.

    Args:
        video_path: Path of a video that OpenCV can open.
        num_clips: Number of frames to sample, and rows in the returned array.
        feature_dim: Width of the zero rows used for padding. Expected to equal the CLIP
            embedding size.

    Returns:
        (features, valid_length). `features` is a float32 array of shape
        (num_clips, feature_dim). Its first `valid_length` rows are L2-normalized CLIP
        embeddings of the frames that decoded successfully, in sampling order; the remaining
        rows are zero padding.

    Raises:
        RuntimeError: if OpenCV is unavailable or the video cannot be opened.
    """
    if cv2 is None:
        raise RuntimeError("OpenCV (cv2) is required for video processing but is not available.")

    model, preprocess, clip_device = get_clip_components()
    cap = cv2.VideoCapture(video_path)
    collected = []
    # try/finally guarantees the capture handle is released even if decoding,
    # preprocessing, or CLIP encoding raises part-way through.
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Unable to open video file: {video_path}")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            # Frame count unknown: fall back to sampling frames 0..num_clips-1.
            total_frames = num_clips

        frame_indices = np.linspace(0, max(total_frames - 1, 0), num=num_clips, dtype=int)

        with torch.no_grad():
            for frame_idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
                success, frame = cap.read()
                if not success or frame is None:
                    # Unreadable frames are dropped, not zero-filled, so later rows shift earlier.
                    continue
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
                image_tensor = preprocess(Image.fromarray(frame_rgb)).unsqueeze(0).to(clip_device)
                clip_feat = model.encode_image(image_tensor)
                clip_feat = clip_feat / clip_feat.norm(dim=-1, keepdim=True)
                collected.append(clip_feat.squeeze(0).cpu().float().numpy())
    finally:
        cap.release()

    if not collected:
        return np.zeros((num_clips, feature_dim), dtype=np.float32), 0

    features = np.vstack(collected)
    valid_length = min(features.shape[0], num_clips)

    if features.shape[0] < num_clips:
        padding = np.zeros((num_clips - features.shape[0], feature_dim), dtype=np.float32)
        features = np.concatenate([features, padding], axis=0)

    return features[:num_clips].astype(np.float32), valid_length

def extract_query_clip_features(query_text: str, seq_len: int = 32, feature_dim: int = 512) -> Tuple[np.ndarray, int]:
    """
    Extract per-token CLIP text features for a query.

    Args:
        query_text: Natural-language query. None, empty, or whitespace-only queries yield
            all-zero features.
        seq_len: Number of token rows returned; longer token sequences are truncated.
        feature_dim: Expected width of each token feature.

    Returns:
        (features, valid_length). `features` is a float32 array of shape
        (seq_len, feature_dim), zero-padded. `valid_length` counts the non-zero token ids,
        capped at seq_len.

    Raises:
        ValueError: if the text encoder's width differs from `feature_dim`.
    """
    # Answer trivial queries before touching CLIP, so an empty query never triggers a model load.
    if not query_text or not query_text.strip():
        return np.zeros((seq_len, feature_dim), dtype=np.float32), 0

    model, _, clip_device = get_clip_components()
    tokens = clip.tokenize([query_text], truncate=True).to(clip_device)
    with torch.no_grad():
        # Same steps as CLIP's text encoder up to ln_final, but every token position is kept
        # (no end-of-text pooling and no text projection).
        x = model.token_embedding(tokens).type(model.dtype)
        x = x + model.positional_embedding.type(model.dtype)
        x = x.permute(1, 0, 2)  # (context_length, batch, dim)
        x = model.transformer(x)
        x = x.permute(1, 0, 2)  # (batch, context_length, dim)
        x = model.ln_final(x)
        text_features = x[0].float().cpu().numpy()

    # Padding token id is 0 in the tokenized output, so non-zero ids are the real tokens.
    valid_length = min(int((tokens[0] != 0).sum().item()), seq_len)

    if text_features.shape[1] != feature_dim:
        raise ValueError(f"Expected text feature dim {feature_dim}, got {text_features.shape[1]}")

    text_features = text_features[:seq_len]
    if text_features.shape[0] < seq_len:
        padding = np.zeros((seq_len - text_features.shape[0], feature_dim), dtype=np.float32)
        text_features = np.concatenate([text_features, padding], axis=0)

    return text_features.astype(np.float32), valid_length

print("Feature extraction functions ready (CLIP + TEF)")

Feature extraction functions ready (CLIP + TEF)


## 4.1 SlowFast video feature extraction
Use a pretrained SlowFast R50 to further encode the video's spatio-temporal dynamics. This complements the CLIP frame features, which are computed from single frames and therefore carry no motion information.
For each sampled clip we:
- perform uniform temporal sampling, spatial scaling and center crop to match SlowFast pretraining distribution;
- build Slow/Fast dual pathways to leverage SlowFast's multi-rate perception;
- register a forward hook at the model head to capture a 2304-d feature vector from the pooling/projection layer;
- zero-pad when clips are missing so the final tensor shape is `(num_clips, 2304)` for easy concatenation.

In [8]:
import numpy as np
from typing import Tuple
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.models.hub import slowfast_r50
from pytorchvideo.transforms import UniformTemporalSubsample

_slowfast_state = {"model": None, "device": None}

SLOWFAST_ALPHA = 4
SLOWFAST_NUM_FRAMES = 32
SLOWFAST_SHORT_SIDE = 256
SLOWFAST_CROP_SIZE = 256
SLOWFAST_CLIP_SECONDS = 2.0
SLOWFAST_MEAN = [0.45, 0.45, 0.45],
SLOWFAST_STD = [0.225, 0.225, 0.225],
SLOWFAST_FEATURE_DIM = 2304

def _resize_short_side(frames: torch.Tensor, short_side: int) -> torch.Tensor:
    """Resize video frames so that the shorter spatial side equals `short_side`."""
    if short_side <= 0:
        return frames
    c, t, h, w = frames.shape
    if h == 0 or w == 0:
        return frames
    scale = short_side / min(h, w)
    target_h = max(int(round(h * scale)), 1)
    target_w = max(int(round(w * scale)), 1)
    frames_btchw = frames.permute(1, 0, 2, 3)  # (t, c, h, w)
    frames_resized = F.interpolate(
        frames_btchw,
        size=(target_h, target_w),
        mode="bilinear",
        align_corners=False,
    )
    return frames_resized.permute(1, 0, 2, 3)  # back to (c, t, h, w)

def _center_crop_video(frames: torch.Tensor, size: int) -> torch.Tensor:
    """Center crop the spatial dimensions to `size`. Pads if needed."""
    if size <= 0:
        return frames
    c, t, h, w = frames.shape
    if h < size or w < size:
        pad_h = max(size - h, 0)
        pad_w = max(size - w, 0)
        frames = frames.permute(1, 0, 2, 3)
        frames = F.pad(frames, (0, pad_w, 0, pad_h))  # pad W then H
        frames = frames.permute(1, 0, 2, 3)
        c, t, h, w = frames.shape
    top = max((h - size) // 2, 0)
    left = max((w - size) // 2, 0)
    return frames[:, :, top:top + size, left:left + size]

def _normalize_video(frames: torch.Tensor, mean, std) -> torch.Tensor:
    mean_tensor = torch.tensor(mean, dtype=frames.dtype, device=frames.device).view(-1, 1, 1, 1)
    std_tensor = torch.tensor(std, dtype=frames.dtype, device=frames.device).view(-1, 1, 1, 1)
    return (frames - mean_tensor) / std_tensor

def _prepare_slowfast_clip(frames: torch.Tensor) -> torch.Tensor:
    """
    Turn a decoded clip into SlowFast input.

    Steps, in order: subsample to SLOWFAST_NUM_FRAMES frames, scale pixel values by 1/255,
    resize the short side, center-crop, then normalize per channel.
    """
    subsampled = UniformTemporalSubsample(SLOWFAST_NUM_FRAMES)(frames)
    scaled = subsampled.float() / 255.0
    resized = _resize_short_side(scaled, SLOWFAST_SHORT_SIDE)
    cropped = _center_crop_video(resized, SLOWFAST_CROP_SIZE)
    return _normalize_video(cropped, SLOWFAST_MEAN, SLOWFAST_STD)

def _pack_slowfast_pathway(frames: torch.Tensor) -> List[torch.Tensor]:
    """
    Split a (C, T, H, W) clip into the [slow, fast] pathway inputs SlowFast expects.

    The fast pathway is the clip itself. The slow pathway keeps ceil(T / SLOWFAST_ALPHA)
    frames (at least one), taken at evenly spaced indices along the time axis.

    Raises:
        ValueError: if `frames` is not 4-D.
    """
    if frames.dim() != 4:
        raise ValueError("Expected video tensor with shape (C, T, H, W)")
    time_steps = frames.shape[1]
    slow_count = max(-(-time_steps // SLOWFAST_ALPHA), 1)  # integer ceil division
    last_index = time_steps - 1
    picked = torch.linspace(0, last_index, steps=slow_count, dtype=torch.long).clamp(max=last_index)
    return [frames.index_select(1, picked), frames]

def get_slowfast_backbone(target_device: torch.device = device) -> torch.nn.Module:
    """
    Return the pretrained SlowFast-R50 in eval mode on `target_device`, building it on first use.

    The instance is memoized in `_slowfast_state`. Requesting a different device reloads the
    pretrained weights onto that device and replaces the cached entry.
    """
    state = _slowfast_state
    if state["model"] is None or state["device"] != target_device:
        backbone = slowfast_r50(pretrained=True).to(target_device)
        backbone.eval()
        state.update(model=backbone, device=target_device)
    return state["model"]

def _locate_slowfast_feature_module(model: torch.nn.Module) -> Tuple[torch.nn.Module, bool]:
    """
    Pick the head submodule to hook for the clip embedding.

    Returns (module, channels_last). The head's `pool` is preferred; its output is captured
    and is channels-first. Otherwise `proj` is used and its *input* is captured.
    That input is channels-last because `proj` acts on the last axis.

    Raises:
        RuntimeError: if neither submodule exists.
    """
    head_block = model.blocks[-1] if hasattr(model, "blocks") else None
    if head_block is not None:
        pool_module = getattr(head_block, "pool", None)
        if pool_module is not None:
            return pool_module, False
        proj_module = getattr(head_block, "proj", None)
        if proj_module is not None:
            return proj_module, True
    raise RuntimeError("Unable to locate SlowFast pooling/projection module for feature capture.")


def _pool_slowfast_activation(activation: torch.Tensor, channels_last: bool) -> torch.Tensor:
    """
    Average a captured head activation down to (batch, channels).

    With `channels_last=True`, every axis between batch and the last (channel) axis is
    averaged; this is the layout of the input to a Linear `proj` layer. Otherwise the
    activation is treated as (B, C, ...) and all trailing axes are averaged.
    Tensors with at most 2 dims are returned unchanged.
    """
    if activation.dim() <= 2:
        return activation
    if channels_last:
        return activation.mean(dim=tuple(range(1, activation.dim() - 1)))
    return activation.mean(dim=tuple(range(2, activation.dim())))


def extract_slowfast_video_features(
    video_path: str,
    num_clips: int = 32,
    clip_duration: float = SLOWFAST_CLIP_SECONDS,
    target_device: torch.device = device,
    quiet: bool = False,
) -> Tuple[np.ndarray, int]:
    """
    Extract SlowFast features for `num_clips` clips centered at evenly spaced times.

    Each clip spans `clip_duration` seconds, clamped to the video bounds. It is preprocessed,
    run through the backbone, and a forward hook captures the head's pooled embedding.
    Clips that decode to no frames are skipped.

    Returns:
        (features, valid). `features` has shape (num_clips, D), with D = SLOWFAST_FEATURE_DIM
        when nothing is captured. Its first `valid` rows are the captured clip embeddings
        in order; the rest are zero padding.

    Raises:
        FileNotFoundError: if `video_path` does not exist.
        RuntimeError: if the SlowFast head has no pool/proj module to hook.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video not found: {video_path}")
    model = get_slowfast_backbone(target_device)
    feature_module, channels_last = _locate_slowfast_feature_module(model)

    captured: List[torch.Tensor] = []

    def _capture_hook(_module, hook_inputs, hook_output):
        # proj path: capture what enters the Linear; pool path: capture what the pool emits.
        source = hook_inputs[0] if channels_last else hook_output
        captured.append(_pool_slowfast_activation(source.detach().cpu(), channels_last))

    collected: List[np.ndarray] = []
    encoded = EncodedVideo.from_path(video_path)
    hook_handle = feature_module.register_forward_hook(_capture_hook)
    try:
        duration = encoded.duration
        if duration is None or duration <= 0:
            duration = num_clips * clip_duration
        clip_centers = np.linspace(
            clip_duration / 2.0,
            max(duration - clip_duration / 2.0, clip_duration / 2.0),
            num=num_clips,
        )
        with torch.no_grad():
            for center in clip_centers:
                start_sec = max(float(center - clip_duration / 2.0), 0.0)
                end_sec = min(float(center + clip_duration / 2.0), duration)
                clip = encoded.get_clip(start_sec=start_sec, end_sec=end_sec)
                # get_clip may return {"video": None} when no frame falls inside the window.
                frames = clip.get("video") if clip else None
                if frames is None or frames.numel() == 0:
                    continue
                pathways = _pack_slowfast_pathway(_prepare_slowfast_clip(frames))
                captured.clear()  # keep only activations from this clip's forward pass
                model([p.unsqueeze(0).to(target_device) for p in pathways])
                if not captured:
                    continue
                pooled = captured[-1]
                if pooled.dim() > 2:
                    pooled = pooled.flatten(start_dim=1)
                collected.append(pooled.squeeze(0).numpy().astype(np.float32))
    finally:
        hook_handle.remove()
        close = getattr(encoded, "close", None)
        if callable(close):
            close()  # release the underlying decoder / file handle

    valid = len(collected)
    if not collected:
        if not quiet:
            print("SlowFast failed to extract valid clips, returning zero vectors.")
        return np.zeros((num_clips, SLOWFAST_FEATURE_DIM), dtype=np.float32), 0

    features = np.stack(collected)
    if features.shape[0] < num_clips:
        padding = np.zeros((num_clips - features.shape[0], features.shape[1]), dtype=np.float32)
        features = np.concatenate([features, padding], axis=0)

    return features[:num_clips], valid



In [9]:
# import torch, os, glob
# cache_dir = os.path.join(torch.hub.get_dir(), "checkpoints")
# for path in glob.glob(os.path.join(cache_dir, "SLOWFAST*")):
#     print(f"Removing {path}")
#     os.remove(path)

In [10]:
# # Test: extract SlowFast features for a given video
# test_video_path = "/Users/shuqi/Desktop/FYP/data/videos/0A8ZT.mp4"
# slowfast_feats, slowfast_valid = extract_slowfast_video_features(
#     test_video_path,
#     num_clips=16,
#     clip_duration=2.0,
#     target_device=device,
#     quiet=False,
#  )
# print(f"SlowFast features shape: {slowfast_feats.shape}")
# print(f"Valid clips extracted: {slowfast_valid}")
# if slowfast_valid > 0:
#     print(f"Sample feature vector L2 norm: {np.linalg.norm(slowfast_feats[0]):.4f}")

## 5. Video utility functions

Helpers for loading a video into the demo state, querying its duration, and jumping to a retrieved segment.

In [11]:
class VideoState:
    """Global demo state: the loaded video, current segment, processing flag, and temp files."""

    def __init__(self):
        self.original_video = None  # path of the uploaded video, None until one is loaded
        self.current_video = None
        self.is_segment = False
        self.segment_info = (0.0, 0.0)  # presumably (start_s, end_s) of the current segment
        self.temp_files = []  # paths deleted by cleanup_temp_files()
        self.is_processing = False  # set while search_moments runs, to reject overlapping requests
        self.video_duration = 0.0  # seconds from get_video_duration(); 0.0 when unknown

    def reset(self):
        """Delete temporary files and return every field to its initial value."""
        self.cleanup_temp_files()
        self.original_video = None
        self.current_video = None
        self.is_segment = False
        self.segment_info = (0.0, 0.0)
        self.is_processing = False
        self.video_duration = 0.0

    def cleanup_temp_files(self):
        """
        Best-effort deletion of every path in `temp_files`, then forget them all.

        Already-missing files are ignored. Other failures (permissions, directories, invalid
        paths) are reported and never raised, so a stale temp file cannot block loading a
        new video.
        """
        for temp_file in self.temp_files:
            try:
                os.remove(temp_file)
            except FileNotFoundError:
                continue
            except (OSError, TypeError, ValueError) as exc:
                print(f"[VideoState] Could not remove temporary file {temp_file!r}: {exc}")
        self.temp_files = []

# Create global state
video_state = VideoState()

def get_video_duration(video_path: str) -> float:
    """
    Return the container duration of `video_path` in seconds, as reported by ffprobe.

    Best-effort: returns 0.0 when ffprobe is missing, times out (10 s), exits non-zero,
    prints something that is not a number (e.g. "N/A"), or `video_path` is None.
    """
    cmd = [
        'ffprobe',
        '-v', 'error',
        '-show_entries', 'format=duration',
        '-of', 'default=noprint_wrappers=1:nokey=1',
        video_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
        if result.returncode == 0:
            return float(result.stdout.strip())
    except (OSError, subprocess.SubprocessError, TypeError, ValueError):
        # OSError: ffprobe not installed; SubprocessError: timeout;
        # TypeError: invalid path argument; ValueError: non-numeric output.
        pass
    return 0.0

def _initial_result_updates():
    """
    Build the result-panel updates shown before any retrieval has run.

    Delegates to `create_result_interface([])` when that builder exists in this namespace.
    Otherwise it emits a "Waiting for retrieval" summary, then three updates per candidate
    slot (hidden, empty value, non-interactive). The slot count is TOP_K_RESULTS, default 5.
    """
    if "create_result_interface" in globals():
        return create_result_interface([])
    num_candidates = globals().get("TOP_K_RESULTS", 5)
    updates = [gr.update(value="<div class='result-summary'>Waiting for retrieval</div>")]
    for _ in range(num_candidates):
        updates.extend([
            gr.update(visible=False),
            gr.update(value=""),
            gr.update(interactive=False),
        ])
    return updates


def load_video(video_path):
    """
    Load a video file into the demo state and return the UI elements for it.

    Args:
        video_path: Path of the uploaded video. May be None or empty when the upload is cleared.

    Returns:
        A tuple (video_html, status_message, [], *result_panel_updates). For a None or empty
        path the HTML is empty and the status asks for an upload.
    """
    import html

    video_state.reset()
    if not video_path:
        # Nothing to load (e.g. upload cleared): keep the state empty instead of failing on the path.
        return ("", "Please upload a video first", [], *_initial_result_updates())

    video_state.original_video = video_path
    video_state.current_video = video_path

    # Get video duration (0.0 if ffprobe cannot determine it)
    video_state.video_duration = get_video_duration(video_path)

    # Get file extension
    file_ext = os.path.splitext(video_path)[1].lower()
    mime_type_map = {
        '.mp4': 'video/mp4',
        '.avi': 'video/x-msvideo',
        '.mov': 'video/quicktime',
        '.mkv': 'video/x-matroska',
        '.webm': 'video/webm'
    }
    mime_type = mime_type_map.get(file_ext, 'video/mp4')
    # The path comes from the user's upload and may embed the original filename, so escape it
    # before placing it inside an HTML attribute.
    safe_path = html.escape(video_path, quote=True)

    # Create video HTML
    video_html = f"""
    <div style=\"width: 100%; max-width: 510px; margin: 0 auto; padding: 5px;\">
        <video 
            id=\"main-video\" 
            controls 
            width=\"100%\" 
            style=\"border-radius: 8px; width: 510px; height: 310px; max-width: 510px; max-height: 310px; object-fit: contain; display: block;\"
            preload=\"metadata\"
        >
            <source src=\"file={safe_path}\" type=\"{mime_type}\">
            Your browser does not support this video format.
        </video>
    </div>
    """

    status_msg = f"Video '{os.path.basename(video_path)}' loaded successfully!\n"
    if video_state.video_duration > 0:
        status_msg += f"Video duration: {video_state.video_duration:.1f} seconds\n"
    status_msg += "You can start retrieval!"

    return (
        video_html,
        status_msg,
        [],
        *_initial_result_updates(),
    )

def simple_video_jump(start_time, end_time):
    """
    Re-render the player so it opens at the midpoint of [start_time, end_time].

    Args:
        start_time, end_time: Segment bounds in seconds. Anything float() accepts is allowed.

    Returns:
        (status_message, video_html). The HTML is "" when no video is loaded or the times
        are invalid.
    """
    import html

    if not video_state.original_video:
        return "Please upload a video first", ""

    if start_time is None or end_time is None:
        return "Invalid time range", ""
    try:
        start_time, end_time = float(start_time), float(end_time)
    except (TypeError, ValueError):
        return "Invalid time range", ""

    # The player seeks to the segment midpoint via a media-fragment "#t=" suffix.
    middle_time = (start_time + end_time) / 2.0

    file_ext = os.path.splitext(video_state.original_video)[1].lower()
    mime_type_map = {
        '.mp4': 'video/mp4',
        '.avi': 'video/x-msvideo',
        '.mov': 'video/quicktime',
        '.mkv': 'video/x-matroska',
        '.webm': 'video/webm',
    }
    mime_type = mime_type_map.get(file_ext, 'video/mp4')
    # The path originates from a user upload; escape it before embedding it in HTML.
    safe_path = html.escape(video_state.original_video, quote=True)

    video_html = f"""
    <div style=\"width: 100%; max-width: 510px; margin: 0 auto; padding: 5px;\">
        <video 
            id=\"main_video\" 
            controls 
            width=\"100%\" 
            style=\"border-radius: 8px; width: 510px; height: 300px; max-width: 510px; max-height: 300px; object-fit: contain; display: block;\"
            preload=\"metadata\"
        >
            <source src=\"file={safe_path}#t={middle_time}\" type=\"{mime_type}\">
            Your browser does not support video playback.
        </video>
        <div style=\"background: #f0f8ff; padding: 8px; margin-top: 8px; border-radius: 4px; text-align: center;\">
            Current segment: {start_time:.1f}s - {end_time:.1f}s
        </div>
    </div>
    """

    status_msg = f"Jumped to: {start_time:.1f}s - {end_time:.1f}s (middle position: {middle_time:.1f}s)"

    return status_msg, video_html


# def simple_video_jump(start_time, end_time):
#     """跳转到视频的指定时刻"""
#     if not video_state.original_video:
#         return "请先上传视频", ""
    
#     if start_time is None or end_time is None:
#         return "无效的时间范围", ""
    
#     file_ext = os.path.splitext(video_state.original_video)[1].lower()
#     mime_type_map = {
#         '.mp4': 'video/mp4',
#         '.avi': 'video/x-msvideo',
#         '.mov': 'video/quicktime',
#         '.mkv': 'video/x-matroska',
#         '.webm': 'video/webm'
#     }
#     mime_type = mime_type_map.get(file_ext, 'video/mp4')
    
#     video_html = f"""
#     <div style=\"width: 100%; max-width: 510px; margin: 0 auto; padding: 5px;\">
#         <video 
#             id=\"main_video\" 
#             controls 
#             width=\"100%\" 
#             style=\"border-radius: 8px; width: 510px; height: 300px; max-width: 510px; max-height: 300px; object-fit: contain; display: block;\"
#             preload=\"metadata\"
#         >
#             <source src=\"file={video_state.original_video}#t={start_time}\" type=\"{mime_type}\">
#             您的浏览器不支持视频播放。
#         </video>
#         <div style=\"background: #f0f8ff; padding: 8px; margin-top: 8px; border-radius: 4px; text-align: center;\">
#             当前时间段: {start_time:.1f}s - {end_time:.1f}s
#         </div>
#     </div>
#     """
    
#     status_msg = f"跳转到: {start_time:.1f}s - {end_time:.1f}s"
    
#     return status_msg, video_html

# print("Video processing functions defined!")

## 6. Model inference function
`search_moments` is the main orchestrator for inference:
- lazily load a trained Switch-NET checkpoint and cache the model/config;
- choose fusion strategy based on `vid_dim`: 514 uses CLIP+TEF, 2818 also includes SlowFast features;
- build masks for video/text to match the sequence lengths used at training;
- run the model and convert center-width span predictions to start-end times for the frontend.

In [12]:
# Process-wide cache so the Switch-NET checkpoint is loaded at most once per session.
class ModelCache:
    """Holds the lazily loaded Switch-NET model together with its resolved config."""

    def __init__(self):
        # All three are filled in by search_moments() the first time it loads a checkpoint.
        self.model, self.config, self.is_loaded = None, None, False

model_cache = ModelCache()

# Number of ranked candidate slots shown in the results panel.
TOP_K_RESULTS = 5

def search_moments(query_text, video_player_html, model_path="output/models/best_model_moe.pth"):
    """
    Search for relevant moments in the uploaded video using the Switch-NET model.

    This is a generator, so the Gradio UI can stream progress. Every ``yield``
    produces a ``(status_message, results)`` pair. Only the final yield of a
    successful run carries a non-empty ``results`` list.

    Args:
        query_text: natural-language query string.
        video_player_html: contents of the video player component. It is not
            read; the video path comes from the global ``video_state``.
        model_path: path to the Switch-NET checkpoint. It is loaded once and
            then kept in ``model_cache``.

    Yields:
        (status_message, results), where ``results`` is a list of dicts with
        keys ``rank``, ``start_time``, ``end_time`` (seconds), ``confidence``
        and ``query_index``, ordered by descending confidence.
    """
    # Whitespace-only input has nothing to encode, so treat it like an empty query.
    if not query_text or not query_text.strip():
        gr.Warning("Please enter a query!")
        yield "Please enter a query", []
        return

    # The uploaded video is tracked in the global video_state, not in the HTML argument.
    actual_video_path = video_state.original_video
    if not actual_video_path:
        gr.Warning("Please upload a video first!")
        yield "Please upload a video file first", []
        return

    # Reject overlapping requests; the flag is cleared in the finally clause below.
    if video_state.is_processing:
        gr.Warning("Processing is in progress, please wait...")
        yield "Processing is in progress, please wait...", []
        return

    video_state.is_processing = True

    try:
        # Lazily load the checkpoint once and keep it in the module-level cache.
        if not model_cache.is_loaded:
            yield "Loading Switch-NET model...\n", []

            if not os.path.exists(model_path):
                yield f"Error: model file not found {model_path}\nPlease train the model or provide a correct checkpoint path", []
                return

            model_cache.model, model_cache.config = load_trained_model(model_path, device)
            model_cache.is_loaded = True
            yield "Model loaded!\n", []

        # Feature sizes and sequence lengths must match what the checkpoint was trained with.
        cfg = model_cache.config or {}
        vid_dim = int(cfg.get("vid_dim", 512))
        txt_dim = int(cfg.get("txt_dim", 512))
        max_v_l = int(cfg.get("max_v_l", 64))
        max_q_l = int(cfg.get("max_q_l", 32))
        clip_duration = float(cfg.get("clip_len", SLOWFAST_CLIP_SECONDS))

        # The CLIP-only pipeline (vid_dim 514) does not use clip_duration, so it
        # is left out of that cache key.
        cache_key = _make_cache_key(
            actual_video_path,
            max_v_l,
            vid_dim,
            clip_duration if vid_dim != 514 else None,
        )
        cache_entry = get_cached_video_features(cache_key, vid_dim)

        fused_video_features: Optional[np.ndarray] = None
        fused_valid_len: Optional[int] = None

        if cache_entry is not None:
            fused_video_features = cache_entry["fused_features"]
            fused_valid_len = cache_entry["fused_valid_len"]
            yield "Using cached video features...\n", []
        else:
            yield "Extracting video features...\n", []
            clip_feats, clip_valid_len = extract_video_clip_features(
                actual_video_path, num_clips=max_v_l, feature_dim=512
            )
            if clip_valid_len == 0:
                yield "Failed to extract frames from the video. Please ensure a valid video file was uploaded", []
                return

            tef_features = build_tef_features(max_v_l)
            if tef_features.shape[0] != max_v_l:
                # Fallback only: np.resize repeats rows cyclically to reach max_v_l.
                tef_features = np.resize(tef_features, (max_v_l, TEF_DIM))

            fused_valid_len = clip_valid_len

            if vid_dim == 514:
                # CLIP + TEF
                fused_video_features = np.concatenate([clip_feats, tef_features], axis=-1)
            elif vid_dim == 2818:
                # CLIP + SlowFast + TEF
                slowfast_feats, slowfast_valid = extract_slowfast_video_features(
                    actual_video_path,
                    num_clips=max_v_l,
                    clip_duration=clip_duration,
                    target_device=device,
                    quiet=True,
                )
                slowfast_padded = np.zeros((max_v_l, SLOWFAST_FEATURE_DIM), dtype=np.float32)
                if slowfast_valid > 0:
                    slowfast_len = min(slowfast_valid, max_v_l)
                    slowfast_padded[:slowfast_len] = slowfast_feats[:slowfast_len]
                    # The valid length is the longer of the two streams.
                    fused_valid_len = max(fused_valid_len, slowfast_len)
                else:
                    print("Warning: SlowFast feature extraction failed, using zero vectors as placeholder.")
                fused_video_features = np.concatenate([clip_feats, slowfast_padded, tef_features], axis=-1)
            else:
                raise ValueError(
                    "vid_dim={} in model config is not supported by the inference pipeline; expected 514 or 2818".format(vid_dim)
                )

            if fused_video_features.shape[1] != vid_dim:
                raise ValueError(
                    f"Fused video feature dim is {fused_video_features.shape[1]}, which does not match vid_dim={vid_dim} from config"
                )

            fused_valid_len = max(1, min(fused_valid_len, max_v_l))
            cache_video_features(cache_key, fused_video_features, fused_valid_len, vid_dim)

        # Force float32. np.concatenate promotes to the widest input dtype, and the
        # model's weights are presumably float32 (torch's default). A float64
        # input tensor would raise a dtype mismatch inside the model.
        fused_video_features = np.ascontiguousarray(fused_video_features, dtype=np.float32)
        fused_valid_len = max(1, min(int(fused_valid_len), max_v_l))
        video_mask = np.zeros(max_v_l, dtype=np.float32)
        video_mask[:fused_valid_len] = 1.0

        yield "Extracting query features...\n", []
        query_features, query_valid_len = extract_query_clip_features(
            query_text, seq_len=max_q_l, feature_dim=txt_dim
        )
        if query_valid_len == 0:
            yield "Query text is empty or could not be encoded; please re-enter", []
            return
        query_features = np.ascontiguousarray(query_features, dtype=np.float32)
        query_length = max(1, min(query_valid_len, max_q_l))
        query_mask = np.zeros(max_q_l, dtype=np.float32)
        query_mask[:query_length] = 1.0

        # Add a batch dimension of 1 and move the inputs to the inference device.
        video_features_tensor = torch.from_numpy(fused_video_features).unsqueeze(0).to(device)
        query_features_tensor = torch.from_numpy(query_features).unsqueeze(0).to(device)
        video_mask_tensor = torch.from_numpy(video_mask).unsqueeze(0).to(device)
        query_mask_tensor = torch.from_numpy(query_mask).unsqueeze(0).to(device)

        yield "Running Switch-NET model inference...\n", []
        with torch.no_grad():
            outputs = model_cache.model(
                src_txt=query_features_tensor,
                src_txt_mask=query_mask_tensor.bool(),
                src_vid=video_features_tensor,
                src_vid_mask=video_mask_tensor.bool(),
                is_training=False,
            )

        yield "Processing model predictions...\n", []
        pred_logits = outputs["pred_logits"][0]  # (num_queries, 2)
        pred_spans = outputs["pred_spans"][0]    # (num_queries, 2) center-width
        # NOTE(review): probability column 1 is used as the "relevant moment"
        # score. Moment-DETR-style criteria label the foreground as class 0, so
        # confirm which index the Switch-NET training loss treats as foreground.
        scores = F.softmax(pred_logits, dim=-1)[:, 1]

        # Convert normalized (center, width) spans to normalized (start, end) in [0, 1].
        spans_start_end = span_cxw_to_xx(pred_spans)
        spans_start_end = torch.clamp(spans_start_end, 0.0, 1.0)

        top_k = min(TOP_K_RESULTS, scores.shape[0])
        topk_scores, topk_indices = torch.topk(scores, k=top_k)

        # Normalized spans are scaled by the video length. When the length is
        # unknown, the 100 s fallback is kept, but the status tells the user.
        duration_known = bool(video_state.video_duration)
        video_duration = video_state.video_duration or 100.0

        results = []
        ranked = zip(topk_scores.tolist(), topk_indices.tolist())
        for rank, (score, idx) in enumerate(ranked, start=1):
            start_time = float(spans_start_end[idx, 0].item()) * video_duration
            end_time = float(spans_start_end[idx, 1].item()) * video_duration
            if start_time > end_time:
                start_time, end_time = end_time, start_time
            results.append({
                "rank": rank,
                "start_time": float(start_time),
                "end_time": float(end_time),
                "confidence": float(score),
                "query_index": int(idx),
            })

        if not results:
            yield "No valid candidate results obtained", []
            return

        best_result = results[0]

        status_lines = [
            "Retrieval complete!",
            f"Top1 match: {best_result['start_time']:.1f}s - {best_result['end_time']:.1f}s",
            f"Confidence: {best_result['confidence']:.3f}",
        ]
        if not duration_known:
            status_lines.append("Note: video duration unknown; times assume a 100 s video")
        status_msg = "\n".join(status_lines)

        yield status_msg, results

    except Exception as e:
        # Show a readable message in the UI. Also print the full traceback so the
        # failure can be diagnosed from the console.
        import traceback
        traceback.print_exc()
        error_msg = f"An error occurred during processing: {str(e)}"
        yield error_msg, []
    finally:
        # Always release the busy flag, including when Gradio closes the generator early.
        video_state.is_processing = False

print("Search function defined!")

Search function defined!


## 7. Gradio UI

Build the interactive web interface.

In [13]:
def create_result_interface(results):
    """
    Build the Gradio updates that render a ranked list of search candidates.

    Args:
        results: ranked candidate dicts with ``rank``, ``start_time`` and
            ``end_time`` (seconds), as yielded by ``search_moments``. ``None``
            or an empty list renders the waiting state.

    Returns:
        A flat list of ``gr.update`` objects. The first updates the summary
        HTML. Then, for each of the ``TOP_K_RESULTS`` candidate slots, there is
        a (row visibility, info HTML, button interactivity) triple.
    """
    # Normalize None to [] so the slot loop below can safely call len().
    results = results or []

    if results:
        best = results[0]
        summary_html = (
            "<div class='result-summary'>"
            "<div class='summary-title'>Most relevant time segment</div>"
            f"<div class='summary-time'>{best['start_time']:.1f}s - {best['end_time']:.1f}s</div>"
            "</div>"
        )
    else:
        summary_html = (
            "<div class='result-summary result-summary--empty'>"
            "Waiting for results"
            "</div>"
        )

    updates = [gr.update(value=summary_html)]

    for idx in range(TOP_K_RESULTS):
        if idx < len(results):
            # Show the slot and enable its "Go" button for this candidate.
            candidate = results[idx]
            info_html = (
                "<div class='candidate-info'>"
                f"<span class='candidate-rank'>Candidate {candidate['rank']}</span>"
                f"<span class='candidate-time'>{candidate['start_time']:.1f}s - {candidate['end_time']:.1f}s</span>"
                "</div>"
            )
            updates.extend([
                gr.update(visible=True),
                gr.update(value=info_html),
                gr.update(interactive=True),
            ])
        else:
            # Hide and clear unused slots so stale candidates do not linger.
            updates.extend([
                gr.update(visible=False),
                gr.update(value=""),
                gr.update(interactive=False),
            ])

    return updates

# Custom stylesheet passed to gr.Blocks(css=CSS). The selectors target the video
# player's elem_id ("video_player_container"), the "candidate-row" elem_classes,
# and the class names used in the summary/candidate HTML built by
# create_result_interface. `!important` is used throughout, presumably so these
# rules win over the Soft theme's defaults.
CSS = """
#video_player_container {
    width: 100% !important;
    max-width: 600px !important;
    height: 400px !important;
    margin: 0 auto 15px auto !important;
    display: flex !important;
    align-items: center !important;
    justify-content: center !important;
    border: 2px dashed #e0e0e0 !important;
    border-radius: 8px !important;
    background-color: #fafafa !important;
}

.result-summary {
    background: #f7fbff !important;
    border: 1px solid #d3e3ff !important;
    border-radius: 8px !important;
    padding: 12px 16px !important;
    margin-bottom: 12px !important;
    display: flex !important;
    flex-direction: column !important;
    gap: 6px !important;
}

.result-summary--empty {
    background: #f9f9f9 !important;
    border-color: #e5e5e5 !important;
    color: #7a7a7a !important;
}

.summary-title {
    font-weight: 600 !important;
    font-size: 15px !important;
    color: #3056d3 !important;
}

.summary-time {
    font-size: 20px !important;
    font-weight: 700 !important;
    color: #1f2937 !important;
}

.candidate-row {
    align-items: center !important;
    margin-bottom: 8px !important;
}

.candidate-info {
    background: #ffffff !important;
    border: 1px solid #e5e7eb !important;
    border-radius: 8px !important;
    padding: 10px 14px !important;
    flex: 1 !important;
    display: flex !important;
    justify-content: space-between !important;
    align-items: center !important;
    font-size: 14px !important;
    color: #1f2937 !important;
}

.candidate-rank {
    font-weight: 600 !important;
    color: #2563eb !important;
}

.candidate-time {
    font-variant-numeric: tabular-nums !important;
}
"""

# Assemble the demo UI. Inside gr.Blocks, component creation order defines the
# on-page layout. Each event's `outputs` list must also line up, position by
# position, with the values its handler returns or yields.
with gr.Blocks(theme=gr.themes.Soft(), css=CSS) as demo:
    gr.Markdown("# üé¨ Switch-NET Video Moment Retrieval")
    gr.Markdown("Upload a video, enter a natural language query, and use the Switch-NET model to find relevant moments in the video!")

    with gr.Row():
        # Left column: player, upload button and status log.
        with gr.Column(scale=2):
            video_player = gr.HTML(
                value='<div style="width: 400px; height: 200px; border: 2px dashed #ccc; border-radius: 8px; display: flex; align-items: center; justify-content: center; background-color: #f9f9f9;"><div style="text-align: center; color: #666;"><h3>üìπ Video Player</h3><p>Please upload a video file to begin</p></div></div>',
                elem_id="video_player_container",
            )
            upload_button = gr.UploadButton("Upload video file", file_types=["video"])
            status_display = gr.Textbox(label="Status", interactive=False, lines=8)

        # Right column: query input, summary and candidate list.
        with gr.Column(scale=1):
            query_box = gr.Textbox(label="Query", placeholder="For example: a person is chopping peppers...")
            search_button = gr.Button("üîç Search", variant="primary")

            gr.Markdown("### Search Results")
            summary_html = gr.HTML("<div class='result-summary result-summary--empty'>Waiting for retrieval</div>")
            gr.Markdown("#### Candidate time segments list")

            # One initially hidden (row, info, button) slot per candidate.
            # create_result_interface later toggles these.
            candidate_components = []
            for idx in range(TOP_K_RESULTS):
                with gr.Row(visible=False, elem_classes="candidate-row") as row:
                    info = gr.HTML("")
                    btn = gr.Button("Go", variant="secondary")
                candidate_components.append((row, info, btn))

    # Latest ranked results, so the "Go" buttons can look up their candidate.
    search_results = gr.State([])

    # load_video is expected to return values in exactly this order: player,
    # status, results state, summary, then one triple per candidate slot.
    upload_outputs = [
        video_player,
        status_display,
        search_results,
        summary_html,
    ]
    for row, info, btn in candidate_components:
        upload_outputs.extend([row, info, btn])

    upload_button.upload(
        fn=load_video,
        inputs=[upload_button],
        outputs=upload_outputs,
    )

    # Streams search_moments progress to the UI. The second argument receives
    # the player's HTML, which search_moments does not read (it uses video_state).
    def handle_search_and_update(query_text, video_path):
        for status, results in search_moments(query_text, video_path):
            interface_updates = create_result_interface(results)
            yield status, results, *interface_updates

    # Order matches handle_search_and_update's yield: status, results, then the
    # summary update and one triple per candidate slot.
    search_outputs = [
        status_display,
        search_results,
        summary_html,
    ]
    for row, info, btn in candidate_components:
        search_outputs.extend([row, info, btn])

    search_button.click(
        fn=handle_search_and_update,
        inputs=[query_box, video_player],
        outputs=search_outputs,
        show_progress=False
    )

    # Seek the player to one candidate. gr.update() with no arguments leaves the
    # player unchanged when the jump cannot be performed.
    def jump_to_candidate(results, candidate_index):
        if video_state.is_processing:
            return "Processing is in progress, please wait...", gr.update()
        if not results:
            return "No available results", gr.update()
        if candidate_index >= len(results):
            return "Candidate index out of range", gr.update()
        candidate = results[candidate_index]
        status, video_html = simple_video_jump(candidate['start_time'], candidate['end_time'])
        return status, video_html

    # Each button gets its own gr.State holding its fixed slot index, so one
    # shared handler serves every slot.
    for idx, (_, _, btn) in enumerate(candidate_components):
        btn.click(
            fn=jump_to_candidate,
            inputs=[search_results, gr.State(idx)],
            outputs=[status_display, video_player]
        )

print("Gradio interface created!")

Gradio interface created!


Model loaded from output/models/best_model_moe.pth
Resolved config: {'txt_dim': 512, 'vid_dim': 2818, 'hidden_dim': 256, 'num_queries': 5, 'aux_loss': True, 'position_embedding': 'sine', 'max_v_l': 194, 'max_q_l': 32, 'span_loss_type': 'l1', 'use_txt_pos': False, 'aud_dim': 0, 'queue_length': 65536, 'momentum': 0.995, 'distillation_coefficient': 0.3, 'num_v2t_encoder_layers': 2, 'num_encoder1_layers': 2, 'num_convolutional_blocks': 4, 'num_encoder2_layers': 2, 'num_decoder_layers': 2, 'num_decoder_loops': 3, 'clip_len': 1}
Trained for 20 epochs
Loaded CLIP model (ViT-B/32) on cpu


## 8. Launch the app

Run the Gradio app.

In [None]:
# Start the Gradio app locally (share=False: no public link is created).
demo.launch(share=False, debug=True)

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.
