# Switch-Net: Expert Iterative Decoder Network for Video Moment Retrieval

This notebook implements Switch-Net, a state-of-the-art model for video moment retrieval and highlight detection. Switch-Net addresses three key challenges in video-language understanding:

1. **Overlapping semantic information** in contrastive learning → Solved by Distill Align module
2. **Inefficient local video feature extraction** → Solved by Convolutional Fuser
3. **Inadequate multimodal decoding** → Solved by Loop Decoder with iterative refinement

**Architecture Overview:** Video/Text Encoders → Distill Align → Convolutional Fuser → Loop Decoder → Prediction Heads

**Paper:** https://arxiv.org/abs/2501.10787

## Setup and Dependencies

Import core libraries for deep learning (PyTorch), numerical computation (NumPy), and data processing. Set random seeds for reproducibility across different runs.

In [None]:
import os
import time
import json
import random
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import math
import copy
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision.ops as ops
import wandb

# Import required for Hungarian algorithm
from scipy.optimize import linear_sum_assignment
import torchvision

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed handed to every generator.

    When CUDA is available, every CUDA device is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Seed every RNG once when this cell runs so the cells below are repeatable.
set_seed(42)

# Device configuration
# Use the GPU when PyTorch reports CUDA as available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Utility Functions

Core utilities for span manipulation, IoU computation, and accuracy metrics:
- **Span conversion:** Transform between (center, width) and (start, end) formats for different loss functions
- **Generalized IoU:** Compute temporal intersection-over-union with penalty for non-overlapping regions
- **Position encoding:** Generate sinusoidal embeddings for temporal positions
- **Accuracy:** Compute top-k classification accuracy

In [None]:
def span_xx_to_cxw(xx_spans: torch.Tensor) -> torch.Tensor:
    """Re-express spans stored as ``[start, end]`` on the last axis as ``[center, width]``.

    Args:
        xx_spans: Tensor whose last axis holds ``(start, end)``.

    Returns:
        Tensor with ``(center, width)`` stacked on a new last axis.
    """
    midpoints = xx_spans.sum(-1) * 0.5
    starts, ends = xx_spans[..., 0], xx_spans[..., 1]
    lengths = ends - starts
    return torch.stack([midpoints, lengths], dim=-1)

def span_cxw_to_xx(cxw_spans: torch.Tensor) -> torch.Tensor:
    """Re-express spans stored as ``[center, width]`` on the last axis as ``[start, end]``.

    Args:
        cxw_spans: Tensor whose last axis holds ``(center, width)``.

    Returns:
        Tensor with ``(start, end)`` stacked on a new last axis.
    """
    centers = cxw_spans[..., 0]
    half_widths = 0.5 * cxw_spans[..., 1]
    return torch.stack([centers - half_widths, centers + half_widths], dim=-1)

def temporal_iou(spans1: torch.Tensor, spans2: torch.Tensor):
    """Pairwise 1D IoU between two sets of ``[start, end]`` spans.

    Args:
        spans1: (N, 2) spans.
        spans2: (M, 2) spans.

    Returns:
        Tuple ``(iou, union)``, both of shape (N, M). The union is clamped to
        at least 1e-6 in the division only; the returned union is unclamped.
    """
    first, second = spans1.float(), spans2.float()
    len_first = first[:, 1] - first[:, 0]
    len_second = second[:, 1] - second[:, 0]

    # Broadcast (N, 1) against (M,) to get every pairwise overlap.
    overlap_start = torch.max(first[:, None, 0], second[:, 0])
    overlap_end = torch.min(first[:, None, 1], second[:, 1])
    inter = torch.clamp(overlap_end - overlap_start, min=0)

    union = len_first[:, None] + len_second - inter
    return inter / torch.clamp(union, min=1e-6), union

def generalized_temporal_iou(spans1: torch.Tensor, spans2: torch.Tensor) -> torch.Tensor:
    """Pairwise generalized IoU for 1D ``[start, end]`` spans.

    The plain IoU is reduced by the fraction of the smallest enclosing span
    that is not covered by the union of the two spans.

    Args:
        spans1: (N, 2) spans with ``end >= start``.
        spans2: (M, 2) spans with ``end >= start``.

    Returns:
        (N, M) tensor of generalized IoU values.
    """
    first, second = spans1.float(), spans2.float()
    assert (first[:, 1] >= first[:, 0]).all()
    assert (second[:, 1] >= second[:, 0]).all()

    iou, union = temporal_iou(first, second)

    # Smallest span that contains both members of each pair.
    hull_start = torch.min(first[:, None, 0], second[:, 0])
    hull_end = torch.max(first[:, None, 1], second[:, 1])
    hull = torch.clamp(hull_end - hull_start, min=0)

    penalty = (hull - union) / torch.clamp(hull, min=1e-6)
    return iou - penalty

def get_detr_position_encoding(max_length, num_pos_feats=128):
    """
    Create positional encodings similar to DETR for video sequences.
    Used when raw counting positions are needed.

    Args:
        max_length: Number of positions (rows) to encode.
        num_pos_feats: Width of each encoding. Odd widths are supported; the
            last sine channel then has no cosine partner.

    Returns:
        (max_length, num_pos_feats) float32 tensor with sine values in the
        even columns and cosine values in the odd columns.
    """
    position = torch.arange(max_length, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, num_pos_feats, 2, dtype=torch.float32)
        * (-math.log(10000.0) / num_pos_feats)
    )
    angles = position * div_term

    pe = torch.zeros(max_length, num_pos_feats)
    pe[:, 0::2] = torch.sin(angles)
    # For an odd width there is one fewer odd column than even columns, so
    # trim the angles to the number of cosine slots actually available.
    pe[:, 1::2] = torch.cos(angles[:, : num_pos_feats // 2])
    return pe

def gen_sineembed_for_position(pos_tensor, num_pos_feats=128, temperature=10000):
    """
    Generate sine/cosine embeddings for reference points.
    Used when normalized [0, 1] positions are needed.

    Args:
        pos_tensor: 3D tensor whose last axis holds ``(center, width)`` at
            indices 0 and 1.
        num_pos_feats: Embedding width per coordinate; must be a positive even
            number. The default of 128 matches the previously hard-coded value.
        temperature: Base of the geometric frequency progression. The default
            of 10000 matches the previously hard-coded value.

    Returns:
        Tensor of shape ``(*pos_tensor.shape[:2], 2 * num_pos_feats)``: the
        center embedding followed by the width embedding, each interleaving
        sine (even channels) and cosine (odd channels).

    Raises:
        ValueError: If ``num_pos_feats`` is not a positive even integer.
    """
    if num_pos_feats <= 0 or num_pos_feats % 2 != 0:
        raise ValueError(
            f"num_pos_feats must be a positive even integer, got {num_pos_feats}"
        )

    scale = 2 * math.pi
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
    # Consecutive channel pairs share one frequency.
    dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)

    def _embed(values):
        # Embed one scalar coordinate per (dim0, dim1) entry.
        angles = (values * scale)[:, :, None] / dim_t
        interleaved = torch.stack(
            (angles[:, :, 0::2].sin(), angles[:, :, 1::2].cos()), dim=3
        )
        return interleaved.flatten(2)

    pos_x = _embed(pos_tensor[:, :, 0])
    pos_w = _embed(pos_tensor[:, :, 1])
    return torch.cat((pos_x, pos_w), dim=2)

def inverse_sigmoid(x, eps=1e-3):
    """Numerically safe logit: ``log(x / (1 - x))`` with clamping.

    Args:
        x: Tensor of probabilities; values outside [0, 1] are clamped first.
        eps: Lower bound applied to both the numerator and the denominator.

    Returns:
        Tensor of logits with the same shape as ``x``.
    """
    prob = x.clamp(min=0, max=1)
    numerator = prob.clamp(min=eps)
    denominator = (1 - prob).clamp(min=eps)
    return torch.log(numerator / denominator)

# Accuracy calculation
def accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent for each ``k`` in ``topk``.

    Args:
        output: (num_items, num_classes) score tensor.
        target: Either a tensor with one class index per item, or a single
            scalar class that is compared against every prediction.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        List with one scalar tensor per ``k``. All entries are 0.0 when
        ``output`` has no rows or a tensor ``target`` is empty.
    """
    def _all_zero():
        return [output.new_tensor(0.0) for _ in topk]

    maxk = max(topk)
    num_items = output.size(0)
    if num_items == 0:
        return _all_zero()

    # (maxk, num_items): row r holds the rank-r prediction of every item.
    pred = output.topk(maxk, 1, True, True)[1].t()

    if not isinstance(target, torch.Tensor):
        target = torch.full_like(pred, target)
    elif target.numel() == 0:
        return _all_zero()
    else:
        target = target.view(1, -1).expand_as(pred)

    correct = pred.eq(target)
    percent_per_hit = 100.0 / max(num_items, 1)
    return [
        correct[:k].reshape(-1).float().sum(0).mul_(percent_per_hit)
        for k in topk
    ]

## Positional Encoding

Inject temporal information into feature sequences using sinusoidal functions at different frequencies. This allows the model to understand the relative ordering of video frames and text tokens.

**Two implementations:**
- **Sine:** Fixed (non-trainable) sinusoidal positional encoding; positions can optionally be normalized and multiplied by a constant scale (default 2π)
- **Learned:** Trainable position embeddings (similar to BERT)

**Why needed:** Transformers are permutation-invariant by design. Positional encoding provides crucial sequence order information for temporal understanding.

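**Output dimension:** Returns `hidden_dim`-dimensional embeddings. The sine variant is built with `num_pos_feats = hidden_dim // 2` and concatenates its sine and cosine halves; the learned variant uses an embedding table of width `hidden_dim`.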

In [None]:
class PositionEmbeddingSine(nn.Module):
    """
    Sine/cosine position embedding for masked feature sequences.

    The module has no trainable parameters: positions are counted from the
    validity mask and mapped through fixed sinusoids.
    """
    def __init__(self, num_pos_feats=128, temperature=10000, normalize=False, scale=None):
        """
        Args:
            num_pos_feats: Number of frequency channels; the output has twice
                this many features (sine half followed by cosine half).
            temperature: Base of the geometric frequency progression.
            normalize: If True, positions are divided by the number of valid
                entries and multiplied by ``scale``.
            scale: Multiplier used when normalizing; defaults to 2*pi.

        Raises:
            ValueError: If ``scale`` is given while ``normalize`` is False.
        """
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x, mask):
        """
        Args:
            x: (batch_size, L, d) features; only the device is read from it
            mask: (batch_size, L), with 1 indicating valid positions
        Returns:
            pos: (batch_size, L, 2 * num_pos_feats) positional encoding
        """
        assert mask is not None, "mask is required for positional encoding"

        valid = mask.to(dtype=torch.float32, device=x.device)

        # Running count of valid entries, e.g. [1,1,1,0,0] -> [1,2,3,3,3].
        positions = valid.cumsum(1, dtype=torch.float32)

        if self.normalize:
            # The last running count is the number of valid entries per row.
            total_valid = positions[:, -1:].clamp(min=1e-6)
            positions = positions / total_valid * self.scale

        # Padding slots get position 0, e.g. [1,2,3,3,3] -> [1,2,3,0,0].
        positions = positions * valid

        # temperature^(2i / num_pos_feats); consecutive channel pairs share a frequency.
        channel = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (channel // 2) / self.num_pos_feats)

        angles = positions[:, :, None] / dim_t  # (batch_size, L, num_pos_feats)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

class PositionEmbeddingLearned(nn.Module):
    """Trainable position embeddings looked up by absolute index."""
    def __init__(self, num_pos_feats=256, max_len=512):
        """
        Args:
            num_pos_feats: Width of each position embedding.
            max_len: Number of rows in the embedding table.
        """
        super().__init__()
        self.position_embeddings = nn.Embedding(max_len, num_pos_feats)

    def forward(self, x, mask):
        """
        Args:
            x: Tensor whose first two dimensions are (batch_size, seq_len).
            mask: Accepted for interface parity with the sine variant; unused.
        Returns:
            (batch_size, seq_len, num_pos_feats) embeddings, identical across
            the batch dimension.
        """
        batch_size, seq_len = x.shape[:2]
        ids = torch.arange(seq_len, dtype=torch.long, device=x.device)
        ids = ids[None, :].expand(batch_size, -1)
        return self.position_embeddings(ids)

def build_position_encoding(hidden_dim, position_embedding_type="sine", max_len=512):
    """Build the pair of positional encoders (video first, then text).

    Args:
        hidden_dim: Model width. The sine variant uses ``hidden_dim // 2``
            frequency channels; the learned variant uses ``hidden_dim`` directly.
        position_embedding_type: ``"sine"`` or ``"learned"``.
        max_len: Table size for the learned variant.

    Returns:
        Tuple ``(position_embed, txt_position_embed)`` of two independent modules.

    Raises:
        ValueError: If ``position_embedding_type`` is not recognized.
    """
    n_steps = hidden_dim // 2

    if position_embedding_type == "sine":
        def make_encoder():
            return PositionEmbeddingSine(n_steps, normalize=True)
    elif position_embedding_type == "learned":
        def make_encoder():
            return PositionEmbeddingLearned(hidden_dim, max_len)
    else:
        raise ValueError(f"Unknown position embedding type: {position_embedding_type}")

    # The video encoder is built first so parameter initialization order is stable.
    position_embed = make_encoder()
    txt_position_embed = make_encoder()
    return position_embed, txt_position_embed

## Unimodal Encoder

Projects video and text features from their original feature spaces (e.g., CLIP: 512-dim, SlowFast: 2304-dim) into a shared latent space of `hidden_dim=256`. This dimensionality reduction enables efficient cross-modal interaction in subsequent modules.

**Architecture:** LayerNorm → Dropout → Linear → ReLU → LayerNorm → Dropout → Linear  
**Design rationale:**  
- Dual normalization layers stabilize training with pre-extracted features
- Heavy dropout (0.5) prevents overfitting to specific feature extractors
- ReLU activation introduces non-linearity for better expressiveness

**Role in pipeline:** First step in creating unified multimodal representations

In [None]:
class UnimodalEncoder(nn.Module):
    """
    Two-layer projection of per-token input features to ``hidden_dim``.
    Architecture: LayerNorm -> Dropout -> Linear -> ReLU -> LayerNorm -> Dropout -> Linear
    """
    def __init__(self, input_dim, hidden_dim, dropout=0.5):
        """
        Args:
            input_dim: Width of the incoming features.
            hidden_dim: Width of the projected features.
            dropout: Dropout probability applied before each linear layer.
        """
        super().__init__()
        # First stage: input_dim -> hidden_dim
        self.norm1 = nn.LayerNorm(input_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        # Second stage: hidden_dim -> hidden_dim
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        """
        Args:
            x: input features (batch_size, seq_len, input_dim)
        Returns:
            encoded features (batch_size, seq_len, hidden_dim)
        """
        hidden = self.norm1(x)
        hidden = self.dropout1(hidden)
        hidden = self.linear1(hidden)
        # In-place is safe: `hidden` is a fresh tensor produced by linear1.
        hidden = F.relu(hidden, inplace=True)
        hidden = self.norm2(hidden)
        hidden = self.dropout2(hidden)
        return self.linear2(hidden)

# Notebook progress marker: printed once the cell defining UnimodalEncoder has run.
print("Unimodal Encoder loaded successfully")

## Distill Align Module

**Key innovation:** Addresses the fundamental problem of overlapping semantics in video-text contrastive learning.

**Problem statement:** Traditional contrastive learning (e.g., CLIP) treats all negatives equally. However, video-text pairs often share partial semantic information:
- "A person walking in the park" vs "A person running in the park"
- Both should be partially similar, not completely negative

**Solution - Similarity Matrix Distillation:**
1. **Momentum encoders:** Maintain stable teacher representations (EMA with β=0.995)
2. **Queue mechanism:** Store 65,536 historical features for large-scale negative sampling
3. **Distillation loss:** Guide the student similarity matrix toward soft targets that blend the momentum (teacher) softmax similarities with the identity (one-hot) matrix, weighted by α
4. **Temperature scaling:** Control the sharpness of similarity distributions

**Two complementary losses:**
- `loss_align`: Video-text alignment distilled from momentum encoder (coefficient α=0.4)
- `loss_sim`: Negative cosine similarity between each modality's predictor output and the other modality's global feature (no queue involved)

**Why momentum:** Prevents representation collapse and provides consistent supervision signals during optimization

In [None]:
class DistillAlign(nn.Module):
    """
    Distill Align module for video-text alignment.

    Produces two losses: ``loss_align``, a soft-target cross-entropy whose
    targets blend the momentum-branch similarity distribution with one-hot
    matches, and ``loss_sim``, a negative cosine similarity between a linear
    predictor of one modality and the raw feature of the other. Two fixed-size
    feature queues supply additional columns for the similarity matrices.
    """
    def __init__(self, hidden_dim, queue_length=65536, temp=0.07, alpha=0.4):
        """
        Args:
            hidden_dim: Width of the global video/text features.
            queue_length: Number of feature columns kept in each queue.
            temp: Initial value of the learnable softmax temperature.
            alpha: Weight of the momentum-branch distribution in the targets.
        """
        super().__init__()
        self.temp = nn.Parameter(torch.ones([]) * temp)
        self.alpha = alpha
        self.queue_length = queue_length

        self.cos = nn.CosineSimilarity(dim=1)

        # Single linear predictor shared by both modalities.
        self.h = nn.Linear(hidden_dim, hidden_dim)

        # Queues are stored column-wise as (hidden_dim, queue_length) and start
        # out as random unit-norm columns. Video is drawn before text.
        for queue_name in ("vid_queue", "txt_queue"):
            random_columns = torch.randn(hidden_dim, queue_length)
            self.register_buffer(queue_name, F.normalize(random_columns, dim=0))

        # Next write position inside the queues.
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

        # Xavier-initialize matrices only; the bias keeps its default init.
        for param in self.h.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    @staticmethod
    def _soft_cross_entropy(logits, soft_targets):
        """Batch mean of the cross-entropy between row-wise soft targets and logits."""
        return -torch.sum(F.log_softmax(logits, dim=1) * soft_targets, dim=1).mean()

    def forward(self, src_vid_cls, src_txt_cls, src_vid_cls_m, src_txt_cls_m,
                epoch_i, batch_idx, train_loader_length=0, is_training=False):
        """Compute the alignment and similarity losses.

        Args:
            src_vid_cls: video features from student branch (batch_size, hidden_dim)
            src_txt_cls: text features from student branch (batch_size, hidden_dim)
            src_vid_cls_m: video features from momentum branch (batch_size, hidden_dim)
            src_txt_cls_m: text features from momentum branch (batch_size, hidden_dim)
            epoch_i: current epoch index; the alpha warm-up only applies at epoch 0
            batch_idx: index of the current batch within the epoch
            train_loader_length: batches per epoch; 0 disables the warm-up
            is_training: when True, the momentum features are pushed into the queues

        Returns:
            Tuple ``(loss_align, loss_sim)`` of scalar tensors.
        """
        # Alpha (the distillation weight) ramps up linearly during the first epoch.
        warming_up = not (epoch_i > 0 or train_loader_length == 0)
        if warming_up:
            alpha = self.alpha * min(1, batch_idx / train_loader_length)
        else:
            alpha = self.alpha

        # Teacher side: everything here is excluded from autograd.
        with torch.no_grad():
            self.temp.clamp_(0.001, 0.5)

            # Current momentum features followed by the queue: (hidden_dim, B + K).
            vid_bank = torch.cat([src_vid_cls_m.t(), self.vid_queue.clone().detach()], dim=1)
            txt_bank = torch.cat([src_txt_cls_m.t(), self.txt_queue.clone().detach()], dim=1)

            # (B, hidden_dim) @ (hidden_dim, B + K) -> (B, B + K)
            teacher_v2t = src_vid_cls_m @ txt_bank / self.temp
            teacher_t2v = src_txt_cls_m @ vid_bank / self.temp

            # Hard labels: sample i is matched only with column i.
            one_hot = torch.zeros(teacher_v2t.size(), device=src_vid_cls.device)
            one_hot.fill_diagonal_(1)
            hard_part = (1 - alpha) * one_hot

            v2t_targets = alpha * F.softmax(teacher_v2t, dim=1) + hard_part
            t2v_targets = alpha * F.softmax(teacher_t2v, dim=1) + hard_part

        # Student side: same banks, student queries.
        loss_v2t = self._soft_cross_entropy(src_vid_cls @ txt_bank / self.temp, v2t_targets)
        loss_t2v = self._soft_cross_entropy(src_txt_cls @ vid_bank / self.temp, t2v_targets)
        loss_align = (loss_v2t + loss_t2v) / 2

        if is_training:
            self._dequeue_and_enqueue(src_vid_cls_m, src_txt_cls_m)

        predicted_from_vid = self.h(F.relu(src_vid_cls))
        predicted_from_txt = self.h(F.relu(src_txt_cls))
        vid_to_txt = self.cos(predicted_from_vid, src_txt_cls).mean()
        txt_to_vid = self.cos(predicted_from_txt, src_vid_cls).mean()
        loss_sim = -(vid_to_txt + txt_to_vid) / 2
        return loss_align, loss_sim

    @torch.no_grad()
    def _dequeue_and_enqueue(self, vid_feat, txt_feat):
        """Write a batch of features into both queues, wrapping around at the end.

        Args:
            vid_feat: (batch_size, hidden_dim) video features.
            txt_feat: (batch_size, hidden_dim) text features.
        """
        batch_size = vid_feat.shape[0]
        if batch_size == 0:
            return

        vid_feat, txt_feat = vid_feat.detach(), txt_feat.detach()
        capacity = self.queue_length
        ptr = int(self.queue_ptr.item())

        if batch_size >= capacity:
            # The batch alone fills the queue: keep only its most recent rows.
            self.vid_queue.copy_(vid_feat[-capacity:].t())
            self.txt_queue.copy_(txt_feat[-capacity:].t())
            next_ptr = 0
        else:
            # Rows that fit before the end of the queue ...
            fits = min(batch_size, capacity - ptr)
            self.vid_queue[:, ptr:ptr + fits] = vid_feat[:fits].t()
            self.txt_queue[:, ptr:ptr + fits] = txt_feat[:fits].t()
            # ... and the remainder, which wraps to the front.
            wrapped = batch_size - fits
            if wrapped > 0:
                self.vid_queue[:, :wrapped] = vid_feat[fits:].t()
                self.txt_queue[:, :wrapped] = txt_feat[fits:].t()
            next_ptr = (ptr + batch_size) % capacity

        self.queue_ptr[0] = next_ptr

## Transformer Components

Standard Transformer encoder and decoder layers forming the backbone for multimodal fusion and temporal modeling.

**TransformerEncoderLayer:**
- Self-attention mechanism for modeling within-modality relationships
- Position-aware: Queries and keys enhanced with positional encoding
- FFN (Feed-Forward Network): Two-layer MLP with ReLU/GELU activation
- Residual connections + Layer normalization for stable gradient flow

**TransformerDecoderLayer:**
- Self-attention: Models interactions among moment queries
- Cross-attention: Attends to fused video-text memory  
- Query position handling: Optional position projection for iterative refinement
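- Three sub-layers: Self-attn → Cross-attn → FFN (each with residual + norm); in this notebook the FFN is a Top-2 mixture-of-experts module that also returns a load-balancing loss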

**Design choices:**
- `batch_first=False`: Sequence length as first dimension for compatibility
- Flexible activation: ReLU (faster) or GELU (smoother gradients)
- `normalize_before`: Flag reserved for a pre-norm variant (as used in ViT-style models); the layers defined below store it but always apply post-norm

**Role in Switch-Net:** These form the building blocks for V2T extraction, convolutional fusion, and loop decoding

In [None]:
# Multi-Layer Perceptron (MLP)
class MLP(nn.Module):
    """Stack of linear layers with ReLU between them (no activation after the last)."""
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        """
        Args:
            input_dim: Width of the input features.
            hidden_dim: Width of every intermediate layer.
            output_dim: Width of the final output.
            num_layers: Total number of linear layers.
        """
        super().__init__()
        self.num_layers = num_layers
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:]))

    def forward(self, x):
        last_index = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            if index < last_index:
                x = F.relu(x)
        return x

# Top-2 Mixture-of-Experts FFN used inside the decoder
class Top2MoEFFN(nn.Module):
    """Top-k gated mixture-of-experts feed-forward module (k defaults to 2)."""
    def __init__(self, d_model, dim_feedforward, num_experts=8, top_k=2,
                 dropout=0.1, activation=F.relu, load_balance_coef=1.0):
        """
        Args:
            d_model: Token width (input and output).
            dim_feedforward: Hidden width of every expert.
            num_experts: Number of two-layer experts.
            top_k: Experts each token is routed to; must not exceed ``num_experts``.
            dropout: Dropout probability on each expert's hidden activations.
            activation: Callable applied between an expert's two linear layers.
            load_balance_coef: Multiplier on the auxiliary load-balancing loss.
        """
        super().__init__()
        assert top_k <= num_experts, "top_k must be <= num_experts"
        self.num_experts = num_experts
        self.top_k = top_k
        self.activation = activation
        self.load_balance_coef = load_balance_coef
        self.gate = nn.Linear(d_model, num_experts)
        self.expert_fc1 = nn.ModuleList(
            nn.Linear(d_model, dim_feedforward) for _ in range(num_experts))
        self.expert_fc2 = nn.ModuleList(
            nn.Linear(dim_feedforward, d_model) for _ in range(num_experts))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: (seq_len, batch_size, d_model)
        Returns:
            output: tensor with same shape as x after expert routing
            load_balance_loss: scalar auxiliary loss encouraging balanced expert usage
        """
        seq_len, batch_size, _ = x.shape
        num_tokens = seq_len * batch_size
        if num_tokens == 0:
            return x, x.new_zeros(())

        tokens = x.reshape(num_tokens, -1)
        router_probs = F.softmax(self.gate(tokens), dim=-1)

        # Keep the k most probable experts per token and renormalize their weights.
        top_weights, top_experts = torch.topk(router_probs, self.top_k, dim=-1)
        top_weights = top_weights / (top_weights.sum(dim=-1, keepdim=True) + 1e-9)

        # Auxiliary loss: (mean router probability) x (fraction of routed slots)
        # per expert, summed and scaled by the number of experts.
        assignment = F.one_hot(top_experts, num_classes=self.num_experts).to(router_probs.dtype)
        total_slots = float(max(num_tokens * self.top_k, 1))
        expert_load = assignment.sum(dim=(0, 1)) / total_slots
        balance = (router_probs.mean(dim=0) * expert_load).sum()
        load_balance_loss = balance * self.num_experts * self.load_balance_coef

        mixed = tokens.new_zeros(tokens.shape)

        # Expert-major, slot-minor traversal; each expert only sees its own tokens.
        for expert_id, (fc1, fc2) in enumerate(zip(self.expert_fc1, self.expert_fc2)):
            for slot in range(self.top_k):
                chosen = top_experts[:, slot] == expert_id
                rows = torch.nonzero(chosen, as_tuple=False).squeeze(1)
                if rows.numel() == 0:
                    continue
                hidden = self.dropout(self.activation(fc1(tokens[rows])))
                contribution = fc2(hidden) * top_weights[rows, slot].unsqueeze(1)
                mixed.index_add_(0, rows, contribution)

        return mixed.view(seq_len, batch_size, -1), load_balance_loss

# Transformer Encoder Layer
class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer: self-attention followed by a two-layer FFN."""
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        """
        Args:
            d_model: Token width.
            nhead: Number of attention heads.
            dim_feedforward: Hidden width of the feed-forward network.
            dropout: Dropout probability used in attention and on both residual branches.
            activation: ``"relu"`` selects ReLU; any other value selects GELU.
            normalize_before: Stored on the module; ``forward`` does not read it.
        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=False)

        # Feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu
        self.normalize_before = normalize_before

    def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        """
        Args:
            src: (seq_len, batch_size, d_model)
            src_mask: optional attention mask forwarded to the attention module
            src_key_padding_mask: optional key padding mask forwarded to the attention module
            pos: (seq_len, batch_size, d_model) positional encoding added to
                queries and keys only (values use ``src`` unchanged)
        Returns:
            (seq_len, batch_size, d_model) encoded sequence
        """
        attn_input = src if pos is None else src + pos
        attn_out, _ = self.self_attn(attn_input, attn_input, value=src,
                                     attn_mask=src_mask,
                                     key_padding_mask=src_key_padding_mask)
        src = self.norm1(src + self.dropout1(attn_out))

        hidden = self.activation(self.linear1(src))
        ffn_out = self.linear2(self.dropout(hidden))
        return self.norm2(src + self.dropout2(ffn_out))

# Transformer Decoder Layer
class TransformerDecoderLayer(nn.Module):
    """Post-norm decoder layer: self-attention, cross-attention, then a mixture-of-experts FFN."""
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False, keep_query_pos=False,
                 moe_num_experts=8, moe_top_k=2, moe_load_balance_coef=0.01):
        """
        Args:
            d_model: Token width.
            nhead: Number of attention heads.
            dim_feedforward: Hidden width of every expert in the MoE FFN.
            dropout: Dropout probability used in attention, residual branches and experts.
            activation: ``"relu"``, ``"gelu"`` or ``"prelu"``; any other value selects GELU.
            normalize_before: Stored on the module; ``forward`` does not read it.
            keep_query_pos: When True, a linear projection of ``query_pos`` is
                added to the cross-attention query instead of the raw ``query_pos``.
            moe_num_experts: Number of experts in the MoE FFN.
            moe_top_k: Experts each token is routed to.
            moe_load_balance_coef: Multiplier on the MoE auxiliary loss.
        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=False)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=False)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        if activation == "prelu":
            chosen_activation = nn.PReLU()
        elif activation == "relu":
            chosen_activation = F.relu
        else:
            # "gelu" and every unrecognized name resolve to GELU.
            chosen_activation = F.gelu
        self.activation_fn = chosen_activation
        self.activation = self.activation_fn
        self.normalize_before = normalize_before
        self.keep_query_pos = keep_query_pos

        self.moe_ffn = Top2MoEFFN(
            d_model,
            dim_feedforward,
            num_experts=moe_num_experts,
            top_k=moe_top_k,
            dropout=dropout,
            activation=self.activation_fn,
            load_balance_coef=moe_load_balance_coef,
        )

        # Optional projection of the query position for cross-attention.
        if keep_query_pos:
            self.ca_qpos_proj = nn.Linear(d_model, d_model)
        else:
            self.ca_qpos_proj = None

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None,
                pos=None, query_pos=None, query_sine_embed=None, is_first=False):
        """
        Args:
            tgt: (num_queries, batch_size, d_model)
            memory: (seq_len, batch_size, d_model)
            tgt_mask, memory_mask: optional attention masks
            tgt_key_padding_mask, memory_key_padding_mask: optional padding masks
            pos: optional positional encoding added to the memory keys
            query_pos: optional positional encoding added to the queries; must
                be provided when the layer was built with ``keep_query_pos=True``
            query_sine_embed, is_first: accepted but not used by this layer
        Returns:
            Tuple ``(tgt, moe_loss)``: the updated queries and the MoE auxiliary loss.
        """
        # --- Self-attention among queries ---
        self_qk = tgt if query_pos is None else tgt + query_pos
        self_out, _ = self.self_attn(self_qk, self_qk, value=tgt, attn_mask=tgt_mask,
                                     key_padding_mask=tgt_key_padding_mask)
        tgt = self.norm1(tgt + self.dropout1(self_out))

        # --- Cross-attention over memory ---
        if self.ca_qpos_proj is not None:
            cross_q = tgt + self.ca_qpos_proj(query_pos)
        elif query_pos is None:
            cross_q = tgt
        else:
            cross_q = tgt + query_pos

        cross_k = memory if pos is None else memory + pos
        cross_out, _ = self.multihead_attn(query=cross_q, key=cross_k, value=memory,
                                           attn_mask=memory_mask,
                                           key_padding_mask=memory_key_padding_mask)
        tgt = self.norm2(tgt + self.dropout2(cross_out))

        # --- Mixture-of-experts feed-forward ---
        ffn_out, moe_loss = self.moe_ffn(tgt)
        tgt = self.norm3(tgt + self.dropout3(ffn_out))

        return tgt, moe_loss

def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones

# Notebook progress marker: printed once the cell defining the transformer layers has run.
print("Transformer components loaded successfully")

### 4.4 Convolutional Fuser

The next three cells unpack the multimodal fusion stack in detail:

- **V2TExtractor** strips away text-irrelevant video content via cross-modal attention so the downstream encoder receives query-aware clips.
- **ConvolutionalBlock** alternates temporal convolutions and residual connections to capture local dynamics that transformers might miss.
- **ConvolutionalFuser** stitches the pieces together, routing video/text through stacked transformer encoders and the convolutional block before handing a fused memory tensor to the decoder.

Skimming the docstrings gives implementation specifics, while the summary above clarifies how the pieces cooperate.

In [None]:
# V2T Extractor - removes text-irrelevant information from video features
class V2TExtractor(nn.Module):
    """
    Video-to-text extractor: fuses video features with text via trilinear
    context-query attention, then with an attention-pooled sentence vector.
    """
    def __init__(self, hidden_dim=256, dropout=0.1):
        """
        Args:
            hidden_dim: Width of both the video and the text features.
            dropout: Dropout probability applied to the inputs of the score computation.
        """
        super().__init__()

        # Trilinear score weights: video-only term, text-only term, interaction term.
        self.w4C = nn.Parameter(torch.empty(hidden_dim, 1), requires_grad=True)
        self.w4Q = nn.Parameter(torch.empty(hidden_dim, 1), requires_grad=True)
        self.w4mlu = nn.Parameter(torch.empty(1, 1, hidden_dim), requires_grad=True)
        self.dropout = nn.Dropout(p=dropout)
        # Pointwise conv that merges the four concatenated attention views.
        self.cqa_linear = nn.Conv1d(4 * hidden_dim, hidden_dim, kernel_size=1,
                                    stride=1, padding=0, bias=True)
        # Scoring vector for attention-pooling the text tokens.
        self.weight = nn.Parameter(torch.empty(hidden_dim, 1), requires_grad=True)
        # Pointwise conv that merges the fused video features with the pooled text.
        self.conv1d = nn.Conv1d(2 * hidden_dim, hidden_dim, kernel_size=1,
                                stride=1, padding=0, bias=True)

        # Initialize the free-standing parameters
        for param in (self.w4C, self.w4Q, self.w4mlu, self.weight):
            nn.init.xavier_uniform_(param)

    def forward(self, src_vid, src_txt, src_vid_mask, src_txt_mask):
        """
        Args:
            src_vid: video features (batch_size, L, hidden_dim)
            src_txt: text features (batch_size, N, hidden_dim)
            src_vid_mask: video mask (batch_size, L)
            src_txt_mask: text mask (batch_size, N)
        Returns:
            output: text-conditioned video features (batch_size, L, hidden_dim),
                non-negative because of the final ReLU
        """
        _, num_clips, _ = src_vid.shape
        _, num_tokens, _ = src_txt.shape

        # Dropout is applied only to the tensors used for scoring.
        context = self.dropout(src_vid)
        query = self.dropout(src_txt)

        # score[i, j] = video term(i) + text term(j) + interaction(i, j)
        vid_term = torch.matmul(context, self.w4C).expand([-1, -1, num_tokens])
        txt_term = torch.matmul(query, self.w4Q).transpose(1, 2).expand([-1, num_clips, -1])
        joint_term = torch.matmul(context * self.w4mlu, query.transpose(1, 2))
        score = vid_term + txt_term + joint_term

        # Step 1: normalize over text tokens (per clip) and over clips (per token).
        attn_over_txt = F.softmax(
            self.mask_logits(score, src_txt_mask.unsqueeze(1)), dim=2)
        attn_over_vid = F.softmax(
            self.mask_logits(score, src_vid_mask.unsqueeze(2)), dim=1).transpose(1, 2)

        # Context-to-query and query-to-context attention
        c2q = torch.matmul(attn_over_txt, src_txt)
        q2c = torch.matmul(torch.matmul(attn_over_txt, attn_over_vid), src_vid)

        # Step 2: merge the raw video features with their attended views.
        fused = torch.cat(
            [src_vid, c2q, torch.mul(src_vid, c2q), torch.mul(src_vid, q2c)], dim=2)
        fused = self.cqa_linear(fused.transpose(1, 2)).transpose(1, 2)

        # Step 3: attention-pool the text into one vector per sample ...
        pool_logits = torch.tensordot(src_txt, self.weight, dims=1)
        pool_logits = self.mask_logits(pool_logits, mask=src_txt_mask.unsqueeze(2))
        pool_weights = F.softmax(pool_logits, dim=1)
        pooled_src_txt = torch.matmul(src_txt.transpose(1, 2), pool_weights).squeeze(2)

        # ... and append it to every clip before the final projection.
        pooled_src_txt = pooled_src_txt.unsqueeze(1).repeat(1, fused.shape[1], 1)
        merged = torch.cat([fused, pooled_src_txt], dim=2)
        merged = self.conv1d(merged.transpose(1, 2)).transpose(1, 2)
        return F.relu(merged)

    def mask_logits(self, inputs, mask, mask_value=-1e30):
        """Add ``mask_value`` to ``inputs`` wherever ``mask`` is 0 (mask is cast to float32)."""
        keep = mask.type(torch.float32)
        return inputs + (1.0 - keep) * mask_value

# Transformer is used for global feature extraction, but we also need local feature extraction
# Convolutional Block for local feature extraction
class _ResidualConvBlock(nn.Module):
    """One residual unit: Conv1d -> BatchNorm1d -> ReLU -> Conv1d -> BatchNorm1d, skip-add, ReLU.

    Defined at module level (rather than inside ``ConvolutionalBlock.__init__``)
    so that modules containing it can be pickled; locally defined classes
    cannot be pickled by reference.
    """
    def __init__(self, hidden_dim=256):
        super().__init__()
        # kernel_size=3 with padding=1 and stride=1 keeps the sequence length unchanged,
        # which is what makes the residual addition in forward() shape-compatible.
        self.conv1 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm1d(hidden_dim)

    def forward(self, x):
        """Apply the residual unit to ``x`` of shape (batch_size, hidden_dim, seq_len)."""
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + residual)


class ConvolutionalBlock(nn.Module):
    """Stack of residual 1-D convolutional blocks for extracting local features.

    This is the standard CNN-with-residual-connections layout: ``n_blocks``
    identical units are applied one after another along the sequence axis.

    Args:
        hidden_dim: number of channels going in and out of every convolution.
        n_blocks: how many residual units to stack.
    """
    def __init__(self, hidden_dim=256, n_blocks=5):
        super().__init__()
        self.blocks = nn.ModuleList([_ResidualConvBlock(hidden_dim) for _ in range(n_blocks)])

    def forward(self, x):
        """
        Args:
            x: (seq_len, batch_size, hidden_dim)
        Returns:
            out: (seq_len, batch_size, hidden_dim)
        """
        # Conv1d/BatchNorm1d expect channels before the sequence axis.
        out = x.permute(1, 2, 0)  # (batch_size, hidden_dim, seq_len)
        for layer in self.blocks:
            out = layer(out)
        out = out.permute(2, 0, 1)  # (seq_len, batch_size, hidden_dim)
        return out

# Notebook cell status message, printed once the definitions above have run.
print("V2T Extractor and Convolutional Block loaded successfully")

#### Transformer Containers

To keep the architecture modular, we reimplement DETR-style encoder/decoder containers. They stack individual layers, can optionally return the intermediate output of every layer, and are shared by both the moment-retrieval and highlight-detection paths.

In [None]:
# Transformer Encoder and Decoder containers
class TransformerEncoder(nn.Module):
    """Sequential container that runs ``num_layers`` clones of an encoder layer.

    Args:
        encoder_layer: prototype layer; it is cloned ``num_layers`` times.
        num_layers: number of clones to stack.
        norm: optional module applied to the output of the last layer.
        return_intermediate: when True, ``forward`` returns the stacked
            (un-normalised) output of every layer instead of the final output.
    """
    def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, src, mask=None, src_key_padding_mask=None, pos=None, **kwargs):
        """Feed ``src`` through every layer in order.

        ``mask`` is handed to each layer as ``src_mask``; ``src_key_padding_mask``,
        ``pos`` and any extra keyword arguments are forwarded unchanged.

        Returns:
            The final (optionally normalised) output, or, when
            ``return_intermediate`` is set, a tensor stacking each layer's
            output along a new leading axis.
        """
        hidden = src
        per_layer = []

        for encoder_layer in self.layers:
            hidden = encoder_layer(
                hidden,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
                **kwargs,
            )
            if self.return_intermediate:
                per_layer.append(hidden)

        # The final norm is always evaluated, even though the stacked
        # intermediate outputs returned below do not include its result.
        if self.norm is not None:
            hidden = self.norm(hidden)

        return torch.stack(per_layer) if self.return_intermediate else hidden


class TransformerDecoder(nn.Module):
    """Stack of Transformer decoder layers with iterative refinement.

    Before every layer the current reference points are turned into a
    sinusoidal embedding that conditions the queries; after every layer a
    shared (or per-layer) ``bbox_embed`` MLP predicts an offset that refines
    the reference points for the next layer.

    Args:
        decoder_layer: prototype layer, cloned ``num_layers`` times. Each clone
            must return ``(output, moe_loss)`` from its forward pass.
        num_layers: number of decoder layers.
        norm: optional module applied to layer outputs.
        return_intermediate: return the output of every layer, not just the last.
        d_model: hidden size.
        query_dim: number of leading reference-point coordinates that are used.
        keep_query_pos: when False, ``ca_qpos_proj`` is removed from every
            layer except the first.
        query_scale_type: "cond_elewise", "cond_scalar" or "fix_elewise";
            selects how the sinusoidal query embedding is rescaled.
        modulate_t_attn: rescale the query embedding with a predicted width.
        bbox_embed_diff_each_layer: use one refinement MLP per layer instead of
            a single shared one.
    """
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False,
                 d_model=256, query_dim=2, keep_query_pos=False, 
                 query_scale_type="cond_elewise", modulate_t_attn=False,
                 bbox_embed_diff_each_layer=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
        self.query_dim = query_dim
        
        # Query scale module
        self.query_scale_type = query_scale_type
        if query_scale_type == "cond_elewise":
            self.query_scale = MLP(d_model, d_model, d_model, 2)
        elif query_scale_type == "cond_scalar":
            self.query_scale = MLP(d_model, d_model, 1, 2)
        elif query_scale_type == "fix_elewise":
            self.query_scale = nn.Embedding(num_layers, d_model)
        
        self.ref_point_head = MLP(d_model, d_model, d_model, 2)
        
        # Bbox embedding for iterative refinement
        if bbox_embed_diff_each_layer:
            self.bbox_embed = nn.ModuleList([MLP(d_model, d_model, 2, 3) for _ in range(num_layers)])
        else:
            self.bbox_embed = MLP(d_model, d_model, 2, 3)
        
        # Initialize bbox_embed: zeroing the last layer makes the initial
        # refinement offset zero.
        if bbox_embed_diff_each_layer:
            for bbox_embed in self.bbox_embed:
                nn.init.constant_(bbox_embed.layers[-1].weight.data, 0)
                nn.init.constant_(bbox_embed.layers[-1].bias.data, 0)
        else:
            nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
        
        self.d_model = d_model
        self.modulate_t_attn = modulate_t_attn
        self.bbox_embed_diff_each_layer = bbox_embed_diff_each_layer
        
        if modulate_t_attn:
            self.ref_anchor_head = MLP(d_model, d_model, 1, 2)
        
        # Only the first layer keeps its cross-attention query-position projection.
        if not keep_query_pos:
            for layer_id in range(num_layers - 1):
                self.layers[layer_id + 1].ca_qpos_proj = None

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None,
                pos=None, refpoints_unsigmoid=None):
        """Decode ``tgt`` against ``memory`` while refining the reference points.

        Args:
            tgt: initial query content features; presumably sequence-first,
                (num_queries, batch_size, d_model) -- TODO confirm against caller.
            memory: encoder output the layers cross-attend to.
            tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask:
                forwarded unchanged to every layer.
            pos: positional encoding for ``memory``, forwarded to every layer.
            refpoints_unsigmoid: reference points in logit space; a sigmoid is
                applied before use. Required (the ``None`` default is not usable).

        Returns:
            ``(outputs, moe_loss)``. With ``return_intermediate`` set,
            ``outputs`` is a two-element list ``[hidden_states, reference_points]``
            with one entry per layer on the leading axis and axes 1 and 2
            swapped; otherwise it is the final output with a new leading axis.
            ``moe_loss`` is the per-layer auxiliary loss averaged over layers.
        """

        # output is our query notes
        # first time it is all zero, then it is updated iteratively
        output = tgt
        # This will store the (normalised) output of each layer
        intermediate = []
        # Reference points squashed to (0, 1); the last axis holds [center, width].
        # Layout presumably matches tgt, i.e. (num_queries, batch_size, 2) -- TODO confirm.
        reference_points = refpoints_unsigmoid.sigmoid()
        ref_points = [reference_points]
        
        total_moe_loss = tgt.new_zeros(())
        
        # Layer by layer decoding
        for layer_id, layer in enumerate(self.layers):
            obj_center = reference_points[..., :self.query_dim]

            # Generate sine embedding for query dynamically
            query_sine_embed = gen_sineembed_for_position(obj_center)
            query_pos = self.ref_point_head(query_sine_embed)
            
            # Apply transformation (the first layer uses the embedding unscaled
            # unless the scale is a fixed per-layer embedding)
            if self.query_scale_type != "fix_elewise":
                pos_transformation = 1 if layer_id == 0 else self.query_scale(output)
            else:
                pos_transformation = self.query_scale.weight[layer_id]
            
            query_sine_embed = query_sine_embed * pos_transformation
            
            # Modulated attention: rescale by predicted width / reference width
            if self.modulate_t_attn:
                reft_cond = self.ref_anchor_head(output).sigmoid()
                query_sine_embed *= (reft_cond[..., 0] / obj_center[..., 1]).unsqueeze(-1)

            # Apply decoder layer (self attention and cross-attention)
            output, layer_moe_loss = layer(output, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                                           tgt_key_padding_mask=tgt_key_padding_mask,
                                           memory_key_padding_mask=memory_key_padding_mask,
                                           pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
                                           is_first=(layer_id == 0))
            total_moe_loss = total_moe_loss + layer_moe_loss
            
            # Iterative box refinement
            if self.bbox_embed is not None:
                if self.bbox_embed_diff_each_layer:
                    tmp = self.bbox_embed[layer_id](output)
                else:
                    tmp = self.bbox_embed(output)
                # The MLP predicts an offset in logit space relative to the current reference.
                tmp[..., :self.query_dim] += inverse_sigmoid(reference_points)
                new_reference_points = tmp[..., :self.query_dim].sigmoid()
                # ref_points keeps the reference that each layer *started* from,
                # so the refinement produced by the last layer is not recorded.
                if layer_id != self.num_layers - 1:
                    ref_points.append(new_reference_points)
                # detach: gradients do not flow through earlier refinement steps
                reference_points = new_reference_points.detach()
            
            if self.return_intermediate:
                # Guard against norm=None (the constructor default); calling
                # self.norm unconditionally raised a TypeError in that case.
                intermediate.append(output if self.norm is None else self.norm(output))
        
        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)
        
        total_moe_loss = total_moe_loss / max(self.num_layers, 1)
        
        if self.return_intermediate:
            if self.bbox_embed is not None:
                # torch.stack puts the layer index first; transpose(1, 2) then swaps
                # the next two axes, presumably giving
                # (num_layers, batch_size, num_queries, d_model) and
                # (num_layers, batch_size, num_queries, 2) -- TODO confirm input layout.
                return [torch.stack(intermediate).transpose(1, 2), 
                       torch.stack(ref_points).transpose(1, 2)], total_moe_loss
            else:
                return [torch.stack(intermediate).transpose(1, 2),
                       reference_points.unsqueeze(0).transpose(1, 2)], total_moe_loss
        
        return output.unsqueeze(0), total_moe_loss


def _get_activation_fn(activation):
    """Resolve an activation name to a callable.

    "relu", "gelu", "glu" and "selu" map to the matching functions in
    ``torch.nn.functional``; "prelu" yields a freshly constructed
    ``nn.PReLU`` module on every call.

    Raises:
        RuntimeError: if ``activation`` is none of the supported names.
    """
    if activation == "prelu":
        return nn.PReLU()

    functional_choices = (
        ("relu", F.relu),
        ("gelu", F.gelu),
        ("glu", F.glu),
        ("selu", F.selu),
    )
    for name, fn in functional_choices:
        if activation == name:
            return fn

    raise RuntimeError(f"activation should be relu/gelu/glu/prelu/selu, not {activation}.")


class T2V_TransformerEncoderLayer_no_global(nn.Module):
    """Text-to-Video cross-attention layer without global token.

    The input sequence is the concatenation ``[video tokens; text tokens]``.
    Video tokens act as queries and text tokens as keys/values; the attended
    video tokens then pass through a feed-forward block. Text tokens are
    returned unchanged.

    Args:
        d_model: hidden size.
        nhead: number of attention heads.
        dim_feedforward: width of the feed-forward block.
        dropout: dropout probability used in attention and feed-forward paths.
        activation: activation name understood by ``_get_activation_fn``.
        normalize_before: stored for interface parity; ``forward`` does not
            consult it.
    """
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=False)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self.nhead = nhead

    def with_pos_embed(self, tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    # A cross-attention layer from text to video and a feedforward layer
    def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None, video_length=None, **kwargs):
        """Let video tokens attend to text tokens.

        Args:
            src: (video_length + text_length, batch_size, d_model).
            src_mask: accepted for interface compatibility; unused.
            src_key_padding_mask: (batch_size, video_length + text_length),
                True marks padded positions. Required.
            pos: optional positional encoding with the same shape as ``src``.
            video_length: number of leading video tokens in ``src``. Required.

        Returns:
            Tensor shaped like ``src``: updated video tokens followed by the
            untouched text tokens.
        """
        assert video_length is not None
        
        pos_src = self.with_pos_embed(src, pos)

        # q: video features + video positional encodings
        # k: text features + text positional encodings
        # v: text features (no positional encoding)
        q, k, v = pos_src[:video_length], pos_src[video_length:], src[video_length:]
        
        # Per-sample mask that is True only where BOTH the query (video) and
        # the key (text) position are padding: (batch_size, L, N).
        qmask = src_key_padding_mask[:, :video_length].unsqueeze(2)
        kmask = src_key_padding_mask[:, video_length:].unsqueeze(1)
        attn_mask = torch.matmul(qmask.float(), kmask.float()).bool()
        # nn.MultiheadAttention lays a 3-D attn_mask out batch-major
        # (row index = batch * nhead + head), so every sample's mask must be
        # repeated nhead times consecutively. Tiling with .repeat(nhead, 1, 1)
        # would hand some heads the mask of a different sample.
        attn_mask = attn_mask.repeat_interleave(self.nhead, dim=0)
        
        src2 = self.self_attn(q, k, value=v, attn_mask=attn_mask,
                             key_padding_mask=src_key_padding_mask[:, video_length:])[0]
        
        # Residual connection around the attention
        src2 = src[:video_length] + self.dropout1(src2)

        # Feed-forward block: norm is applied to the branch input, and again
        # after the residual addition.
        src3 = self.norm1(src2)
        src3 = self.linear2(self.dropout(self.activation(self.linear1(src3))))
        
        src2 = src2 + self.dropout2(src3)
        src2 = self.norm2(src2)
       
        # Re-attach the unchanged text tokens so the sequence layout is preserved
        src = torch.cat([src2, src[video_length:]])
        
        return src

# Notebook cell status message, printed once the definitions above have run.
print("Transformer Encoder and Decoder containers loaded successfully")

#### Decoder Stack & Heads

With the container utilities in place, the next definitions wire up the Switch-Net fusion and decoder stack (formerly LD-DETR):
- **ConvolutionalFuser** chains the V2T extractor, the T2V encoder, two Transformer encoders, and the convolutional blocks to produce the fused memory tensor.
- **LoopDecoder** performs iterative query refinement by running the decoder layers several times, feeding each pass's output into the next.
- **PredictionHeads** branch into the moment-retrieval classifier/regressor and the highlight-detection saliency head, keeping auxiliary outputs when deep supervision is enabled.

The decoder and heads consume the fused memory tensor produced by the fuser and emit the logits/spans/saliency signals that feed the training losses.

In [None]:
# Convolutional Fuser - combines multiple encoding stages
class ConvolutionalFuser(nn.Module):
    """
    Fuses video and text features into a single video-length memory tensor.

    Stages, in order: V2T Extractor -> T2V Encoder -> Transformer Encoder 1
    -> Convolutional Blocks -> Transformer Encoder 2.
    """
    def __init__(self, hidden_dim=256, nhead=8, dim_feedforward=1024, dropout=0.1,
                 activation="prelu", num_v2t_encoder_layers=2, num_encoder1_layers=2,
                 num_convolutional_blocks=5, num_encoder2_layers=2, normalize_before=False):
        super().__init__()

        # Every attention layer prototype is built from the same hyper-parameters.
        layer_args = (hidden_dim, nhead, dim_feedforward, dropout, activation, normalize_before)

        # Stage 1: video-to-text co-attention
        self.v2t_extractor = V2TExtractor(hidden_dim, dropout)

        # Stage 2: text-to-video cross-attention encoder (no final norm)
        self.t2v_encoder = TransformerEncoder(
            T2V_TransformerEncoderLayer_no_global(*layer_args),
            num_v2t_encoder_layers, None)

        # Stage 3: global self-attention over video tokens
        self.transformer_encoder1 = TransformerEncoder(
            TransformerEncoderLayer(*layer_args),
            num_encoder1_layers, None)

        # Stage 4: local feature extraction
        self.convolutional_block = ConvolutionalBlock(hidden_dim, num_convolutional_blocks)

        # Stage 5: second round of global self-attention
        self.transformer_encoder2 = TransformerEncoder(
            TransformerEncoderLayer(*layer_args),
            num_encoder2_layers, None)

        # Xavier-initialise every weight matrix of the attention-based stages;
        # the convolutional blocks keep their default initialisation.
        xavier_targets = (self.v2t_extractor, self.t2v_encoder,
                          self.transformer_encoder1, self.transformer_encoder2)
        for stage in xavier_targets:
            for weight in stage.parameters():
                if weight.dim() > 1:
                    nn.init.xavier_uniform_(weight)

    def forward(self, src_vid, src_txt, src_vid_mask, src_txt_mask, pos_vid, pos_txt):
        """
        Args:
            src_vid: video features (batch_size, L, hidden_dim)
            src_txt: text features (batch_size, N, hidden_dim)
            src_vid_mask: video mask (batch_size, L)
            src_txt_mask: text mask (batch_size, N)
            pos_vid: video position encoding (batch_size, L, hidden_dim)
            pos_txt: text position encoding (batch_size, N, hidden_dim)
        Returns:
            memory: fused multimodal features (L, batch_size, hidden_dim)
        """
        # Co-attention first: make the video features text-aware.
        src_vid = self.v2t_extractor(src_vid, src_txt, src_vid_mask, src_txt_mask)
        num_clips = src_vid.shape[1]

        # Build the joint [video; text] sequence and switch to the
        # sequence-first layout (L+N, batch_size, hidden_dim).
        valid = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool()
        tokens = torch.cat([src_vid, src_txt], dim=1).permute(1, 0, 2)
        positions = torch.cat([pos_vid, pos_txt], dim=1).permute(1, 0, 2)

        # The encoders receive the inverted mask as their key-padding mask.
        tokens = self.t2v_encoder(tokens, src_key_padding_mask=~valid, pos=positions,
                                  video_length=num_clips)

        # From here on only the video part of the sequence is kept.
        clip_tokens = tokens[:num_clips]
        clip_valid = valid[:, :num_clips]
        clip_positions = positions[:num_clips]

        memory = self.transformer_encoder1(clip_tokens, src_key_padding_mask=~clip_valid,
                                           pos=clip_positions)
        memory = self.convolutional_block(memory)
        memory = self.transformer_encoder2(memory, src_key_padding_mask=~clip_valid,
                                           pos=clip_positions)

        return memory


class LoopDecoder(nn.Module):
    """
    Decoder that is applied repeatedly to refine its own predictions.

    The same ``TransformerDecoder`` is run ``num_decoder_loops`` times; the
    last-layer hidden states and reference points of one pass seed the next.
    """
    def __init__(self, hidden_dim=256, nhead=8, dim_feedforward=1024, dropout=0.1,
                 activation="prelu", normalize_before=False, keep_query_pos=False,
                 num_decoder_layers=2, return_intermediate_dec=True, query_dim=2,
                 query_scale_type="cond_elewise", modulate_t_attn=True,
                 bbox_embed_diff_each_layer=False, num_decoder_loops=3,
                 moe_num_experts=8, moe_top_k=2, moe_load_balance_coef=0.01):
        super().__init__()
        self.num_decoder_loops = num_decoder_loops

        # Prototype layer that the decoder container clones
        layer_prototype = TransformerDecoderLayer(
            hidden_dim, nhead, dim_feedforward, dropout,
            activation, normalize_before,
            keep_query_pos=keep_query_pos,
            moe_num_experts=moe_num_experts,
            moe_top_k=moe_top_k,
            moe_load_balance_coef=moe_load_balance_coef,
        )
        output_norm = nn.LayerNorm(hidden_dim)

        self.transformer_decoder = TransformerDecoder(
            layer_prototype,
            num_decoder_layers,
            output_norm,
            return_intermediate=return_intermediate_dec,
            d_model=hidden_dim,
            query_dim=query_dim,
            keep_query_pos=keep_query_pos,
            query_scale_type=query_scale_type,
            modulate_t_attn=modulate_t_attn,
            bbox_embed_diff_each_layer=bbox_embed_diff_each_layer,
        )

        # Xavier-initialise every weight matrix inside the decoder
        for weight in self.transformer_decoder.parameters():
            if weight.dim() > 1:
                nn.init.xavier_uniform_(weight)

    def forward(self, tgt, memory, src_vid_mask, pos_vid, refpoint_embed):
        """
        Args:
            tgt: zero matrix for queries (num_queries, batch_size, hidden_dim)
            memory: multimodal features (L, batch_size, hidden_dim); only read, never updated
            src_vid_mask: video mask (batch_size, L)
            pos_vid: video position encoding (batch_size, L, hidden_dim)
            refpoint_embed: reference point embeddings (num_queries, batch_size, 2)
        Returns:
            hs: decoded features (num_layers, batch_size, num_queries, hidden_dim)
            reference: refined reference points (num_layers, batch_size, num_queries, 2)
            moe_loss: scalar load-balancing loss averaged over the loops

        Note: the decoder result is unpacked into (hs, reference), which
        presumably requires ``return_intermediate_dec=True`` -- TODO confirm.
        """
        # The decoder works sequence-first: (L, batch_size, hidden_dim)
        pos_vid = pos_vid.permute(1, 0, 2)

        queries, anchors = tgt, refpoint_embed
        hs = reference = None  # stay None only if no loop runs
        moe_loss_sum = memory.new_zeros(())
        final_loop = self.num_decoder_loops - 1

        for loop_idx in range(self.num_decoder_loops):
            (hs, reference), loop_moe_loss = self.transformer_decoder(
                queries,
                memory,
                memory_key_padding_mask=~src_vid_mask,  # inverted video mask
                pos=pos_vid,
                refpoints_unsigmoid=anchors,
            )
            moe_loss_sum = moe_loss_sum + loop_moe_loss

            if loop_idx < final_loop:
                # Seed the next pass with the last layer's outputs, converted
                # back to the sequence-first layout; reference points go back
                # to logit space because the decoder applies a sigmoid itself.
                queries = hs[-1].transpose(0, 1)
                anchors = inverse_sigmoid(reference[-1]).transpose(0, 1)

        return hs, reference, moe_loss_sum / max(self.num_decoder_loops, 1)


class VideoMomentRetrievalPredictionHead(nn.Module):
    """Moment-retrieval head: a 2-way classifier plus a 3-layer span regressor."""
    def __init__(self, hidden_dim, span_loss_type):
        super().__init__()
        self.span_loss_type = span_loss_type
        # Two classes: foreground / background
        self.class_embed = nn.Linear(hidden_dim, 2)
        # Span regressor, kept as three separate linear layers
        self.span_embed1 = nn.Linear(hidden_dim, hidden_dim)
        self.span_embed2 = nn.Linear(hidden_dim, hidden_dim)
        # Final layer outputs (center, width)
        self.span_embed3 = nn.Linear(hidden_dim, 2)

    def forward(self, hs, reference):
        """
        Args:
            hs: decoded features (num_layers, batch_size, num_queries, hidden_dim)
            reference: reference points (num_layers, batch_size, num_queries, 2)
        Returns:
            pred_logits: classification logits of the last layer
            pred_spans: predicted spans of the last layer
            pred_logits_others: classification logits of all earlier layers
            pred_spans_others: predicted spans of all earlier layers
        """
        logits = self.class_embed(hs)
        anchor_logits = inverse_sigmoid(reference)

        # Regress an offset and add it to the reference points in logit space
        # (residual connection).
        hidden = F.relu(self.span_embed1(hs))
        hidden = F.relu(self.span_embed2(hidden))
        spans = self.span_embed3(hidden) + anchor_logits

        # For the L1 span loss the spans are squashed to [0, 1].
        if self.span_loss_type == "l1":
            spans = spans.sigmoid()

        return logits[-1], spans[-1], logits[:-1], spans[:-1]


class HighlightDetectionPredictionHead(nn.Module):
    """Highlight-detection head producing one saliency score per clip."""
    def __init__(self, hidden_dim=256, clip_len=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.clip_len = clip_len
        self.gru = nn.GRU(hidden_dim, hidden_dim, num_layers=1, 
                         bidirectional=False, batch_first=True)
        self.saliency_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, pred_logits, pred_spans, memory, src_vid):
        """
        Args:
            pred_logits: classification logits (batch_size, num_queries, 2)
            pred_spans: predicted spans (batch_size, num_queries, 2)
            memory: multimodal features (batch_size, L, hidden_dim)
            src_vid: video features (batch_size, L, hidden_dim); they are fed
                to the GRU, so their last dimension must equal hidden_dim
        Returns:
            saliency_scores: per-clip saliency scores (batch_size, L)
        """
        batch_size = memory.size(0)
        video_length = memory.shape[1]

        # Pick, per sample, the query with the highest class-0 probability.
        fg_prob = F.softmax(pred_logits, -1)[..., 0]
        _, ranking = torch.sort(fg_prob, dim=-1, descending=True)
        best_query = ranking[:, :1]

        # Spans -> (start, end) clip indices
        clip_spans = span_cxw_to_xx(pred_spans) * (video_length * self.clip_len)
        clip_spans = torch.floor(clip_spans / self.clip_len)

        batch_index = torch.arange(batch_size).unsqueeze(1)
        best_spans = clip_spans[batch_index, best_query].squeeze(1)

        # Cut the selected moment out of every sample and zero-pad it back to
        # the full video length so the slices can be stacked.
        last_clip = video_length - 1
        moment_feats = []
        for i, span in enumerate(best_spans):
            first = int(span[0].clamp(0, last_clip))
            final = int(span[1].clamp(0, last_clip))
            window = src_vid[i, first:final + 1, :]

            missing = video_length - window.size(0)
            if missing > 0:
                window = F.pad(window, (0, 0, 0, missing), value=0)
            else:
                window = window[:video_length, :]

            moment_feats.append(window)

        # Summarise each moment with the GRU's final hidden state.
        _, hidden = self.gru(torch.stack(moment_feats, dim=0))
        hidden = hidden[-1, :, :]

        # Similarity between the moment summary and every clip
        weight = torch.matmul(hidden.unsqueeze(1), src_vid.transpose(1, 2)).squeeze(1)
        # Use the similarity to re-weight memory, keeping a residual copy
        memory = memory * weight.unsqueeze(-1) + memory

        # Project, sum over channels and scale by sqrt(hidden_dim)
        projected = self.saliency_proj(memory)
        saliency_scores = torch.sum(projected, dim=-1) / np.sqrt(self.hidden_dim)

        return saliency_scores


class PredictionHeads(nn.Module):
    """Bundles the moment-retrieval head and the highlight-detection head."""
    def __init__(self, hidden_dim, span_loss_type, clip_len=2, aux_loss=True):
        super().__init__()
        self.aux_loss = aux_loss
        self.video_moment_retrieval_prediction_head = VideoMomentRetrievalPredictionHead(
            hidden_dim, span_loss_type)
        self.highlight_detection_prediction_head = HighlightDetectionPredictionHead(
            hidden_dim, clip_len)

    def forward(self, hs, reference, memory, src_vid):
        """
        Args:
            hs: decoded features
            reference: reference points
            memory: multimodal features
            src_vid: video features handed to the highlight-detection head
        Returns:
            pred_logits: classification predictions
            pred_spans: span predictions
            saliency_scores: highlight scores
            aux_outputs: list of per-layer {"pred_logits", "pred_spans"} dicts
                for deep supervision, or None when aux_loss is disabled
        """
        # Moment retrieval runs first; its last-layer predictions drive the
        # highlight-detection head.
        pred_logits, pred_spans, earlier_logits, earlier_spans = (
            self.video_moment_retrieval_prediction_head(hs, reference))

        saliency_scores = self.highlight_detection_prediction_head(
            pred_logits, pred_spans, memory, src_vid)

        if not self.aux_loss:
            return pred_logits, pred_spans, saliency_scores, None

        # One dict per earlier decoder layer
        aux_outputs = [
            {"pred_logits": layer_logits, "pred_spans": layer_spans}
            for layer_logits, layer_spans in zip(earlier_logits, earlier_spans)
        ]
        return pred_logits, pred_spans, saliency_scores, aux_outputs

# Notebook cell status message, printed once the definitions above have run.
print("Convolutional Fuser, Loop Decoder, and Prediction Heads loaded successfully")

### 4.5 Complete Switch-Net Model

Now we combine all components into the complete Switch-Net model (renamed from LD-DETR).

In [None]:
# Complete Switch-Net Model (formerly LD-DETR)
class SwitchNet(nn.Module):
    """
    Switch-Net: Expert Decoder Refinement Network for video moment retrieval and highlight detection.

    Pipeline: unimodal encoders (plus momentum copies) -> DistillAlign losses
    -> ConvolutionalFuser -> LoopDecoder -> PredictionHeads.

    Backward compatibility: the legacy name LD-DETR remains available via an alias.
    """
    def __init__(self, txt_dim, vid_dim, hidden_dim, num_queries, aux_loss=False,
                 position_embedding="sine", max_v_l=75, max_q_l=32, span_loss_type="l1",
                 use_txt_pos=False, aud_dim=0, queue_length=65536, momentum=0.995,
                 distillation_coefficient=0.4, num_v2t_encoder_layers=2,
                 num_encoder1_layers=2, num_convolutional_blocks=5,
                 num_encoder2_layers=2, num_decoder_layers=2, num_decoder_loops=3,
                 clip_len=2, moe_num_experts=8, moe_top_k=2, moe_load_balance_coef=0.01):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_queries = num_queries
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        self.max_q_l = max_q_l
        self.use_txt_pos = use_txt_pos
        self.momentum = momentum
        self.clip_len = clip_len
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_load_balance_coef = moe_load_balance_coef
        
        # Positional embedding
        self.position_embed, self.txt_position_embed = build_position_encoding(
            hidden_dim, position_embedding, max_q_l)
        
        # Unimodal encoders; audio (if any) is concatenated to the video
        # channels, hence vid_dim + aud_dim.
        self.video_encoder = UnimodalEncoder(vid_dim + aud_dim, hidden_dim)
        self.text_encoder = UnimodalEncoder(txt_dim, hidden_dim)
        self.momentum_video_encoder = UnimodalEncoder(vid_dim + aud_dim, hidden_dim)
        self.momentum_text_encoder = UnimodalEncoder(txt_dim, hidden_dim)
        
        # (online, momentum) pairs; the momentum member starts as a copy and is
        # then only changed through _momentum_update.
        self.model_pairs = [
            [self.video_encoder, self.momentum_video_encoder],
            [self.text_encoder, self.momentum_text_encoder],
        ]
        self._copy_params()
        
        # Distill align
        self.distill_align = DistillAlign(hidden_dim, queue_length=queue_length,
                                         alpha=distillation_coefficient)
        
        # Convolutional fuser
        self.convolutional_fuser = ConvolutionalFuser(
            hidden_dim=hidden_dim, num_v2t_encoder_layers=num_v2t_encoder_layers,
            num_encoder1_layers=num_encoder1_layers,
            num_convolutional_blocks=num_convolutional_blocks,
            num_encoder2_layers=num_encoder2_layers)
        
        # Loop decoder; each query owns a learnable 2-d reference point
        self.query_embed = nn.Embedding(num_queries, 2)
        self.loop_decoder = LoopDecoder(
            hidden_dim=hidden_dim, num_decoder_layers=num_decoder_layers,
            num_decoder_loops=num_decoder_loops,
            moe_num_experts=moe_num_experts,
            moe_top_k=moe_top_k,
            moe_load_balance_coef=moe_load_balance_coef)
        
        # Prediction heads
        self.prediction_heads = PredictionHeads(hidden_dim, span_loss_type,
                                               clip_len, aux_loss)

    def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, vid=None,
                qid=None, src_aud=None, src_aud_mask=None, epoch_i=0,
                batch_idx=0, train_loader_length=0, targets=None, is_training=False):
        """
        Forward pass through Switch-Net.

        Args:
            src_txt: text features (batch_size, N, txt_dim)
            src_txt_mask: text mask (batch_size, N)
            src_vid: video features (batch_size, L, vid_dim)
            src_vid_mask: video mask (batch_size, L)
            vid: accepted for interface compatibility; not used here
            qid: accepted for interface compatibility; not used here
            src_aud: audio features (batch_size, L, aud_dim) [optional];
                concatenated to src_vid along the feature axis
            src_aud_mask: audio mask (batch_size, L) [optional]; not used here
            epoch_i: current epoch (forwarded to DistillAlign)
            batch_idx: current batch index (forwarded to DistillAlign)
            train_loader_length: total batches per epoch (forwarded to DistillAlign)
            targets: ground truth targets; not used here
            is_training: training mode flag (forwarded to DistillAlign)
            
        Returns:
            out: dictionary with keys "loss_align", "loss_sim",
                "loss_moe_balance", "video_mask", "pred_logits", "pred_spans",
                "saliency_scores" and "aux_outputs" (an empty list when
                auxiliary outputs are disabled)
        """
        # Concatenate audio if provided
        if src_aud is not None:
            src_vid = torch.cat([src_vid, src_aud], dim=2)
        
        # Unimodal encoders; keep the raw inputs for the momentum encoders
        src_vid_copy = src_vid
        src_txt_copy = src_txt
        src_vid = self.video_encoder(src_vid)
        src_txt = self.text_encoder(src_txt)

        # We don't care about gradients for momentum encoders
        with torch.no_grad():
            # Only move the momentum weights while training; updating them on
            # every forward pass would let evaluation mutate model state.
            if self.training or is_training:
                self._momentum_update()
            src_vid_m = self.momentum_video_encoder(src_vid_copy)
            src_txt_m = self.momentum_text_encoder(src_txt_copy)
        
        # Distill align: works on L2-normalised, mean-pooled sequence features
        # from the online and the momentum encoders.
        loss_align, loss_sim = self.distill_align(
            F.normalize(src_vid.mean(1), dim=-1),
            F.normalize(src_txt.mean(1), dim=-1),
            F.normalize(src_vid_m.mean(1), dim=-1).detach(),
            F.normalize(src_txt_m.mean(1), dim=-1).detach(),
            epoch_i=epoch_i, batch_idx=batch_idx,
            train_loader_length=train_loader_length, is_training=is_training)
        
        # Positional embedding (text positions are all-zero unless use_txt_pos)
        pos_vid = self.position_embed(src_vid, src_vid_mask)
        pos_txt = (self.txt_position_embed(src_txt) if self.use_txt_pos 
                  else torch.zeros_like(src_txt))
        
        # Convolutional fuser
        memory = self.convolutional_fuser(src_vid, src_txt, src_vid_mask.bool(),
                                         src_txt_mask.bool(), pos_vid, pos_txt)
        
        # Loop decoder
        _, bs, d = memory.shape  # (L, batch_size, hidden_dim)

        # initialize reference point embeddings: (num_queries, batch_size, 2)
        refpoint_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)

        # initialize tgt as zeros matrix, where tgt is the query notes;
        # allocated directly on the target device to avoid a host-to-device copy
        tgt = torch.zeros(refpoint_embed.shape[0], bs, d, device=src_vid.device)
        
        hs, reference, moe_loss = self.loop_decoder(tgt, memory, src_vid_mask.bool(),
                                                    pos_vid, refpoint_embed)
        memory = memory.transpose(0, 1)  # (batch_size, L, hidden_dim)
        
        # Prediction heads (src_vid here is the encoded video, not the raw input)
        pred_logits, pred_spans, saliency_scores, aux_outputs = \
            self.prediction_heads(hs, reference, memory, src_vid)
        
        # Prepare output
        out = {
            "loss_align": loss_align,
            "loss_sim": loss_sim,
            "loss_moe_balance": moe_loss,
            "video_mask": src_vid_mask,
            "pred_logits": pred_logits,
            "pred_spans": pred_spans,
            "saliency_scores": saliency_scores,
            "aux_outputs": aux_outputs if aux_outputs is not None else []
        }
        
        return out

    @torch.no_grad()
    def _copy_params(self):
        """Copy the online encoders' parameters into the momentum encoders and freeze the copies."""
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), 
                                     model_pair[1].parameters()):
                param_m.data.copy_(param.data)
                param_m.requires_grad = False

    @torch.no_grad()
    def _momentum_update(self):
        """Exponential-moving-average update: m <- momentum * m + (1 - momentum) * online."""
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(),
                                     model_pair[1].parameters()):
                param_m.data = param_m.data * self.momentum + \
                              param.data * (1.0 - self.momentum)

# Legacy alias to keep external demos functional: code that still refers to
# the old LD_DETR name gets the SwitchNet class itself.
LD_DETR = SwitchNet
# Notebook cell status message, printed once the definitions above have run.
print("Complete Switch-Net model loaded successfully")

### 4.6 Hungarian Matcher and Loss Criterion

The Hungarian algorithm matches predictions to ground truth, and the loss criterion computes all training losses.

In [None]:
# Hungarian Matcher
class HungarianMatcher(nn.Module):
    """One-to-one assignment between query predictions and ground-truth spans.

    A batch-wide cost matrix combining a classification term, a span term and
    (for ``span_loss_type == "l1"``) a generalized-IoU term is built, then each
    sample's slice is solved with ``scipy.optimize.linear_sum_assignment``.
    """

    def __init__(self, cost_class=1, cost_span=1, cost_giou=1,
                 span_loss_type="l1", max_v_l=75):
        """Store the relative cost weights and the span parameterisation.

        Args:
            cost_class: weight of the foreground-probability cost.
            cost_span: weight of the span cost (L1 distance or bin probability).
            cost_giou: weight of the generalized temporal IoU cost.
            span_loss_type: ``"l1"`` for regressed spans; any other value treats
                ``pred_spans`` as logits over ``max_v_l`` bins per boundary.
            max_v_l: number of bins per boundary in the non-"l1" mode.
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_span = cost_span
        self.cost_giou = cost_giou
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        self.foreground_label = 0
        assert cost_class != 0 or cost_span != 0 or cost_giou != 0, \
            "All costs cannot be 0"

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Compute the optimal prediction/target pairing for every sample.

        Args:
            outputs: dict holding "pred_spans" and "pred_logits", both with
                leading dimensions (batch, num_queries).
            targets: dict whose "span_labels" entry is a list with one dict per
                sample, each carrying a "spans" tensor.

        Returns:
            A list with one ``(pred_idx, target_idx)`` tuple of int64 tensors
            per sample.
        """
        predicted = outputs["pred_spans"]
        bs, num_queries = predicted.shape[:2]
        span_targets = targets["span_labels"]

        # Probabilities for all queries of the batch, flattened to one axis.
        probs = outputs["pred_logits"].flatten(0, 1).softmax(-1)
        gt_spans = torch.cat([entry["spans"] for entry in span_targets])
        gt_labels = torch.full([len(gt_spans)], self.foreground_label,
                               dtype=torch.long, device=probs.device)

        # Higher foreground probability -> lower cost.
        class_cost = -probs[:, gt_labels]

        if self.span_loss_type == "l1":
            flat_spans = predicted.flatten(0, 1)
            span_cost = torch.cdist(flat_spans, gt_spans, p=1)
            giou_cost = -generalized_temporal_iou(
                span_cxw_to_xx(flat_spans), span_cxw_to_xx(gt_spans))
        else:
            # Bin-classification spans: cost is the negated probability mass
            # assigned to the ground-truth start and end bins.
            bin_probs = predicted.view(bs * num_queries, 2, self.max_v_l).softmax(-1)
            start_probs = bin_probs[:, 0]
            end_probs = bin_probs[:, 1]
            span_cost = -start_probs[:, gt_spans[:, 0]] - \
                end_probs[:, gt_spans[:, 1]]
            giou_cost = 0

        cost = self.cost_span * span_cost + self.cost_giou * giou_cost + \
            self.cost_class * class_cost
        cost = cost.view(bs, num_queries, -1).cpu()

        # Columns are grouped per sample; sample k only uses its own column group.
        per_sample = [len(entry["spans"]) for entry in span_targets]
        matches = []
        for sample_idx, column_group in enumerate(cost.split(per_sample, -1)):
            rows, cols = linear_sum_assignment(column_group[sample_idx])
            matches.append((torch.as_tensor(rows, dtype=torch.int64),
                            torch.as_tensor(cols, dtype=torch.int64)))
        return matches


def build_matcher(cost_span=10, cost_giou=1, cost_class=4,
                  span_loss_type="l1", max_v_l=75):
    """Create a ``HungarianMatcher`` configured with the given cost weights.

    All arguments are forwarded unchanged by keyword.
    """
    matcher_kwargs = dict(
        cost_class=cost_class,
        cost_span=cost_span,
        cost_giou=cost_giou,
        span_loss_type=span_loss_type,
        max_v_l=max_v_l,
    )
    return HungarianMatcher(**matcher_kwargs)


print("Hungarian Matcher loaded successfully")

In [None]:
class SetCriterion(nn.Module):
    """Loss computation for Switch-Net (compatible with the legacy LD-DETR name).

    ``forward`` pairs predictions with targets, evaluates every loss named in
    ``losses`` through the matching ``loss_<name>`` method, and repeats the
    span/label losses for each auxiliary output with an ``_<index>`` suffix on
    the returned keys. Returned losses are unweighted.

    Args:
        matcher: callable ``(outputs, targets)`` returning one
            ``(pred_idx, target_idx)`` pair per sample; used when
            ``use_matcher`` is true.
        weight_dict: mapping from loss key to weight, stored for callers.
        losses: names of the losses to compute (e.g. ``"spans"``, ``"labels"``).
        span_loss_type: ``"l1"`` for regressed spans with a GIoU term; any other
            value uses cross-entropy over ``max_v_l`` bins per boundary.
        max_v_l: number of bins per boundary for the cross-entropy span loss.
        eos_coef: class weight applied to the background label.
        saliency_margin: margin of the pairwise saliency hinge loss.
        saliency_label_scale: highest rank threshold of the rank-contrastive
            saliency term; truncated to an int and floored at 1.
        use_matcher: when false, all query indices are paired with
            ``arange(num_targets)`` instead of calling ``matcher``.
    """

    # Losses that are only defined for the final output, not for auxiliary ones.
    _AUX_SKIPPED_LOSSES = frozenset({"saliency", "align", "sim", "moe_balance"})

    def __init__(self, matcher, weight_dict, losses, span_loss_type="l1", max_v_l=75,
                 eos_coef=0.1, saliency_margin=0.2, saliency_label_scale=12, use_matcher=True):
        super().__init__()
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        self.saliency_margin = saliency_margin
        self.saliency_label_scale = max(int(saliency_label_scale), 1)
        self.use_matcher = use_matcher

        self.foreground_label = 0
        self.background_label = 1
        self.eos_coef = eos_coef

        # Per-class weights for the label loss: foreground 1, background eos_coef.
        empty_weight = torch.ones(2)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)

    def loss_spans(self, outputs, targets, indices):
        """Span regression (L1 + GIoU) or span bin cross-entropy on matched pairs.

        Returns:
            dict with scalar "loss_span" and "loss_giou". When no prediction is
            matched to any target both are zero (still attached to the graph)
            instead of the NaN that averaging an empty tensor would produce.
        """
        assert "pred_spans" in outputs
        targets = targets["span_labels"]
        idx = self._get_src_permutation_idx(indices)
        src_spans = outputs["pred_spans"][idx]
        tgt_spans = torch.cat([t["spans"][i] for t, (_, i) in zip(targets, indices)], dim=0)

        if src_spans.shape[0] == 0:
            # Sum over an empty selection is an exact 0 that keeps autograd happy.
            zero = src_spans.sum()
            return {"loss_span": zero, "loss_giou": zero}

        if self.span_loss_type == "l1":
            loss_span = F.l1_loss(src_spans, tgt_spans, reduction="none")
            # Only the diagonal pairs prediction k with its own target k.
            loss_giou = 1 - torch.diag(generalized_temporal_iou(
                span_cxw_to_xx(src_spans), span_cxw_to_xx(tgt_spans)))
        else:
            n_spans = src_spans.shape[0]
            # cross_entropy expects the class (bin) axis at dim 1
            src_spans = src_spans.view(n_spans, 2, self.max_v_l).transpose(1, 2)
            loss_span = F.cross_entropy(src_spans, tgt_spans, reduction="none")
            loss_giou = loss_span.new_zeros([1])

        return {"loss_span": loss_span.mean(), "loss_giou": loss_giou.mean()}

    def loss_labels(self, outputs, targets, indices, log=True):
        """Weighted foreground/background cross-entropy over all queries.

        Matched queries are labelled foreground, every other query background.
        When ``log`` is true a "class_error" entry (100 minus the first value
        returned by ``accuracy``) is added for monitoring.
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]
        idx = self._get_src_permutation_idx(indices)

        target_classes = torch.full(src_logits.shape[:2], self.background_label,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = self.foreground_label

        src_logits_flat = src_logits.flatten(0, 1)
        target_classes_flat = target_classes.flatten()

        loss_ce = F.cross_entropy(
            src_logits_flat.to(torch.float32),
            target_classes_flat,
            weight=self.empty_weight.to(src_logits.device)
        )

        losses = {"loss_label": loss_ce}

        if log and target_classes_flat.numel() > 0:
            with torch.no_grad():
                class_err = 100 - accuracy(src_logits_flat, target_classes_flat)[0]
            losses["class_error"] = class_err
        elif log:
            losses["class_error"] = torch.tensor(100.0, device=src_logits.device)

        return losses

    def _rank_contrastive_loss(self, saliency_scores, vid_token_mask, saliency_contrast_label):
        """Rank-contrastive saliency term averaged over thresholds 1..saliency_label_scale."""
        tau = 0.5
        max_rank = max(self.saliency_label_scale, 1)

        # Push masked-out clips far below every real score before normalising.
        masked_scores = vid_token_mask * saliency_scores + (1.0 - vid_token_mask) * -1e3

        # Nothing below depends on the rank threshold, so it is computed once.
        drop_mask = ~(saliency_contrast_label > 100)
        cur_saliency_scores = masked_scores * drop_mask / tau + ~drop_mask * -1e3
        logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
        exp_logits = torch.exp(logits)
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)

        total = 0.0
        for rank in range(1, max_rank + 1):
            pos_mask = saliency_contrast_label >= rank
            if torch.sum(pos_mask) == 0:
                # Positives shrink monotonically with the threshold, so every
                # higher rank is empty as well.
                break
            batch_drop_mask = torch.sum(pos_mask, dim=1) > 0
            mean_log_prob_pos = (pos_mask * log_prob * vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
            loss = -mean_log_prob_pos * batch_drop_mask
            total = total + loss.mean()
        return total / max_rank

    def loss_saliency(self, outputs, targets, indices, log=True):
        """Saliency loss: pairwise margin hinge plus a rank-contrastive term.

        Returns zero when "saliency_pos_labels" is absent. The rank-contrastive
        term needs the dense "saliency_all_labels"; when only the pos/neg index
        pairs are supplied it is skipped rather than raising ``KeyError``.
        """
        if "saliency_pos_labels" not in targets:
            zero = torch.tensor(0.0, device=outputs["saliency_scores"].device)
            return {"loss_saliency": zero}

        loss_rank_contrastive = 0.0
        if "saliency_all_labels" in targets:
            loss_rank_contrastive = self._rank_contrastive_loss(
                outputs["saliency_scores"], outputs["video_mask"],
                targets["saliency_all_labels"])

        saliency_scores = outputs["saliency_scores"]
        pos_indices = targets["saliency_pos_labels"]
        neg_indices = targets["saliency_neg_labels"]
        num_pairs = pos_indices.shape[1]
        batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
        pos_scores = torch.stack([saliency_scores[batch_indices, pos_indices[:, col_idx]]
                                  for col_idx in range(num_pairs)], dim=1)
        neg_scores = torch.stack([saliency_scores[batch_indices, neg_indices[:, col_idx]]
                                  for col_idx in range(num_pairs)], dim=1)
        # Hinge: a negative clip should score at least `margin` below its positive.
        loss_saliency = (torch.clamp(self.saliency_margin + neg_scores - pos_scores,
                                     min=0).sum() / (len(pos_scores) * num_pairs) * 2)
        loss_saliency = loss_saliency + loss_rank_contrastive

        return {"loss_saliency": loss_saliency}

    def loss_align(self, outputs, targets, indices, log=True):
        """Pass through the alignment loss already computed by the model."""
        return {"loss_align": outputs["loss_align"]}

    def loss_sim(self, outputs, targets, indices, log=True):
        """Pass through the similarity loss already computed by the model."""
        return {"loss_sim": outputs["loss_sim"]}

    def loss_moe_balance(self, outputs, targets, indices, log=True):
        """Return the MoE balance loss as a tensor on this criterion's device.

        A missing/``None`` value becomes 0; a non-tensor value is wrapped in a
        tensor; a tensor is moved to the device of the ``empty_weight`` buffer.
        """
        moe_loss = outputs.get("loss_moe_balance")
        base_tensor = self.empty_weight
        if moe_loss is None:
            moe_loss = base_tensor.new_tensor(0.0)
        elif not isinstance(moe_loss, torch.Tensor):
            moe_loss = base_tensor.new_tensor(moe_loss)
        else:
            moe_loss = moe_loss.to(base_tensor.device)
        return {"loss_moe_balance": moe_loss}

    def _get_src_permutation_idx(self, indices):
        """Flatten per-sample matches into (batch_idx, query_idx) index tensors."""
        batch_idx = torch.cat([torch.full_like(src, i)
                               for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _match(self, outputs, targets):
        """Return per-sample (pred_idx, target_idx) pairs for one set of outputs."""
        if self.use_matcher:
            return self.matcher(outputs, targets)
        pred_spans = outputs["pred_spans"]
        return [(torch.arange(pred_spans.shape[1], device=pred_spans.device),
                 torch.arange(t["spans"].shape[0], device=pred_spans.device))
                for t in targets["span_labels"]]

    def forward(self, outputs, targets):
        """Compute all configured losses.

        Args:
            outputs: model output dict; may contain an "aux_outputs" list of
                dicts with their own "pred_spans"/"pred_logits".
            targets: target dict with at least "span_labels".

        Returns:
            dict mapping loss names to tensors. Losses computed on auxiliary
            output ``i`` carry the suffix ``_i``.
        """
        indices = self._match(outputs, targets)
        losses = {}
        for loss in self.losses:
            losses.update(getattr(self, f"loss_{loss}")(outputs, targets, indices))

        for aux_idx, aux_outputs in enumerate(outputs.get("aux_outputs") or []):
            aux_indices = self._match(aux_outputs, targets)
            for loss in self.losses:
                if loss in self._AUX_SKIPPED_LOSSES:
                    continue
                l_dict = getattr(self, f"loss_{loss}")(aux_outputs, targets, aux_indices)
                losses.update({k + f"_{aux_idx}": v for k, v in l_dict.items()})

        return losses

## Important Notes

**1. Data Loading**
- Current notebook has placeholder data paths
- You need to implement actual dataset loading
- Create proper dataloaders with collate functions
- Ensure features are pre-extracted (CLIP, SlowFast)

**2. Queue Length**
- `queue_length` in Distill Align must be divisible by batch size
- Default: 65536 (works with batch sizes like 32, 64, 128)
- If you get errors, adjust queue_length or batch_size

**3. Feature Extraction**
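- Video features: SlowFast (2304-dim) + CLIP (512-dim) = 2816-dim when both are used; the default configuration in this notebook uses CLIP features only, with `config.vid_dim = 514`
- Text features: CLIP text encoder (512-dim)
- Ensure your features match the dimensions set in `Config` (`vid_dim`, `txt_dim`)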

**4. Training Data Format**
The model expects data in this format:
```python
{
    "qid": query_id,
    "query": "text description",
    "duration": video_duration,
    "vid": video_id,
    "relevant_windows": [[start, end]]  # in seconds
}
```

### Testing the Notebook:

Before full training, test with:
1. Small batch size (e.g., 2-4 samples)
2. Few epochs (e.g., 2-3)
3. Reduced max_v_l (e.g., 30 instead of 75)
4. Check all losses are computed correctly

## 5. Configuration and Hyperparameters

This section instantiates a single `Config` object that centralises every tunable setting: model depth, optimiser defaults, and the custom loss weights we just re-balanced. The comments in the cell call out the rationale for each group of parameters so future adjustments can be made without hunting through the notebook.

In [None]:
class Config:
    """Configuration container for training hyperparameters.

    Every setting is a plain instance attribute assigned in ``__init__``; the
    cells below read them as ``config.<name>`` when building the model, the
    matcher, the criterion, the optimizer and the training loop.
    """
    def __init__(self):
        # General settings
        self.device = device  # module-level torch.device selected during setup
        self.hidden_dim = 256
        self.dropout = 0.1

        # Feature dimensions
        self.txt_dim = 512
        # NOTE(review): 514 looks like 512-dim CLIP plus 2 TEF channels appended
        # by the data loader — confirm against the dataset code.
        self.vid_dim = 514
        self.aud_dim = 0

        # Query/video parameters
        self.num_queries = 10
        self.max_q_l = 32
        self.max_v_l = -1  # use full sequence length by default
        self.clip_len = 1

        # Position embedding
        self.position_embedding = "sine"
        self.use_txt_pos = False

        # Distill Align parameters
        self.queue_length = 65536
        self.momentum = 0.995
        self.distillation_coefficient = 0.3

        # Encoder parameters
        self.num_v2t_encoder_layers = 2
        self.num_encoder1_layers = 2
        self.num_convolutional_blocks = 4
        self.num_encoder2_layers = 2

        # Decoder parameters
        self.num_decoder_layers = 2
        self.num_decoder_loops = 3
        self.moe_num_experts = 8
        self.moe_top_k = 2
        self.moe_load_balance_coef = 1.0

        # Loss options
        self.span_loss_type = "l1"
        self.aux_loss = True

        # Optimization parameters
        self.lr = 1e-4
        self.weight_decay = 1e-4
        self.n_epoch = 200
        self.batch_size = 64
        # train_one_epoch clips the global gradient norm to this value when it is > 0
        self.grad_clip = 0.1

        # Loss weights (the span/giou/label values are also passed to the
        # Hungarian matcher as its cost weights)
        self.span_loss_coef = 10
        self.giou_loss_coef = 1
        self.label_loss_coef = 4
        self.saliency_loss_coef = 1.5
        self.align_loss_coef = 0.6
        self.sim_loss_coef = 0.4
        self.moe_loss_coef = 0.01

        # Saliency and span stabilization
        self.saliency_margin = 0.2
        # SetCriterion truncates this to an int and floors it at 1
        self.saliency_label_scale = 3.0
        # NOTE(review): the consumers of the two warm-up settings and of
        # min_span_width_ratio are not visible in this part of the notebook.
        self.saliency_warmup_epochs = 3
        self.align_warmup_epochs = 2
        self.min_span_width_ratio = 0.0



# Single shared configuration instance read by the cells below.
config = Config()

# A non-positive max_v_l is reported in words ("full sequence") rather than as a number.
max_v_desc = "full sequence" if config.max_v_l <= 0 else f"{config.max_v_l} clips"
# Echo the headline settings so the run log records what was used.
print("Configuration loaded")
print(f"- Hidden dimension: {config.hidden_dim}")
print(f"- Number of queries: {config.num_queries}")
print(f"- Max video length: {max_v_desc}")
print(f"- Clip length (seconds per feature): {config.clip_len}")
print(f"- Batch size: {config.batch_size}")
print(f"- Learning rate: {config.lr}")

## 6. Data Loading

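Before instantiating the datasets we pin down the paths to the JSONL annotations and pre-computed CLIP features. By default they are relative to `DATA_ROOT`, so set `DATA_ROOT` to an absolute location if the notebook is launched from a different working directory. Keeping this configuration in one place makes it obvious how to repoint the notebook to a new dataset split or feature directory.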

In [None]:
# Dataset paths - adjust these to your actual data paths.
# DATA_ROOT is a relative path, so it resolves against the current working directory.
DATA_ROOT = "data"
TRAIN_PATH = f"{DATA_ROOT}/train.jsonl"
TEST_PATH = f"{DATA_ROOT}/test.jsonl"

# Feature directories
FEATURE_ROOT = f"{DATA_ROOT}/features"

# The Video Feature directories (can include multiple types).
# Only the CLIP directory is active; uncomment the SlowFast entry to add it.
VIDEO_FEAT_DIRS = [
    f"{FEATURE_ROOT}/video_clip_features",
    # f"{FEATURE_ROOT}/video_slowfast_features",
 ]  # TEF channels are appended automatically during loading

QUERY_FEAT_DIR = f"{FEATURE_ROOT}/query_clip_features"

# Echo the resolved locations so a wrong path is visible before any loading starts.
print("Data paths configured:")
print(f"- Train data: {TRAIN_PATH}")
print(f"- Test data: {TEST_PATH}")
print("- Video features:")
for path in VIDEO_FEAT_DIRS:
    print(f"    {path}")
print(f"- Query features: {QUERY_FEAT_DIR}")

## 7. Model Initialization

Now let's initialize the Switch-Net model (formerly LD-DETR) with our configuration.

In [None]:
# Initialize the Switch-Net model.
# Every constructor argument is taken from the shared `config` object, and the
# finished module is moved to `config.device`.
model = SwitchNet(
    txt_dim=config.txt_dim,
    vid_dim=config.vid_dim,
    hidden_dim=config.hidden_dim,
    num_queries=config.num_queries,
    aux_loss=config.aux_loss,
    position_embedding=config.position_embedding,
    max_v_l=config.max_v_l,
    max_q_l=config.max_q_l,
    span_loss_type=config.span_loss_type,
    use_txt_pos=config.use_txt_pos,
    aud_dim=config.aud_dim,
    queue_length=config.queue_length,
    momentum=config.momentum,
    distillation_coefficient=config.distillation_coefficient,
    num_v2t_encoder_layers=config.num_v2t_encoder_layers,
    num_encoder1_layers=config.num_encoder1_layers,
    num_convolutional_blocks=config.num_convolutional_blocks,
    num_encoder2_layers=config.num_encoder2_layers,
    num_decoder_layers=config.num_decoder_layers,
    num_decoder_loops=config.num_decoder_loops,
    clip_len=config.clip_len,
    moe_num_experts=config.moe_num_experts,
    moe_top_k=config.moe_top_k,
    moe_load_balance_coef=config.moe_load_balance_coef,
 ).to(config.device)

# Count parameters
def count_parameters(model):
    """Return the total element count of the parameters that require gradients."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# Only parameters with requires_grad=True are counted by count_parameters.
total_params = count_parameters(model)
print(f"Total trainable parameters: {total_params:,}")
# The device is read from the first parameter returned by model.parameters().
print(f"Model device: {next(model.parameters()).device}")
print(f"Model initialized successfully")

## 8. Loss Function and Matcher

We mirror Switch-Net's (formerly LD-DETR) matching strategy (Hungarian assignment on span + GIoU + class costs) and then assemble a weighted loss dictionary. The coefficients reflect our latest diagnostics: span and classification remain dominant, while saliency, alignment, and similarity receive tempered weights so they still guide learning without overwhelming the gradients.

In [None]:
# Build matcher for Hungarian algorithm.
# The matching costs reuse the span / GIoU / label loss coefficients.
matcher = build_matcher(
    cost_span=config.span_loss_coef,
    cost_giou=config.giou_loss_coef,
    cost_class=config.label_loss_coef,
    span_loss_type=config.span_loss_type,
    max_v_l=config.max_v_l
 )

# Weight dictionary for losses; the keys match the names returned by SetCriterion.
weight_dict = {
    "loss_span": config.span_loss_coef,
    "loss_giou": config.giou_loss_coef,
    "loss_label": config.label_loss_coef,
    "loss_saliency": config.saliency_loss_coef,
    "loss_align": config.align_loss_coef,           # DistillAlign loss weight
    "loss_sim": config.sim_loss_coef,              # DistillAlign similarity loss weight
    "loss_moe_balance": config.moe_loss_coef,      # MoE load balancing loss weight
 }

# Add auxiliary losses: span/GIoU/label weights are duplicated under "<key>_<i>"
# for i in range(num_decoder_layers - 1).
# NOTE(review): SetCriterion emits one suffixed set per entry of
# outputs["aux_outputs"], and train_one_epoch silently ignores any key that is
# missing here — confirm the model returns exactly num_decoder_layers - 1
# auxiliary outputs.
if config.aux_loss:
    aux_weight_dict = {}
    for i in range(max(config.num_decoder_layers - 1, 0)):
        aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()
                               if k not in ["loss_saliency", "loss_align", "loss_sim", "loss_moe_balance"]})
    weight_dict.update(aux_weight_dict)

# Initialize criterion and move its registered buffer to the training device.
criterion = SetCriterion(
    matcher=matcher,
    weight_dict=weight_dict,
    losses=["spans", "labels", "saliency", "align", "sim", "moe_balance"],
    span_loss_type=config.span_loss_type,
    max_v_l=config.max_v_l,
    saliency_margin=config.saliency_margin,
    saliency_label_scale=config.saliency_label_scale,
    use_matcher=True
 ).to(config.device)

print("Loss function and matcher initialized")
print(f"Loss weights: {list(weight_dict.keys())}")

## 9. Optimizer and Scheduler

We use AdamW with weight decay and a cosine annealing schedule that gradually decays the learning rate over the full training horizon.

In [None]:
# Optimizer: a single AdamW parameter group covering every model parameter.
optimizer = AdamW(
    model.parameters(),
    lr=config.lr,
    weight_decay=config.weight_decay
)

# Learning rate scheduler: cosine annealing down to 0 over n_epoch scheduler
# steps (T_max is floored at 1 so a zero-epoch config still builds).
lr_scheduler = CosineAnnealingLR(
    optimizer,
    T_max=max(config.n_epoch, 1),
    eta_min=0.0
)

print(f"Optimizer: AdamW (lr={config.lr}, weight_decay={config.weight_decay})")
print(f"Scheduler: CosineAnnealingLR (T_max={max(config.n_epoch, 1)}, eta_min=0.0)")

## 10. Training Loop

The following cell defines the main training loop helper functions for Switch-Net, including:

- `prepare_batch_inputs`: Prepares and moves batch data to the correct device.
- `train_one_epoch`: Runs a full training epoch with detailed diagnostics (loss breakdowns, gradient norm tracking, etc.).
- Utility functions for logging and loss formatting.

In [None]:
# Base loss names used for reporting. _canonical_loss_key matches these as
# prefixes, so suffixed auxiliary keys such as "loss_span_0" are folded into
# their base name.
PRIMARY_LOSS_KEYS = [
    "loss_span",
    "loss_giou",
    "loss_label",
    "loss_saliency",
    "loss_align",
    "loss_sim",
    "loss_moe_balance",
]

def prepare_batch_inputs(batch_data, device):
    """
    Move one batch onto ``device`` and split it into model inputs and targets.

    Args:
        batch_data: ``(model_inputs, targets)`` pair. ``model_inputs`` must hold
            'src_vid', 'src_txt', 'src_vid_mask' and 'src_txt_mask'; ``targets``
            must hold 'span_labels' (a list of dicts) and may hold the saliency
            label tensors.
        device: destination passed to ``Tensor.to``.

    Returns:
        ``(model_inputs, processed_targets)``. Only the four input tensors are
        kept; non-tensor values inside each span-label dict pass through as-is;
        'relevant_clips' is the same tensor as 'saliency_all_labels'.
    """
    raw_inputs, raw_targets = batch_data

    input_names = ('src_vid', 'src_txt', 'src_vid_mask', 'src_txt_mask')
    model_inputs = {name: raw_inputs[name].to(device) for name in input_names}

    span_labels = []
    for entry in raw_targets['span_labels']:
        moved = {}
        for key, value in entry.items():
            moved[key] = value.to(device) if isinstance(value, torch.Tensor) else value
        span_labels.append(moved)
    processed_targets = {'span_labels': span_labels}

    # Positive and negative index pairs are expected to travel together.
    if 'saliency_pos_labels' in raw_targets:
        for name in ('saliency_pos_labels', 'saliency_neg_labels'):
            processed_targets[name] = raw_targets[name].to(device)

    if 'saliency_all_labels' in raw_targets:
        dense_labels = raw_targets['saliency_all_labels'].to(device)
        processed_targets['saliency_all_labels'] = dense_labels
        processed_targets['relevant_clips'] = dense_labels

    return model_inputs, processed_targets


def _canonical_loss_key(name: str) -> Optional[str]:
    """Map a loss name to its reporting bucket.

    Returns the first entry of ``PRIMARY_LOSS_KEYS`` that ``name`` starts with,
    ``"class_error"`` for exactly that name, and ``None`` for anything else.
    """
    bucket = next((key for key in PRIMARY_LOSS_KEYS if name.startswith(key)), None)
    if bucket is not None:
        return bucket
    return "class_error" if name == "class_error" else None

def _format_loss_dict(loss_dict: Dict[str, torch.Tensor]) -> str:
    """Render the loss entries as ``"name:value"`` pairs joined by ``" | "``.

    Only keys starting with ``"loss_"`` and the key ``"class_error"`` are shown,
    in sorted key order, each value formatted with three decimals.
    """
    def _render(key):
        value = loss_dict[key]
        if isinstance(value, torch.Tensor):
            value = value.item()
        else:
            value = float(value)
        return f"{key}:{value:.3f}"

    reported = filter(lambda k: k.startswith("loss_") or k == "class_error",
                      sorted(loss_dict))
    return " | ".join(map(_render, reported))

def _compute_grad_norm(parameters) -> float:
    """Return the L2 norm over the gradients of all given parameters.

    Parameters whose ``grad`` is ``None`` are ignored; ``0.0`` is returned when
    none of them carries a gradient.
    """
    per_param = []
    for param in parameters:
        if param.grad is None:
            continue
        per_param.append(param.grad.detach().data.norm(2))
    if len(per_param) == 0:
        return 0.0
    # the norm of the per-parameter norms is the norm of all gradients together
    stacked = torch.stack(per_param)
    return float(torch.norm(stacked, 2).item())

def train_one_epoch(model, criterion, data_loader, optimizer, epoch, config):
    """Run one optimisation pass over ``data_loader`` and report statistics.

    For every batch the weighted sum of the criterion's losses (only keys that
    are present in ``criterion.weight_dict``) is back-propagated, gradients are
    clipped to ``config.grad_clip`` when that is positive, and the optimizer is
    stepped. Progress is printed for the first batch, the last batch and every
    1000th batch.

    Args:
        model: network called with the prepared inputs plus bookkeeping kwargs.
        criterion: loss module returning a dict and exposing ``weight_dict``.
        data_loader: sized iterable; element ``[1]`` of each batch is the
            ``(model_inputs, targets)`` pair given to ``prepare_batch_inputs``.
        optimizer: optimizer updating the model parameters.
        epoch: zero-based epoch index.
        config: object providing ``device`` and ``grad_clip``.

    Returns:
        dict with 'loss' (mean weighted loss), 'lr' (first param group),
        'loss_breakdown' (per canonical key mean), 'grad_norm' (mean norm) and
        'epoch_time' (seconds).
    """
    model.train()
    criterion.train()

    num_batches = len(data_loader)
    running_loss = 0.0
    per_key_totals = defaultdict(float)
    grad_norm_sum = 0.0
    grad_norm_steps = 0
    started_at = time.time()

    progress = tqdm(data_loader, desc=f"Epoch {epoch+1}", leave=False)
    for batch_idx, batch in enumerate(progress):
        model_inputs, targets = prepare_batch_inputs(batch[1], config.device)

        outputs = model(
            **model_inputs,
            epoch_i=epoch,
            batch_idx=batch_idx,
            train_loader_length=num_batches,
            targets=targets,
            is_training=True,
        )

        loss_dict = criterion(outputs, targets)
        weights = criterion.weight_dict
        # losses without a registered weight (e.g. class_error) are diagnostics only
        weighted_loss = sum(
            value * weights[name]
            for name, value in loss_dict.items()
            if name in weights
        )

        optimizer.zero_grad()
        weighted_loss.backward()

        if config.grad_clip > 0:
            # clip_grad_norm_ returns the norm measured before clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
            grad_norm = float(grad_norm.item() if hasattr(grad_norm, "item") else grad_norm)
        else:
            grad_norm = _compute_grad_norm(model.parameters())

        optimizer.step()

        running_loss += weighted_loss.item()
        for name, value in loss_dict.items():
            bucket = _canonical_loss_key(name)
            if bucket is None:
                continue
            per_key_totals[bucket] += value.item() if isinstance(value, torch.Tensor) else float(value)

        grad_norm_sum += grad_norm
        grad_norm_steps += 1

        step = batch_idx + 1
        if step == 1 or step == num_batches or step % 1000 == 0:
            elapsed = time.time() - started_at
            diagnostics = _format_loss_dict(loss_dict)
            print(
                f"  Batch {batch_idx + 1}/{num_batches} | "
                f"loss={weighted_loss.item():.4f} | grad_norm={grad_norm:.3f} | "
                f"elapsed={elapsed:.1f}s"
            )
            if diagnostics:
                print(f"    {diagnostics}")

    divisor = max(num_batches, 1)
    loss_breakdown = {key: total / divisor for key, total in per_key_totals.items()}
    avg_grad_norm = grad_norm_sum / grad_norm_steps if grad_norm_steps > 0 else 0.0

    epoch_time = time.time() - started_at
    print(f"Epoch {epoch+1} finished in {epoch_time/60:.2f} min")

    return {
        'loss': running_loss / divisor,
        'lr': optimizer.param_groups[0]['lr'],
        'loss_breakdown': loss_breakdown,
        'grad_norm': avg_grad_norm,
        'epoch_time': epoch_time,
    }


print("Training loop helper functions loaded successfully")

## 11. Evaluation and Inference

Evaluation metrics for Video Moment Retrieval:
- **R@N, IoU=m**: Recall at N predictions with IoU threshold m
- **mIoU**: Mean Intersection over Union

For Highlight Detection:
- **mAP**: Mean Average Precision

In [None]:
def inference(model, data_loader, config):
    """
    Run the model over a validation/test loader without gradient tracking.

    Args:
        model: Switch-Net model (alias LD_DETR remains supported)
        data_loader: Validation/test data loader; element ``[1]`` of each batch
            is the ``(model_inputs, targets)`` pair, element ``[0]`` is stored
            as "video_ids" (assumed to contain the video IDs)
        config: Configuration object providing ``device``

    Returns:
        predictions: one dict per batch with NumPy arrays under "pred_logits",
        "pred_spans" and "saliency_scores", plus "video_ids"
    """
    model.eval()
    collected = []
    array_keys = ("pred_logits", "pred_spans", "saliency_scores")

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(data_loader, desc="Inference")):
            # targets are prepared alongside the inputs but not fed to the model
            model_inputs, _ = prepare_batch_inputs(batch[1], config.device)

            outputs = model(
                **model_inputs,
                epoch_i=0,
                batch_idx=batch_idx,
                train_loader_length=len(data_loader),
                targets=None,
                is_training=False
            )

            record = {key: outputs[key].cpu().numpy() for key in array_keys}
            record["video_ids"] = batch[0]  # assuming batch[0] holds the video IDs
            collected.append(record)

    return collected


print("Inference function defined")

## 12. Complete Training Script

This is the main training loop that brings everything together.

## Dataset and Training

**Current Setup (CLIP only):**
- Video features: 512-dim CLIP
- Query features: 512-dim CLIP (truncated to 32 tokens)
- Model vid_dim: 514 in `Config` (512-dim CLIP plus the 2 TEF channels appended during loading)

**Future: Adding SlowFast features:**
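1. Extract SlowFast features to `data/features/video_slowfast_features/` (i.e. `{FEATURE_ROOT}/video_slowfast_features`)
2. Update VIDEO_FEAT_DIRS to list: `[clip_dir, slowfast_dir]`
3. Update `config.vid_dim` to the new input width: 2816 (512 + 2304) for the raw features, plus the TEF channels appended during loading (2818 if two TEF channels are added, as in the CLIP-only setup)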
4. Re-run model initialization
5. Features will be automatically concatenated during data loading

In [None]:
class VideoDataset(Dataset):
    """Dataset for video moment retrieval with TEF augmentation and saliency supervision.

    Annotations are read from a JSON-lines file. Each record must provide
    ``qid`` and ``vid`` and may provide ``duration`` and ``relevant_windows``
    (a list of ``[start, end]`` pairs). Video features are looked up as
    ``<vid>.npz`` (array key ``features``) or ``<vid>.npy`` inside every
    directory of ``video_feat_dirs`` and concatenated along the last axis,
    followed by two temporal (TEF) channels. Query features are read from
    ``<qid>.npz`` (array key ``last_hidden_state``).

    Args:
        data_path: Path to the JSON-lines annotation file.
        video_feat_dirs: One directory, or a list/tuple of directories, holding
            per-video feature files.
        query_feat_dir: Directory holding per-query ``.npz`` files.
        max_v_l: Number of video clips every sample is truncated/padded to.
            ``None`` or a non-positive value means "infer from the data".
        max_q_l: Number of query tokens every sample is truncated/padded to.
        clip_len: Stored on the instance; not used by any method of this class.
        normalize_features: L2-normalise each feature row when True.
        saliency_pairs: Number of (positive, negative) clip index pairs returned
            per sample.
        saliency_value_scale: Score written into the dense saliency vector for
            clips that fall inside a relevant window.
        min_span_width_ratio: Stored (clamped to >= 0) and logged; not applied
            to the spans by this class.
            NOTE(review): confirm whether the criterion applies it downstream.
    """
    def __init__(self, data_path, video_feat_dirs, query_feat_dir, 
                 max_v_l=75, max_q_l=32, clip_len=2, normalize_features=True,
                 saliency_pairs=5, saliency_value_scale=12.0, min_span_width_ratio=0.0):
        self.data_path = data_path
        # A tuple of directories is treated like a list; any other value is
        # taken to be a single directory.
        if isinstance(video_feat_dirs, list):
            self.video_feat_dirs = video_feat_dirs
        elif isinstance(video_feat_dirs, tuple):
            self.video_feat_dirs = list(video_feat_dirs)
        else:
            self.video_feat_dirs = [video_feat_dirs]
        self.query_feat_dir = query_feat_dir
        self.max_v_l = max_v_l if (max_v_l is not None and max_v_l > 0) else None
        self.max_q_l = max_q_l
        self.clip_len = clip_len
        self.normalize_features = normalize_features
        self.saliency_pairs = saliency_pairs
        self.saliency_value_scale = float(saliency_value_scale)
        self.min_span_width_ratio = float(max(min_span_width_ratio, 0.0))
        
        # One JSON object per non-blank line.
        self.data = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    self.data.append(json.loads(line.strip()))
        
        self.feature_dim = None
        self.query_dim = None
        self._infer_feature_dims()
        if self.max_v_l is None:
            self.max_v_l = self._infer_max_context_length()
        
        print(f"Loaded {len(self.data)} samples from {data_path}")
        print(f"[VideoDataset] video_dim={self.feature_dim}, query_dim={self.query_dim}, normalize={self.normalize_features}")
        print(f"[VideoDataset] max_v_l={self.max_v_l}, saliency scale={self.saliency_value_scale}, min_span_width_ratio={self.min_span_width_ratio}")

    @staticmethod
    def _load_video_array(feat_dir, vid):
        """Return the raw feature array for ``vid`` in ``feat_dir``, or None if absent.

        ``<vid>.npz`` (key ``features``) takes precedence over ``<vid>.npy``.
        """
        npz_path = os.path.join(feat_dir, f"{vid}.npz")
        if os.path.exists(npz_path):
            with np.load(npz_path) as arr:
                return arr['features']
        npy_path = os.path.join(feat_dir, f"{vid}.npy")
        if os.path.exists(npy_path):
            return np.load(npy_path)
        return None
        
    def _infer_feature_dims(self):
        """Probe the annotations until both video and query feature widths are known.

        The video width is only accepted from a sample whose features exist in
        *every* feature directory; a partial sum would under-report the width
        of the concatenated features. Falls back to 512 per directory (+2 TEF
        channels) and 512 for queries when nothing can be probed.
        """
        for item in self.data:
            vid = item.get('vid')
            qid = item.get('qid')
            if self.feature_dim is None:
                dims = 0
                complete = True
                for feat_dir in self.video_feat_dirs:
                    feat = self._load_video_array(feat_dir, vid)
                    if feat is None:
                        complete = False
                        break
                    dims += feat.shape[-1]
                if complete and dims > 0:
                    self.feature_dim = dims + 2  # +2 for TEF channels
            if self.query_dim is None:
                q_path = os.path.join(self.query_feat_dir, f"{qid}.npz")
                if os.path.exists(q_path):
                    with np.load(q_path) as q_arr:
                        self.query_dim = q_arr['last_hidden_state'].shape[-1]
            if self.feature_dim is not None and self.query_dim is not None:
                break
        if self.feature_dim is None:
            self.feature_dim = 512 * len(self.video_feat_dirs) + 2
        if self.query_dim is None:
            self.query_dim = 512
        
    def _infer_max_context_length(self):
        """Infer the maximum temporal length across all videos."""
        max_len = 0
        for item in self.data:
            vid = item.get('vid')
            lengths = []
            for feat_dir in self.video_feat_dirs:
                feat = self._load_video_array(feat_dir, vid)
                if feat is not None:
                    lengths.append(feat.shape[0])
            if lengths:
                # Streams are cut to the shortest one when concatenated.
                max_len = max(max_len, min(lengths))
        return int(max_len or 1)

    # Build Temporal Encoding Features (TEF)
    def _build_tef(self, length):
        """Return a ``(length, 2)`` array of normalised [start, end] positions per clip."""
        if length <= 0:
            return np.zeros((0, 2), dtype=np.float32)
        positions = np.arange(length, dtype=np.float32)
        denom = float(max(length, 1))
        tef_start = positions / denom
        tef_end = np.minimum((positions + 1.0) / denom, 1.0)
        return np.stack([tef_start, tef_end], axis=1)
        
    def __len__(self):
        return len(self.data)
    
    def _generate_saliency_labels(self, spans, duration, video_len):
        """Build saliency supervision for one sample.

        Args:
            spans: List of ``[start, end]`` windows in the same unit as ``duration``.
            duration: Video duration.
            video_len: Number of valid (unpadded) clips.

        Returns:
            ``(pos, neg, scores)``: two int64 arrays of length
            ``saliency_pairs`` with sampled positive/negative clip indices, and
            a float32 array of length ``max_v_l`` holding
            ``saliency_value_scale`` on clips covered by a window and 0 elsewhere.
        """
        scores = np.zeros(self.max_v_l, dtype=np.float32)
        if video_len <= 0 or duration <= 0 or not spans:
            filler = np.zeros(self.saliency_pairs, dtype=np.int64)
            return filler, filler.copy(), scores
        # Each clip is assumed to cover an equal share of the duration.
        clip_length = duration / max(video_len, 1)
        pos_indices = set()
        for start, end in spans:
            st_idx = int(np.floor(start / clip_length))
            ed_idx = int(np.ceil(end / clip_length)) - 1
            st_idx = max(0, min(st_idx, video_len - 1))
            # Clamping to st_idx guarantees at least one positive clip per window.
            ed_idx = max(st_idx, min(ed_idx, video_len - 1))
            pos_indices.update(range(st_idx, ed_idx + 1))
            scores[st_idx:ed_idx + 1] = self.saliency_value_scale
        neg_pool = [i for i in range(video_len) if i not in pos_indices]
        if not pos_indices:
            pos_indices = {0}
        if not neg_pool:
            # Every clip is positive: fall back to sampling negatives from positives.
            neg_pool = list(pos_indices)
        pos_indices = list(pos_indices)
        num_pairs = min(self.saliency_pairs, len(pos_indices), len(neg_pool))
        if num_pairs == 0:
            num_pairs = 1
        pos_samples = np.random.choice(pos_indices, size=num_pairs, replace=len(pos_indices) < num_pairs)
        neg_samples = np.random.choice(neg_pool, size=num_pairs, replace=len(neg_pool) < num_pairs)
        # Pad to a fixed length by repeating the first sampled index.
        pos_array = np.full(self.saliency_pairs, pos_samples[0], dtype=np.int64)
        neg_array = np.full(self.saliency_pairs, neg_samples[0], dtype=np.int64)
        pos_array[:num_pairs] = pos_samples
        neg_array[:num_pairs] = neg_samples
        return pos_array, neg_array, scores
    
    def __getitem__(self, idx):
        """Return one padded sample as a dict of numpy arrays and metadata.

        Missing or unreadable features do not raise: a warning is printed and
        an all-zero feature array of the expected width is used instead.
        """
        item = self.data[idx]
        qid = item['qid']
        vid = item['vid']
        duration = float(item.get('duration', 0.0))
        relevant_windows = item.get('relevant_windows', [])
        duration = max(duration, 1e-6)  # avoid division by zero below
        
        try:
            v_feat_list = []
            for feat_dir in self.video_feat_dirs:
                _feat = self._load_video_array(feat_dir, vid)
                if _feat is None:
                    raise FileNotFoundError(f"Missing features for {vid} in {feat_dir}")
                _feat = _feat.astype(np.float32)
                if self.max_v_l is not None:
                    _feat = _feat[:self.max_v_l]
                if self.normalize_features:
                    norms = np.linalg.norm(_feat, axis=1, keepdims=True) + 1e-6
                    _feat = _feat / norms
                v_feat_list.append(_feat)
            # Align all streams to the shortest one before concatenating channels.
            min_len = min(len(e) for e in v_feat_list)
            v_feat_list = [e[:min_len] for e in v_feat_list]
            video_feat = np.concatenate(v_feat_list, axis=1)
            tef = self._build_tef(video_feat.shape[0])
            video_feat = np.concatenate([video_feat, tef], axis=1)
            self.feature_dim = video_feat.shape[-1]
        except Exception as exc:
            # Deliberate best-effort: keep the epoch running with a zero sample.
            print(f"Warning: Could not load video features for vid={vid} - {exc}")
            fallback_len = int(self.max_v_l or 1)
            fallback_dim = int(self.feature_dim or (len(self.video_feat_dirs) * 512 + 2))
            video_feat = np.zeros((fallback_len, fallback_dim), dtype=np.float32)
        
        try:
            with np.load(os.path.join(self.query_feat_dir, f"{qid}.npz")) as query_data:
                query_feat = query_data['last_hidden_state'].astype(np.float32)[:self.max_q_l]
            if self.normalize_features:
                norms = np.linalg.norm(query_feat, axis=1, keepdims=True) + 1e-6
                query_feat = query_feat / norms
            self.query_dim = query_feat.shape[-1]
        except Exception as exc:
            print(f"Warning: Could not load query features for qid={qid} - {exc}")
            fallback_query_dim = int(self.query_dim or 512)
            query_feat = np.zeros((self.max_q_l, fallback_query_dim), dtype=np.float32)
        
        # Ground-truth windows as (center, width), both normalised by duration.
        spans = []
        for start, end in relevant_windows:
            center = (start + end) / (2 * duration)
            width = max(end - start, 1e-6) / duration
            spans.append([center, width])
        spans = np.array(spans, dtype=np.float32) if spans else np.zeros((0, 2), dtype=np.float32)
        
        video_len = min(video_feat.shape[0], self.max_v_l)
        query_len = min(query_feat.shape[0], self.max_q_l)
        # Masks are 1 on real positions and 0 on padding.
        video_mask = np.zeros(self.max_v_l, dtype=np.float32)
        video_mask[:video_len] = 1
        query_mask = np.zeros(self.max_q_l, dtype=np.float32)
        query_mask[:query_len] = 1
        
        video_feat_padded = np.zeros((self.max_v_l, video_feat.shape[-1]), dtype=np.float32)
        video_feat_padded[:video_len] = video_feat[:video_len]
        query_feat_padded = np.zeros((self.max_q_l, query_feat.shape[-1]), dtype=np.float32)
        query_feat_padded[:query_len] = query_feat[:query_len]
        
        saliency_pos, saliency_neg, saliency_all = self._generate_saliency_labels(
            relevant_windows, duration, video_len)
        
        return {
            'qid': qid,
            'vid': vid,
            'video_feat': video_feat_padded,
            'query_feat': query_feat_padded,
            'video_mask': video_mask,
            'query_mask': query_mask,
            'spans': spans,
            'duration': duration,
            'saliency_pos_labels': saliency_pos,
            'saliency_neg_labels': saliency_neg,
            'saliency_all_labels': saliency_all
        }

def collate_fn(batch):
    """Merge a list of ``VideoDataset`` samples into batched tensors.

    Returns:
        A pair ``(meta, (model_inputs, targets))`` where ``meta`` is
        ``{'meta': [vid, ...]}``, ``model_inputs`` holds the stacked
        feature/mask tensors and ``targets`` holds the per-sample span labels
        plus the stacked saliency labels.
    """
    def _stacked(field):
        # Stack one per-sample numpy field along a new leading batch axis.
        return torch.from_numpy(np.stack([sample[field] for sample in batch]))

    model_inputs = {
        'src_vid': _stacked('video_feat').float(),
        'src_txt': _stacked('query_feat').float(),
        'src_vid_mask': _stacked('video_mask').float(),
        'src_txt_mask': _stacked('query_mask').float(),
    }

    # Span counts differ per sample, so spans stay a list of per-sample tensors.
    span_labels = [
        {
            'spans': torch.from_numpy(sample['spans']).float()
            if sample['spans'].size > 0
            else torch.zeros((0, 2))
        }
        for sample in batch
    ]

    targets = {
        'span_labels': span_labels,
        'saliency_pos_labels': _stacked('saliency_pos_labels').long(),
        'saliency_neg_labels': _stacked('saliency_neg_labels').long(),
        'saliency_all_labels': _stacked('saliency_all_labels').float(),
    }

    video_ids = [sample['vid'] for sample in batch]
    return {'meta': video_ids}, (model_inputs, targets)

In [None]:
# Build the train/test datasets and dataloaders, then smoke-test one batch
try:
    # Constructor options shared by both splits. max_v_l is supplied per split
    # because config.max_v_l is refreshed from the training set in between.
    _shared_dataset_options = dict(
        video_feat_dirs=VIDEO_FEAT_DIRS,  # Supports multiple feature types
        query_feat_dir=QUERY_FEAT_DIR,
        max_q_l=config.max_q_l,
        clip_len=config.clip_len,
        normalize_features=True,
        saliency_value_scale=config.saliency_label_scale,
        min_span_width_ratio=config.min_span_width_ratio,
    )

    train_dataset = VideoDataset(
        data_path=TRAIN_PATH,
        max_v_l=config.max_v_l,
        **_shared_dataset_options,
    )

    # Propagate the dimensions the dataset discovered back into the config
    config.max_v_l = train_dataset.max_v_l
    config.vid_dim = train_dataset.feature_dim
    config.txt_dim = train_dataset.query_dim
    print(f"Updated config feature dims -> vid_dim={config.vid_dim}, txt_dim={config.txt_dim}")
    print(f"Updated config max_v_l -> {config.max_v_l}")

    # If a model/criterion already exists in this session, keep its temporal length in sync
    if 'model' in globals():
        model.max_v_l = config.max_v_l
    if 'criterion' in globals():
        criterion.max_v_l = config.max_v_l

    test_dataset = VideoDataset(
        data_path=TEST_PATH,
        max_v_l=config.max_v_l,
        **_shared_dataset_options,
    )

    print(f"\nDatasets created successfully!")
    print(f"Train samples: {len(train_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    # Only the training loader shuffles; both run in the main process
    # (num_workers=0 is convenient for debugging, raise it for faster loading).
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_fn,
    )

    print(f"\nDataloaders created successfully!")
    print(f"Train batches: {len(train_loader)}")
    print(f"Test batches: {len(test_loader)}")

    # Pull a single batch to confirm the pipeline works end to end
    print("\nTesting data loading...")
    meta, (model_inputs, targets) = next(iter(train_loader))
    for _label, _key in (
        ("Video features", 'src_vid'),
        ("Query features", 'src_txt'),
        ("Video mask", 'src_vid_mask'),
        ("Query mask", 'src_txt_mask'),
    ):
        print(f"{_label} shape: {model_inputs[_key].shape}")
except Exception as e:
    # Report the failure in the cell output, then let the notebook stop here.
    print(f"Error creating datasets or dataloaders: {e}")
    raise

# Evaluation Metrics for Video Moment Retrieval

This section implements **standard evaluation metrics** used in video moment retrieval:

1. **Temporal IoU (Intersection over Union)**
   - Measures overlap between predicted and ground truth spans
   - Formula: `IoU = (intersection) / (union)`
   - Range: [0, 1], where 1 = perfect match

2. **Recall at N (R@N, IoU≥θ)**
   - Percentage of queries with at least one correct prediction in top-N
   - Common settings: R@1, R@5, R@10 with IoU≥0.5 or IoU≥0.7
   - Example: R@5, IoU≥0.5 = "Top-5 predictions contain at least one span with IoU ≥ 0.5"

3. **mean Average Precision (mAP @ IoU≥θ)**
   - Average precision across all queries at a given IoU threshold
   - Considers ranking quality (better predictions ranked higher = higher mAP)
   - Common thresholds: IoU≥0.5, IoU≥0.7

These metrics are **essential for evaluating moment retrieval quality** and comparing model performance.

# Evaluation Functions

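Functions to compute retrieval metrics (temporal IoU, mIoU, and R@N) for moment retrieval. Note that mAP, although described above, is not computed by the functions in this cell.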

In [None]:
def compute_iou_single(pred_span, gt_span):
    """
    Temporal IoU between one predicted span and one ground-truth span.

    Args:
        pred_span: [start, end] predicted span
        gt_span: [start, end] ground truth span

    Returns:
        iou: overlap length divided by union length; 0.0 when the union is
            not positive (e.g. both spans have zero length)
    """
    pred_start, pred_end = pred_span[0], pred_span[1]
    gt_start, gt_end = gt_span[0], gt_span[1]

    # Negative overlap means the spans are disjoint.
    overlap = max(0, min(pred_end, gt_end) - max(pred_start, gt_start))
    combined = (pred_end - pred_start) + (gt_end - gt_start) - overlap

    return 0.0 if combined <= 0 else overlap / combined


def evaluate_moment_retrieval(model, dataloader, config, device, iou_thresholds=(0.3, 0.5, 0.7)):
    """
    Evaluate moment retrieval performance on test/validation set.
    
    Computes:
    - mIoU: Mean IoU of the top-1 ranked prediction for each query
    - R@N, IoU≥θ: Recall at top-N predictions for each IoU threshold
    
    Samples without any ground-truth span are skipped and do not count
    towards any metric.
    
    Args:
        model: Trained Switch-Net model (legacy LD_DETR alias supported)
        dataloader: DataLoader yielding ``(meta, batch_data)`` pairs
        config: Configuration object (accepted for interface compatibility;
            not read by this function)
        device: Device to run evaluation on
        iou_thresholds: Iterable of IoU thresholds for R@N metrics
        
    Returns:
        metrics: Dictionary with ``mIoU`` (in [0, 1]) and
            ``R@{1,5,10}_IoU{thr}`` (percentages) for every threshold
    """
    model.eval()
    
    iou_thresholds = list(iou_thresholds)
    recall_ks = (1, 5, 10)
    max_k = max(recall_ks)
    
    all_top1_ious = []
    recall_hits = {thr: {f'R@{k}': 0 for k in recall_ks} for thr in iou_thresholds}
    loader_length = len(dataloader)
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            _, batch_data = batch
            model_inputs, targets = prepare_batch_inputs(batch_data, device)
            
            outputs = model(
                **model_inputs,
                epoch_i=0,
                batch_idx=0,
                train_loader_length=loader_length,
                targets=None,
                is_training=False,
            )
            
            pred_logits = outputs['pred_logits']  # (batch_size, num_queries, num_classes)
            pred_spans_cxw = outputs['pred_spans']  # (batch_size, num_queries, 2) in (center, width)
            pred_spans_xx = span_cxw_to_xx(pred_spans_cxw)  # Convert to (start, end) for IoU
            target_list = targets['span_labels']
            batch_size = pred_spans_xx.shape[0]
            
            for sample_idx in range(batch_size):
                if sample_idx >= len(target_list):
                    continue
                
                gt_spans_cxw = target_list[sample_idx]['spans']
                if gt_spans_cxw.numel() == 0:
                    continue
                
                gt_spans_xx = span_cxw_to_xx(gt_spans_cxw)
                gt_spans_np = gt_spans_xx.detach().cpu().numpy()
                if gt_spans_np.ndim == 1:
                    gt_spans_np = gt_spans_np[None, :]
                
                sample_spans = pred_spans_xx[sample_idx].detach().cpu().numpy()
                if sample_spans.shape[0] == 0:
                    continue
                
                sample_logits = pred_logits[sample_idx]
                if sample_logits.shape[-1] == 1:
                    # Softmax over a single logit is always 1.0 and would make
                    # the ranking arbitrary; sigmoid keeps the logit ordering.
                    scores = torch.sigmoid(sample_logits).squeeze(-1)
                else:
                    # Column 0 is used as the confidence score.
                    # NOTE(review): assumes class index 0 is "foreground" - confirm against the criterion.
                    scores = F.softmax(sample_logits, dim=-1)[:, 0]
                scores_np = scores.detach().cpu().numpy()
                ranked = np.argsort(scores_np)[::-1]
                
                # Only the top-max_k predictions can influence any metric.
                top_spans = sample_spans[ranked[:max_k]]
                best_ious = [
                    max(compute_iou_single(pred_span, gt_span) for gt_span in gt_spans_np)
                    for pred_span in top_spans
                ]
                
                all_top1_ious.append(best_ious[0])
                
                for thr in iou_thresholds:
                    for k in recall_ks:
                        if any(iou >= thr for iou in best_ious[:k]):
                            recall_hits[thr][f'R@{k}'] += 1
    
    total_queries = len(all_top1_ious)
    metrics = {'mIoU': float(np.mean(all_top1_ious)) if all_top1_ious else 0.0}
    
    for thr in iou_thresholds:
        for k in recall_ks:
            if total_queries > 0:
                metrics[f'R@{k}_IoU{thr}'] = recall_hits[thr][f'R@{k}'] / total_queries * 100
            else:
                metrics[f'R@{k}_IoU{thr}'] = 0.0
    
    return metrics


def print_metrics(metrics, prefix=""):
    """Pretty print evaluation metrics.

    Prints ``mIoU`` followed by R@1/R@5/R@10 for every IoU threshold found in
    ``metrics`` (detected through the ``R@1_IoU<thr>`` keys), in ascending
    threshold order, so results for non-default thresholds are shown as well.

    Args:
        metrics: Dictionary with an ``mIoU`` entry and ``R@{1,5,10}_IoU<thr>``
            entries, as returned by the evaluation function.
        prefix: Text placed in front of the "Evaluation Metrics:" heading.

    Raises:
        KeyError: If ``mIoU`` is missing, or if a threshold has an R@1 entry
            but no matching R@5/R@10 entry.
    """
    marker = 'R@1_IoU'

    def _threshold_order(label):
        # Numeric thresholds sort by value; anything unparsable goes last.
        try:
            return (0, float(label), label)
        except ValueError:
            return (1, 0.0, label)

    thresholds = sorted(
        (key[len(marker):] for key in metrics
         if isinstance(key, str) and key.startswith(marker)),
        key=_threshold_order,
    )

    print(f"\n{prefix}Evaluation Metrics:")
    print(f"  mIoU: {metrics['mIoU']:.4f}")
    
    for thr in thresholds:
        print(f"  IoU≥{thr}:")
        print(f"    R@1:  {metrics[f'R@1_IoU{thr}']:.2f}%")
        print(f"    R@5:  {metrics[f'R@5_IoU{thr}']:.2f}%")
        print(f"    R@10: {metrics[f'R@10_IoU{thr}']:.2f}%")

print("Evaluation functions loaded!")

# Training Pipeline with Checkpointing

Complete training pipeline with:
- Model checkpointing (best, latest, final)
- Results logging (CSV format)
- Early stopping mechanism
- Periodic evaluation on test set

In [None]:
import csv
import json
from datetime import datetime

# Output layout: checkpoints and result logs live in sibling folders under OUTPUT_DIR
OUTPUT_DIR = 'output'
MODELS_DIR, RESULTS_DIR = (os.path.join(OUTPUT_DIR, sub) for sub in ('models', 'results'))

# exist_ok=True makes re-running the cell harmless
os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

print(
    "Output directories created:",
    f"  Models: {MODELS_DIR}",
    f"  Results: {RESULTS_DIR}",
    sep="\n",
)


def save_checkpoint(model, optimizer, epoch, metrics, filename, config):
    """Write a training checkpoint to ``filename`` with ``torch.save``.

    The checkpoint bundles the epoch number, the model and optimizer state
    dicts, the evaluation metrics (an empty dict when none are given) and the
    attribute dictionary of ``config``.
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        metrics=metrics if metrics else {},
        config=vars(config),
    )
    torch.save(state, filename)
    print(f"  Checkpoint saved to {filename}")


def load_checkpoint(model, optimizer, filename):
    """Load a previously saved checkpoint into ``model`` and ``optimizer``.

    Args:
        model: Module whose parameters are overwritten in place.
        optimizer: Optimizer whose state is overwritten in place.
        filename: Path to a checkpoint containing ``model_state_dict``,
            ``optimizer_state_dict`` and ``epoch`` entries.

    Returns:
        ``(epoch, metrics)`` as stored in the checkpoint; ``metrics`` is an
        empty dict when the checkpoint has none.
    """
    # Deserialize onto the CPU so a checkpoint written on a GPU machine can be
    # restored on a CPU-only host; load_state_dict then copies the values into
    # the existing parameters/state on whatever device they already live on.
    checkpoint = torch.load(filename, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint.get('metrics', {})


class TrainingLogger:
    """Per-epoch CSV logger that also keeps the rows in memory for a JSON summary."""

    LOSS_KEYS = [
        'loss_span',
        'loss_giou',
        'loss_label',
        'loss_saliency',
        'loss_align',
        'loss_sim',
        'loss_moe_balance',
        'class_error',
    ]

    def __init__(self, log_file):
        self.log_file = log_file
        self.results = []

        general_columns = ['timestamp', 'epoch', 'train_loss', 'lr', 'grad_norm']
        # mIoU first, then R@1/R@5/R@10 grouped by IoU threshold.
        eval_columns = ['mIoU'] + [
            f'R@{top_n}_IoU{thr}'
            for thr in ('0.3', '0.5', '0.7')
            for top_n in (1, 5, 10)
        ]
        self.fieldnames = general_columns + self.LOSS_KEYS + eval_columns
        self._eval_columns = tuple(eval_columns)

        # Start a fresh file containing only the header row.
        with open(log_file, 'w', newline='') as handle:
            writer = csv.DictWriter(handle, fieldnames=self.fieldnames)
            writer.writeheader()

    def log(self, epoch, train_stats, eval_metrics=None):
        """Record one epoch: append a CSV row and remember it for the summary.

        Evaluation columns are only filled when ``eval_metrics`` is non-empty;
        otherwise they are left out of the row.
        """
        row = {
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            'epoch': epoch,
            'train_loss': float(train_stats.get('loss', 0.0)),
            'lr': float(train_stats.get('lr', 0.0)),
            'grad_norm': float(train_stats.get('grad_norm', 0.0) or 0.0),
        }

        per_loss = train_stats.get('loss_breakdown', {}) or {}
        row.update((name, float(per_loss.get(name, 0.0))) for name in self.LOSS_KEYS)

        if eval_metrics:
            row.update(
                (name, float(eval_metrics.get(name, 0.0))) for name in self._eval_columns
            )

        with open(self.log_file, 'a', newline='') as handle:
            writer = csv.DictWriter(handle, fieldnames=self.fieldnames)
            writer.writerow(row)

        self.results.append(row)

    def save_summary(self, summary_file):
        """Dump the logged history plus a few aggregate figures to a JSON file."""
        history = self.results
        best_miou = max((entry.get('mIoU', 0.0) for entry in history), default=0.0)
        final_loss = history[-1]['train_loss'] if history else 0.0

        payload = {
            'total_epochs': len(history),
            'best_miou': best_miou,
            'final_loss': final_loss,
            'training_history': history,
        }

        with open(summary_file, 'w') as handle:
            json.dump(payload, handle, indent=2)

        print(f"Training summary saved to {summary_file}")


class EarlyStopping:
    """Signal when a monitored value has stopped improving.

    Call the instance once per evaluation with the current value. It returns
    True once ``patience`` consecutive calls have failed to beat the best value
    seen so far by more than ``min_delta`` (lower is better for ``mode='min'``,
    higher is better otherwise). The stop flag is never cleared afterwards.
    """

    def __init__(self, patience=5, min_delta=0.001, mode='min'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_value = None
        self.should_stop = False

    def __call__(self, current_value):
        # The very first value only establishes the baseline.
        if self.best_value is None:
            self.best_value = current_value
            return False

        # Signed gain, oriented so that "larger" always means "better".
        if self.mode == 'min':
            gain = self.best_value - current_value
        else:
            gain = current_value - self.best_value

        if gain > self.min_delta:
            self.best_value = current_value
            self.counter = 0
            return self.should_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.should_stop = True
            print(f"\nEarly stopping triggered! No improvement for {self.patience} epochs.")
        return self.should_stop


print("Training utilities initialized!")

### Training Orchestration

The next cell drives the full training schedule: it initialises logging, steps the scheduler, triggers periodic evaluation, and manages checkpoint persistence (best/latest/final). Run it after all preceding definitions have executed.

In [None]:
# Training configuration
# The names defined in this cell (num_epochs, eval_every, logger, wandb_run, best_* ...)
# are module-level state read and updated by the training loop that follows.
num_epochs = config.n_epoch
eval_every = 1  # Evaluate every N epochs
use_early_stopping = False  # Set to True to enable early stopping
early_stop_patience = 5

# Weights & Biases experiment tracking
use_wandb = True  # Set to False to disable logging to Weights & Biases
wandb_project = "Switch-net-moment-retrieval"  # Update this to your own W&B project name if needed
wandb_run = None
# NOTE: wandb is imported unconditionally at the top of the file, so this guard
# only takes effect if that import is turned into an optional one that yields None.
if use_wandb and wandb is None:
    print("Weights & Biases module not found. Install it with `pip install wandb` or set use_wandb=False.")
    use_wandb = False

# Initialize logger
# The same timestamp is reused for the CSV log, the W&B run name and the JSON summary.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = os.path.join(RESULTS_DIR, f'training_log_{timestamp}.csv')
logger = TrainingLogger(log_file)

# Initialize early stopping (optional)
# `early_stopping` is only defined when the feature is enabled; the loop checks
# use_early_stopping before touching it.
if use_early_stopping:
    early_stopping = EarlyStopping(patience=early_stop_patience, mode='max')  # Monitor mIoU
    print(f"Early stopping enabled with patience={early_stop_patience}")

# Optionally initialise Weights & Biases
if use_wandb:
    wandb_run_name = f"Switch-net_{timestamp}"
    initial_lr = optimizer.param_groups[0].get('lr', None) if optimizer.param_groups else None
    wandb_config = {
        "model": "Switch-net",
        "num_epochs": num_epochs,
        "eval_every": eval_every,
        "optimizer": optimizer.__class__.__name__,
        "learning_rate": initial_lr,
        "batch_size": getattr(config, 'batch_size', None),
        "dataset": getattr(config, 'dataset_name', None),
        "timestamp": timestamp,
    }
    try:
        wandb_run = wandb.init(
            project=wandb_project,
            name=wandb_run_name,
            # Entries whose value is None are left out of the run config.
            config={k: v for k, v in wandb_config.items() if v is not None},
            reinit=True,
        )
        # Plot train/* and val/* metrics against the logged "epoch" value.
        wandb_run.define_metric("epoch")
        wandb_run.define_metric("train/*", step_metric="epoch")
        wandb_run.define_metric("val/*", step_metric="epoch")
        print(f"W&B logging enabled -> project='{wandb_project}', run='{wandb_run_name}'")
    except Exception as exc:
        # Tracking is optional: any initialisation failure disables it
        # and training continues without W&B.
        print(f"WARNING: W&B initialisation failed: {exc}")
        print("Continuing without W&B logging. Set use_wandb=False to silence this message.")
        wandb_run = None
        use_wandb = False

# Track best model
# best_epoch stays 0 until an evaluation produces an mIoU above best_miou.
best_miou = 0.0
best_epoch = 0
best_checkpoint_path = None
completed_epochs = 0

print(f"\n{'='*80}")
print(f"Starting Training")
print(f"{'='*80}")
print(f"Total epochs: {num_epochs}")
print(f"Evaluation frequency: every {eval_every} epochs")
print(f"Results will be saved to: {RESULTS_DIR}")
print(f"Models will be saved to: {MODELS_DIR}")
if use_wandb:
    print("Metrics will also be logged to Weights & Biases.")
print(f"{'='*80}\n")

# Training loop
# Per epoch: train -> step the LR scheduler -> report/log training stats ->
# (optionally) evaluate and save the best checkpoint -> save the latest
# checkpoint -> append to the CSV log -> log validation metrics -> early-stop check.
# `epoch` is 0-based; everything user-facing (prints, checkpoints, logs) uses epoch + 1.
for epoch in range(num_epochs):
    print(f"\n{'='*80}")
    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"{'='*80}")

    train_stats = train_one_epoch(model, criterion, train_loader, optimizer, epoch, config)
    # The scheduler is stepped once per epoch, after training.
    lr_scheduler.step()

    print(f"\nTraining Stats:")
    print(f"  Loss: {train_stats['loss']:.4f}")
    print(f"  Learning Rate: {train_stats['lr']:.6f}")
    # The remaining entries of train_stats are optional and printed only when present.
    if train_stats.get('epoch_time') is not None:
        print(f"  Epoch time: {train_stats['epoch_time'] / 60:.2f} min")
    if train_stats.get('grad_norm') is not None:
        print(f"  Avg Grad Norm: {train_stats['grad_norm']:.4f}")
    if train_stats.get('saliency_warmup_active'):
        print("  Saliency warm-up active -> loss_saliency weight set to 0 this epoch")
    applied_loss_weights = train_stats.get('applied_loss_weights')
    if applied_loss_weights:
        print("  Applied loss weights:")
        for key in sorted(applied_loss_weights):
            print(f"    {key}: {applied_loss_weights[key]:.3f}")
    if train_stats.get('loss_breakdown'):
        print("  Loss Breakdown:")
        for key in sorted(train_stats['loss_breakdown']):
            print(f"    {key}: {train_stats['loss_breakdown'][key]:.4f}")

    # Training metrics go to W&B under the "train/" namespace.
    if use_wandb and wandb_run is not None:
        wandb_log = {
            "epoch": epoch + 1,
            "train/loss": float(train_stats.get('loss', 0.0)),
            "train/lr": float(train_stats.get('lr', 0.0)),
        }
        if train_stats.get('grad_norm') is not None:
            wandb_log["train/grad_norm"] = float(train_stats['grad_norm'])
        if train_stats.get('epoch_time') is not None:
            wandb_log["train/epoch_minutes"] = float(train_stats['epoch_time'] / 60.0)
        loss_breakdown = train_stats.get('loss_breakdown') or {}
        for key, value in loss_breakdown.items():
            wandb_log[f"train/{key}"] = float(value)
        for key, value in (applied_loss_weights or {}).items():
            wandb_log[f"train/weight_{key}"] = float(value)
        wandb.log(wandb_log, step=epoch + 1)

    # Evaluate every `eval_every` epochs and always after the last epoch.
    eval_metrics = None
    should_eval = ((epoch + 1) % eval_every == 0) or ((epoch + 1) == num_epochs)
    if should_eval:
        print("\nRunning evaluation...")
        eval_metrics = evaluate_moment_retrieval(
            model,
            test_loader,
            config,
            device,
            iou_thresholds=[0.3, 0.5, 0.7],
        )
        print_metrics(eval_metrics, prefix="Validation ")

        # Model selection is driven by mIoU; only a strict improvement
        # overwrites the best checkpoint.
        current_miou = eval_metrics.get('mIoU', 0.0)
        if current_miou > best_miou:
            best_miou = current_miou
            best_epoch = epoch + 1
            best_checkpoint_path = os.path.join(MODELS_DIR, 'best_model_moe.pth')
            save_checkpoint(model, optimizer, epoch + 1, eval_metrics, best_checkpoint_path, config)
            print(f"  New best model saved (mIoU={current_miou:.4f})")
            if use_wandb and wandb_run is not None:
                wandb_run.summary['best_miou'] = best_miou
                wandb_run.summary['best_epoch'] = best_epoch

    # The "latest" checkpoint is overwritten every epoch, evaluated or not.
    latest_checkpoint_path = os.path.join(MODELS_DIR, 'latest_model_moe.pth')
    save_checkpoint(model, optimizer, epoch + 1, eval_metrics or {}, latest_checkpoint_path, config)

    # A separate "final" checkpoint is written only when the full schedule
    # completes; a run stopped early leaves just the best/latest files.
    if (epoch + 1) == num_epochs:
        final_checkpoint_path = os.path.join(MODELS_DIR, f'final_model_epoch{epoch + 1}_moe.pth')
        save_checkpoint(model, optimizer, epoch + 1, eval_metrics or {}, final_checkpoint_path, config)

    logger.log(epoch + 1, train_stats, eval_metrics)
    completed_epochs = epoch + 1

    # Validation metrics are logged at the same W&B step as the training metrics above.
    if use_wandb and wandb_run is not None and eval_metrics:
        val_log = {"epoch": epoch + 1}
        for key, value in eval_metrics.items():
            val_log[f"val/{key}"] = float(value)
        wandb.log(val_log, step=epoch + 1)

    # Early stopping only sees epochs that were actually evaluated.
    if use_early_stopping and eval_metrics is not None:
        if early_stopping(eval_metrics.get('mIoU', 0.0)):
            print("\nEarly stopping criterion met. Stopping training.")
            break

# Persist the per-epoch history collected by the CSV logger as a JSON summary.
summary_file = os.path.join(RESULTS_DIR, f'training_summary_{timestamp}.json')
logger.save_summary(summary_file)

if best_epoch > 0:
    print(f"\nBest mIoU {best_miou:.4f} achieved at epoch {best_epoch}.")
else:
    print("\nTraining finished without evaluation. Consider increasing eval frequency.")

if use_wandb and wandb_run is not None:
    # Use plain item assignment: the W&B run summary is a dict-like wrapper
    # rather than a real dict, and helpers such as setdefault() are not part of
    # its documented interface. Assigning is equivalent here because the
    # training loop writes these very variables whenever it updates the summary.
    wandb_run.summary['best_miou'] = best_miou
    wandb_run.summary['best_epoch'] = best_epoch
    wandb_run.summary['trained_epochs'] = completed_epochs
    wandb_run.summary['log_file'] = log_file
    try:
        wandb_run.finish()
    except Exception:
        # Fall back to the module-level finish if the run object refuses to close.
        wandb.finish()

In [None]:
# Quick evaluation with the latest saved checkpoint
# The training loop writes 'latest_model_moe.pth' into MODELS_DIR after every
# epoch; point at that file rather than a name no cell in this notebook saves.
latest_checkpoint_path = os.path.join(MODELS_DIR, 'latest_model_moe.pth')
if not os.path.exists(latest_checkpoint_path):
    # Explicit exception instead of `assert`, which is stripped under `python -O`.
    raise FileNotFoundError(f'Checkpoint not found: {latest_checkpoint_path}')

checkpoint = torch.load(latest_checkpoint_path, map_location=device)
# strict=False tolerates architecture drift, so report what was skipped
# instead of silently evaluating a partially initialised model.
load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
if load_result.missing_keys:
    print(f"WARNING: {len(load_result.missing_keys)} model keys missing from checkpoint: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"WARNING: {len(load_result.unexpected_keys)} unexpected keys in checkpoint: {load_result.unexpected_keys}")
model.to(device)
model.eval()
print(f"Loaded checkpoint from {latest_checkpoint_path}")

quick_metrics = evaluate_moment_retrieval(
    model,
    test_loader,
    config,
    device,
    iou_thresholds=[0.3, 0.5, 0.7],
)
print_metrics(quick_metrics, prefix="Latest model ")