In [3]:
import os
import json
import numpy as np
import pandas as pd

def load_or_create_data(json_path, excel_path, labels_path, video_dir, text_dir, caption_dir):
    """
    Load labels from NPY and JSON. If JSON is missing or mismatched, generate it from Excel.

    Args:
        json_path (str): path to the JSON file storing data_list
        excel_path (str): path to the Excel file containing the video list
        labels_path (str): path to the NPY file containing labels
        video_dir (str): directory containing .npy video features
        text_dir (str): directory containing .npy text features
        caption_dir (str): directory containing .npy caption features

    Returns:
        data_list (list of dict): list of dictionaries with keys 'video_feat', 'text_feat', 'caption_feat'
        labels (np.ndarray): labels
    """

    # Load labels
    if not os.path.exists(labels_path):
        raise FileNotFoundError(f"Labels file not found: {labels_path}")
    labels = np.load(labels_path)

    # Check JSON
    need_create_json = True
    if os.path.exists(json_path):
        with open(json_path, "r", encoding="utf-8") as f:
            try:
                data_list = json.load(f)
                if len(data_list) == len(labels):
                    need_create_json = False
            except json.JSONDecodeError:
                need_create_json = True

    # Create JSON 
    if need_create_json:
        if not os.path.exists(excel_path):
            raise FileNotFoundError(f"Excel file not found: {excel_path}")
        df = pd.read_excel(excel_path)
        data_list = []
        for vid in df['video_id'].astype(str):
            data_list.append({
                "name": vid,
                "video_feat": os.path.join(video_dir, f"{vid}.npy"),
                "text_feat": os.path.join(text_dir, f"{vid}.npy"),
                "caption_feat": os.path.join(caption_dir, f"{vid}.npy")
            })
        # Create folder
        os.makedirs(os.path.dirname(json_path), exist_ok=True)
        # Save JSON
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(data_list, f, ensure_ascii=False, indent=2)
        print(f"JSON created, items: {len(data_list)}")
    else:
        print(f"JSON loaded, items: {len(data_list)}")

    return data_list, labels

video_dir = "/Users/jumita/Downloads/final/frame"
text_dir = "/Users/jumita/Downloads/final/text"
caption_dir = "/Users/jumita/Downloads/final/caption"
json_path = "/Users/jumita/Downloads/final_code/data.json"
excel_path = "/Users/jumita/Downloads/Book5.xlsx"
labels_path = "/Users/jumita/Downloads/final_code/labels.npy"

data_list, labels = load_or_create_data(
    json_path=json_path,
    excel_path=excel_path,
    labels_path=labels_path,
    video_dir=video_dir,
    text_dir=text_dir,
    caption_dir=caption_dir
)



JSON loaded, items: 5047


In [4]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

# ================= VIDEO FEATURE AUGMENTER =================
class VideoFeatureAugmenter:
    """
    Data augmenter for video features.

    This class applies temporal cropping and adaptive Gaussian noise 
    to video features to improve model generalization during training.

    Attributes
    ----------
    augment_prob : float
        Probability of applying augmentation.
    dropout : nn.Dropout
        Dropout layer used for augmentation (optional, not used here).
    noise_scale : float
        Scaling factor for Gaussian noise.

    Methods
    -------
    temporal_crop(features, crop_ratio=0.9):
        Randomly crops the sequence along the temporal dimension.
    adaptive_noise(features):
        Adds Gaussian noise proportional to the feature's standard deviation.
    __call__(features):
        Applies augmentation with probability `augment_prob`.
    """

    def __init__(self, augment_prob=0.5, dropout_prob=0.3, noise_scale=0.05):
        self.augment_prob = augment_prob
        self.dropout = nn.Dropout(p=dropout_prob)
        self.noise_scale = noise_scale

    def temporal_crop(self, features, crop_ratio=0.9):
        """
        Crop the feature sequence randomly along the temporal dimension.

        Parameters
        ----------
        features : torch.Tensor
            Video feature tensor of shape (T, D) or (T, H, D).
        crop_ratio : float
            Fraction of the sequence to keep.

        Returns
        -------
        torch.Tensor
            Cropped feature tensor.
        """
        T = features.shape[0]
        new_T = int(T * crop_ratio)
        start = np.random.randint(0, T - new_T + 1)
        return features[start:start+new_T]

    def adaptive_noise(self, features):
        """
        Add Gaussian noise scaled by feature standard deviation.

        Parameters
        ----------
        features : torch.Tensor
            Feature tensor.

        Returns
        -------
        torch.Tensor
            Noisy feature tensor.
        """
        std_per_feature = features.std(dim=0, keepdim=True) + 1e-6
        noise = torch.randn_like(features) * std_per_feature * self.noise_scale
        return features + noise

    def __call__(self, features):
        """
        Apply augmentation with probability `augment_prob`.

        Parameters
        ----------
        features : torch.Tensor
            Input feature tensor.

        Returns
        -------
        torch.Tensor
            Augmented feature tensor.
        """
        if torch.rand(1) < self.augment_prob:
            if features.ndim >= 2:
                features = self.temporal_crop(features)
            features = self.adaptive_noise(features)
        return features

# ================= TEXT/CAPTION FEATURE AUGMENTER =================
class TextFeatureAugmenter:
    """
    Data augmenter for text or caption features.

    Applies dropout and adaptive Gaussian noise to embeddings.

    Attributes
    ----------
    augment_prob : float
        Probability of applying augmentation.
    dropout : nn.Dropout
        Dropout layer applied to features.
    noise_scale : float
        Scaling factor for Gaussian noise.

    Methods
    -------
    adaptive_noise(features):
        Adds Gaussian noise proportional to feature's std.
    __call__(features):
        Applies augmentation with probability `augment_prob`.
    """
    def __init__(self, augment_prob=0.3, dropout_prob=0.05, noise_scale=0.02):
        self.augment_prob = augment_prob
        self.dropout = nn.Dropout(p=dropout_prob)
        self.noise_scale = noise_scale

    def adaptive_noise(self, features):
        """
        Add Gaussian noise scaled by feature standard deviation.

        Parameters
        ----------
        features : torch.Tensor
            Input feature tensor.

        Returns
        -------
        torch.Tensor
            Noisy feature tensor.
        """
        std_per_feature = features.std(dim=0, keepdim=True) + 1e-6
        noise = torch.randn_like(features) * std_per_feature * self.noise_scale
        return features + noise

    def __call__(self, features):
        """
        Apply dropout and noise augmentation with probability `augment_prob`.

        Parameters
        ----------
        features : torch.Tensor
            Input feature tensor.

        Returns
        -------
        torch.Tensor
            Augmented feature tensor.
        """
        if torch.rand(1) < self.augment_prob:
            features = self.dropout(features)
            features = self.adaptive_noise(features)
        return features


# ================= DATASET =================
class VideoTextCaptionDataset(Dataset):
    """
    PyTorch Dataset for multimodal video, text, and caption features.

    Handles optional data augmentation and per-sample feature normalization.

    Parameters
    ----------
    data_list : list of dict
        List of dictionaries containing paths for 'video_feat', 'text_feat', 'caption_feat'.
    labels : np.ndarray
        Array of integer labels corresponding to each sample.
    augment : bool, optional
        Whether to apply data augmentation (default: False).

    Methods
    -------
    __getitem__(idx):
        Loads, optionally augments, and returns a single sample.
    __len__():
        Returns the total number of samples.
    """
    def __init__(self, data_list, labels, augment=False):
        self.data_list = data_list
        self.labels = labels
        self.augment = augment
        self.video_aug = VideoFeatureAugmenter() if augment else None
        self.text_aug = TextFeatureAugmenter() if augment else None

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """
        Load and process a single sample.

        Returns
        -------
        dict
            Dictionary with keys: 'video_feat', 'text_feat', 'caption_feat', 'label'.
        """
        item = self.data_list[idx]

        video_feat = torch.from_numpy(np.load(item["video_feat"])).float()
        text_feat = torch.from_numpy(np.load(item["text_feat"])).float()
        caption_feat = torch.from_numpy(np.load(item["caption_feat"])).float()

        # Apply augmentation if enabled
        if self.augment:
            if self.video_aug and video_feat.ndim >= 2:  
                video_feat = self.video_aug(video_feat)
            if self.text_aug:
                text_feat = self.text_aug(text_feat)
                caption_feat = self.text_aug(caption_feat)

        # Flatten features if needed
        if video_feat.ndim == 3:
            video_feat = video_feat.mean(dim=(0,1))
        elif video_feat.ndim == 2:
            video_feat = video_feat.mean(dim=0)
        
        text_feat = text_feat if text_feat.ndim == 1 else text_feat.mean(dim=0)
        caption_feat = caption_feat if caption_feat.ndim == 1 else caption_feat.mean(dim=0)

        # Normalize each feature
        video_feat = (video_feat - video_feat.mean()) / (video_feat.std() + 1e-6)
        text_feat = (text_feat - text_feat.mean()) / (text_feat.std() + 1e-6)
        caption_feat = (caption_feat - caption_feat.mean()) / (caption_feat.std() + 1e-6)

        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return {
            "video_feat": video_feat,
            "text_feat": text_feat,
            "caption_feat": caption_feat,
            "label": label
        }

def create_dataloader(data_list, labels, batch_size=16, shuffle=True, augment=False):
    """
    Create a PyTorch DataLoader for the multimodal dataset.

    Parameters
    ----------
    data_list : list of dict
        List of dictionaries containing feature file paths.
    labels : np.ndarray
        Array of labels.
    batch_size : int, optional
        Batch size (default: 16).
    shuffle : bool, optional
        Whether to shuffle data (default: True).
    augment : bool, optional
        Whether to apply data augmentation (default: False).

    Returns
    -------
    DataLoader
        PyTorch DataLoader for the dataset.
    """
    dataset = VideoTextCaptionDataset(data_list, labels, augment=augment)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

# ================= LOAD DATA =================
data_list, labels = load_or_create_data(
    json_path="/Users/jumita/Downloads/final_code/data.json",
    excel_path="/Users/jumita/Downloads/Book5.xlsx",
    labels_path="/Users/jumita/Downloads/final_code/labels.npy",
    video_dir="/Users/jumita/Downloads/final/frame",
    text_dir="/Users/jumita/Downloads/final/text",
    caption_dir="/Users/jumita/Downloads/final/caption"
)

# ================= SPLIT DATA =================
data_train, data_temp, labels_train, labels_temp = train_test_split(
    data_list, labels, test_size=0.2, random_state=42, stratify=labels
)

try:
    data_val, data_test, labels_val, labels_test = train_test_split(
        data_temp, labels_temp, test_size=0.5, random_state=42, stratify=labels_temp
    )
except ValueError:
    data_val, data_test, labels_val, labels_test = train_test_split(
        data_temp, labels_temp, test_size=0.5, random_state=42, stratify=None
    )

def check_overlap(set1, set2):
    s1 = set([item["video_feat"] for item in set1])
    s2 = set([item["video_feat"] for item in set2])
    return len(s1 & s2)

print("Overlap train-val:", check_overlap(data_train, data_val))
print("Overlap train-test:", check_overlap(data_train, data_test))
print("Overlap val-test:", check_overlap(data_val, data_test))


# ================= DATALOADERS =================
train_loader = create_dataloader(data_train, labels_train, batch_size=16, shuffle=True, augment=True)
val_loader = create_dataloader(data_val, labels_val, batch_size=16, shuffle=False)
test_loader = create_dataloader(data_test, labels_test, batch_size=16, shuffle=False)

# ================= DEBUG SAMPLE =================
sample = next(iter(train_loader))
print("Augmented sample shapes:")
print("Video:", sample["video_feat"].shape)
print("Text:", sample["text_feat"].shape)
print("Caption:", sample["caption_feat"].shape) 



JSON loaded, items: 5047
Overlap train-val: 0
Overlap train-test: 0
Overlap val-test: 0
Augmented sample shapes:
Video: torch.Size([16, 256])
Text: torch.Size([16, 256])
Caption: torch.Size([16, 256])


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ================= StochasticDepth =================
class StochasticDepth(nn.Module):
    """
    Implements Stochastic Depth (also called DropPath) regularization.

    During training, randomly drops entire residual paths with probability `drop_prob`.

    Parameters
    ----------
    drop_prob : float
        Probability of dropping a path.
    """
    def __init__(self, drop_prob):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """
        Forward pass with stochastic depth.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, ...).

        Returns
        -------
        torch.Tensor
            Output tensor, either dropped or scaled.
        """
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1 - self.drop_prob
        mask = (torch.rand(x.shape[0], device=x.device) < keep_prob).to(x.dtype)
        mask = mask.view(-1, *([1] * (x.dim() - 1)))
        return x * mask / keep_prob

# ================= LoRA Linear =================
class LoRALinear(nn.Module):
    """
    Linear layer with Low-Rank Adaptation (LoRA).

    Parameters
    ----------
    in_features : int
        Input feature dimension.
    out_features : int
        Output feature dimension.
    r : int
        Rank of the LoRA adaptation matrices.
    alpha : int
        Scaling factor for LoRA updates.
    """
    def __init__(self, in_features, out_features, r=6, alpha=6):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        self.alpha = alpha
        self.scaling = self.alpha / max(1, r)
        self.linear = nn.Linear(in_features, out_features)
        if r > 0:
            self.lora_A = nn.Parameter(torch.randn(r, in_features) * 1e-3)
            self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        else:
            self.lora_A = None
            self.lora_B = None

    def forward(self, x):
        """
        Forward pass through LoRA linear layer.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, in_features).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (B, out_features).
        """
        result = self.linear(x)
        if self.r > 0:
            lora_update = (x @ self.lora_A.T) @ self.lora_B.T
            result += lora_update * self.scaling
        return result

# ================= LoRA MultiheadAttention =================
class LoRAMultiheadAttention(nn.Module):
    """
    Multihead Attention layer with LoRA adaptation.

    Parameters
    ----------
    embed_dim : int
        Dimension of embeddings.
    num_heads : int
        Number of attention heads.
    dropout : float
        Dropout probability in attention.
    batch_first : bool
        If True, input shape is (B, S, D).
    r : int
        LoRA rank.
    alpha : int
        LoRA scaling factor.
    """
    def __init__(self, embed_dim, num_heads, dropout=0.0, batch_first=True, r=6, alpha=6):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.r = r
        self.alpha = alpha
        self.scaling = self.alpha / max(1, r)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=batch_first)
        if r > 0:
            self.lora_A_q = nn.Parameter(torch.randn(r, embed_dim) * 1e-3)
            self.lora_B_q = nn.Parameter(torch.zeros(embed_dim, r))
            self.lora_A_k = nn.Parameter(torch.randn(r, embed_dim) * 1e-3)
            self.lora_B_k = nn.Parameter(torch.zeros(embed_dim, r))
            self.lora_A_v = nn.Parameter(torch.randn(r, embed_dim) * 1e-3)
            self.lora_B_v = nn.Parameter(torch.zeros(embed_dim, r))
            self.lora_A_out = nn.Parameter(torch.randn(r, embed_dim) * 1e-3)
            self.lora_B_out = nn.Parameter(torch.zeros(embed_dim, r))

    def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None):
        """
        Forward pass through LoRA multihead attention.

        Parameters
        ----------
        query, key, value : torch.Tensor
            Input tensors of shape (B, S, D).
        key_padding_mask : torch.Tensor, optional
            Mask for padded tokens.
        need_weights : bool
            If True, returns attention weights.
        attn_mask : torch.Tensor, optional
            Attention mask.

        Returns
        -------
        attn_output : torch.Tensor
            Attention output tensor.
        attn_weights : torch.Tensor
            Attention weights (if `need_weights` is True).
        """
        if self.r > 0:
            query = query + (query @ self.lora_A_q.T) @ self.lora_B_q.T * self.scaling
            key   = key   + (key   @ self.lora_A_k.T) @ self.lora_B_k.T * self.scaling
            value = value + (value @ self.lora_A_v.T) @ self.lora_B_v.T * self.scaling

        attn_output, attn_weights = self.attn(
            query, key, value,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask
        )

        if self.r > 0:
            attn_output = attn_output + (attn_output @ self.lora_A_out.T) @ self.lora_B_out.T * self.scaling
        return attn_output, attn_weights

# ================= CoAttentionBlock =================
class CoAttentionBlock(nn.Module):
    """
    Co-Attention block between query and key/value features with LoRA and residual feed-forward.

    Parameters
    ----------
    dim_q : int
        Dimension of query features.
    dim_kv : int
        Dimension of key/value features.
    num_heads : int
        Number of attention heads.
    hidden_dim : int
        Hidden dimension for feed-forward network.
    dropout : float
        Dropout probability.
    lora_r : int
        LoRA rank.
    drop_path_prob : float
        Stochastic depth probability.
    """
    def __init__(self, dim_q=256, dim_kv=256, num_heads=8, hidden_dim=256, dropout=0.2, lora_r=6, drop_path_prob=0.2):
        super().__init__()
        self.query_proj = LoRALinear(dim_q, dim_q, r=lora_r)
        self.key_proj   = LoRALinear(dim_kv, dim_q, r=lora_r)
        self.value_proj = LoRALinear(dim_kv, dim_q, r=lora_r)
        self.attn = LoRAMultiheadAttention(dim_q, num_heads, dropout=dropout, batch_first=True, r=lora_r)
        self.gate = nn.Sequential(LoRALinear(dim_q * 2, dim_q, r=lora_r), nn.Sigmoid())
        self.ffn = nn.Sequential(
            LoRALinear(dim_q, hidden_dim, r=lora_r),
            nn.GELU(),
            nn.Dropout(dropout),
            LoRALinear(hidden_dim, dim_q, r=lora_r),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(dim_q)
        self.norm2 = nn.LayerNorm(dim_q)
        self.drop_path = StochasticDepth(drop_path_prob) if drop_path_prob > 0.0 else None

    def forward(self, query, key, value, key_padding_mask=None):
        """
        Forward pass of CoAttentionBlock.

        Parameters
        ----------
        query, key, value : torch.Tensor
            Input feature tensors.
        key_padding_mask : torch.Tensor, optional
            Mask for padded tokens.

        Returns
        -------
        torch.Tensor
            Output tensor after co-attention, gating, and feed-forward.
        """
        q = self.query_proj(query)
        k = self.key_proj(key)
        v = self.value_proj(value)
        attn_out, _ = self.attn(q, k, v, key_padding_mask=key_padding_mask)
        gate = self.gate(torch.cat([query, attn_out], dim=-1))
        out = self.norm1(query + gate * attn_out)
        ffn_out = self.ffn(out)
        if self.drop_path is not None:
            ffn_out = self.drop_path(ffn_out)
        return self.norm2(out + ffn_out)

# ================= CoAttentionLayer =================
class CoAttentionLayer(nn.Module):
    """
    Co-Attention layer for interaction between video and text features.

    It applies co-attention blocks in both directions:
    video-to-text and text-to-video, then pools the sequences
    and concatenates the outputs.

    Parameters
    ----------
    dim_q : int
        Dimension of query features (video features).
    dim_kv : int
        Dimension of key/value features (text features).
    num_heads : int
        Number of attention heads.
    hidden_dim : int
        Hidden dimension for feed-forward networks.
    dropout : float
        Dropout probability.
    lora_r : int
        LoRA rank.
    max_video_tokens : int
        Maximum number of video tokens to keep after pooling.
    """
    def __init__(self, dim_q=256, dim_kv=256, num_heads=8, hidden_dim=256, dropout=0.5, lora_r=6, max_video_tokens=32):
        super().__init__()
        self.video2text = CoAttentionBlock(dim_q, dim_kv, num_heads, hidden_dim, dropout, lora_r)
        self.text2video = CoAttentionBlock(dim_kv, dim_q, num_heads, hidden_dim, dropout, lora_r)
        self.max_video_tokens = max_video_tokens

    def forward(self, video_feat, text_feat, video_mask=None, text_mask=None):
        """
        Forward pass of co-attention layer.

        Parameters
        ----------
        video_feat : torch.Tensor
            Video features, shape (B, T, D) or (B, N, T, D) if patches.
        text_feat : torch.Tensor
            Text features, shape (B, S, D).
        video_mask : torch.Tensor, optional
            Mask for video tokens.
        text_mask : torch.Tensor, optional
            Mask for text tokens.

        Returns
        -------
        torch.Tensor
            Concatenated pooled co-attention features, shape (B, D_video + D_text).
        """
        B = video_feat.shape[0]

        # Flatten video patches if needed
        if video_feat.dim() == 4:  
            T, N, D = video_feat.shape[1], video_feat.shape[2], video_feat.shape[3]
            video_feat = video_feat.view(B, T * N, D)
        elif video_feat.dim() == 2:
            video_feat = video_feat.unsqueeze(1)

        # Pooling video sequence
        if video_feat.shape[1] > self.max_video_tokens:
            video_feat = video_feat.transpose(1, 2)  # (B, D, Seq)
            video_feat = F.adaptive_avg_pool1d(video_feat, self.max_video_tokens)
            video_feat = video_feat.transpose(1, 2)

        if text_feat.dim() == 2:
            text_feat = text_feat.unsqueeze(1)

        # Forward co-attention with optional masks
        v2t = self.video2text(video_feat, text_feat, text_feat, key_padding_mask=text_mask)
        t2v = self.text2video(text_feat, video_feat, video_feat, key_padding_mask=video_mask)
        
        # Pool sequences
        v2t_pooled = v2t.mean(dim=1)
        t2v_pooled = t2v.mean(dim=1)
        return torch.cat([v2t_pooled, t2v_pooled], dim=-1)

# ================= LateFusionLoRA =================
class LateFusionLoRA(nn.Module):
    """
    Late fusion module with LoRA for combining two feature modalities.

    Applies two-layer MLP with LayerNorm, GELU activation, dropout,
    and optional stochastic depth, then outputs class logits.

    Parameters
    ----------
    dim_feat1 : int
        Dimension of first input feature.
    dim_feat2 : int
        Dimension of second input feature.
    dim_hidden : int
        Hidden layer dimension.
    num_classes : int
        Number of output classes.
    r : int
        LoRA rank.
    """
    def __init__(self, dim_feat1, dim_feat2, dim_hidden, num_classes, r=6):
        super().__init__()
        self.fc1 = LoRALinear(dim_feat1 + dim_feat2, dim_hidden * 2, r=r)
        self.norm1 = nn.LayerNorm(dim_hidden * 2)
        self.fc2 = LoRALinear(dim_hidden * 2, dim_hidden, r=r)
        self.norm2 = nn.LayerNorm(dim_hidden)
        self.fc3 = LoRALinear(dim_hidden, num_classes, r=r)
        self.dropout = nn.Dropout(0.35)
        self.drop_path = StochasticDepth(0.2) if dim_hidden > 256 else None

    def forward(self, x1, x2):
        """
        Forward pass of late fusion module.

        Parameters
        ----------
        x1, x2 : torch.Tensor
            Input feature tensors to fuse, shape (B, dim_feat1), (B, dim_feat2).

        Returns
        -------
        torch.Tensor
            Output class logits, shape (B, num_classes).
        """
        x = torch.cat([x1, x2], dim=-1)
        x = self.fc1(x)
        x = self.norm1(x)
        x = F.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.norm2(x)
        x = F.gelu(x)
        if self.drop_path is not None:
            x = self.drop_path(x)
        return self.fc3(x)

# ================= MultimodalModel =================
class MultimodalModel(nn.Module):
    """
    Multimodal model combining co-attention between video and text features
    and late fusion with caption features, with LoRA finetuning support.

    Parameters
    ----------
    video_dim : int
        Dimension of video features.
    text_dim : int
        Dimension of text features.
    caption_dim : int
        Dimension of caption features.
    hidden_dim : int
        Hidden dimension for co-attention and fusion layers.
    num_classes : int
        Number of output classes.
    r : int
        LoRA rank.
    finetune_mode : str
        Finetuning mode: 'coattn_plus_latefusion', 'late_fusion_only', 'coattn_only', etc.
    dropout_prob : float
        Dropout probability for the fusion layer.
    """
    def __init__(self, video_dim=256, text_dim=256, caption_dim=256, hidden_dim=256,
                 num_classes=3, r=6, finetune_mode: str = 'coattn_plus_latefusion',
                 dropout_prob=0.1):
        super().__init__()
        self.ca1 = CoAttentionLayer(video_dim, text_dim, num_heads=8, hidden_dim=hidden_dim, lora_r=r)
        self.fc_ca1 = LoRALinear(video_dim * 2, video_dim, r=r)
        self.dropout = nn.Dropout(dropout_prob)   
        self.late_fusion = LateFusionLoRA(video_dim, caption_dim, hidden_dim, num_classes, r=r)
        self.set_trainable(finetune_mode)

    def set_trainable(self, finetune_mode: str = 'coattn_plus_latefusion'):
        """
        Set which parameters are trainable according to finetune_mode.

        Parameters
        ----------
        finetune_mode : str
            Mode of finetuning.
        """
        for p in self.parameters():
            p.requires_grad = False

        finetune_mode = finetune_mode.lower()

        def enable_lora(module):
            for name, p in module.named_parameters():
                if "lora_A" in name or "lora_B" in name:
                    p.requires_grad = True

        def unfreeze_part(module, names_to_unfreeze):
            for name, p in module.named_parameters():
                for target_name in names_to_unfreeze:
                    if target_name in name:
                        p.requires_grad = True

        if finetune_mode == 'late_fusion_only':
            enable_lora(self.late_fusion)
            enable_lora(self.fc_ca1)
        elif finetune_mode == 'coattn_only':
            enable_lora(self.ca1)
            enable_lora(self.fc_ca1)
        elif finetune_mode == 'coattn_plus_latefusion':
            enable_lora(self.ca1)
            enable_lora(self.late_fusion)
            enable_lora(self.fc_ca1)
        elif finetune_mode == 'coattn_plus_latefusion_part':
            enable_lora(self.ca1)
            enable_lora(self.late_fusion)
            enable_lora(self.fc_ca1)
            unfreeze_part(self.ca1.video2text.query_proj, ['linear'])
            unfreeze_part(self.late_fusion.fc1, ['linear'])
        elif finetune_mode == 'all':
            for p in self.parameters():
                p.requires_grad = True
        else:
            raise ValueError(f"Unknown finetune_mode: {finetune_mode}")

    def get_param_groups(self, lr_late=2e-4, lr_coattn=1.5e-4, lr_fc_ca1=5e-5, weight_decay=5e-4):
        """
        Return parameter groups with separate learning rates for optimizer.

        Parameters
        ----------
        lr_late, lr_coattn, lr_fc_ca1 : float
            Learning rates for respective modules.
        weight_decay : float
            Weight decay for optimizer.

        Returns
        -------
        list of dict
            Parameter groups for optimizer.
        """
        groups = []
        if any(p.requires_grad for p in self.late_fusion.parameters()):
            groups.append({'params': [p for p in self.late_fusion.parameters() if p.requires_grad], 'lr': lr_late, 'weight_decay': weight_decay})
        if any(p.requires_grad for p in self.ca1.parameters()):
            groups.append({'params': [p for p in self.ca1.parameters() if p.requires_grad], 'lr': lr_coattn, 'weight_decay': weight_decay})
        if any(p.requires_grad for p in self.fc_ca1.parameters()):
            groups.append({'params': [p for p in self.fc_ca1.parameters() if p.requires_grad], 'lr': lr_fc_ca1, 'weight_decay': weight_decay})
        return groups

    def forward(self, video_feat, text_feat, caption_feat, video_mask=None, text_mask=None):
        """
        Forward pass of the multimodal model.

        Parameters
        ----------
        video_feat : torch.Tensor
            Video features.
        text_feat : torch.Tensor
            Text features.
        caption_feat : torch.Tensor
            Caption features.
        video_mask : torch.Tensor, optional
            Mask for video tokens.
        text_mask : torch.Tensor, optional
            Mask for text tokens.

        Returns
        -------
        torch.Tensor
            Class logits for each sample.
        """
        ca1_out = self.ca1(video_feat, text_feat, video_mask=video_mask, text_mask=text_mask)
        ca1_out = self.fc_ca1(ca1_out)
        ca1_out = self.dropout(ca1_out)
        return self.late_fusion(ca1_out, caption_feat)

In [None]:
import os
import torch
import torch.nn as nn
import numpy as np
from sklearn.model_selection import StratifiedKFold
import time

# ====================== EARLY STOPPING ======================
class EarlyStopping:
    """
    Early stopping utility to stop training when a monitored metric stops improving.

    Parameters
    ----------
    patience : int
        Number of epochs to wait after last improvement before stopping.
    delta : float
        Minimum change to qualify as improvement.
    verbose : bool
        Whether to print messages when counter increases.
    mode : str
        'min' for metrics to minimize (e.g., loss), 'max' for metrics to maximize (e.g., accuracy).
    """
    def __init__(self, patience=10, delta=0.0, verbose=False, mode='min'):
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, metric_value):
        """
        Call method to update early stopping status.

        Parameters
        ----------
        metric_value : float
            Current metric value to monitor.

        Returns
        -------
        bool
            True if training should stop, False otherwise.
        """
        score = -metric_value if self.mode == 'min' else metric_value
        if self.best_score is None:
            self.best_score = score
            return False
        if score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} / {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
                return True
        else:
            self.best_score = score
            self.counter = 0
        return False

# ====================== TRAINER ======================
class MultimodalTrainer:
    """
    Trainer class for multimodal model with gradient accumulation, max gradient clipping,
    and optional learning rate scheduler.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to train.
    device : torch.device
        Device for computation.
    optimizer : torch.optim.Optimizer
        Optimizer instance.
    criterion : nn.Module
        Loss function.
    scheduler : torch.optim.lr_scheduler._LRScheduler, optional
        Learning rate scheduler.
    grad_accum_steps : int
        Number of steps for gradient accumulation.
    max_grad_norm : float
        Maximum gradient norm for clipping.
    """
    def __init__(self, model, device, optimizer, criterion, scheduler=None, grad_accum_steps=2, max_grad_norm=1.0):
        self.model = model.to(device)
        self.device = device
        self.optimizer = optimizer
        self.criterion = criterion
        self.scheduler = scheduler
        self.grad_accum_steps = grad_accum_steps
        self.max_grad_norm = max_grad_norm
        self.best_metrics = {'val_loss': float('inf'), 'val_acc': 0, 'epoch': 0}

    def _to_device(self, batch):
        """Move batch dictionary to device."""
        return {k: v.to(self.device) for k, v in batch.items()}

    def train_epoch(self, train_loader):
        """
        Train the model for one epoch.

        Parameters
        ----------
        train_loader : DataLoader
            Training dataloader.

        Returns
        -------
        float
            Average training loss.
        float
            Training accuracy.
        """
        self.model.train()
        total_loss, correct, total = 0, 0, 0
        for batch_idx, batch in enumerate(train_loader):
            batch = self._to_device(batch)
            outputs = self.model(
                batch['video_feat'], batch['text_feat'], batch['caption_feat'],
                video_mask=batch.get('video_mask', None),
                text_mask=batch.get('text_mask', None)
            )
            loss = self.criterion(outputs, batch['label']) / self.grad_accum_steps
            loss.backward()

            if (batch_idx + 1) % self.grad_accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.optimizer.step()
                self.optimizer.zero_grad()

            total_loss += loss.item() * batch['label'].size(0)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == batch['label']).sum().item()
            total += batch['label'].size(0)
        return total_loss / total, correct / total

    @torch.no_grad()
    def validate(self, val_loader):
        """
        Evaluate the model on validation data.

        Parameters
        ----------
        val_loader : DataLoader
            Validation dataloader.

        Returns
        -------
        float
            Average validation loss.
        float
            Validation accuracy.
        """
        self.model.eval()
        total_loss, correct, total = 0, 0, 0
        for batch in val_loader:
            batch = self._to_device(batch)
            outputs = self.model(
                batch['video_feat'], 
                batch['text_feat'], 
                batch['caption_feat'],
                video_mask=batch.get('video_mask', None),
                text_mask=batch.get('text_mask', None)
            )
            loss = self.criterion(outputs, batch['label'])
            total_loss += loss.item() * batch['label'].size(0)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == batch['label']).sum().item()
            total += batch['label'].size(0)
        return total_loss / total, correct / total

    def train(self, train_loader, val_loader, num_epochs=25, patience=10, checkpoint_dir="checkpoints"):
        """
        Full training loop with early stopping and checkpoint saving.

        Parameters
        ----------
        train_loader : DataLoader
            Training dataloader.
        val_loader : DataLoader
            Validation dataloader.
        num_epochs : int
            Maximum number of training epochs.
        patience : int
            Early stopping patience.
        checkpoint_dir : str
            Directory to save checkpoints.

        Returns
        -------
        dict
            Best validation metrics including loss, accuracy, and epoch.
        """
        os.makedirs(checkpoint_dir, exist_ok=True)  
        checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')

        start_time = time.time()
        early_stopping = EarlyStopping(patience=patience, verbose=True, mode='min')

        for epoch in range(num_epochs):
            train_loss, train_acc = self.train_epoch(train_loader)
            val_loss, val_acc = self.validate(val_loader)

            if self.scheduler:
                self.scheduler.step(val_loss)

            if val_acc > self.best_metrics['val_acc']:
                self.best_metrics.update({'val_loss': val_loss, 'val_acc': val_acc, 'epoch': epoch + 1})
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
                    'val_loss': val_loss,
                    'val_acc': val_acc
                }, checkpoint_path)

            print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

            if early_stopping(val_loss):
                print(f"Early stopping triggered at epoch {epoch+1}. Best Val Acc: {self.best_metrics['val_acc']:.4f}")
                break

        total_time = time.time() - start_time
        print(f"\nTraining completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m")
        print(f"Best Val Acc: {self.best_metrics['val_acc']:.4f} at epoch {self.best_metrics['epoch']}")

        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        return self.best_metrics

# ====================== 3-FOLD CV with mask support ======================
def run_3fold_cv(data_list, labels, create_dataloader, num_classes=3, batch_size=16, num_epochs=25, patience=10):
    """
    Run 3-fold cross-validation training for a multimodal model.

    Parameters
    ----------
    data_list : list
        List of data dictionaries.
    labels : np.ndarray or torch.Tensor
        Array of class labels.
    create_dataloader : callable
        Function to create dataloaders from data and labels.
    num_classes : int
        Number of classes.
    batch_size : int
        Batch size.
    num_epochs : int
        Maximum number of training epochs.
    patience : int
        Early stopping patience.

    Returns
    -------
    list
        List of trained models for each fold.
    callable
        Ensemble prediction function for a batch.
    """
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
    fold_models = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(data_list, labels)):
        print(f"\n=== Fold {fold+1}/3 ===")

        data_train = [data_list[i] for i in train_idx]
        labels_train = labels[train_idx]
        data_val = [data_list[i] for i in val_idx]
        labels_val = labels[val_idx]

        train_loader = create_dataloader(data_train, labels_train, batch_size=batch_size, shuffle=True, augment=True)
        val_loader = create_dataloader(data_val, labels_val, batch_size=batch_size, shuffle=False, augment=False)

        model = MultimodalModel(
            video_dim=256,
            text_dim=256,
            caption_dim=256,
            hidden_dim=256,
            num_classes=num_classes,
            r=6,
            finetune_mode='coattn_plus_latefusion_part'
        )

        # Class weights
        class_weights = torch.tensor([1.22, 0.90, 0.94], device=device)
        criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.08)

        # Optimizer + scheduler
        params_to_update = model.get_param_groups(lr_late=1.5e-4, lr_coattn=1.5e-4, lr_fc_ca1=5e-5, weight_decay=5e-4)
        optimizer = torch.optim.AdamW(params_to_update)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

        checkpoint_dir = f"/Users/jumita/Downloads/checkpoints_fold_{fold+1}"

        trainer = MultimodalTrainer(model=model, device=device, optimizer=optimizer,
                                     criterion=criterion, scheduler=scheduler, grad_accum_steps=2, max_grad_norm=1.0)

        trainer.train(train_loader, val_loader, num_epochs=num_epochs, patience=patience, checkpoint_dir=checkpoint_dir)

        checkpoint = torch.load(os.path.join(checkpoint_dir, 'best_model.pth'), map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        fold_models.append(model)

    print("\nEnsemble trained with 3 folds.")

    def ensemble_predict(batch):
        """
        Make predictions using the ensemble of 3 trained fold models.

        Parameters
        ----------
        batch : dict
            Batch dictionary with 'video_feat', 'text_feat', 'caption_feat'.

        Returns
        -------
        torch.Tensor
            Predicted class indices.
        """
        batch = {k: v.to(device) for k, v in batch.items()}
        probs = []
        with torch.no_grad():
            for model in fold_models:
                output = model(
                    batch['video_feat'], batch['text_feat'], batch['caption_feat'],
                    video_mask=batch.get('video_mask', None),
                    text_mask=batch.get('text_mask', None)
                )
                probs.append(torch.softmax(output, dim=1))
        avg_probs = torch.stack(probs).mean(dim=0)
        return avg_probs.argmax(dim=1)

    return fold_models, ensemble_predict


best_model = run_3fold_cv(data_list, labels, create_dataloader) 


=== Fold 1/3 ===
Epoch 1/25 | Train Loss: 0.5100, Acc: 0.5033 | Val Loss: 0.9212, Acc: 0.6007
Epoch 2/25 | Train Loss: 0.4574, Acc: 0.6070 | Val Loss: 0.8808, Acc: 0.6263
Epoch 3/25 | Train Loss: 0.4320, Acc: 0.6391 | Val Loss: 0.8575, Acc: 0.6506
Epoch 4/25 | Train Loss: 0.4110, Acc: 0.6647 | Val Loss: 0.8393, Acc: 0.6625
Epoch 5/25 | Train Loss: 0.3963, Acc: 0.6920 | Val Loss: 0.8402, Acc: 0.6661
EarlyStopping counter: 1 / 10
Epoch 6/25 | Train Loss: 0.3876, Acc: 0.6971 | Val Loss: 0.8438, Acc: 0.6702
EarlyStopping counter: 2 / 10
Epoch 7/25 | Train Loss: 0.3773, Acc: 0.7090 | Val Loss: 0.8305, Acc: 0.6649
Epoch 8/25 | Train Loss: 0.3668, Acc: 0.7283 | Val Loss: 0.8286, Acc: 0.6809
Epoch 9/25 | Train Loss: 0.3579, Acc: 0.7396 | Val Loss: 0.8417, Acc: 0.6673
EarlyStopping counter: 1 / 10
Epoch 10/25 | Train Loss: 0.3478, Acc: 0.7497 | Val Loss: 0.8430, Acc: 0.6738
EarlyStopping counter: 2 / 10
Epoch 11/25 | Train Loss: 0.3428, Acc: 0.7666 | Val Loss: 0.8392, Acc: 0.6762
EarlyStopping

In [None]:
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import logging

logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("peft").setLevel(logging.ERROR)

def test_ensemble(fold_models, test_loader, device):
    """
    Evaluate an ensemble of trained models on a test dataset.

    Parameters
    ----------
    fold_models : list of nn.Module
        List of trained fold models.
    test_loader : DataLoader
        Dataloader for test dataset.
    device : torch.device
        Device to run inference on (CPU/GPU/MPS).

    Returns
    -------
    tuple
        Accuracy, weighted precision, weighted recall, weighted F1-score.
    """
    all_preds = []
    all_labels = []

    for batch in test_loader:
        batch = {k: v.to(device) for k,v in batch.items()}
        probs = []
        with torch.no_grad():
            for model in fold_models:
                model.eval()
                output = model(batch['video_feat'], batch['text_feat'], batch['caption_feat'])
                probs.append(torch.softmax(output, dim=1))
        avg_probs = torch.stack(probs).mean(dim=0)
        predicted = avg_probs.argmax(dim=1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(batch['label'].cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds, average='weighted')
    rec = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')
    print(f"Ensemble Test | Acc: {acc:.4f}, Prec: {prec:.4f}, Rec: {rec:.4f}, F1: {f1:.4f}")
    return acc, prec, rec, f1


# ================= DEVICE =================
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# ================= CHECKPOINTS =================
model_paths = [
    "/Users/jumita/Downloads/checkpoints_fold_1/best_model.pth",
    "/Users/jumita/Downloads/checkpoints_fold_2/best_model.pth",
    "/Users/jumita/Downloads/checkpoints_fold_3/best_model.pth"
]

# ================= MODEL FACTORY =================
def create_multimodal_model():
    """
    Create a fresh instance of the MultimodalModel with the same architecture
    as used during training.
    """
    return MultimodalModel(
        video_dim=256,
        text_dim=256,
        caption_dim=256,
        hidden_dim=256,
        num_classes=3,
        r=6,  
        finetune_mode='coattn_plus_latefusion_part'
    )

# ================= LOAD MODELS =================
fold_models = []
for path in model_paths:
    checkpoint = torch.load(path, map_location=device)  # load checkpoint 
    model = create_multimodal_model()
    model.load_state_dict(checkpoint["model_state_dict"])  # load weight
    model.to(device)
    model.eval()
    fold_models.append(model)

# ================= ENSEMBLE TEST =================
acc, prec, rec, f1 = test_ensemble(fold_models, test_loader, device)
print("Final Ensemble Results:", acc, prec, rec, f1)


Ensemble Test | Acc: 0.8099, Prec: 0.8165, Rec: 0.8099, F1: 0.8097
Final Ensemble Results: 0.80990099009901 0.816537841950824 0.80990099009901 0.8096565291857968
