In [35]:
import os
import sys
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import train_test_split
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

sys.path.append("/Users/jumita/Downloads/FakingRecipe-main/model/")

# ================= Device =================
device = "mps" if torch.cuda.is_available() else "cpu"

# ================= Dataset =================
class VideoTextCaptionDataset(Dataset):
    """
    PyTorch Dataset for loading video, text, and caption features
    along with their corresponding labels.
    """
    def __init__(self, data_list, labels):
        """
        Args:
            data_list (list of dict): Each dict contains paths to video, text, and caption features.
            labels (array-like): Labels corresponding to each data sample.
        """
        self.data_list = data_list
        self.labels = labels

    def process_feat(self, feat):
        """
        Process a feature array by flattening or averaging depending on its shape.

        Args:
            feat (np.ndarray): Input feature array of shape (256,), (seq, 256), or (seq, 197, 256).

        Returns:
            torch.Tensor: Processed 1D feature tensor of size 256.
        """
        feat = np.array(feat)
        if feat.ndim == 1:  # (256,)
            return torch.from_numpy(feat).float()
        elif feat.ndim == 2:  # (seq, 256)
            return torch.from_numpy(feat.mean(axis=0)).float()
        elif feat.ndim == 3:  # (seq, 197, 256)
            seq_len = feat.shape[0] * feat.shape[1]
            feat = feat.reshape(seq_len, feat.shape[2])
            return torch.from_numpy(feat.mean(axis=0)).float()
        else:
            raise ValueError(f"Unexpected feature shape: {feat.shape}")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """
        Returns the features and label of a given index.

        Args:
            idx (int): Index of the data sample.

        Returns:
            tuple: (video_feat, text_feat, caption_feat, label) as tensors.
        """
        item = self.data_list[idx]
        video_feat = self.process_feat(np.load(item["video_feat"]))
        text_feat = self.process_feat(np.load(item["text_feat"]))
        caption_feat = self.process_feat(np.load(item["caption_feat"]))
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return video_feat, text_feat, caption_feat, label

# ================= Load or Create JSON =================
def load_or_create_data(json_path, excel_path, labels_path, video_dir, text_dir, caption_dir):
    """
    Load data list from JSON or create it from Excel and feature directories.

    Args:
        json_path (str): Path to JSON file storing data list.
        excel_path (str): Path to Excel file containing video IDs.
        labels_path (str): Path to .npy file containing labels.
        video_dir (str): Directory containing video feature .npy files.
        text_dir (str): Directory containing text feature .npy files.
        caption_dir (str): Directory containing caption feature .npy files.

    Returns:
        tuple: (data_list, labels)
    """
    labels = np.load(labels_path)
    if os.path.exists(json_path):
        with open(json_path, "r") as f:
            data_list = json.load(f)
    else:
        df = pd.read_excel(excel_path)
        data_list = []
        for vid in df['video_id'].astype(str):
            data_list.append({
                "video_feat": os.path.join(video_dir, f"{vid}.npy"),
                "text_feat": os.path.join(text_dir, f"{vid}.npy"),
                "caption_feat": os.path.join(caption_dir, f"{vid}.npy")
            })
        os.makedirs(os.path.dirname(json_path), exist_ok=True)
        with open(json_path, "w") as f:
            json.dump(data_list, f)
    return data_list, labels

# ================= Paths =================
json_path    = "/Users/jumita/Downloads/final_code/data.json"
excel_path   = "/Users/jumita/Downloads/Book5.xlsx"
labels_path  = "/Users/jumita/Downloads/final_code/labels.npy"
video_dir    = "/Users/jumita/Downloads/final/frame"
text_dir     = "/Users/jumita/Downloads/final/text"
caption_dir  = "/Users/jumita/Downloads/final/caption"

data_list, labels = load_or_create_data(json_path, excel_path, labels_path, video_dir, text_dir, caption_dir)

# ================= Remove duplicates =================
def remove_duplicates(data_list, labels):
    """
    Remove duplicate video entries to avoid repeated data.

    Args:
        data_list (list of dict): List containing feature paths.
        labels (np.ndarray): Corresponding labels.

    Returns:
        tuple: (unique_data_list, unique_labels)
    """
    seen = set()
    unique_data = []
    unique_labels = []
    for item, label in zip(data_list, labels):
        if item["video_feat"] not in seen:
            seen.add(item["video_feat"])
            unique_data.append(item)
            unique_labels.append(label)
    return unique_data, np.array(unique_labels)

data_list, labels = remove_duplicates(data_list, labels)

# ================= Split =================
data_train, data_temp, labels_train, labels_temp = train_test_split(
    data_list, labels, test_size=0.2, random_state=42, stratify=labels
)

try:
    data_val, data_test, labels_val, labels_test = train_test_split(
        data_temp, labels_temp, test_size=0.5, random_state=42, stratify=labels_temp
    )
except ValueError:
    data_val, data_test, labels_val, labels_test = train_test_split(
        data_temp, labels_temp, test_size=0.5, random_state=42, stratify=None
    )

# ================= DataLoader =================
train_loader = DataLoader(VideoTextCaptionDataset(data_train, labels_train), batch_size=16, shuffle=True)
val_loader   = DataLoader(VideoTextCaptionDataset(data_val, labels_val), batch_size=16, shuffle=False)
test_loader  = DataLoader(VideoTextCaptionDataset(data_test, labels_test), batch_size=16, shuffle=False)


In [37]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

# ================= VIDEO FEATURE AUGMENTER =================
class VideoFeatureAugmenter:
    """
    Data augmenter for video features.

    This class applies temporal cropping and adaptive Gaussian noise 
    to video features to improve model generalization during training.

    Attributes
    ----------
    augment_prob : float
        Probability of applying augmentation.
    dropout : nn.Dropout
        Dropout layer used for augmentation (optional, not used here).
    noise_scale : float
        Scaling factor for Gaussian noise.

    Methods
    -------
    temporal_crop(features, crop_ratio=0.9):
        Randomly crops the sequence along the temporal dimension.
    adaptive_noise(features):
        Adds Gaussian noise proportional to the feature's standard deviation.
    __call__(features):
        Applies augmentation with probability `augment_prob`.
    """

    def __init__(self, augment_prob=0.5, dropout_prob=0.3, noise_scale=0.05):
        self.augment_prob = augment_prob
        self.dropout = nn.Dropout(p=dropout_prob)
        self.noise_scale = noise_scale

    def temporal_crop(self, features, crop_ratio=0.9):
        """
        Crop the feature sequence randomly along the temporal dimension.

        Parameters
        ----------
        features : torch.Tensor
            Video feature tensor of shape (T, D) or (T, H, D).
        crop_ratio : float
            Fraction of the sequence to keep.

        Returns
        -------
        torch.Tensor
            Cropped feature tensor.
        """
        T = features.shape[0]
        new_T = int(T * crop_ratio)
        start = np.random.randint(0, T - new_T + 1)
        return features[start:start+new_T]

    def adaptive_noise(self, features):
        """
        Add Gaussian noise scaled by feature standard deviation.

        Parameters
        ----------
        features : torch.Tensor
            Feature tensor.

        Returns
        -------
        torch.Tensor
            Noisy feature tensor.
        """
        std_per_feature = features.std(dim=0, keepdim=True) + 1e-6
        noise = torch.randn_like(features) * std_per_feature * self.noise_scale
        return features + noise

    def __call__(self, features):
        """
        Apply augmentation with probability `augment_prob`.

        Parameters
        ----------
        features : torch.Tensor
            Input feature tensor.

        Returns
        -------
        torch.Tensor
            Augmented feature tensor.
        """
        if torch.rand(1) < self.augment_prob:
            if features.ndim >= 2:
                features = self.temporal_crop(features)
            features = self.adaptive_noise(features)
        return features

# ================= TEXT/CAPTION FEATURE AUGMENTER =================
class TextFeatureAugmenter:
    """
    Data augmenter for text or caption features.

    Applies dropout and adaptive Gaussian noise to embeddings.

    Attributes
    ----------
    augment_prob : float
        Probability of applying augmentation.
    dropout : nn.Dropout
        Dropout layer applied to features.
    noise_scale : float
        Scaling factor for Gaussian noise.

    Methods
    -------
    adaptive_noise(features):
        Adds Gaussian noise proportional to feature's std.
    __call__(features):
        Applies augmentation with probability `augment_prob`.
    """
    def __init__(self, augment_prob=0.3, dropout_prob=0.05, noise_scale=0.02):
        self.augment_prob = augment_prob
        self.dropout = nn.Dropout(p=dropout_prob)
        self.noise_scale = noise_scale

    def adaptive_noise(self, features):
        """
        Add Gaussian noise scaled by feature standard deviation.

        Parameters
        ----------
        features : torch.Tensor
            Input feature tensor.

        Returns
        -------
        torch.Tensor
            Noisy feature tensor.
        """
        std_per_feature = features.std(dim=0, keepdim=True) + 1e-6
        noise = torch.randn_like(features) * std_per_feature * self.noise_scale
        return features + noise

    def __call__(self, features):
        """
        Apply dropout and noise augmentation with probability `augment_prob`.

        Parameters
        ----------
        features : torch.Tensor
            Input feature tensor.

        Returns
        -------
        torch.Tensor
            Augmented feature tensor.
        """
        if torch.rand(1) < self.augment_prob:
            features = self.dropout(features)
            features = self.adaptive_noise(features)
        return features


# ================= DATASET =================
class VideoTextCaptionDataset(Dataset):
    """
    PyTorch Dataset for multimodal video, text, and caption features.

    Handles optional data augmentation and per-sample feature normalization.

    Parameters
    ----------
    data_list : list of dict
        List of dictionaries containing paths for 'video_feat', 'text_feat', 'caption_feat'.
    labels : np.ndarray
        Array of integer labels corresponding to each sample.
    augment : bool, optional
        Whether to apply data augmentation (default: False).

    Methods
    -------
    __getitem__(idx):
        Loads, optionally augments, and returns a single sample.
    __len__():
        Returns the total number of samples.
    """
    def __init__(self, data_list, labels, augment=False):
        self.data_list = data_list
        self.labels = labels
        self.augment = augment
        self.video_aug = VideoFeatureAugmenter() if augment else None
        self.text_aug = TextFeatureAugmenter() if augment else None

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = self.data_list[idx]

        video_feat = torch.from_numpy(np.load(item["video_feat"])).float()
        text_feat = torch.from_numpy(np.load(item["text_feat"])).float()
        caption_feat = torch.from_numpy(np.load(item["caption_feat"])).float()

        if self.augment:
            if self.video_aug and video_feat.ndim >= 2:
                video_feat = self.video_aug(video_feat)
            if self.text_aug:
                text_feat = self.text_aug(text_feat)
                caption_feat = self.text_aug(caption_feat)

        # Flatten
        if video_feat.ndim == 3:
            video_feat = video_feat.mean(dim=(0,1))
        elif video_feat.ndim == 2:
            video_feat = video_feat.mean(dim=0)
        
        text_feat = text_feat if text_feat.ndim == 1 else text_feat.mean(dim=0)
        caption_feat = caption_feat if caption_feat.ndim == 1 else caption_feat.mean(dim=0)

        # Normalize
        video_feat = (video_feat - video_feat.mean()) / (video_feat.std() + 1e-6)
        text_feat = (text_feat - text_feat.mean()) / (text_feat.std() + 1e-6)
        caption_feat = (caption_feat - caption_feat.mean()) / (caption_feat.std() + 1e-6)

        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return {
            "video_feat": video_feat,
            "text_feat": text_feat,
            "caption_feat": caption_feat,
            "label": label
        }

def create_dataloader(data_list, labels, batch_size=16, shuffle=True, augment=False):
    """
    Create a PyTorch DataLoader for the multimodal dataset.

    Parameters
    ----------
    data_list : list of dict
        List of dictionaries containing feature file paths.
    labels : np.ndarray
        Array of labels.
    batch_size : int, optional
        Batch size (default: 16).
    shuffle : bool, optional
        Whether to shuffle data (default: True).
    augment : bool, optional
        Whether to apply data augmentation (default: False).

    Returns
    -------
    DataLoader
        PyTorch DataLoader for the dataset.
    """
    dataset = VideoTextCaptionDataset(data_list, labels, augment=augment)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

# ================= LOAD DATA =================
data_list, labels = load_or_create_data(
    json_path="/Users/jumita/Downloads/final_code/data.json",
    excel_path="/Users/jumita/Downloads/Book5.xlsx",
    labels_path="/Users/jumita/Downloads/final_code/labels.npy",
    video_dir="/Users/jumita/Downloads/final/frame",
    text_dir="/Users/jumita/Downloads/final/text",
    caption_dir="/Users/jumita/Downloads/final/caption"
)

# ================= SPLIT DATA =================
data_train, data_temp, labels_train, labels_temp = train_test_split(
    data_list, labels, test_size=0.2, random_state=42, stratify=labels
)

try:
    data_val, data_test, labels_val, labels_test = train_test_split(
        data_temp, labels_temp, test_size=0.5, random_state=42, stratify=labels_temp
    )
except ValueError:
    data_val, data_test, labels_val, labels_test = train_test_split(
        data_temp, labels_temp, test_size=0.5, random_state=42, stratify=None
    )

def check_overlap(set1, set2):
    s1 = set([item["video_feat"] for item in set1])
    s2 = set([item["video_feat"] for item in set2])
    return len(s1 & s2)

print("Overlap train-val:", check_overlap(data_train, data_val))
print("Overlap train-test:", check_overlap(data_train, data_test))
print("Overlap val-test:", check_overlap(data_val, data_test))


# ================= DATALOADERS =================
train_loader = create_dataloader(data_train, labels_train, batch_size=16, shuffle=True, augment=True)
val_loader = create_dataloader(data_val, labels_val, batch_size=16, shuffle=False)
test_loader = create_dataloader(data_test, labels_test, batch_size=16, shuffle=False)

# ================= DEBUG SAMPLE =================
sample = next(iter(train_loader))
print("Augmented sample shapes:")
print("Video:", sample["video_feat"].shape)
print("Text:", sample["text_feat"].shape)
print("Caption:", sample["caption_feat"].shape) 



Overlap train-val: 0
Overlap train-test: 0
Overlap val-test: 0
Augmented sample shapes:
Video: torch.Size([16, 256])
Text: torch.Size([16, 256])
Caption: torch.Size([16, 256])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ================= StochasticDepth =================
class StochasticDepth(nn.Module):
    """
    Implements Stochastic Depth (also called DropPath) regularization.

    During training, randomly drops entire residual paths with probability `drop_prob`.

    Parameters
    ----------
    drop_prob : float
        Probability of dropping a path.
    """
    def __init__(self, drop_prob):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """
        Forward pass with stochastic depth.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, ...).

        Returns
        -------
        torch.Tensor
            Output tensor, either dropped or scaled.
        """
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1 - self.drop_prob
        mask = (torch.rand(x.shape[0], device=x.device) < keep_prob).to(x.dtype)
        mask = mask.view(-1, *([1] * (x.dim() - 1)))
        return x * mask / keep_prob

# ================= LoRA Linear =================
class LoRALinear(nn.Module):
    """
    Linear layer with Low-Rank Adaptation (LoRA).

    Parameters
    ----------
    in_features : int
        Input feature dimension.
    out_features : int
        Output feature dimension.
    r : int
        Rank of the LoRA adaptation matrices.
    alpha : int
        Scaling factor for LoRA updates.
    """
    def __init__(self, in_features, out_features, r=6, alpha=6):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        self.alpha = alpha
        self.scaling = self.alpha / max(1, r)
        self.linear = nn.Linear(in_features, out_features)
        if r > 0:
            self.lora_A = nn.Parameter(torch.randn(r, in_features) * 1e-3)
            self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        else:
            self.lora_A = None
            self.lora_B = None

    def forward(self, x):
        """
        Forward pass through LoRA linear layer.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, in_features).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (B, out_features).
        """
        result = self.linear(x)
        if self.r > 0:
            lora_update = (x @ self.lora_A.T) @ self.lora_B.T
            result += lora_update * self.scaling
        return result

# ================= LoRA MultiheadAttention =================
class LoRAMultiheadAttention(nn.Module):
    """
    Multihead Attention layer with LoRA adaptation.

    Parameters
    ----------
    embed_dim : int
        Dimension of embeddings.
    num_heads : int
        Number of attention heads.
    dropout : float
        Dropout probability in attention.
    batch_first : bool
        If True, input shape is (B, S, D).
    r : int
        LoRA rank.
    alpha : int
        LoRA scaling factor.
    """
    def __init__(self, embed_dim, num_heads, dropout=0.0, batch_first=True, r=6, alpha=6):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.r = r
        self.alpha = alpha
        self.scaling = self.alpha / max(1, r)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=batch_first)
        if r > 0:
            self.lora_A_q = nn.Parameter(torch.randn(r, embed_dim) * 1e-3)
            self.lora_B_q = nn.Parameter(torch.zeros(embed_dim, r))
            self.lora_A_k = nn.Parameter(torch.randn(r, embed_dim) * 1e-3)
            self.lora_B_k = nn.Parameter(torch.zeros(embed_dim, r))
            self.lora_A_v = nn.Parameter(torch.randn(r, embed_dim) * 1e-3)
            self.lora_B_v = nn.Parameter(torch.zeros(embed_dim, r))
            self.lora_A_out = nn.Parameter(torch.randn(r, embed_dim) * 1e-3)
            self.lora_B_out = nn.Parameter(torch.zeros(embed_dim, r))

    def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None):
        """
        Forward pass through LoRA multihead attention.

        Parameters
        ----------
        query, key, value : torch.Tensor
            Input tensors of shape (B, S, D).
        key_padding_mask : torch.Tensor, optional
            Mask for padded tokens.
        need_weights : bool
            If True, returns attention weights.
        attn_mask : torch.Tensor, optional
            Attention mask.

        Returns
        -------
        attn_output : torch.Tensor
            Attention output tensor.
        attn_weights : torch.Tensor
            Attention weights (if `need_weights` is True).
        """
        if self.r > 0:
            query = query + (query @ self.lora_A_q.T) @ self.lora_B_q.T * self.scaling
            key   = key   + (key   @ self.lora_A_k.T) @ self.lora_B_k.T * self.scaling
            value = value + (value @ self.lora_A_v.T) @ self.lora_B_v.T * self.scaling

        attn_output, attn_weights = self.attn(
            query, key, value,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask
        )

        if self.r > 0:
            attn_output = attn_output + (attn_output @ self.lora_A_out.T) @ self.lora_B_out.T * self.scaling
        return attn_output, attn_weights

# ================= CoAttentionBlock =================
class CoAttentionBlock(nn.Module):
    """
    Co-Attention block between query and key/value features with LoRA and residual feed-forward.

    Parameters
    ----------
    dim_q : int
        Dimension of query features.
    dim_kv : int
        Dimension of key/value features.
    num_heads : int
        Number of attention heads.
    hidden_dim : int
        Hidden dimension for feed-forward network.
    dropout : float
        Dropout probability.
    lora_r : int
        LoRA rank.
    drop_path_prob : float
        Stochastic depth probability.
    """
    def __init__(self, dim_q=256, dim_kv=256, num_heads=8, hidden_dim=256, dropout=0.2, lora_r=6, drop_path_prob=0.2):
        super().__init__()
        self.query_proj = LoRALinear(dim_q, dim_q, r=lora_r)
        self.key_proj   = LoRALinear(dim_kv, dim_q, r=lora_r)
        self.value_proj = LoRALinear(dim_kv, dim_q, r=lora_r)
        self.attn = LoRAMultiheadAttention(dim_q, num_heads, dropout=dropout, batch_first=True, r=lora_r)
        self.gate = nn.Sequential(LoRALinear(dim_q * 2, dim_q, r=lora_r), nn.Sigmoid())
        self.ffn = nn.Sequential(
            LoRALinear(dim_q, hidden_dim, r=lora_r),
            nn.GELU(),
            nn.Dropout(dropout),
            LoRALinear(hidden_dim, dim_q, r=lora_r),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(dim_q)
        self.norm2 = nn.LayerNorm(dim_q)
        self.drop_path = StochasticDepth(drop_path_prob) if drop_path_prob > 0.0 else None

    def forward(self, query, key, value, key_padding_mask=None):
        """
        Forward pass of CoAttentionBlock.

        Parameters
        ----------
        query, key, value : torch.Tensor
            Input feature tensors.
        key_padding_mask : torch.Tensor, optional
            Mask for padded tokens.

        Returns
        -------
        torch.Tensor
            Output tensor after co-attention, gating, and feed-forward.
        """
        q = self.query_proj(query)
        k = self.key_proj(key)
        v = self.value_proj(value)
        attn_out, _ = self.attn(q, k, v, key_padding_mask=key_padding_mask)
        gate = self.gate(torch.cat([query, attn_out], dim=-1))
        out = self.norm1(query + gate * attn_out)
        ffn_out = self.ffn(out)
        if self.drop_path is not None:
            ffn_out = self.drop_path(ffn_out)
        return self.norm2(out + ffn_out)

# ================= CoAttentionLayer =================
class CoAttentionLayer(nn.Module):
    """
    Co-Attention layer for interaction between video and text features.

    It applies co-attention blocks in both directions:
    video-to-text and text-to-video, then pools the sequences
    and concatenates the outputs.

    Parameters
    ----------
    dim_q : int
        Dimension of query features (video features).
    dim_kv : int
        Dimension of key/value features (text features).
    num_heads : int
        Number of attention heads.
    hidden_dim : int
        Hidden dimension for feed-forward networks.
    dropout : float
        Dropout probability.
    lora_r : int
        LoRA rank.
    max_video_tokens : int
        Maximum number of video tokens to keep after pooling.
    """
    def __init__(self, dim_q=256, dim_kv=256, num_heads=8, hidden_dim=256, dropout=0.5, lora_r=6, max_video_tokens=32):
        super().__init__()
        self.video2text = CoAttentionBlock(dim_q, dim_kv, num_heads, hidden_dim, dropout, lora_r)
        self.text2video = CoAttentionBlock(dim_kv, dim_q, num_heads, hidden_dim, dropout, lora_r)
        self.max_video_tokens = max_video_tokens

    def forward(self, video_feat, text_feat, video_mask=None, text_mask=None):
        """
        Forward pass of co-attention layer.

        Parameters
        ----------
        video_feat : torch.Tensor
            Video features, shape (B, T, D) or (B, N, T, D) if patches.
        text_feat : torch.Tensor
            Text features, shape (B, S, D).
        video_mask : torch.Tensor, optional
            Mask for video tokens.
        text_mask : torch.Tensor, optional
            Mask for text tokens.

        Returns
        -------
        torch.Tensor
            Concatenated pooled co-attention features, shape (B, D_video + D_text).
        """
        B = video_feat.shape[0]

        if video_feat.dim() == 4:  
            T, N, D = video_feat.shape[1], video_feat.shape[2], video_feat.shape[3]
            video_feat = video_feat.view(B, T * N, D)
        elif video_feat.dim() == 2:
            video_feat = video_feat.unsqueeze(1)

        # Pooling video sequence
        if video_feat.shape[1] > self.max_video_tokens:
            video_feat = video_feat.transpose(1, 2)  # (B, D, Seq)
            video_feat = F.adaptive_avg_pool1d(video_feat, self.max_video_tokens)
            video_feat = video_feat.transpose(1, 2)

        if text_feat.dim() == 2:
            text_feat = text_feat.unsqueeze(1)

        # Forward co-attention with optional masks
        v2t = self.video2text(video_feat, text_feat, text_feat, key_padding_mask=text_mask)
        t2v = self.text2video(text_feat, video_feat, video_feat, key_padding_mask=video_mask)
        
        # Pool sequences
        v2t_pooled = v2t.mean(dim=1)
        t2v_pooled = t2v.mean(dim=1)
        return torch.cat([v2t_pooled, t2v_pooled], dim=-1)

# ================= LateFusionLoRA =================
class LateFusionLoRA(nn.Module):
    """
    Late fusion module with LoRA for combining two feature modalities.

    Applies two-layer MLP with LayerNorm, GELU activation, dropout,
    and optional stochastic depth, then outputs class logits.

    Parameters
    ----------
    dim_feat1 : int
        Dimension of first input feature.
    dim_feat2 : int
        Dimension of second input feature.
    dim_hidden : int
        Hidden layer dimension.
    num_classes : int
        Number of output classes.
    r : int
        LoRA rank.
    """
    def __init__(self, dim_feat1, dim_feat2, dim_hidden, num_classes, r=6):
        super().__init__()
        self.fc1 = LoRALinear(dim_feat1 + dim_feat2, dim_hidden * 2, r=r)
        self.norm1 = nn.LayerNorm(dim_hidden * 2)
        self.fc2 = LoRALinear(dim_hidden * 2, dim_hidden, r=r)
        self.norm2 = nn.LayerNorm(dim_hidden)
        self.fc3 = LoRALinear(dim_hidden, num_classes, r=r)
        self.dropout = nn.Dropout(0.35)
        self.drop_path = StochasticDepth(0.2) if dim_hidden > 256 else None

    def forward(self, x1, x2):
        """
        Forward pass of late fusion module.

        Parameters
        ----------
        x1, x2 : torch.Tensor
            Input feature tensors to fuse, shape (B, dim_feat1), (B, dim_feat2).

        Returns
        -------
        torch.Tensor
            Output class logits, shape (B, num_classes).
        """
        x = torch.cat([x1, x2], dim=-1)
        x = self.fc1(x)
        x = self.norm1(x)
        x = F.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.norm2(x)
        x = F.gelu(x)
        if self.drop_path is not None:
            x = self.drop_path(x)
        return self.fc3(x)

# ================= MultimodalModel =================
class MultimodalModel(nn.Module):
    """
    Multimodal model combining co-attention between video and text features
    and late fusion with caption features, with LoRA finetuning support.

    Parameters
    ----------
    video_dim : int
        Dimension of video features.
    text_dim : int
        Dimension of text features.
    caption_dim : int
        Dimension of caption features.
    hidden_dim : int
        Hidden dimension for co-attention and fusion layers.
    num_classes : int
        Number of output classes.
    r : int
        LoRA rank.
    finetune_mode : str
        Finetuning mode: 'coattn_plus_latefusion', 'late_fusion_only', 'coattn_only', etc.
    dropout_prob : float
        Dropout probability for the fusion layer.
    """
    def __init__(self, video_dim=256, text_dim=256, caption_dim=256, hidden_dim=256,
                 num_classes=3, r=6, finetune_mode: str = 'coattn_plus_latefusion',
                 dropout_prob=0.1):
        super().__init__()
        self.ca1 = CoAttentionLayer(video_dim, text_dim, num_heads=8, hidden_dim=hidden_dim, lora_r=r)
        self.fc_ca1 = LoRALinear(video_dim * 2, video_dim, r=r)
        self.dropout = nn.Dropout(dropout_prob)   
        self.late_fusion = LateFusionLoRA(video_dim, caption_dim, hidden_dim, num_classes, r=r)
        self.set_trainable(finetune_mode)

    def set_trainable(self, finetune_mode: str = 'coattn_plus_latefusion'):
        """
        Set which parameters are trainable according to finetune_mode.

        Parameters
        ----------
        finetune_mode : str
            Mode of finetuning.
        """
        for p in self.parameters():
            p.requires_grad = False

        finetune_mode = finetune_mode.lower()

        def enable_lora(module):
            for name, p in module.named_parameters():
                if "lora_A" in name or "lora_B" in name:
                    p.requires_grad = True

        def unfreeze_part(module, names_to_unfreeze):
            for name, p in module.named_parameters():
                for target_name in names_to_unfreeze:
                    if target_name in name:
                        p.requires_grad = True

        if finetune_mode == 'late_fusion_only':
            enable_lora(self.late_fusion)
            enable_lora(self.fc_ca1)
        elif finetune_mode == 'coattn_only':
            enable_lora(self.ca1)
            enable_lora(self.fc_ca1)
        elif finetune_mode == 'coattn_plus_latefusion':
            enable_lora(self.ca1)
            enable_lora(self.late_fusion)
            enable_lora(self.fc_ca1)
        elif finetune_mode == 'coattn_plus_latefusion_part':
            enable_lora(self.ca1)
            enable_lora(self.late_fusion)
            enable_lora(self.fc_ca1)
            unfreeze_part(self.ca1.video2text.query_proj, ['linear'])
            unfreeze_part(self.late_fusion.fc1, ['linear'])
        elif finetune_mode == 'all':
            for p in self.parameters():
                p.requires_grad = True
        else:
            raise ValueError(f"Unknown finetune_mode: {finetune_mode}")

    def get_param_groups(self, lr_late=2e-4, lr_coattn=1.5e-4, lr_fc_ca1=5e-5, weight_decay=5e-4):
        """
        Return parameter groups with separate learning rates for optimizer.

        Parameters
        ----------
        lr_late, lr_coattn, lr_fc_ca1 : float
            Learning rates for respective modules.
        weight_decay : float
            Weight decay for optimizer.

        Returns
        -------
        list of dict
            Parameter groups for optimizer.
        """
        groups = []
        if any(p.requires_grad for p in self.late_fusion.parameters()):
            groups.append({'params': [p for p in self.late_fusion.parameters() if p.requires_grad], 'lr': lr_late, 'weight_decay': weight_decay})
        if any(p.requires_grad for p in self.ca1.parameters()):
            groups.append({'params': [p for p in self.ca1.parameters() if p.requires_grad], 'lr': lr_coattn, 'weight_decay': weight_decay})
        if any(p.requires_grad for p in self.fc_ca1.parameters()):
            groups.append({'params': [p for p in self.fc_ca1.parameters() if p.requires_grad], 'lr': lr_fc_ca1, 'weight_decay': weight_decay})
        return groups

    def forward(self, video_feat, text_feat, caption_feat, video_mask=None, text_mask=None):
        """
        Forward pass of the multimodal model.

        Parameters
        ----------
        video_feat : torch.Tensor
            Video features.
        text_feat : torch.Tensor
            Text features.
        caption_feat : torch.Tensor
            Caption features.
        video_mask : torch.Tensor, optional
            Mask for video tokens.
        text_mask : torch.Tensor, optional
            Mask for text tokens.

        Returns
        -------
        torch.Tensor
            Class logits for each sample.
        """
        ca1_out = self.ca1(video_feat, text_feat, video_mask=video_mask, text_mask=text_mask)
        ca1_out = self.fc_ca1(ca1_out)
        ca1_out = self.dropout(ca1_out)
        return self.late_fusion(ca1_out, caption_feat)

In [40]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import numpy as np
from torch.utils.data import DataLoader

# ================= Device =================
device = "mps" if torch.cuda.is_available() else "cpu"

# ================= Label mapping =================
unique_labels = ["0", "1", "2"]
label2id = {name: i for i, name in enumerate(unique_labels)}

# ================= Placeholder feature extractors =================
def extract_video_feature(video_str):
    """
    Simulate video feature extraction.

    Args:
        video_str (str): Placeholder video identifier.

    Returns:
        torch.Tensor: Random 256-dimensional video feature vector.
    """
    return torch.randn(256, dtype=torch.float32)

def extract_text_embedding(text_str):
    """
    Simulate text or caption feature extraction.

    Args:
        text_str (str): Placeholder text identifier.

    Returns:
        torch.Tensor: Random 256-dimensional text feature vector.
    """
    return torch.randn(256, dtype=torch.float32)

train_dataset = VideoTextCaptionDataset(data_train, labels_train)
val_dataset   = VideoTextCaptionDataset(data_val, labels_val)
test_dataset  = VideoTextCaptionDataset(data_test, labels_test)


# ================= Collate function =================
def collate_fn(batch):
    """
    Collate a batch of dataset items into tensors suitable for DataLoader.

    Args:
        batch (list): List of dataset items (video_feat, text_feat, caption_feat, label).

    Returns:
        tuple: Batched tensors (video_feats, text_feats, caption_feats, labels)
    """
    video_feats = torch.stack([item["video_feat"] for item in batch])
    text_feats = torch.stack([item["text_feat"] for item in batch])
    caption_feats = torch.stack([item["caption_feat"] for item in batch])
    labels = torch.tensor([item["label"] for item in batch], dtype=torch.long)
    return video_feats, text_feats, caption_feats, labels

# ================= Baselines =================
class SVFENDBaseline(nn.Module):
    """
    Simple MLP baseline combining video, text, and caption features.
    """
    def __init__(self, feature_dim=256, num_classes=3):
        """
        Forward pass.

        Args:
            video (torch.Tensor): Video features.
            text (torch.Tensor): Text features.
            caption (torch.Tensor): Caption features.

        Returns:
            torch.Tensor: Class logits.
        """
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(feature_dim*3, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )
    def forward(self, video, text, caption):
        x = torch.cat([video, text, caption], dim=1)
        return self.mlp(x)

class MiniMMBT(nn.Module):
    """
    Mini Multimodal Bitransformer (MiniMMBT) baseline.
    Combines video and text sequences via a Transformer encoder.
    """
    def __init__(self, feature_dim=256, num_classes=3, num_layers=2):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=feature_dim, nhead=8)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(feature_dim, num_classes)
    def forward(self, video, text):
        """
        Forward pass for MiniMMBT.

        Args:
            video (torch.Tensor): Video features of shape (batch, seq_len, feature_dim) or (batch, feature_dim)
            text (torch.Tensor): Text features of shape (batch, seq_len, feature_dim) or (batch, feature_dim)

        Returns:
            torch.Tensor: Class logits.
        """
        if video.dim() == 2: video = video.unsqueeze(1)
        if text.dim() == 2: text = text.unsqueeze(1)
        x = torch.cat([video, text], dim=1)
        x = self.transformer(x)
        cls_token = x[:,0,:]
        return self.fc(cls_token)

class LateFusionMLP(nn.Module):
    """
    Late fusion baseline that averages temporal features before MLP classification.
    """
    def __init__(self, feature_dim=256, num_classes=3):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(feature_dim*3, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )
    def forward(self, video, text, caption):
        """
        Forward pass for late fusion.

        Args:
            video (torch.Tensor): Video features (batch, seq_len, feature_dim) or (batch, feature_dim)
            text (torch.Tensor): Text features (batch, seq_len, feature_dim) or (batch, feature_dim)
            caption (torch.Tensor): Caption features (batch, feature_dim)

        Returns:
            torch.Tensor: Class logits.
        """
        if video.dim() == 3: video = video.mean(dim=1)
        if text.dim() == 3: text = text.mean(dim=1)
        x = torch.cat([video, text, caption], dim=1)
        return self.mlp(x)

# ================= Training / Testing =================
def train_model(model, train_loader, val_loader, num_epochs=15, lr=2e-4):
    """
    Train a PyTorch model.

    Args:
        model (nn.Module): Model to train.
        train_loader (DataLoader): Training DataLoader.
        val_loader (DataLoader): Validation DataLoader.
        num_epochs (int): Number of epochs.
        lr (float): Learning rate.

    Returns:
        nn.Module: Trained model.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for video, text, caption, label in train_loader:
            video, text, caption, label = video.to(device), text.to(device), caption.to(device), label.to(device)
            optimizer.zero_grad()
            if isinstance(model, MiniMMBT):
                output = model(video, text)
            else:
                output = model(video, text, caption)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        train_loss = total_loss / len(train_loader)

        # Validation
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for video, text, caption, label in val_loader:
                video, text, caption, label = video.to(device), text.to(device), caption.to(device), label.to(device)
                if isinstance(model, MiniMMBT):
                    output = model(video, text)
                else:
                    output = model(video, text, caption)
                preds = output.argmax(dim=-1)
                correct += (preds == label).sum().item()
                total += label.size(0)
        val_acc = correct / total
        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f}, Val Acc: {val_acc:.4f}")
    return model

def test_model(model, test_loader):
    """
    Test a trained PyTorch model.

    Args:
        model (nn.Module): Trained model.
        test_loader (DataLoader): Test DataLoader.

    Returns:
        tuple: (accuracy, f1, precision, recall)
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for video, text, caption, label in test_loader:
            video, text, caption, label = video.to(device), text.to(device), caption.to(device), label.to(device)
            if isinstance(model, MiniMMBT):
                output = model(video, text)
            else:
                output = model(video, text, caption)
            preds = output.argmax(dim=-1)
            all_preds.append(preds.cpu())
            all_labels.append(label.cpu())
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    acc = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro')
    prec = precision_score(all_labels, all_preds, average='macro')
    rec = recall_score(all_labels, all_preds, average='macro')
    return acc, f1, prec, rec

def train_logistic_regression(train_loader, val_loader):
    """
    Train a logistic regression classifier on extracted features.

    Args:
        train_loader (DataLoader): Training data loader.
        val_loader (DataLoader): Validation data loader.

    Returns:
        tuple: (accuracy, f1, precision, recall) on validation set.
    """
    X_train, y_train = [], []
    X_val, y_val = [], []
    for video, text, caption, label in train_loader:
        X_train.append(torch.cat([video, text, caption], dim=1).cpu().numpy())
        y_train.append(label.cpu().numpy())
    for video, text, caption, label in val_loader:
        X_val.append(torch.cat([video, text, caption], dim=1).cpu().numpy())
        y_val.append(label.cpu().numpy())

    X_train = np.concatenate(X_train, axis=0)
    y_train = np.concatenate(y_train, axis=0)
    X_val = np.concatenate(X_val, axis=0)
    y_val = np.concatenate(y_val, axis=0)

    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)

    clf = LogisticRegression(max_iter=1000, multi_class='multinomial')
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_val)

    acc = accuracy_score(y_val, y_pred)
    f1 = f1_score(y_val, y_pred, average='macro')
    prec = precision_score(y_val, y_pred, average='macro')
    rec = recall_score(y_val, y_pred, average='macro')
    return acc, f1, prec, rec

# ================= Run all baselines =================
def run_all_baselines(train_loader, val_loader, test_loader, feature_dim=256, num_classes=3):
    """
    Run all baseline models: SV-FEND, MiniMMBT, LateFusionMLP, and Logistic Regression.

    Args:
        train_loader (DataLoader): Training data loader.
        val_loader (DataLoader): Validation data loader.
        test_loader (DataLoader): Test data loader.
        feature_dim (int): Feature dimensionality.
        num_classes (int): Number of classes.

    Returns:
        dict: Dictionary of results for each baseline with metrics (acc, f1, precision, recall).
    """
    results = {}

    print("=== Training SV-FEND-like ===")
    svfend = SVFENDBaseline(feature_dim, num_classes)
    svfend = train_model(svfend, train_loader, val_loader)
    results['SV-FEND'] = test_model(svfend, test_loader)

    print("\n=== Training MiniMMBT ===")
    mmbt = MiniMMBT(feature_dim, num_classes)
    mmbt = train_model(mmbt, train_loader, val_loader)
    results['MiniMMBT'] = test_model(mmbt, test_loader)

    print("\n=== Training LateFusionMLP ===")
    lfm = LateFusionMLP(feature_dim, num_classes)
    lfm = train_model(lfm, train_loader, val_loader)
    results['LateFusionMLP'] = test_model(lfm, test_loader)

    print("\n=== Logistic Regression ===")
    results['LogisticRegression'] = train_logistic_regression(train_loader, val_loader)

    print("\n=== Benchmark Results ===")
    for k, v in results.items():
        acc, f1, prec, rec = v
        print(f"{k}: Acc={acc:.4f}, F1={f1:.4f}, Prec={prec:.4f}, Rec={rec:.4f}")

    return results

# ================= USAGE EXAMPLE =================
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

results = run_all_baselines(train_loader, val_loader, test_loader)


=== Training SV-FEND-like ===
Epoch 1/15 - Train Loss: 0.9079, Val Acc: 0.6099
Epoch 2/15 - Train Loss: 0.7645, Val Acc: 0.6455
Epoch 3/15 - Train Loss: 0.6889, Val Acc: 0.6495
Epoch 4/15 - Train Loss: 0.6335, Val Acc: 0.6495
Epoch 5/15 - Train Loss: 0.5877, Val Acc: 0.6752
Epoch 6/15 - Train Loss: 0.5419, Val Acc: 0.6594
Epoch 7/15 - Train Loss: 0.5053, Val Acc: 0.6851
Epoch 8/15 - Train Loss: 0.4646, Val Acc: 0.6792
Epoch 9/15 - Train Loss: 0.4353, Val Acc: 0.6653
Epoch 10/15 - Train Loss: 0.3956, Val Acc: 0.6772
Epoch 11/15 - Train Loss: 0.3611, Val Acc: 0.6693
Epoch 12/15 - Train Loss: 0.3331, Val Acc: 0.6713
Epoch 13/15 - Train Loss: 0.3052, Val Acc: 0.6733
Epoch 14/15 - Train Loss: 0.2811, Val Acc: 0.6653
Epoch 15/15 - Train Loss: 0.2514, Val Acc: 0.6594

=== Training MiniMMBT ===




Epoch 1/15 - Train Loss: 0.9438, Val Acc: 0.5762
Epoch 2/15 - Train Loss: 0.7732, Val Acc: 0.5921
Epoch 3/15 - Train Loss: 0.6606, Val Acc: 0.6376
Epoch 4/15 - Train Loss: 0.5125, Val Acc: 0.6119
Epoch 5/15 - Train Loss: 0.3367, Val Acc: 0.6396
Epoch 6/15 - Train Loss: 0.2222, Val Acc: 0.5901
Epoch 7/15 - Train Loss: 0.1435, Val Acc: 0.6119
Epoch 8/15 - Train Loss: 0.0959, Val Acc: 0.6198
Epoch 9/15 - Train Loss: 0.0652, Val Acc: 0.5960
Epoch 10/15 - Train Loss: 0.0749, Val Acc: 0.6079
Epoch 11/15 - Train Loss: 0.0646, Val Acc: 0.5802
Epoch 12/15 - Train Loss: 0.0440, Val Acc: 0.5980
Epoch 13/15 - Train Loss: 0.0567, Val Acc: 0.6040
Epoch 14/15 - Train Loss: 0.0283, Val Acc: 0.5941
Epoch 15/15 - Train Loss: 0.0441, Val Acc: 0.6040

=== Training LateFusionMLP ===
Epoch 1/15 - Train Loss: 0.9099, Val Acc: 0.6079
Epoch 2/15 - Train Loss: 0.7658, Val Acc: 0.6317
Epoch 3/15 - Train Loss: 0.6939, Val Acc: 0.6673
Epoch 4/15 - Train Loss: 0.6382, Val Acc: 0.6495
Epoch 5/15 - Train Loss: 0.5917




=== Benchmark Results ===
SV-FEND: Acc=0.6713, F1=0.6693, Prec=0.6727, Rec=0.6673
MiniMMBT: Acc=0.5980, F1=0.5932, Prec=0.6021, Rec=0.5919
LateFusionMLP: Acc=0.6772, F1=0.6768, Prec=0.6769, Rec=0.6768
LogisticRegression: Acc=0.5802, F1=0.5782, Prec=0.5779, Rec=0.5789
