# Siamese Neural Networks for Movie Trailer Recommendation System

### MASTER TRAINER
* Able to train all 4 model variants (a 2×2 grid):
* cosine vs. Euclidean distance
* pair-based vs. triplet-based training

### Code preparation

In [1]:
# # RUN AND THEN RESTART KERNEL!
# import subprocess
# import sys

# # Downgrade protobuf to compatible version
# result = subprocess.run([sys.executable, '-m', 'pip', 'install', 'protobuf==3.20.3'], 
#                        capture_output=True, text=True)
# print("STDOUT:", result.stdout)
# print("STDERR:", result.stderr)

# !export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

In [2]:
# !pip install torchvision transformers
# !pip install torch torchvision transformers numpy pandas scikit-learn tqdm librosa opencv-python matplotlib tensorboard boto3 av 
# !pip install resampy soundfile
# !pip install ffmpeg-python
# !pip install easyocr
# !pip install torchinfo
# !pip install kornia

In [3]:
# !conda install -c conda-forge ffmpeg -y

In [4]:
# Standard Library Imports
import gc
import io
import logging
import math
import os
import pickle
import random
import tempfile
import time
import traceback
import warnings
from collections import defaultdict, OrderedDict
from datetime import datetime
from functools import lru_cache
from pathlib import Path

# Third-party Core Libraries
import numpy as np
import pandas as pd
import psutil
from tqdm import tqdm
# from tqdm.auto import tqdm  # For notebook compatibility

# Machine Learning & Scientific Computing
from sklearn.decomposition import PCA
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, average_precision_score, confusion_matrix,
    precision_recall_curve
)

# Deep Learning - PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils as utils
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
from torchvision import transforms
import torchvision.models as models

# Deep Learning - Transformers & Specialized
from transformers import BertModel, BertTokenizer
import kornia.filters as KFilters

# Computer Vision & Audio Processing
import cv2
import librosa

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Cloud & External Services
import boto3

# Monitoring & Logging
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Warning Filters Configuration
# Silence noisy warnings from the audio/video stack:
#   - FutureWarnings attributed to librosa,
#   - UserWarnings whose message starts with "PySoundFile failed" (presumably emitted
#     when librosa falls back to another audio loader — TODO confirm),
#   - any UserWarning whose message contains "Parameter".
# NOTE(review): the ".*Parameter.*" pattern is broad and may also hide unrelated
# UserWarnings from torch/transformers; consider narrowing it.
warnings.filterwarnings("ignore", category=FutureWarning, module="librosa")
warnings.filterwarnings("ignore", category=UserWarning, message="PySoundFile failed.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*Parameter.*")
cv2.setLogLevel(0)  # Suppress OpenCV warnings

# Logging Configuration
# Root logger at INFO with timestamped records; `logger` is this module's named logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# CUDA Configuration
# Make CUDA kernel launches synchronous so device-side errors are reported at the call
# that caused them.
# NOTE(review): this is a debugging aid that slows GPU execution, and it only takes effect
# if CUDA has not been initialized yet in this process. Consider removing it for real
# training runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Reproducibility Setup
def set_seed(seed=42):
    """Make runs reproducible by seeding every RNG the notebook relies on.

    Seeds Python's ``random``, NumPy, PyTorch on CPU, the current CUDA device and all
    CUDA devices, then forces deterministic (non-benchmarking) cuDNN kernels.

    Args:
        seed: Integer seed shared by all generators (default 42).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

print('All imports successfully loaded! Ready to go!')

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f257d9bf100>>
Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 

KeyboardInterrupt



## Feature Extractor Modules

### Visual Features

In [None]:
class VisualProcessingModule(nn.Module):
    """Turn a clip's frames (or their precomputed DINOv2 embeddings) into one visual feature vector.

    Two modes are supported:

    1. Backbone mode (``use_precomputed_embeddings=False``): raw frames are encoded by a
       mostly-frozen DINOv2 ViT-S/14 or ResNet-50 backbone, mean-pooled over frames and
       passed through a small fusion MLP.
    2. Precomputed mode (``use_precomputed_embeddings=True``): per-frame DINOv2 embeddings
       are supplied by the caller; raw frames are only used for the optional low-level
       statistics (brightness, contrast, edge density).

    ``self.output_dim`` is ``feature_dim // 2``, plus 64 when low-level features are enabled.
    """

    def __init__(self, backbone="dinov2", low_level_features=True, use_precomputed_embeddings=False, dinov2_embedding_dim=384):
        """
        Args:
            backbone: "dinov2" or "resnet"; only used when use_precomputed_embeddings=False.
            low_level_features: Whether to extract and append low-level visual features.
            use_precomputed_embeddings: If True, forward() expects precomputed DINOv2
                embeddings and no backbone is loaded.
            dinov2_embedding_dim: Dimension of the precomputed embeddings (384 for ViT-S/14).

        Raises:
            ValueError: If ``backbone`` is neither "dinov2" nor "resnet" in backbone mode.
        """
        super().__init__()
        self.backbone_type = backbone
        self.low_level_features = low_level_features
        self.use_precomputed_embeddings = use_precomputed_embeddings

        # Setup backbone and feature dimensions
        if use_precomputed_embeddings:
            # No backbone needed, we'll receive precomputed embeddings
            self.backbone = None
            self.feature_dim = dinov2_embedding_dim
            print(f"Using precomputed embeddings mode with feature_dim={self.feature_dim}")
        elif backbone == "dinov2":
            self.backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14', trust_repo=True)
            self.feature_dim = 384

            # Conservative fine-tuning: freeze everything, then unfreeze only the final norm layer.
            print("DINOv2 Fine-tuning: Conservative - ONLY final 'norm' layer.")
            for param in self.backbone.parameters():
                param.requires_grad = False

            if hasattr(self.backbone, 'norm') and isinstance(self.backbone.norm, nn.LayerNorm):
                print("Unfreezing DINOv2's final 'norm' layer.")
                for param in self.backbone.norm.parameters():
                    param.requires_grad = True

                num_trainable = sum(p.numel() for p in self.backbone.parameters() if p.requires_grad)
                print(f"DINOv2 trainable parameters: {num_trainable}")
            else:
                print("DINOv2 'norm' layer not found. Backbone remains fully frozen.")
        elif backbone == "resnet":
            resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
            # Drop the classification head; the remaining global average pool yields 2048-d features.
            self.backbone = nn.Sequential(*list(resnet.children())[:-1])
            self.feature_dim = 2048

            # Freeze all but the last 10 parameter tensors.
            for param in list(self.backbone.parameters())[:-10]:
                param.requires_grad = False
        else:
            # Fail fast: without a backbone, feature_dim is undefined and the layers below cannot be built.
            raise ValueError(f"Unsupported backbone {backbone!r}; expected 'dinov2' or 'resnet'.")

        # MLP layers for processing features
        self.norm_after_backbone_pooling = nn.LayerNorm(self.feature_dim, eps=1e-5)
        self.dropout_before_fusion = nn.Dropout(0.4)

        self.frame_fusion_mlp = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim // 2),
            nn.BatchNorm1d(self.feature_dim // 2, eps=1e-5),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(self.feature_dim // 2, self.feature_dim // 2)
        )

        self.dropout_after_fusion = nn.Dropout(0.4)

        # Low-level feature processing: 5 raw statistics -> 64-d projection.
        if low_level_features:
            self.low_level_dim_output = 64
            self.low_level_projection = nn.Linear(5, self.low_level_dim_output)
            self.output_dim = (self.feature_dim // 2) + self.low_level_dim_output
        else:
            self.output_dim = self.feature_dim // 2

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize the trainable head layers with conservative gains.

        Only the layers defined in this module (fusion MLP, norms, low-level projection)
        are initialized. The pretrained backbone is skipped: DINOv2 is built from
        nn.Linear / nn.LayerNorm layers, and re-initializing them would wipe its
        pretrained weights (which are then frozen).
        """
        for module_name, module in self.named_modules():
            if module_name == "backbone" or module_name.startswith("backbone."):
                continue

            if isinstance(module, nn.Linear):
                if 'frame_fusion_mlp.0' in module_name:
                    gain = 1.0
                    print(f"Applying standard gain ({gain}) to {module_name}")
                else:
                    gain = nn.init.calculate_gain('relu')
                nn.init.xavier_uniform_(module.weight, gain=gain)

                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)

            elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
                nn.init.constant_(module.weight, 1.0)
                nn.init.constant_(module.bias, 0.0)

    def forward(self, frames, precomputed_dino_embeddings=None):
        """
        Forward pass supporting both modes:

        Mode 1 (original): forward(frames)
        - frames: (Batch, NumFrames, C, H, W)

        Mode 2 (precomputed): forward(frames, precomputed_dino_embeddings)
        - frames: (Batch, NumFrames, C, H, W) used only for low-level features, or None
        - precomputed_dino_embeddings: (Batch, NumFrames, DINO_DIM)

        Raises:
            ValueError: In precomputed mode when precomputed_dino_embeddings is None.
        """
        if self.use_precomputed_embeddings:
            if precomputed_dino_embeddings is None:
                raise ValueError("precomputed_dino_embeddings must be provided when use_precomputed_embeddings=True")
            return self._forward_with_precomputed(frames, precomputed_dino_embeddings)
        return self._forward_with_backbone(frames)

    def _forward_with_precomputed(self, frames_for_low_level, precomputed_dino_frame_embeddings):
        """Forward pass using precomputed per-frame DINOv2 embeddings.

        Args:
            frames_for_low_level: Raw frames (Batch, NumFrames, C, H, W) used only for the
                low-level statistics, or None. When None and low-level features are enabled,
                the low-level slot is zero-filled so the width is still ``self.output_dim``.
            precomputed_dino_frame_embeddings: (Batch, NumFrames, DINO_DIM).

        Returns:
            L2-normalized features of shape (Batch, self.output_dim).
        """
        # 1. Pool the precomputed per-frame embeddings over the frame axis.
        video_features_pooled = precomputed_dino_frame_embeddings.mean(dim=1)  # (Batch, DINO_DIM)

        if torch.isnan(video_features_pooled).any() or torch.isinf(video_features_pooled).any():
            print("V- Pooled DINO had NaN/Inf before L2 norm. Clamping.")
            video_features_pooled = torch.nan_to_num(video_features_pooled, nan=0.0, posinf=1.0, neginf=-1.0)

        # Nudge rows with (near-)zero norm so the L2 normalization below has a direction.
        pooled_norm = torch.norm(video_features_pooled, p=2, dim=1, keepdim=True)
        if torch.any(pooled_norm < 1e-7):
            noise = torch.randn_like(video_features_pooled) * 1e-7
            video_features_pooled = torch.where(
                (pooled_norm < 1e-7).expand_as(video_features_pooled),
                video_features_pooled + noise,
                video_features_pooled,
            )
        if torch.any(torch.norm(video_features_pooled, p=2, dim=1) < 1e-7):
            # Still degenerate (e.g. the noise was ~0 too): add a tiny constant offset to every row.
            video_features_pooled = video_features_pooled + 1e-7 * torch.ones_like(video_features_pooled)

        video_features_l2 = F.normalize(video_features_pooled, p=2, dim=1, eps=1e-6)

        # 2. LayerNorm (on the L2-normalized vector) followed by the fusion MLP.
        video_features_normed_by_layernorm = self.norm_after_backbone_pooling(video_features_l2)
        if torch.isnan(video_features_normed_by_layernorm).any() or torch.isinf(video_features_normed_by_layernorm).any():
            video_features_normed_by_layernorm = torch.nan_to_num(video_features_normed_by_layernorm, nan=0.0)

        video_features_to_fuse = self.dropout_before_fusion(video_features_normed_by_layernorm)

        fused_video_features = self.frame_fusion_mlp(video_features_to_fuse)
        if torch.isnan(fused_video_features).any() or torch.isinf(fused_video_features).any():
            fused_video_features = torch.nan_to_num(fused_video_features, nan=0.0)

        final_high_level_features = self.dropout_after_fusion(fused_video_features)

        # 3. Low-level features (computed on the fly from the raw frames when available).
        if self.low_level_features:
            if frames_for_low_level is not None:
                low_level_raw_avg = self.extract_low_level_features(frames_for_low_level)  # (Batch, 5)
                if torch.isnan(low_level_raw_avg).any():
                    low_level_raw_avg = torch.nan_to_num(low_level_raw_avg, nan=0.0)

                low_level_projected = torch.tanh(self.low_level_projection(low_level_raw_avg))
                if torch.isnan(low_level_projected).any():
                    low_level_projected = torch.nan_to_num(low_level_projected, nan=0.0)
            else:
                # No frames: keep the output width equal to self.output_dim for downstream layers.
                low_level_projected = final_high_level_features.new_zeros(
                    final_high_level_features.size(0), self.low_level_dim_output
                )
            output_features = torch.cat([final_high_level_features, low_level_projected], dim=1)
        else:
            output_features = final_high_level_features

        if torch.isnan(output_features).any():
            output_features = torch.nan_to_num(output_features, nan=0.0)

        # 4. Final L2 normalization (with the same zero-norm nudge as above).
        output_magnitude_visual = torch.norm(output_features, p=2, dim=1, keepdim=True)
        if torch.any(output_magnitude_visual < 1e-7):
            noise_visual = torch.randn_like(output_features) * 1e-7
            output_features = torch.where(
                (output_magnitude_visual < 1e-7).expand_as(output_features),
                output_features + noise_visual,
                output_features,
            )

        output_features_final_normalized = F.normalize(output_features, p=2, dim=1, eps=1e-6)

        if torch.isnan(output_features_final_normalized).any():
            output_features_final_normalized = torch.zeros_like(output_features_final_normalized)

        return output_features_final_normalized

    def _forward_with_backbone(self, frames):
        """Forward pass that extracts per-frame features with the backbone.

        Args:
            frames: (Batch, NumFrames, 3, H, W).

        Returns:
            Features of shape (Batch, self.output_dim).
            NOTE(review): unlike the precomputed path, this output is not L2-normalized.
        """
        batch_size, num_frames = frames.shape[0], frames.shape[1]

        if torch.isnan(frames).any():
            print("V- WARNING: Input frames contain NaN values! Replacing with zeros.")
            frames = torch.nan_to_num(frames, nan=0.0)

        # 1. Backbone feature extraction.
        # For DINOv2, parameters whose name lacks 'norm' are temporarily frozen for this
        # forward pass. Their original flags are restored in `finally`, even if the backbone raises.
        original_requires_grad = None
        if self.backbone_type == "dinov2":
            original_requires_grad = {name: param.requires_grad for name, param in self.backbone.named_parameters()}
            for name, param in self.backbone.named_parameters():
                if 'norm' not in name:
                    param.requires_grad_(False)

        try:
            frames_flat = frames.reshape(-1, 3, frames.shape[3], frames.shape[4])
            chunk_size = 8  # Adjust based on GPU memory
            frame_features_flat = torch.cat(
                [self.backbone(frames_flat[i:i + chunk_size]) for i in range(0, frames_flat.size(0), chunk_size)],
                dim=0,
            )
        finally:
            if original_requires_grad is not None:
                for name, param in self.backbone.named_parameters():
                    param.requires_grad_(original_requires_grad[name])

        if torch.isnan(frame_features_flat).any() or torch.isinf(frame_features_flat).any():
            print("V- ❌ CRITICAL: NaN/Inf detected from backbone output! Cleaning.")
            frame_features_flat = torch.nan_to_num(frame_features_flat, nan=0.0, posinf=10.0, neginf=-10.0)

        frame_features = frame_features_flat.reshape(batch_size, num_frames, -1)
        frame_features = torch.clamp(frame_features, -10.0, 10.0)  # Safety clamp

        # 2. Temporal pooling and normalization.
        video_features_pooled = frame_features.mean(dim=1)

        if torch.isnan(video_features_pooled).any() or torch.isinf(video_features_pooled).any():
            print("V- ❌ NaN/Inf in video_features_pooled! Zeroing.")
            video_features_pooled = torch.nan_to_num(video_features_pooled, nan=0.0)

        video_features_normed = self.norm_after_backbone_pooling(video_features_pooled)
        if torch.isnan(video_features_normed).any() or torch.isinf(video_features_normed).any():
            print("V- ❌ NaN/Inf after normalization! Zeroing.")
            video_features_normed = torch.nan_to_num(video_features_normed, nan=0.0)

        # 3. Frame fusion MLP and final dropout.
        video_features_to_fuse = self.dropout_before_fusion(video_features_normed)
        fused_video_features = self.frame_fusion_mlp(video_features_to_fuse)
        final_high_level_features = self.dropout_after_fusion(fused_video_features)

        # 4. Low-level features (if enabled).
        if self.low_level_features:
            low_level_raw = self.extract_low_level_features(frames)
            if torch.isnan(low_level_raw).any() or torch.isinf(low_level_raw).any():
                print("V- ❌ NaN/Inf in low_level_raw! Zeroing.")
                low_level_raw = torch.nan_to_num(low_level_raw, nan=0.0)

            low_level_projected = torch.tanh(self.low_level_projection(low_level_raw))  # Bound low-level features
            if torch.isnan(low_level_projected).any() or torch.isinf(low_level_projected).any():
                print("V- ❌ NaN/Inf in low_level_projected! Zeroing.")
                low_level_projected = torch.nan_to_num(low_level_projected, nan=0.0)

            output_features = torch.cat([final_high_level_features, low_level_projected], dim=1)
        else:
            output_features = final_high_level_features

        if torch.isnan(output_features).any() or torch.isinf(output_features).any():
            print("V- ❌ NaN/Inf in final output_features! Zeroing.")
            output_features = torch.nan_to_num(output_features, nan=0.0)

        return output_features

    def extract_low_level_features(self, frames):
        """Compute per-video averages of simple per-frame statistics.

        Args:
            frames: (B, N, C, H, W); uint8 frames are scaled to [0, 1], other dtypes are used as-is.

        Returns:
            (B, 5) tensor: [brightness, contrast, edge_density, color_entropy, rule_of_thirds].
            The last two are currently zero placeholders.
        """
        B, N, C, H, W = frames.shape
        frames_flat = frames.reshape(B * N, C, H, W)

        if frames_flat.dtype == torch.uint8:
            frames_flat_float = frames_flat.float() / 255.0
        else:
            frames_flat_float = frames_flat

        # Luma-weighted grayscale from the first three channels.
        gray_frames_flat = (0.299 * frames_flat_float[:, 0:1] +
                            0.587 * frames_flat_float[:, 1:2] +
                            0.114 * frames_flat_float[:, 2:3])

        brightness = gray_frames_flat.mean(dim=[1, 2, 3])
        contrast = gray_frames_flat.std(dim=[1, 2, 3], unbiased=False) + 1e-8

        # Edge density via Kornia's Sobel filter; falls back to zeros on failure.
        edge_density = torch.zeros_like(brightness)
        if KFilters is not None:
            try:
                sobel_output = KFilters.sobel(gray_frames_flat)
                if sobel_output.shape[1] == 2:  # Returns gx and gy
                    sobel_magnitude = torch.sqrt(sobel_output[:, 0:1] ** 2 + sobel_output[:, 1:2] ** 2 + 1e-10)
                else:  # Returns magnitude
                    sobel_magnitude = sobel_output
                edge_density = sobel_magnitude.mean(dim=[1, 2, 3])
            except Exception as e:
                print(f"V- Kornia Sobel filter error: {e}. Using zero for edge density.")
                edge_density = torch.zeros_like(brightness)

        # Placeholders for additional features
        color_entropy = torch.zeros_like(brightness)
        rule_of_thirds_metric = torch.zeros_like(brightness)

        low_level_features_flat = torch.stack([
            brightness, contrast, edge_density, color_entropy, rule_of_thirds_metric
        ], dim=1)

        features_avg = low_level_features_flat.reshape(B, N, -1).mean(dim=1)

        if torch.isnan(features_avg).any():
            features_avg = torch.nan_to_num(features_avg, nan=0.0)

        return features_avg

### Audio Processing (optimized)

In [None]:
class OptimizedAudioProcessingModule(nn.Module):
    def __init__(self, use_vggish=True, use_spectrogram_cnn=True, vggish_embedding_dim=128):
        super().__init__()
        self.use_vggish = use_vggish
        self.use_spectrogram_cnn = use_spectrogram_cnn
        
        self.vggish_dim = vggish_embedding_dim if self.use_vggish else 0
            
        # Spectrogram CNN setup
        if self.use_spectrogram_cnn:
            self.spec_cnn = nn.Sequential(
                nn.Conv2d(1, 16, kernel_size=3, padding=1),      # 0
                nn.BatchNorm2d(16),                             # 1
                nn.GELU(),                                      # 2
                nn.MaxPool2d(2),                                # 3
                nn.Dropout2d(0.1),                              # 4
                nn.Conv2d(16, 32, kernel_size=3, padding=1),    # 5
                nn.BatchNorm2d(32),                             # 6
                nn.GELU(),                                      # 7
                nn.MaxPool2d(2),                                # 8
                nn.Dropout2d(0.1),                              # 9
                nn.Conv2d(32, 64, kernel_size=3, padding=1),    # 10
                nn.BatchNorm2d(64),                             # 11
                nn.GELU(),                                      # 12
                nn.AdaptiveAvgPool2d((1, 1)),                   # 13
                nn.Dropout2d(0.2)                               # 14
            )
            
            self.spec_dim = 64
        else:
            self.spec_dim = 0

        # Fusion layer setup
        self.input_to_fusion_dim = self.vggish_dim + self.spec_dim 
        self.fused_audio_intermediate_dim = 128 
        self.final_audio_output_dim = 192 

        if self.input_to_fusion_dim > 0 and (self.vggish_dim == 0 or self.spec_dim == 0):
            # Only one feature type is active
            self.fusion = nn.Identity()
            self.output_dim = self.input_to_fusion_dim
            print(f"AUDIO_MODULE: Using Identity for self.fusion as only one audio feature type "
                  f"({'VGGish' if self.vggish_dim > 0 else 'SpecCNN'}) is active. "
                  f"Output dim for temporal pooling: {self.output_dim}")
        elif self.vggish_dim > 0 and self.spec_dim > 0:
            # Both feature types are active - need fusion
            self.fusion = nn.Sequential(
                nn.Linear(self.input_to_fusion_dim, self.fused_audio_intermediate_dim),
                nn.LayerNorm(self.fused_audio_intermediate_dim, eps=1e-5), 
                nn.GELU(),
                nn.Dropout(0.5),
                nn.Linear(self.fused_audio_intermediate_dim, self.final_audio_output_dim) 
            )
            self.output_dim = self.final_audio_output_dim 
        else:
            # No audio features active
            self.fusion = nn.Identity()
            self.output_dim = 0
            print("AUDIO_MODULE: No audio features active, self.fusion is Identity, output_dim is 0.")

        # Temporal pooling setup
        self.temporal_pooling = nn.Sequential(
            nn.Linear(self.output_dim if self.output_dim > 0 else 1, self.output_dim if self.output_dim > 0 else 1), 
            nn.Tanh()
        )
        
        self._initialize_weights()
        
    def _initialize_weights(self):
        for module_name, module in self.named_modules(): 
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)
            elif isinstance(module, nn.Linear):
                if 'fusion.' in module_name or 'temporal_pooling.' in module_name:
                    gain_val = 0.5 
                    nn.init.xavier_uniform_(module.weight, gain=gain_val) 
                else:
                    nn.init.xavier_uniform_(module.weight, gain=0.5) 
                
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)
            elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm)): 
                nn.init.constant_(module.weight, 1.0)
                nn.init.constant_(module.bias, 0.0)
        
    def compute_spectrogram_batch(self, audio, sr=16000):
        """Compute spectrograms for a batch of audio samples."""
        device = audio.device
        batch_size = audio.shape[0]
        audio = torch.clamp(audio, -1.0, 1.0)
        window = torch.hann_window(512, device=device)
        specs = []
        
        for i in range(batch_size):
            audio_sample = audio[i]
            if audio_sample.std() > 1e-8: 
                audio_sample = (audio_sample - audio_sample.mean()) / (audio_sample.std() + 1e-8)
            stft = torch.stft(
                audio_sample, 
                n_fft=512, 
                hop_length=256, 
                window=window,
                return_complex=True
            )
            spec = torch.abs(stft)
            spec = torch.clamp(spec, min=1e-8, max=100.0) 
            spec = torch.log(spec + 1e-8)
            if spec.std() > 1e-8:
                spec = (spec - spec.mean()) / (spec.std() + 1e-8)
            spec = torch.clamp(spec, -5.0, 5.0) 
            specs.append(spec)
            
        specs = torch.stack(specs).unsqueeze(1)
        return specs

    def forward(self, precomputed_vggish_embedding=None, waveform_for_spec_cnn=None):
        """
        Forward pass with separate inputs for VGGish embeddings and spectrogram CNN.
        
        Args:
            precomputed_vggish_embedding: (Batch, VGGISH_DIM) - precomputed VGGish features
            waveform_for_spec_cnn: (Batch, NumSamples) or (Batch, 1, NumSamples) - raw audio for spec CNN
        """
        features_to_combine = []
        
        # Process VGGish embeddings if provided
        if self.use_vggish and precomputed_vggish_embedding is not None:
            if torch.isnan(precomputed_vggish_embedding).any() or torch.isinf(precomputed_vggish_embedding).any():
                print("A- VGGish emb had NaN/Inf before L2 norm. Clamping.")
                precomputed_vggish_embedding = torch.nan_to_num(precomputed_vggish_embedding, nan=0.0)
            
            # --- ADD L2 NORMALIZATION FOR VGGISH EMBEDDING ---
            vggish_norm_val = torch.norm(precomputed_vggish_embedding, p=2, dim=1, keepdim=True)
            if torch.any(vggish_norm_val < 1e-7):
                noise = torch.randn_like(precomputed_vggish_embedding) * 1e-7
                precomputed_vggish_embedding = torch.where((vggish_norm_val < 1e-7).expand_as(precomputed_vggish_embedding),
                                                           precomputed_vggish_embedding + noise, precomputed_vggish_embedding)
            
            normalized_vggish_embedding = F.normalize(precomputed_vggish_embedding, p=2, dim=1, eps=1e-6)
            # print(f"A- VGGish (after L2 norm): mean={normalized_vggish_embedding.mean().item():.4f}, std={normalized_vggish_embedding.std().item():.4f}, norm={torch.norm(normalized_vggish_embedding, p=2, dim=1).mean().item():.4f}")
            features_to_combine.append(normalized_vggish_embedding)
            # --- END ADDED L2 NORMALIZATION ---
        
        # Process spectrogram CNN if waveform provided
        if self.use_spectrogram_cnn and waveform_for_spec_cnn is not None:
            # Ensure waveform is (Batch, NumSamples) for compute_spectrogram_batch
            if waveform_for_spec_cnn.dim() == 3 and waveform_for_spec_cnn.shape[1] == 1:
                waveform_for_spec_cnn = waveform_for_spec_cnn.squeeze(1)

            specs = self.compute_spectrogram_batch(waveform_for_spec_cnn)  # (Batch, 1, N_MELS, Width)
            
            # Handle dimension issues
            if specs.dim() == 5: 
                specs = specs.squeeze(2)
            specs = torch.clamp(specs, -5.0, 5.0)
            
            # Process through CNN
            spec_cnn_output = self.spec_cnn(specs).squeeze(-1).squeeze(-1)
            spec_cnn_output = torch.clamp(spec_cnn_output, -10.0, 10.0)
            if torch.isnan(spec_cnn_output).any(): 
                spec_cnn_output = torch.nan_to_num(spec_cnn_output, nan=0.0)
            features_to_combine.append(spec_cnn_output)
            
        # Handle case where no features are available
        if not features_to_combine:
            expected_output_dim = self.output_dim if self.output_dim > 0 else 1
            # Determine batch size from one of the inputs if possible, or default if all are None
            bs = (precomputed_vggish_embedding.shape[0] if precomputed_vggish_embedding is not None else 
                  (waveform_for_spec_cnn.shape[0] if waveform_for_spec_cnn is not None else 1))
            device = (precomputed_vggish_embedding.device if precomputed_vggish_embedding is not None else
                     (waveform_for_spec_cnn.device if waveform_for_spec_cnn is not None else torch.device('cpu')))
            return torch.zeros(bs, expected_output_dim, device=device)

        # Combine features
        if len(features_to_combine) > 1:
            combined_audio_features = torch.cat(features_to_combine, dim=1)
            processed_features = self.fusion(combined_audio_features)
        elif features_to_combine:
            processed_features = features_to_combine[0]
            # Apply fusion if it's not Identity
            if not isinstance(self.fusion, nn.Identity):
                processed_features = self.fusion(processed_features)

        # Apply temporal pooling and normalization
        output = self.temporal_pooling(processed_features)
        output = F.normalize(output, p=2, dim=1, eps=1e-8)
        return output

### Text Processing 

In [None]:
class TextProcessingModule(nn.Module):
    """Encode movie text metadata into a single L2-normalised embedding.

    Feature groups are concatenated in this fixed order and passed through a
    small MLP (``self.fusion``) that produces ``output_dim`` (384) values:

    1. the BERT ``[CLS]`` vector of ``"{title} [SEP] {plot}"`` (768-d), if ``use_bert``;
    2. TF-IDF features projected 5000 -> 100 and LayerNorm'ed, if ``use_tfidf``;
    3. a learned language embedding of size ``lang_embedding_dim``;
    4. one scalar year feature.

    Groups 3 and 4 are always part of the MLP input width, so when
    ``language_idx`` / ``year`` are not supplied they are zero-filled.

    NOTE: ``self.fusion`` contains ``BatchNorm1d``, which PyTorch rejects for
    batches of size 1 in training mode.
    """

    def __init__(self, use_bert=True, use_tfidf=True, num_languages=150,
                 lang_embedding_dim=16):
        super().__init__()
        self.use_bert = use_bert
        self.use_tfidf = use_tfidf

        if use_bert:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.bert = BertModel.from_pretrained('bert-base-uncased')
            self.bert_dim = 768
            # Freeze the encoder stack and pooler: BERT is used as a fixed feature extractor
            for layer in self.bert.encoder.layer:
                for param in layer.parameters():
                    param.requires_grad = False
            for param in self.bert.pooler.parameters():
                param.requires_grad = False
        else:
            self.bert_dim = 0

        if use_tfidf:
            self.tfidf_dim = 100
            self.tfidf_projection = nn.Linear(5000, self.tfidf_dim)
            # Changed eps to be more conservative for LayerNorm
            self.tfidf_norm = nn.LayerNorm(self.tfidf_dim, eps=1e-6)
        else:
            self.tfidf_dim = 0

        # Language embedding and scalar year feature
        self.lang_embedding_dim = lang_embedding_dim
        self.language_embedding = nn.Embedding(num_languages, self.lang_embedding_dim)
        self.year_dim = 1  # Year is a single feature

        # Fusion input width: every enabled group plus the language and year slots
        self.input_dim = self.bert_dim + self.tfidf_dim + self.lang_embedding_dim + self.year_dim

        # Fusion MLP
        self.output_dim = 384
        self.fusion = nn.Sequential(
            nn.Linear(self.input_dim, 256),
            nn.BatchNorm1d(256, eps=1e-5),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, self.output_dim)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-init the Linear layers and reset the norm layers owned by this module.

        The pretrained BERT sub-model (``self.bert``) is deliberately skipped:
        ``self.modules()`` recurses into it, and re-initialising its Linear and
        LayerNorm layers would overwrite the weights loaded by ``from_pretrained``.
        """
        bert = getattr(self, 'bert', None)
        pretrained_modules = set(bert.modules()) if bert is not None else set()
        for module in self.modules():
            if module in pretrained_modules:
                continue
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=1.0)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)
            elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
                nn.init.constant_(module.weight, 1.0)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)

    def process_bert(self, text_batch):
        """Return frozen BERT ``[CLS]`` vectors (B, 768) for a list of strings, NaN/Inf-sanitised."""
        encoded = self.tokenizer(text_batch, padding=True, truncation=True,
                                 max_length=128, return_tensors='pt')
        encoded = {k: v.to(next(self.bert.parameters()).device) for k, v in encoded.items()}

        with torch.no_grad():
            outputs = self.bert(**encoded)
            bert_features = outputs.last_hidden_state[:, 0, :]

        if torch.isnan(bert_features).any() or torch.isinf(bert_features).any():
            bert_features = torch.nan_to_num(bert_features, nan=0.0, posinf=1.0, neginf=-1.0)

        return bert_features

    def process_tfidf(self, tfidf_features):
        """Project (B, 5000) TF-IDF vectors to (B, 100) and LayerNorm them, sanitising NaN/Inf."""
        if torch.isnan(tfidf_features).any() or torch.isinf(tfidf_features).any():
            tfidf_features = torch.nan_to_num(tfidf_features, nan=0.0, posinf=1.0, neginf=-1.0)

        projected = self.tfidf_projection(tfidf_features)

        if torch.isnan(projected).any() or torch.isinf(projected).any():
            projected = torch.nan_to_num(projected, nan=0.0, posinf=1.0, neginf=-1.0)

        if self.training and projected.size(0) > 1:
            # Add tiny noise when some projected feature has ~zero variance across the batch
            var = projected.var(dim=0, keepdim=True)
            zero_var_mask = var < 1e-8
            if zero_var_mask.any():
                noise = torch.randn_like(projected) * 1e-8
                projected = projected + noise

        normalized_tfidf = self.tfidf_norm(projected)

        if torch.isnan(normalized_tfidf).any() or torch.isinf(normalized_tfidf).any():
            normalized_tfidf = torch.nan_to_num(normalized_tfidf, nan=0.0, posinf=1.0, neginf=-1.0)

        return normalized_tfidf

    def forward(self, title, plot, tfidf_features=None, year=None, language_idx=None):
        """Embed a batch of movies.

        Args:
            title, plot: equal-length sequences of strings, one per movie.
            tfidf_features: optional (B, 5000) tensor. Zeros are used when the
                module was built with ``use_tfidf`` and this is omitted.
            year: optional year feature of shape (B,) or (B, 1). Zeros if omitted.
            language_idx: optional integer language ids of shape (B,) or (B, 1).
                A zero embedding is used if omitted.

        Returns:
            (B, output_dim) tensor of L2-normalised embeddings.
        """
        features = []
        text_batch = [f"{t} [SEP] {p}" for t, p in zip(title, plot)]
        device = next(self.parameters()).device

        if self.use_bert:
            bert_features = self.process_bert(text_batch)  # Already handles NaNs

            # Nudge (near-)zero vectors so F.normalize has a direction to work with
            bert_norm_val = torch.norm(bert_features, p=2, dim=1, keepdim=True)
            if torch.any(bert_norm_val < 1e-7):
                noise = torch.randn_like(bert_features) * 1e-7
                bert_features = torch.where((bert_norm_val < 1e-7).expand_as(bert_features),
                                            bert_features + noise, bert_features)
            bert_features = F.normalize(bert_features, p=2, dim=1, eps=1e-6)
            features.append(bert_features)

        if self.use_tfidf and tfidf_features is not None:
            features.append(self.process_tfidf(tfidf_features))
        elif self.use_tfidf:
            print("T- TF-IDF features were expected but not provided. Creating zeros tensor.")
            features.append(torch.zeros(len(title), self.tfidf_dim, device=device))

        lang_emb = None
        if language_idx is not None:
            if language_idx.dim() > 1:
                language_idx = language_idx.squeeze(1)
            # nn.Embedding requires integer indices
            lang_emb = self.language_embedding(language_idx.long())

        year_feat = None
        if year is not None:
            # Accept (B,) as well as (B, 1): torch.cat below needs a 2-D column
            year_feat = year.unsqueeze(1) if year.dim() == 1 else year
            # Match the MLP parameter dtype so e.g. float64/int years don't break the Linear layer
            year_feat = year_feat.to(device=device, dtype=self.fusion[0].weight.dtype)

        if not features and lang_emb is None and year_feat is None:
            print("T- WARNING: No text features were processed. Returning zeros.")
            batch_size = len(title) if title else 1
            return torch.zeros(batch_size, self.output_dim, device=device)

        # self.input_dim always reserves the language and year slots, so zero-fill
        # missing ones instead of feeding the fusion MLP a narrower input.
        present = [f for f in (*features, lang_emb, year_feat) if f is not None]
        batch_size = present[0].shape[0]
        if lang_emb is None:
            lang_emb = torch.zeros(batch_size, self.lang_embedding_dim, device=device)
        if year_feat is None:
            year_feat = torch.zeros(batch_size, self.year_dim, device=device)
        features.extend((lang_emb, year_feat))

        combined = torch.cat(features, dim=1)

        if torch.isnan(combined).any() or torch.isinf(combined).any():
            combined = torch.nan_to_num(combined, nan=0.0, posinf=1.0, neginf=-1.0)

        output = self.fusion(combined)

        # Sanitize output from fusion MLP
        if torch.isnan(output).any() or torch.isinf(output).any():
            output = torch.nan_to_num(output, nan=0.0, posinf=1.0, neginf=-1.0)

        # Add tiny noise to samples with near-zero norm before normalisation
        output_norm = torch.norm(output, p=2, dim=1, keepdim=True)
        if torch.any(output_norm < 1e-7):
            noise = torch.randn_like(output) * 1e-7
            near_zero_mask = (output_norm < 1e-7).expand_as(output)
            output = torch.where(near_zero_mask, output + noise, output)

        output = F.normalize(output, p=2, dim=1, eps=1e-6)

        # Final check after F.normalize
        if torch.isnan(output).any() or torch.isinf(output).any():
            print("T- 👺 NaN/Inf in text_module output AFTER F.normalize. Re-clamping to zeros.")
            output_norm_check = torch.norm(output, p=2, dim=1, keepdim=True)
            print(f"   Input norms to F.normalize: min={output_norm_check.min().item():.8f}, "
                  f"max={output_norm_check.max().item():.8f}, mean={output_norm_check.mean().item():.8f}")
            output = torch.zeros_like(output)

        return output

## Fusion Networks

In [None]:
#SIMPLIFIED!
class SimpleFusionNetwork(nn.Module):
    """Concatenate per-modality features and map them to an embedding with a 2-layer MLP.

    Modality selection by ``ablation_mode``:
    visual is used whenever ``visual_dim > 0``; audio is added for
    ``'visual_audio'`` and ``'full'``; text is added only for ``'full'``.
    The MLP is LayerNorm -> Linear -> ReLU -> Dropout(0.5) -> Linear, and its
    output is L2-normalised.
    """

    def __init__(self, visual_dim, audio_dim, text_dim, embedding_dim=128, ablation_mode='full'):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.ablation_mode = ablation_mode

        with_audio = ablation_mode in ('visual_audio', 'full')
        with_text = ablation_mode == 'full'
        self.fusion_input_dim = sum((
            visual_dim if visual_dim > 0 else 0,
            audio_dim if with_audio and audio_dim > 0 else 0,
            text_dim if with_text and text_dim > 0 else 0,
        ))

        if self.fusion_input_dim == 0:
            raise ValueError("Fusion input dimension cannot be zero.")

        print(f"Fusion Input Dim: {self.fusion_input_dim}, Final Embedding Dim: {embedding_dim}")

        # Hidden width: at least twice the embedding, at least half the input, never above 1024
        hidden_dim = min(max(embedding_dim * 2, self.fusion_input_dim // 2), 1024)

        self.fusion_mlp = nn.Sequential(
            nn.LayerNorm(self.fusion_input_dim),
            nn.Linear(self.fusion_input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, embedding_dim)
        )

        self._initialize_weights()
        print(f"🔧 Robust SimpleFusionNetwork (LayerNorm -> Linear -> ReLU -> Dropout -> Linear) initialized.")

    def _initialize_weights(self):
        """Xavier-init (gain 0.7) the MLP's Linear layers; reset LayerNorm to identity."""
        for layer in self.fusion_mlp.modules():
            if isinstance(layer, nn.LayerNorm):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight, gain=0.7)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)

    def forward(self, visual_features, audio_features=None, text_features=None):
        """Fuse the enabled, non-None modality features into an L2-normalised embedding.

        Returns a (1, embedding_dim) zero tensor when nothing can be fused.
        """
        mode = self.ablation_mode
        candidates = (
            (visual_features, True),
            (audio_features, mode in ('visual_audio', 'full')),
            (text_features, mode == 'full'),
        )
        parts = [feat for feat, enabled in candidates if enabled and feat is not None]

        if not parts:
            device = 'cpu' if visual_features is None else visual_features.device
            return torch.zeros((1, self.embedding_dim), device=device)

        fused_input = torch.cat(parts, dim=1)

        # Sanitise before the MLP so NaN/Inf cannot propagate through LayerNorm
        has_bad_values = torch.isnan(fused_input).any() or torch.isinf(fused_input).any()
        if has_bad_values:
            print("F- ❌ NaN/Inf detected in fused_input before MLP. Clamping.")
            fused_input = torch.nan_to_num(fused_input, nan=0.0, posinf=1.0, neginf=-1.0)

        return F.normalize(self.fusion_mlp(fused_input), p=2, dim=1, eps=1e-8)

In [None]:
#REVISED HYBRID MODULE
class HybridFusionNetwork(nn.Module):
    """Fuse visual, audio and text features via visual-queried cross-attention plus an MLP.

    The visual vector is refined twice: once attending to audio and once to
    text, each time with a residual connection and LayerNorm. The raw visual
    vector and both refined versions are concatenated (3 * visual_dim),
    projected to ``embedding_dim`` and L2-normalised.

    ``num_heads`` must divide ``visual_dim`` (nn.MultiheadAttention requirement).

    NOTE(review): keys/values are a single token per sample (sequence length 1),
    so the softmax attention weight is always 1. The query, and therefore
    ``visual_query``, does not change the attention output. Each attention
    block reduces to a projection of the audio/text context, plus attention
    dropout while training.
    """

    def __init__(self, visual_dim, audio_dim, text_dim, embedding_dim=128, num_heads=4, dropout=0.2):
        super().__init__()
        self.embedding_dim = embedding_dim

        def make_attention():
            return nn.MultiheadAttention(embed_dim=visual_dim, num_heads=num_heads,
                                         dropout=dropout, batch_first=True)

        # Early fusion: visual query, plus key/value projections for each context modality
        self.visual_query = nn.Linear(visual_dim, visual_dim)

        self.audio_kv = nn.Linear(audio_dim, 2 * visual_dim)
        self.va_attention = make_attention()
        self.va_norm = nn.LayerNorm(visual_dim)

        self.text_kv = nn.Linear(text_dim, 2 * visual_dim)
        self.vt_attention = make_attention()
        self.vt_norm = nn.LayerNorm(visual_dim)

        # Late fusion over [raw visual, audio-refined visual, text-refined visual]
        stacked_dim = 3 * visual_dim
        hidden_dim = 2 * embedding_dim
        self.late_fusion_mlp = nn.Sequential(
            nn.LayerNorm(stacked_dim),
            nn.Linear(stacked_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(hidden_dim, embedding_dim)
        )

        self._initialize_weights()
        print(f"🔧 Revised HybridFusionNetwork initialized.")

    def _initialize_weights(self):
        """Xavier-init (ReLU gain) every Linear, including attention out-projections; reset LayerNorms."""
        relu_gain = nn.init.calculate_gain('relu')
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight, gain=relu_gain)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.LayerNorm):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    @staticmethod
    def _cross_attend(query, visual_features, context, kv_proj, attention, norm):
        """Attend from the visual query to one context vector; return norm(visual + attended)."""
        key, value = (part.unsqueeze(1) for part in kv_proj(context).chunk(2, dim=-1))  # (B, 1, D_vis) each
        attended, _ = attention(query=query, key=key, value=value)
        return norm(visual_features + attended.squeeze(1))

    def forward(self, visual_features, audio_features, text_features):
        """Return (B, embedding_dim) L2-normalised embeddings; all three inputs are required."""
        query = self.visual_query(visual_features).unsqueeze(1)  # (B, 1, D_vis)

        visual_after_audio = self._cross_attend(
            query, visual_features, audio_features,
            self.audio_kv, self.va_attention, self.va_norm,
        )
        visual_after_text = self._cross_attend(
            query, visual_features, text_features,
            self.text_kv, self.vt_attention, self.vt_norm,
        )

        stacked = torch.cat([visual_features, visual_after_audio, visual_after_text], dim=1)
        return F.normalize(self.late_fusion_mlp(stacked), p=2, dim=1, eps=1e-8)

In [None]:
class UltraSimpleFusionNetwork(nn.Module):
    """Fuse modality features with a single LayerNorm -> Linear projection.

    Modality selection by ``ablation_mode``:
    visual is used whenever ``visual_dim > 0``; audio is added for
    ``'visual_audio'`` and ``'full'``; text is added only for ``'full'``.
    There is no hidden layer, and the output is L2-normalised.
    """

    def __init__(self, visual_dim, audio_dim, text_dim, embedding_dim=128, ablation_mode='full'):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.ablation_mode = ablation_mode

        self.fusion_input_dim = 0
        if visual_dim > 0:
            self.fusion_input_dim += visual_dim
        if ablation_mode in ('visual_audio', 'full') and audio_dim > 0:
            self.fusion_input_dim += audio_dim
        if ablation_mode == 'full' and text_dim > 0:
            self.fusion_input_dim += text_dim

        if self.fusion_input_dim == 0:
            raise ValueError("Fusion input dimension cannot be zero.")

        print(f"Fusion Input Dim: {self.fusion_input_dim}, Final Embedding Dim: {embedding_dim}")

        # A direct, normalised projection with no hidden layers
        self.fusion_mlp = nn.Sequential(
            nn.LayerNorm(self.fusion_input_dim),
            nn.Linear(self.fusion_input_dim, self.embedding_dim)
        )

        self._initialize_weights()
        print(f"🔧 ULTRA-STABLE UltraSimpleFusionNetwork (LayerNorm -> Linear) initialized.")

    def _initialize_weights(self):
        """Xavier-init the projection with a small gain (0.1); reset LayerNorm to identity."""
        for module in self.fusion_mlp.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.1)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.LayerNorm):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, visual_features, audio_features=None, text_features=None):
        """Fuse the enabled, non-None modality features into an L2-normalised embedding.

        Returns a (1, embedding_dim) zero tensor when nothing can be fused.
        """
        features_to_fuse = []
        if visual_features is not None:
            features_to_fuse.append(visual_features)
        if self.ablation_mode in ('visual_audio', 'full') and audio_features is not None:
            features_to_fuse.append(audio_features)
        if self.ablation_mode == 'full' and text_features is not None:
            features_to_fuse.append(text_features)

        if not features_to_fuse:
            return torch.zeros((1, self.embedding_dim), device='cpu')

        fused_input = torch.cat(features_to_fuse, dim=1)

        if torch.isnan(fused_input).any() or torch.isinf(fused_input).any():
            # Map +/-inf to +/-1 rather than nan_to_num's default of +/-float max.
            # A float-max entry would dominate or overflow the LayerNorm statistics.
            fused_input = torch.nan_to_num(fused_input, nan=0.0, posinf=1.0, neginf=-1.0)

        embedding = self.fusion_mlp(fused_input)

        return F.normalize(embedding, p=2, dim=1, eps=1e-8)

## Loss Functions

In [None]:
class ContrastiveLossCosine(nn.Module):
    """Contrastive loss on cosine distance ``d = 1 - cos(e1, e2)``.

    Per pair: ``label * d**2 + (1 - label) * max(0, margin - d)**2``, averaged
    over the batch. ``label`` is 1 for similar pairs and 0 for dissimilar ones.
    """

    def __init__(self, margin=0.5):  # A smaller margin is often used for Cosine distance
        super(ContrastiveLossCosine, self).__init__()
        self.margin = margin
        self.eps = 1e-9  # For numerical stability

    def forward(self, embedding1, embedding2, label):
        """Return the mean loss for (B, D) embeddings and B labels of shape (B,) or (B, 1)."""
        cosine_similarity = F.cosine_similarity(embedding1, embedding2, dim=1, eps=self.eps)

        # Distance = 1 - Similarity, shape (B,)
        cosine_distance = 1 - cosine_similarity

        # Flatten the label and match the distance dtype. A (B, 1) label would
        # otherwise broadcast against the (B,) distances into a (B, B) matrix,
        # and bool labels cannot be used in `1 - label`.
        label = torch.as_tensor(label, device=cosine_distance.device).to(cosine_distance.dtype).reshape(-1)

        loss_positive = label * torch.pow(cosine_distance, 2)
        loss_negative = (1 - label) * torch.pow(torch.clamp(self.margin - cosine_distance, min=0.0), 2)

        return torch.mean(loss_positive + loss_negative)

In [None]:
class ContrastiveLossEuclidean(nn.Module):
    """Contrastive loss on Euclidean distance ``d = ||e1 - e2||_2``.

    Per pair: ``label * d**2 + (1 - label) * max(0, margin - d)**2``, averaged
    over the batch. ``label`` is 1 for similar pairs and 0 for dissimilar ones.
    """

    def __init__(self, margin=1.0):
        super(ContrastiveLossEuclidean, self).__init__()
        self.margin = margin
        self.eps = 1e-9  # passed to F.pairwise_distance for numerical stability

    def forward(self, embedding1, embedding2, label):
        """Return the mean loss for (B, D) embeddings and B labels of shape (B,) or (B, 1)."""
        euclidean_distance = F.pairwise_distance(embedding1, embedding2, p=2, eps=self.eps)

        # Flatten the label and match the distance dtype. A (B, 1) label would
        # otherwise broadcast against the (B,) distances into a (B, B) matrix,
        # and bool labels cannot be used in `1 - label`.
        label = torch.as_tensor(label, device=euclidean_distance.device).to(euclidean_distance.dtype).reshape(-1)

        loss_positive = label * torch.pow(euclidean_distance, 2)
        loss_negative = (1 - label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)

        return torch.mean(loss_positive + loss_negative)

In [None]:
class TripletLossCosine(nn.Module):
    """Cosine-similarity triplet loss.

    Pushes each anchor to be more similar to its positive than to its
    negative by at least ``margin``:

        loss = mean(max(0, sim(a, n) - sim(a, p) + margin))
    """

    def __init__(self, margin=0.3):
        super().__init__()
        self.margin = margin
        self.eps = 1e-9  # forwarded to F.cosine_similarity

    def forward(self, anchor_emb, positive_emb, negative_emb):
        """Return the batch-mean hinge loss for (B, D) anchor/positive/negative embeddings."""
        def similarity_to_anchor(other):
            return F.cosine_similarity(anchor_emb, other, dim=1, eps=self.eps)

        positive_sim = similarity_to_anchor(positive_emb)
        negative_sim = similarity_to_anchor(negative_emb)

        violation = negative_sim - positive_sim + self.margin
        return violation.clamp(min=0.0).mean()

In [None]:
class TripletLossEuclidean(nn.Module):
    """Euclidean-distance triplet loss.

    Pushes each anchor closer to its positive than to its negative by at
    least ``margin``:

        loss = mean(max(0, dist(a, p) - dist(a, n) + margin))
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor_emb, positive_emb, negative_emb):
        """Return the batch-mean hinge loss for (B, D) anchor/positive/negative embeddings."""
        def distance_from_anchor(other):
            return F.pairwise_distance(anchor_emb, other, p=2)

        positive_dist = distance_from_anchor(positive_emb)
        negative_dist = distance_from_anchor(negative_emb)

        violation = positive_dist - negative_dist + self.margin
        return violation.clamp(min=0.0).mean()

## Siamese Neural Network (SNN) Architectures

In [None]:
class OptimizedPairBasedSiameseNetwork(nn.Module):
    """Pair-based Siamese network: one shared tower embeds both movies of a pair.

    The tower runs visual -> (optional) audio -> (optional) text -> fusion.
    The same sub-module instances, and therefore weights, are applied to both
    sides of the pair.
    """

    def __init__(self, visual_module, audio_module, text_module, fusion_network):
        super().__init__()
        self.visual_module = visual_module
        self.audio_module = audio_module
        self.text_module = text_module
        self.fusion_network = fusion_network

    def forward_single(self,
                       frames_for_low_level,         # For visual low-level
                       dino_frame_embeddings,        # For visual high-level (precomputed)
                       vggish_embedding=None,        # For audio (precomputed VGGish)
                       waveform_for_spec_cnn=None,   # For audio (raw wave for spec CNN)
                       title=None, plot=None, tfidf_features=None,  # For text
                       year=None, language_idx=None):
        """Embed one batch of movies with the shared tower and return the fused embedding.

        If a text module exists but ``title`` or ``plot`` is missing, zero text
        features of width ``text_module.output_dim`` are fused instead.
        """
        visual_features = self.visual_module(
            frames=frames_for_low_level,
            precomputed_dino_embeddings=dino_frame_embeddings
        )

        audio_features = (
            None if self.audio_module is None
            else self.audio_module(
                precomputed_vggish_embedding=vggish_embedding,
                waveform_for_spec_cnn=waveform_for_spec_cnn
            )
        )

        if self.text_module is None:
            text_features = None
        elif title is None or plot is None:
            # No text for this batch: keep the fusion input width with zeros
            text_features = torch.zeros(visual_features.shape[0], self.text_module.output_dim,
                                        device=visual_features.device)
        else:
            text_features = self.text_module(
                title=title,
                plot=plot,
                tfidf_features=tfidf_features,
                year=year,
                language_idx=language_idx
            )

        return self.fusion_network(visual_features, audio_features, text_features)

    def forward(self,
                frames1_for_low_level, dino_frame_embeddings1, waveform1, vggish_embedding1,
                frames2_for_low_level, dino_frame_embeddings2, waveform2, vggish_embedding2,
                video1_title=None, video1_plot=None, video1_tfidf=None, video1_year=None, video1_language_idx=None,
                video2_title=None, video2_plot=None, video2_tfidf=None, video2_year=None, video2_language_idx=None):
        """Embed both movies of each pair and return ``(embedding1, embedding2)``.

        If any visual input (frames or DINO embeddings, either side) contains
        NaN, two zero tensors of shape (B, embedding_dim) are returned instead;
        embedding_dim falls back to 128 if the fusion network does not define it.
        Audio inputs have NaN/+-Inf zeroed and TF-IDF inputs have NaN zeroed
        before use.
        """
        batch_size = frames1_for_low_level.shape[0]
        device = frames1_for_low_level.device

        visual_inputs = (frames1_for_low_level, dino_frame_embeddings1,
                         frames2_for_low_level, dino_frame_embeddings2)
        if any(torch.isnan(tensor).any() for tensor in visual_inputs):
            emb_dim = getattr(self.fusion_network, 'embedding_dim', 128)
            return (torch.zeros(batch_size, emb_dim, device=device),
                    torch.zeros(batch_size, emb_dim, device=device))

        def scrub(tensor, **fill):
            return None if tensor is None else torch.nan_to_num(tensor, **fill)

        if self.audio_module:
            silence = dict(nan=0.0, posinf=0.0, neginf=0.0)
            waveform1 = scrub(waveform1, **silence)
            vggish_embedding1 = scrub(vggish_embedding1, **silence)
            waveform2 = scrub(waveform2, **silence)
            vggish_embedding2 = scrub(vggish_embedding2, **silence)
        if self.text_module:
            video1_tfidf = scrub(video1_tfidf, nan=0.0)
            video2_tfidf = scrub(video2_tfidf, nan=0.0)

        embedding1 = self.forward_single(
            frames1_for_low_level, dino_frame_embeddings1,
            vggish_embedding=vggish_embedding1, waveform_for_spec_cnn=waveform1,
            title=video1_title, plot=video1_plot, tfidf_features=video1_tfidf,
            year=video1_year, language_idx=video1_language_idx
        )
        embedding2 = self.forward_single(
            frames2_for_low_level, dino_frame_embeddings2,
            vggish_embedding=vggish_embedding2, waveform_for_spec_cnn=waveform2,
            title=video2_title, plot=video2_plot, tfidf_features=video2_tfidf,
            year=video2_year, language_idx=video2_language_idx
        )
        return embedding1, embedding2

In [None]:
class OnlineTripletSiameseNetwork(nn.Module):
    """Single processing tower used for online triplet mining.

    Each call turns ONE batch of movies into a batch of L2-normalised
    embeddings; triplets are formed from those embeddings by the caller.
    """

    def __init__(self, visual_module, audio_module, text_module, fusion_network):
        super().__init__()
        self.visual_module = visual_module
        self.audio_module = audio_module
        self.text_module = text_module
        self.fusion_network = fusion_network

    def forward(self, **batch_data):
        """Embed a batch of movies.

        Keyword inputs:
            frames_for_low_level, dino_frame_embeddings: required (KeyError if absent).
            vggish_embedding, waveform_for_spec_cnn: read only if an audio module exists.
            title, plot, tfidf_features, year, language_idx: read only if a text
                module exists.
            Optional keys default to None.

        Returns:
            (batch_size, embedding_dim) tensor, L2-normalised with eps=1e-6.
        """
        lookup = batch_data.get

        visual = self.visual_module(
            frames=batch_data['frames_for_low_level'],
            precomputed_dino_embeddings=batch_data['dino_frame_embeddings']
        )

        audio = None
        if self.audio_module is not None:
            audio = self.audio_module(
                precomputed_vggish_embedding=lookup('vggish_embedding'),
                waveform_for_spec_cnn=lookup('waveform_for_spec_cnn')
            )

        text = None
        if self.text_module is not None:
            text = self.text_module(
                title=lookup('title'),
                plot=lookup('plot'),
                tfidf_features=lookup('tfidf_features'),
                year=lookup('year'),
                language_idx=lookup('language_idx')
            )

        # Normalise here as well so the loss always receives unit-norm embeddings,
        # whatever the fusion network returns.
        fused = self.fusion_network(visual, audio, text)
        return F.normalize(fused, p=2, dim=1, eps=1e-6)

## Dataset and Helpers

In [None]:
class VideoDatasetS3:
    """Helper class to download trailer videos and pre-extracted features from S3.

    Video files are copied into the system temp directory under the basename
    of their S3 key. The ``CachedS3Dataset`` byte cache is consulted first,
    with a direct S3 download as the fallback. Feature files are loaded into
    CPU tensors.
    """

    def __init__(self, bucket_name):
        self.s3_client = boto3.client('s3')
        self.bucket_name = bucket_name
        self.cache = CachedS3Dataset(bucket_name)
        # Cache of movie IDs to full S3 keys
        self.movie_key_cache = {}

    @staticmethod
    def _discard(path):
        """Best-effort removal of a scratch file; a missing file or OSError is ignored."""
        if path is None:
            return
        try:
            if os.path.exists(path):
                os.remove(path)
        except OSError:
            pass

    def _resolve_video_key(self, movie_id, title):
        """Return (and memoise) the S3 key of ``movie_id``'s trailer, or None if none is found.

        With a title, the key is built as ``movie_trailers/{movie_id}_{title}.mp4``.
        Without one, the first object listed under ``movie_trailers/{movie_id}_`` is used.
        """
        if movie_id in self.movie_key_cache:
            return self.movie_key_cache[movie_id]

        if title:
            video_key = f"movie_trailers/{movie_id}_{title}.mp4"
        else:
            try:
                response = self.s3_client.list_objects_v2(
                    Bucket=self.bucket_name,
                    Prefix=f"movie_trailers/{movie_id}_"
                )
                contents = response.get('Contents')
                if not contents:
                    print(f"No video found for movie ID {movie_id}")
                    return None
                video_key = contents[0]['Key']
            except Exception as e:
                print(f"Error listing objects for movie ID {movie_id}: {e}")
                return None

        self.movie_key_cache[movie_id] = video_key
        return video_key

    def get_video_path(self, movie_id, title=""):
        """Download video from S3 (or the byte cache) and return its local path, or None on failure."""
        video_key = self._resolve_video_key(movie_id, title)
        if video_key is None:
            return None

        temp_dir = tempfile.gettempdir()
        local_path = os.path.join(temp_dir, os.path.basename(video_key))

        # Reuse an existing local copy
        if os.path.exists(local_path):
            return local_path

        # Materialise into a unique scratch file and rename into place only once
        # complete. A failed or interrupted write then never leaves a truncated
        # file at local_path for the exists() check above to return later.
        partial_path = None
        try:
            fd, partial_path = tempfile.mkstemp(dir=temp_dir, suffix='.part')
            os.close(fd)

            file_content = self.cache.get(video_key)
            if file_content is not None:
                with open(partial_path, 'wb') as f:
                    f.write(file_content)
            else:
                # Not in the byte cache: download directly from S3
                self.s3_client.download_file(self.bucket_name, video_key, partial_path)
                print(f"Downloaded {video_key} to {local_path}")

            os.replace(partial_path, local_path)
            return local_path
        except Exception as e:
            print(f"Error downloading video {video_key}: {e}")
            return None
        finally:
            # No-op after a successful os.replace; removes the scratch file otherwise
            self._discard(partial_path)

    def cleanup_temp_files(self):
        """Remove temporary downloaded video files to free space.

        Deletes, from the temp directory:
        * the local copy of every trailer key this instance has resolved
          (get_video_path names them after the key's basename), and
        * any ``_*.mp4`` / ``_*.avi`` file, the original cleanup pattern, kept
          for compatibility.
        """
        temp_dir = tempfile.gettempdir()
        known_downloads = {os.path.basename(key) for key in self.movie_key_cache.values()}
        for filename in os.listdir(temp_dir):
            legacy_match = filename.startswith("_") and filename.endswith((".mp4", ".avi"))
            if filename in known_downloads or legacy_match:
                try:
                    os.remove(os.path.join(temp_dir, filename))
                except Exception as e:
                    print(f"Could not remove {filename}: {e}")

    # TODO: test with different frame counts (16-32-64(max))
    def get_preextracted_frames(self, movie_id, num_frames=16):
        """Get the pre-extracted frame tensor for ``movie_id`` from S3.

        Tensors with more than ``num_frames`` frames are subsampled to
        ``num_frames`` evenly spaced frames; shorter ones are returned unchanged.
        On any error a ``(num_frames, 3, 224, 224)`` zero tensor is returned.
        NOTE(review): short clips are not padded, so the first dimension can be
        < num_frames. Confirm callers tolerate that.
        """
        frames_key = f"movie_trailers_frames_tensors/{movie_id}_frames.pt"
        temp_path = None
        try:
            # Closed mkstemp file: the download writes to the path itself, and the
            # finally-block removes it even when the download or torch.load fails.
            fd, temp_path = tempfile.mkstemp(suffix='.pt')
            os.close(fd)
            self.s3_client.download_file(self.bucket_name, frames_key, temp_path)
            frames = torch.load(temp_path, map_location='cpu')

            # Stored clips may have more frames than needed -> sample evenly
            if frames.shape[0] > num_frames:
                indices = torch.linspace(0, frames.shape[0] - 1, num_frames).long()
                frames = frames[indices]

            return frames
        except Exception as e:
            print(f"Error loading pre-extracted frames for {movie_id}: {e}")
            # Fallback to zeros
            return torch.zeros((num_frames, 3, 224, 224))
        finally:
            self._discard(temp_path)

    def get_preextracted_dino_frame_embeddings(self, movie_id):
        """Load per-frame DINO embeddings for ``movie_id``, resized to NUM_FRAMES_FROM_PT rows.

        Longer sequences are subsampled evenly; shorter ones are zero-padded at
        the end. On any error a ``(NUM_FRAMES_FROM_PT, DINO_DIM)`` zero tensor is
        returned.
        """
        embedding_key = 'UNKNOWN'  # shown in the error message if building the key itself fails
        try:
            embedding_key = f"{DINO_PER_FRAME_EMBEDDING_S3_PREFIX}{movie_id}_dino_per_frame.pt"

            file_content_bytes = self.cache.get(embedding_key)
            if file_content_bytes is None:
                # Not in the cache and the download failed, or the key does not exist
                print(f"DINO per-frame emb for {movie_id} not found via cache or S3 for key {embedding_key}")
                raise FileNotFoundError(f"DINO per-frame emb {embedding_key} not found.")

            dino_embeddings = torch.load(io.BytesIO(file_content_bytes), map_location='cpu')
            num_stored = dino_embeddings.shape[0]
            if num_stored != NUM_FRAMES_FROM_PT:
                print(f"Warning: DINO per-frame for {movie_id} has {num_stored} frames, expected {NUM_FRAMES_FROM_PT}. Adjusting/Padding.")
                if num_stored > NUM_FRAMES_FROM_PT:
                    indices = torch.linspace(0, num_stored - 1, NUM_FRAMES_FROM_PT).long()
                    dino_embeddings = dino_embeddings[indices]
                else:
                    padding = dino_embeddings.new_zeros(
                        (NUM_FRAMES_FROM_PT - num_stored, dino_embeddings.shape[1]))
                    dino_embeddings = torch.cat((dino_embeddings, padding), dim=0)
            return dino_embeddings
        except Exception as e:
            print(f"Error loading pre-extracted DINO per-frame embeddings for {movie_id} (key: {embedding_key}): {e}")
            return torch.zeros(NUM_FRAMES_FROM_PT, DINO_DIM)

    def get_preextracted_audio_vggish_wave(self, movie_id):
        """Return ``(vggish_embedding, waveform)`` for ``movie_id``; the waveform is made at least 2-D.

        On any error returns zeros of shape ``(VGGISH_DIM,)`` and
        ``(1, SAMPLE_RATE * MAX_AUDIO_LENGTH_SEC)``.
        """
        audio_key = 'UNKNOWN'  # shown in the error message if building the key itself fails
        try:
            audio_key = f"{AUDIO_VGGISH_WAVE_S3_PREFIX}{movie_id}_audio_vggish_wave.pt"

            file_content_bytes = self.cache.get(audio_key)
            if file_content_bytes is None:
                print(f"Audio data for {movie_id} not found via cache or S3 for key {audio_key}")
                raise FileNotFoundError(f"Audio data {audio_key} not found.")

            audio_data = torch.load(io.BytesIO(file_content_bytes), map_location='cpu')

            waveform = audio_data['waveform']
            if waveform.ndim == 1:
                waveform = waveform.unsqueeze(0)  # add a leading channel axis
            return audio_data['vggish_embedding'], waveform
        except Exception as e:
            print(f"Error loading pre-extracted VGGish & waveform for {movie_id} (key: {audio_key}): {e}")
            return torch.zeros(VGGISH_DIM), torch.zeros(1, SAMPLE_RATE * MAX_AUDIO_LENGTH_SEC)

        
class TextDataset:
    """Helper class to manage per-movie text, year and language features.

    The metadata CSV is downloaded from S3 once and indexed by ``movieId``.
    This class reads the ``title``, ``plot``, ``year`` and ``language`` columns.
    A TF-IDF vectorizer is fitted on the plots. A language vocabulary is built
    or reused, and the year range is recorded for min-max normalization.
    """
    def __init__(self, csv_path, lang_vocab=None):
        """
        Args:
            csv_path: ``s3://bucket/key`` URI of the metadata CSV.
            lang_vocab: Optional ``{language: index}`` mapping to reuse, so several
                instances can share the same indices. When falsy, a new vocabulary
                is built from this CSV.
        """
        self.s3_client = boto3.client('s3')

        # Parse "s3://bucket/path/to/file.csv" into bucket and key.
        bucket_name = csv_path.split('//')[1].split('/')[0]
        key = '/'.join(csv_path.split('//')[1].split('/')[1:])

        # Download to a named temp file and read it. Always remove the file
        # afterwards: delete=False alone would leave it on disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.csv') as temp_file:
            temp_path = temp_file.name
        try:
            self.s3_client.download_file(bucket_name, key, temp_path)
            self.text_df = pd.read_csv(temp_path)
        finally:
            if os.path.exists(temp_path):
                os.unlink(temp_path)

        # Set index to movieId for faster lookups
        self.text_df.set_index('movieId', inplace=True)

        # BERT tokenizer. get_text_features does not use it; it is kept as an
        # attribute in case other code reads it.
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # TF-IDF vectorizer fitted on plots (missing plots treated as '').
        self.tfidf = TfidfVectorizer(max_features=5000)
        self.tfidf.fit(self.text_df['plot'].fillna('').values)
        # Actual feature width. It can be smaller than max_features on a small
        # corpus, so the "movie not found" fallback must use it too.
        self.tfidf_dim = len(self.tfidf.vocabulary_)

        # Language vocabulary: reuse the given one or build it from the
        # '|'-separated 'language' column.
        if lang_vocab:
            self.lang_vocab = lang_vocab
        else:
            print("Building new language vocabulary from scratch...")
            # '' is always present so missing/unknown languages have a valid
            # index; it sorts first and therefore gets index 0.
            all_langs = {''}
            for langs in self.text_df['language'].dropna().str.split('|'):
                all_langs.update(langs)
            self.lang_vocab = {lang: i for i, lang in enumerate(sorted(all_langs))}

        self.num_languages = len(self.lang_vocab)
        print(f"🌍 Created language vocabulary with {self.num_languages} languages.")

        # Year range used to normalize years to roughly [0, 1].
        self.min_year = self.text_df['year'].min()
        self.max_year = self.text_df['year'].max()
        print(f"🗓️ Year range for normalization: {self.min_year} - {self.max_year}")

    @staticmethod
    def _as_text(value):
        """Return ``value`` if it is a string, else ``''`` (missing CSV cells are NaN floats)."""
        return value if isinstance(value, str) else ""

    def get_text_features(self, movie_id):
        """Get text, TF-IDF, year, and language features for a given movie ID.

        Returns:
            dict with the following keys:

            - ``'title'`` and ``'plot'``: str.
            - ``'tfidf_features'``: float32 tensor of length ``self.tfidf_dim``.
            - ``'year'``: float32 tensor of shape (1,), min-max normalized.
            - ``'language_idx'``: long tensor of shape (1,).

            Movies absent from the CSV get empty strings, zero TF-IDF, year 0.5
            and language index 0.
        """
        try:
            movie_data = self.text_df.loc[movie_id]
        except KeyError:
            return {
                'title': "", 'plot': "",
                'tfidf_features': torch.zeros(self.tfidf_dim, dtype=torch.float32),
                'year': torch.tensor([0.5], dtype=torch.float32),  # Default normalized year
                'language_idx': torch.tensor([0], dtype=torch.long)  # Default language index
            }

        # Missing cells come back as NaN floats. Coerce them to '' so that
        # string formatting and TfidfVectorizer.transform (which expects
        # strings) do not fail.
        title = self._as_text(movie_data.get('title', ""))
        plot = self._as_text(movie_data.get('plot', ""))

        # Year: an absent column or a NaN cell falls back to min_year.
        year = movie_data.get('year', self.min_year)
        if pd.isna(year):
            year = self.min_year
        normalized_year = (year - self.min_year) / (self.max_year - self.min_year + 1e-6)
        year_tensor = torch.tensor([normalized_year], dtype=torch.float32)

        # Language: first entry of a '|'-separated list; NaN/non-string -> ''.
        lang_val = movie_data.get('language')
        lang_str = lang_val.split('|')[0] if isinstance(lang_val, str) else ""
        lang_idx = self.lang_vocab.get(lang_str, self.lang_vocab.get("", 0))
        lang_tensor = torch.tensor([lang_idx], dtype=torch.long)

        tfidf_vector = self.tfidf.transform([plot]).toarray()[0]
        tfidf_tensor = torch.tensor(tfidf_vector, dtype=torch.float32)

        return {
            'title': title,
            'plot': plot,
            'tfidf_features': tfidf_tensor,
            'year': year_tensor,
            'language_idx': lang_tensor
        }

def extract_video_frames(video_path, num_frames=16):
    """Extract ``num_frames`` evenly spaced frames from a local video file.

    Frames are converted from BGR to RGB, resized to 224x224, turned into
    tensors and normalized with the ImageNet mean/std used below.

    Args:
        video_path: Local path to the video, or None.
        num_frames: Number of frames to sample, evenly spread over the video.

    Returns:
        Float tensor of shape (num_frames, 3, 224, 224). An all-zero tensor is
        returned when the path is None or missing, the frame count is not
        positive, or any error occurs. Individual frames that fail to decode
        are zero-filled.
    """
    if video_path is None or not os.path.exists(video_path):
        # Return zeros if video doesn't exist
        return torch.zeros((num_frames, 3, 224, 224))

    try:
        cap = cv2.VideoCapture(video_path)
        # Release the capture on every exit path, including exceptions raised
        # while decoding or transforming frames.
        try:
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if frame_count <= 0:
                return torch.zeros((num_frames, 3, 224, 224))

            # Calculate frame indices to extract (evenly distributed)
            indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)

            transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

            frames = []
            for idx in indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ret, frame = cap.read()
                if ret:
                    # OpenCV decodes BGR; the transforms expect RGB.
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(transform(frame))
                else:
                    # If frame reading fails, add zeros
                    frames.append(torch.zeros((3, 224, 224)))
        finally:
            cap.release()

        # Defensive padding so the output always has num_frames entries
        while len(frames) < num_frames:
            frames.append(torch.zeros((3, 224, 224)))

        return torch.stack(frames)

    except Exception as e:
        print(f"Error extracting frames from {video_path}: {e}")
        return torch.zeros((num_frames, 3, 224, 224))

def extract_audio_features(video_path, max_length=10, sr=16000):
    """Decode a file's audio into a fixed-length waveform and a log-mel spectrogram.

    Args:
        video_path: Local path of the media file, or None.
        max_length: Clip duration (seconds) passed to ``librosa.load``. The
            waveform is padded or truncated to ``sr * max_length`` samples.
        sr: Target sample rate for ``librosa.load``.

    Returns:
        dict with two entries:

        - ``'waveform'``: float32 tensor of shape (1, sr * max_length).
        - ``'spectrogram'``: float32 tensor of shape (128, 100). It holds 128
          mel bands, standardized, then truncated or zero-padded to 100 frames.

        All-zero tensors of those shapes are returned when the path is None or
        missing, or when decoding fails.
    """
    def _blank():
        return {
            'waveform': torch.zeros((1, sr * max_length)),
            'spectrogram': torch.zeros((128, 100))
        }

    if video_path is None or not os.path.exists(video_path):
        return _blank()

    try:
        audio, _ = librosa.load(video_path, sr=sr, mono=True, duration=max_length)

        # Force exactly sr * max_length samples: zero-pad short clips, cut long ones.
        shortfall = sr * max_length - len(audio)
        if shortfall > 0:
            audio = np.pad(audio, (0, shortfall), mode='constant')
        else:
            audio = audio[:sr * max_length]

        mel = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128)
        spec = librosa.power_to_db(mel, ref=np.max)
        # Standardize to zero mean / unit variance (epsilon guards silent clips).
        spec = (spec - spec.mean()) / (spec.std() + 1e-8)

        # Fix the time axis at 100 frames.
        n_bands, n_frames = spec.shape
        if n_frames > 100:
            spec = spec[:, :100]
        else:
            spec = np.concatenate([spec, np.zeros((n_bands, 100 - n_frames))], axis=1)

        return {
            'waveform': torch.tensor(audio, dtype=torch.float32).unsqueeze(0),
            'spectrogram': torch.tensor(spec, dtype=torch.float32)
        }

    except Exception as e:
        print(f"Error extracting audio from {video_path}: {e}")
        # Full traceback to help debug decoding problems.
        traceback.print_exc()
        return _blank()

In [None]:
class PairDataset(Dataset):
    """Dataset for pair-based Siamese network.

    Positive pairs (label 1) and negative pairs (label 0) are read from
    ``{split}_positive_pairs.pkl`` and ``{split}_negative_pairs.pkl`` under the
    S3 prefix given by ``data_path``. Each item bundles both movies' pre-extracted
    visual and audio features and, in ``'full'`` ablation mode, their text features.
    """
    def __init__(self, data_path, bucket_name, text_csv_path, ablation_mode, split='train', lang_vocab=None):
        """
        Args:
            data_path: ``s3://bucket/prefix`` location of the pair pickles.
            bucket_name: Bucket passed to ``VideoDatasetS3``.
            text_csv_path: S3 URI of the metadata CSV (used only in ``'full'`` mode).
            ablation_mode: ``'full'`` adds text features; ``'visual_only'`` skips audio loading.
            split: Split name used in the pickle file names.
            lang_vocab: Optional language vocabulary forwarded to ``TextDataset``.
        """
        self.s3_client = boto3.client('s3')
        self.bucket_name = bucket_name
        self.split = split
        self.ablation_mode = ablation_mode  # Store it

        # Parse "s3://bucket/prefix/" into bucket and prefix (no trailing slash).
        s3_path_parts = data_path.split('/')
        base_bucket = s3_path_parts[2]
        base_prefix = '/'.join(s3_path_parts[3:]).rstrip('/')

        self.positive_pairs = self._load_pickle(
            base_bucket, f"{base_prefix}/{split}_positive_pairs.pkl", "positive pairs")
        self.negative_pairs = self._load_pickle(
            base_bucket, f"{base_prefix}/{split}_negative_pairs.pkl", "negative pairs")

        # Combine positive (label 1) and negative (label 0) pairs
        self.all_pairs = self._label_pairs(self.positive_pairs, self.negative_pairs)

        # Initialize video and text helpers
        self.video_dataset = VideoDatasetS3(bucket_name)
        if self.ablation_mode == 'full':
            self.text_dataset = TextDataset(text_csv_path, lang_vocab=lang_vocab)
        else:
            self.text_dataset = None

    def _load_pickle(self, bucket, key, what):
        """Download ``s3://bucket/key`` to a temp file, unpickle it, and always delete the file."""
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            temp_path = temp_file.name
        try:
            self.s3_client.download_file(bucket, key, temp_path)
            with open(temp_path, 'rb') as f:
                # NOTE(security): pickle.load can execute arbitrary code; the
                # source bucket must be trusted.
                return pickle.load(f)
        except Exception as e:
            print(f"Error downloading/loading {what}: {e}")
            print(f"Attempted to access: s3://{bucket}/{key}")
            raise
        finally:
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    @staticmethod
    def _label_pairs(positive_pairs, negative_pairs):
        """Return positives as ``(id1, id2, 1)`` followed by negatives as ``(id1, id2, 0)``."""
        labelled = [(movie1_id, movie2_id, 1) for movie1_id, movie2_id in positive_pairs]
        labelled += [(movie1_id, movie2_id, 0) for movie1_id, movie2_id in negative_pairs]
        return labelled

    def __len__(self):
        return len(self.all_pairs)

    # NOTE(review): lru_cache on instance methods keys on ``self`` and keeps the
    # instance alive (ruff B019). These helpers are not called by __getitem__.
    @lru_cache(maxsize=128)  # try 32 if memory is tight
    def _cached_get_video_frames(self, video_path):
        return extract_video_frames(video_path)

    @lru_cache(maxsize=128)
    def _cached_get_audio_features(self, video_path):
        return extract_audio_features(video_path)

    def __getitem__(self, idx):
        """Return the feature dict for both movies of pair ``idx`` plus its float label."""
        movie1_id, movie2_id, label = self.all_pairs[idx]

        # --- Visual ---
        # Load the original (normalized) frames for low-level feature calculation
        frames1_for_low_level = self.video_dataset.get_preextracted_frames(movie1_id, num_frames=NUM_FRAMES_FROM_PT)
        frames2_for_low_level = self.video_dataset.get_preextracted_frames(movie2_id, num_frames=NUM_FRAMES_FROM_PT)

        # Load pre-extracted DINOv2 per-frame embeddings
        dino_frame_embeddings1 = self.video_dataset.get_preextracted_dino_frame_embeddings(movie1_id)
        dino_frame_embeddings2 = self.video_dataset.get_preextracted_dino_frame_embeddings(movie2_id)

        # Report NaNs for both movies (previously only movie1 was checked).
        if torch.isnan(dino_frame_embeddings1).any():
            print(f"!!! NaN found in DINO embedding for movie_id: {movie1_id}")
        if torch.isnan(dino_frame_embeddings2).any():
            print(f"!!! NaN found in DINO embedding for movie_id: {movie2_id}")

        # --- Audio --- (zero placeholders in 'visual_only' mode)
        vggish_emb1, waveform1 = torch.zeros(VGGISH_DIM), torch.zeros(1, SAMPLE_RATE * MAX_AUDIO_LENGTH_SEC)
        vggish_emb2, waveform2 = torch.zeros(VGGISH_DIM), torch.zeros(1, SAMPLE_RATE * MAX_AUDIO_LENGTH_SEC)

        if self.ablation_mode != 'visual_only':
            vggish_emb1, waveform1 = self.video_dataset.get_preextracted_audio_vggish_wave(movie1_id)
            vggish_emb2, waveform2 = self.video_dataset.get_preextracted_audio_vggish_wave(movie2_id)

        item = {
            'movie1_id': movie1_id,
            'movie2_id': movie2_id,
            'frames1_for_low_level': frames1_for_low_level,
            'dino_frame_embeddings1': dino_frame_embeddings1,
            'frames2_for_low_level': frames2_for_low_level,
            'dino_frame_embeddings2': dino_frame_embeddings2,
            'vggish_embedding1': vggish_emb1,
            'waveform1': waveform1,
            'vggish_embedding2': vggish_emb2,
            'waveform2': waveform2,
            'label': torch.tensor(label, dtype=torch.float32)
        }

        # --- Text --- (only in 'full' ablation mode)
        if self.ablation_mode == 'full' and self.text_dataset is not None:
            text1_data = self.text_dataset.get_text_features(movie1_id)
            item['video1_title'] = text1_data['title']
            item['video1_plot'] = text1_data['plot']
            item['video1_tfidf'] = text1_data['tfidf_features']
            item['video1_year'] = text1_data['year']
            item['video1_language_idx'] = text1_data['language_idx']

            text2_data = self.text_dataset.get_text_features(movie2_id)
            item['video2_title'] = text2_data['title']
            item['video2_plot'] = text2_data['plot']
            item['video2_tfidf'] = text2_data['tfidf_features']
            item['video2_year'] = text2_data['year']
            item['video2_language_idx'] = text2_data['language_idx']

        return item

In [None]:
# =================================================================
# CLASS 1: FOR ONLINE TRAINING (provides Anchor-Positive pairs)
# =================================================================
class TripletDatasetForOnlineTraining(Dataset):
    """
    Training-only dataset of (anchor, positive) pairs.

    Pairs come from ``{split}_positive_pairs.pkl``. Each item is a dict holding
    the feature dicts of the anchor and the positive movie. Negatives are not
    stored here; they are chosen online inside the training loop.
    """
    def __init__(self, data_path, bucket_name, text_csv_path, ablation_mode, split='train', lang_vocab=None):
        self.s3_client = boto3.client('s3')
        self.bucket_name = bucket_name
        self.split = split
        self.ablation_mode = ablation_mode

        parts = data_path.split('/')
        if len(parts) < 3:
            raise ValueError(f"Invalid S3 path format: {data_path}")

        source_bucket = parts[2]
        source_prefix = '/'.join(parts[3:]).rstrip('/')

        # POSITIVE pairs serve as (Anchor, Positive)
        pairs_key = f"{source_prefix}/{split}_positive_pairs.pkl"

        print(f"Loading (Anchor, Positive) pairs from: s3://{source_bucket}/{pairs_key}")
        self.anchor_positive_pairs = self._download_pairs(source_bucket, pairs_key, split)

        # Feature helpers; text is only loaded in 'full' mode.
        self.video_dataset = VideoDatasetS3(bucket_name)
        self.text_dataset = (
            TextDataset(text_csv_path, lang_vocab=lang_vocab)
            if self.ablation_mode == 'full'
            else None
        )

    def _download_pairs(self, bucket, key, split):
        """Download and unpickle the pair list; the temp file is removed in all cases."""
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as handle:
            try:
                self.s3_client.download_file(bucket, key, handle.name)
                with open(handle.name, 'rb') as f:
                    pairs = pickle.load(f)
                print(f"Successfully loaded {len(pairs)} anchor-positive pairs for '{split}' split.")
                return pairs
            except Exception as e:
                print(f"FATAL ERROR: Could not download/load positive pairs for online training from s3://{bucket}/{key}")
                print(f"Error details: {e}")
                raise
            finally:
                os.unlink(handle.name)

    def __len__(self):
        return len(self.anchor_positive_pairs)

    def _get_single_item_features(self, movie_id):
        """Collect every feature for one movie into a dict.

        Audio is zero-filled in 'visual_only' mode. Text keys are None unless the
        mode is 'full' and a TextDataset exists.
        """
        video = self.video_dataset

        frames = video.get_preextracted_frames(movie_id, num_frames=NUM_FRAMES_FROM_PT)
        dino = video.get_preextracted_dino_frame_embeddings(movie_id)

        if self.ablation_mode == 'visual_only':
            vggish, wave = torch.zeros(VGGISH_DIM), torch.zeros(1, SAMPLE_RATE * MAX_AUDIO_LENGTH_SEC)
        else:
            vggish, wave = video.get_preextracted_audio_vggish_wave(movie_id)

        features = {
            'movie_id': torch.tensor(movie_id, dtype=torch.long),  # needed by online mining
            'frames_for_low_level': frames,
            'dino_frame_embeddings': dino,
            'vggish_embedding': vggish,
            'waveform_for_spec_cnn': wave
        }

        uses_text = self.ablation_mode == 'full' and self.text_dataset is not None
        if uses_text:
            features.update(self.text_dataset.get_text_features(movie_id))
        else:
            # None placeholders keep the model's forward signature satisfied.
            features.update(dict.fromkeys(('title', 'plot', 'tfidf_features', 'year', 'language_idx')))

        return features

    def __getitem__(self, idx):
        anchor_id, positive_id = self.anchor_positive_pairs[idx]
        return {
            "anchor_data": self._get_single_item_features(anchor_id),
            "positive_data": self._get_single_item_features(positive_id),
        }


# ======================================================================
# CLASS 2: FOR VALIDATION (provides pre-generated A, P, N triplets)
# ======================================================================
class TripletDatasetForValidation(Dataset):
    """
    This dataset is used ONLY for VALIDATION.

    It loads pre-generated (anchor, positive, negative) triplets to provide a
    consistent benchmark for measuring validation loss. Triplets are read from
    ``{split}_triplets.pkl``. If that fails, ``all_triplets.pkl`` is loaded and
    split 70/15/15 into train/validation/test with a fixed seed.
    """
    def __init__(self, data_path, bucket_name, text_csv_path, ablation_mode, split='validation', lang_vocab=None):
        self.s3_client = boto3.client('s3')
        self.bucket_name = bucket_name
        self.split = split
        self.ablation_mode = ablation_mode

        s3_path_parts = data_path.split('/')
        if len(s3_path_parts) < 3:
            raise ValueError(f"Invalid S3 path format: {data_path}")

        base_bucket = s3_path_parts[2]
        base_prefix = '/'.join(s3_path_parts[3:]).rstrip('/')

        # This dataset loads the pre-generated TRIPLETS
        triplet_key = f"{base_prefix}/{split}_triplets.pkl"

        print(f"Loading (A, P, N) triplets from: s3://{base_bucket}/{triplet_key}")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as temp_file:
            try:
                self.s3_client.download_file(base_bucket, triplet_key, temp_file.name)
                with open(temp_file.name, 'rb') as f:
                    self.triplets = pickle.load(f)
                print(f"Successfully loaded {len(self.triplets)} triplets for '{split}' split.")
            except Exception:
                print(f"Could not load '{triplet_key}'. Trying 'all_triplets.pkl' instead...")
                triplet_key = f"{base_prefix}/all_triplets.pkl"
                try:
                    self.s3_client.download_file(base_bucket, triplet_key, temp_file.name)
                    with open(temp_file.name, 'rb') as f:
                        all_triplets = pickle.load(f)

                    self.triplets = self._split_triplets(all_triplets, split)

                    print(f"Loaded from 'all_triplets.pkl' and took {len(self.triplets)} for '{split}' split.")

                except Exception as e2:
                    print(f"FATAL ERROR: Could not download/load triplets from either key. Error: {e2}")
                    raise
            finally:
                os.unlink(temp_file.name)

        # Initialize helper classes (same as the training dataset)
        self.video_dataset = VideoDatasetS3(bucket_name)
        if self.ablation_mode == 'full':
            self.text_dataset = TextDataset(text_csv_path, lang_vocab=lang_vocab)
        else:
            self.text_dataset = None

    @staticmethod
    def _split_triplets(all_triplets, split):
        """Shuffle ``all_triplets`` in place with seed 42 and return the slice for ``split``.

        The slices are train = first 70%, validation = next 15%, and anything
        else = the remaining 15%. A private ``random.Random(42)`` gives the same
        permutation as ``random.seed(42); random.shuffle(...)`` but does not
        reseed the global RNG used elsewhere in training.
        """
        random.Random(42).shuffle(all_triplets)
        train_end = int(0.7 * len(all_triplets))
        val_end = int(0.85 * len(all_triplets))

        if split == 'train':
            return all_triplets[:train_end]
        if split == 'validation':
            return all_triplets[train_end:val_end]
        return all_triplets[val_end:]

    def __len__(self):
        return len(self.triplets)

    def _get_single_item_features(self, movie_id):
        """Helper function to load all features for a single movie_id.

        Audio stays zero in 'visual_only' mode. Text keys are None unless the
        mode is 'full' and a TextDataset exists.
        """
        item_data = {
            'movie_id': torch.tensor(movie_id, dtype=torch.long),
            'frames_for_low_level': self.video_dataset.get_preextracted_frames(movie_id, num_frames=NUM_FRAMES_FROM_PT),
            'dino_frame_embeddings': self.video_dataset.get_preextracted_dino_frame_embeddings(movie_id),
            'vggish_embedding': torch.zeros(VGGISH_DIM),
            'waveform_for_spec_cnn': torch.zeros(1, SAMPLE_RATE * MAX_AUDIO_LENGTH_SEC)
        }
        if self.ablation_mode != 'visual_only':
            item_data['vggish_embedding'], item_data['waveform_for_spec_cnn'] = self.video_dataset.get_preextracted_audio_vggish_wave(movie_id)

        if self.ablation_mode == 'full' and self.text_dataset is not None:
            text_data = self.text_dataset.get_text_features(movie_id)
            item_data.update(text_data)
        else:
            item_data.update({'title': None, 'plot': None, 'tfidf_features': None, 'year': None, 'language_idx': None})

        return item_data

    def __getitem__(self, idx):
        """Return the feature dicts for the pre-generated anchor, positive and negative."""
        anchor_id, positive_id, negative_id = self.triplets[idx]
        anchor_data = self._get_single_item_features(anchor_id)
        positive_data = self._get_single_item_features(positive_id)
        negative_data = self._get_single_item_features(negative_id)

        return {
            "anchor_data": anchor_data,
            "positive_data": positive_data,
            "negative_data": negative_data
        }

In [None]:
class AllMoviesDataset(Dataset):
    """
    A dataset that returns individual movie items and their movie_id.
    This is used for proper in-batch negative mining.

    The movie set is the sorted union of all ids that appear in the split's
    positive and negative pair files. A pair file that cannot be loaded is
    reported and skipped.
    """
    def __init__(self, data_path, bucket_name, text_csv_path, ablation_mode, split='train', lang_vocab=None):
        self.ablation_mode = ablation_mode

        print(f"Initializing AllMoviesDataset for '{split}' split...")
        s3_client = boto3.client('s3')
        s3_path_parts = data_path.split('/')
        base_bucket = s3_path_parts[2]
        base_prefix = '/'.join(s3_path_parts[3:]).rstrip('/')

        all_movie_ids = set()
        for pair_file in [f"{split}_positive_pairs.pkl", f"{split}_negative_pairs.pkl"]:
            key = f"{base_prefix}/{pair_file}"
            temp_path = None
            try:
                with tempfile.NamedTemporaryFile(delete=False) as temp_f:
                    temp_path = temp_f.name
                s3_client.download_file(base_bucket, key, temp_path)
                with open(temp_path, 'rb') as f:
                    pairs = pickle.load(f)
                for id1, id2 in pairs:
                    all_movie_ids.add(id1)
                    all_movie_ids.add(id2)
            except Exception as e:
                print(f"Warning: Could not load {key}. It might not exist for this split. Error: {e}")
            finally:
                # Remove the temp file on failure as well as on success.
                if temp_path is not None and os.path.exists(temp_path):
                    os.unlink(temp_path)

        self.movie_ids = sorted(all_movie_ids)
        print(f"Found {len(self.movie_ids)} unique movies for the '{split}' split.")

        self.video_dataset = VideoDatasetS3(bucket_name)
        if self.ablation_mode == 'full':
            self.text_dataset = TextDataset(text_csv_path, lang_vocab=lang_vocab)
        else:
            self.text_dataset = None

    def _get_single_item_features(self, movie_id):
        """Load all features for one movie using TripletDatasetForValidation's logic.

        That method only reads ``video_dataset``, ``ablation_mode`` and
        ``text_dataset``, which this class also sets. It is a real method rather
        than a bound method stored on the instance: such an instance attribute
        cannot be unpickled, e.g. by DataLoader workers started with 'spawn'.
        """
        return TripletDatasetForValidation._get_single_item_features(self, movie_id)

    def __len__(self):
        return len(self.movie_ids)

    def __getitem__(self, idx):
        movie_id = self.movie_ids[idx]
        return self._get_single_item_features(movie_id)

In [None]:
def save_checkpoint(state, is_best=False, checkpoint_dir='checkpoints', s3_bucket='md-data-content-recommendation', s3_prefix='n2n-model/cosine-with-languages-checkpoints'):
    """
    Save a model checkpoint to S3, falling back to local files on failure.

    Args:
        state: Object passed to ``torch.save``; must contain an ``'epoch'`` key.
        is_best: Also store the checkpoint as ``best_model.pth``.
        checkpoint_dir: Local directory used for the fallback save (always created).
        s3_bucket: S3 bucket name.
        s3_prefix: Prefix path in the S3 bucket.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    local_filename = os.path.join(checkpoint_dir, f"epoch_{state['epoch']}.pth")

    try:
        s3_client = boto3.client('s3')

        # Serialize once. Each upload gets its own fresh stream, so an uploader
        # that consumes or closes its file object cannot break the best-model
        # upload (which would wrongly trigger the local fallback).
        buffer = io.BytesIO()
        torch.save(state, buffer)
        payload = buffer.getvalue()

        s3_path = f"{s3_prefix}/epoch_{state['epoch']}.pth"
        s3_client.upload_fileobj(io.BytesIO(payload), s3_bucket, s3_path)
        print(f"Checkpoint saved to s3://{s3_bucket}/{s3_path}")

        if is_best:
            best_s3_path = f"{s3_prefix}/best_model.pth"
            s3_client.upload_fileobj(io.BytesIO(payload), s3_bucket, best_s3_path)
            print(f"New best model saved to s3://{s3_bucket}/{best_s3_path}")

    except Exception as e:
        print(f"Error saving to S3: {e}. Falling back to local save.")
        torch.save(state, local_filename)
        print(f"Checkpoint saved to {local_filename}")

        if is_best:
            best_path = os.path.join(checkpoint_dir, "best_model.pth")
            torch.save(state, best_path)
            print(f"New best model saved to {best_path}")

def load_checkpoint_from_s3(s3_bucket, s3_key, map_location=None):
    """
    Download a checkpoint from S3 into memory and deserialize it.

    Args:
        s3_bucket: S3 bucket name.
        s3_key: Object key of the checkpoint.
        map_location: Forwarded to ``torch.load`` to remap tensor devices.

    Returns:
        The object produced by ``torch.load``.

    Note:
        ``weights_only=False`` lets ``torch.load`` unpickle arbitrary objects,
        which can execute code. Only load checkpoints from trusted buckets.
    """
    client = boto3.client('s3')
    with io.BytesIO() as stream:
        print(f"Loading checkpoint from s3://{s3_bucket}/{s3_key}")
        client.download_fileobj(s3_bucket, s3_key, stream)
        stream.seek(0)  # rewind before deserializing
        return torch.load(stream, map_location=map_location, weights_only=False)

def clean_local_checkpoints(checkpoint_dir='checkpoints', keep_latest=2):
    """
    Remove old local checkpoints, keeping only the ``keep_latest`` highest epochs.

    Only files named ``epoch_<N>.pth`` (N a decimal integer) are considered.
    Other names, including ``epoch_final.pth``, are left untouched instead of
    crashing the epoch parsing.

    Args:
        checkpoint_dir: Directory containing checkpoints (a missing directory is a no-op).
        keep_latest: Number of latest checkpoints to keep.
    """
    if not os.path.exists(checkpoint_dir):
        return

    prefix, suffix = 'epoch_', '.pth'
    numbered = []
    for name in os.listdir(checkpoint_dir):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        epoch_text = name[len(prefix):-len(suffix)]
        # isdecimal (not isdigit) guarantees int() accepts the text.
        if epoch_text.isdecimal():
            numbered.append((int(epoch_text), name))

    if len(numbered) <= keep_latest:
        return

    # Newest (highest epoch) first.
    numbered.sort(reverse=True)

    for _, name in numbered[keep_latest:]:
        os.remove(os.path.join(checkpoint_dir, name))
        print(f"Removed old checkpoint: {name}")


In [None]:
class CachedS3Dataset:
    """In-memory LRU cache of raw S3 object bytes for one bucket.

    ``get(key)`` returns the bytes of ``s3://bucket_name/key``. At most
    ``cache_size`` objects are kept, and the least recently used one is
    evicted first. A ``cache_size`` of 0 or less disables caching.
    """
    def __init__(self, bucket_name, cache_size=1000):
        self.cache = OrderedDict()  # insertion order == recency order (oldest first)
        self.cache_size = cache_size
        self.s3 = boto3.client('s3')
        self.bucket_name = bucket_name

    def get(self, key):
        """Return the object's bytes, or None if the download fails (best-effort)."""
        if key in self.cache:
            # Mark as most recently used.
            self.cache.move_to_end(key)
            return self.cache[key]

        buffer = io.BytesIO()
        try:
            self.s3.download_fileobj(self.bucket_name, key, buffer)
        except Exception as e:
            print(f"Error downloading {key} to cache: {e}")
            return None

        data = buffer.getvalue()
        self._remember(key, data)
        return data

    def _remember(self, key, data):
        """Insert ``data`` as most recent, evicting LRU entries to respect ``cache_size``.

        Previously a non-positive ``cache_size`` made ``popitem`` fail on an
        empty dict, so successful downloads were reported as errors.
        """
        if self.cache_size <= 0:
            return
        while len(self.cache) >= self.cache_size:
            self.cache.popitem(last=False)  # drop least recently used
        self.cache[key] = data

## Model Factory

In [None]:
def create_model_and_criterion(config, device, num_languages):
    """Build the Siamese model, its loss, optimizer and LR scheduler from a config dict.

    Config keys read:

    - ``'fusion_network'``: one of ``'Hybrid'``, ``'Simple'``, ``'UltraSimple'``.
    - ``'model_type'``: ``'pair'`` or ``'triplet'``.
    - ``'loss_type'``: ``'cosine'`` or ``'euclidean'``.
    - ``['hyperparameters']['lr']``: learning rate.

    Returns:
        ``(model, criterion, optimizer, scheduler)``:

        - ``model``: moved to ``device``.
        - ``criterion``: on ``device``.
        - ``optimizer``: AdamW with weight_decay 1e-2.
        - ``scheduler``: ReduceLROnPlateau in 'min' mode (patience 3, factor 0.7).

    Raises:
        ValueError: on an unknown fusion network, model type, or loss type.
    """
    # Shared encoders (identical for every model variant).
    visual_module = VisualProcessingModule(use_precomputed_embeddings=True, dinov2_embedding_dim=DINO_DIM)
    audio_module = OptimizedAudioProcessingModule(use_vggish=True, vggish_embedding_dim=VGGISH_DIM)
    text_module = TextProcessingModule(num_languages=num_languages)

    # Fusion network.
    fusion_name = config['fusion_network']
    if fusion_name == 'Hybrid':
        fusion_cls = HybridFusionNetwork
    elif fusion_name == 'Simple':
        fusion_cls = SimpleFusionNetwork
    elif fusion_name == 'UltraSimple':
        fusion_cls = UltraSimpleFusionNetwork
    else:
        raise ValueError(f"Unknown fusion network: {fusion_name}")
    fusion_net = fusion_cls(visual_module.output_dim, audio_module.output_dim, text_module.output_dim)

    # Siamese architecture and the matching loss.
    model_name = config['model_type']
    loss_name = config['loss_type']

    if model_name == 'pair':
        model = OptimizedPairBasedSiameseNetwork(visual_module, audio_module, text_module, fusion_net)
        if loss_name == 'cosine':
            loss_fn = ContrastiveLossCosine(margin=0.8)
        elif loss_name == 'euclidean':
            loss_fn = ContrastiveLossEuclidean(margin=1.2)
        else:
            raise ValueError("Invalid loss type for pair model")
    elif model_name == 'triplet':
        model = OnlineTripletSiameseNetwork(visual_module, audio_module, text_module, fusion_net)
        if loss_name == 'cosine':
            loss_fn = TripletLossCosine(margin=0.3)
        elif loss_name == 'euclidean':
            loss_fn = TripletLossEuclidean(margin=1.0)
        else:
            raise ValueError("Invalid loss type for triplet model")
    else:
        raise ValueError(f"Unknown model_type: {model_name}")

    criterion = loss_fn.to(device)

    # Optimizer is built before the model is moved; Module.to moves parameters
    # in place, so the optimizer keeps referencing the same tensors.
    learning_rate = config['hyperparameters']['lr']
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-2)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.7, verbose=True)

    return model.to(device), criterion, optimizer, scheduler

In [None]:
def create_small_subset_for_testing(dataset, size=1000, seed=42):
    """
    Shrink a dataset IN PLACE to a small subset for quick test runs.

    Handles two dataset shapes:
      * pair datasets exposing ``all_pairs`` (sequence of tuples whose third
        element is the label, 1 = positive / 0 = negative). Up to
        ``size // 2`` positives and ``size // 2`` negatives are kept (fewer if
        a class has fewer samples), then shuffled together.
      * triplet datasets exposing ``triplets``. A plain random sample of
        ``size`` triplets is kept (no balancing needed).
    Any other object is returned untouched with a warning.

    Sampling uses a private ``random.Random(seed)``. It draws exactly the same
    subsets the old ``random.seed(42)`` version did, but it no longer resets
    the process-wide ``random`` state that ``set_seed`` established for the
    rest of the run.

    Args:
        dataset: Pair- or triplet-style dataset object; mutated in place.
        size: Target subset size.
        seed: Seed for the local RNG (defaults to the previous fixed value).

    Returns:
        The same ``dataset`` object.
    """
    rng = random.Random(seed)

    if hasattr(dataset, 'all_pairs'):
        print("Subsetting a PairDataset...")

        # Single pass over the pairs instead of three separate scans.
        positive_pairs, negative_pairs = [], []
        for pair in dataset.all_pairs:
            if pair[2] == 1:
                positive_pairs.append(pair)
            elif pair[2] == 0:
                negative_pairs.append(pair)

        original_size = len(dataset.all_pairs)
        current_positive = len(positive_pairs)
        current_negative = original_size - current_positive
        print(f"Original dataset - Size: {original_size}, Positive: {current_positive}, Negative: {current_negative}")

        # Half positive, half negative (or whatever is available).
        half_size = size // 2
        sampled_positive = rng.sample(positive_pairs, min(half_size, len(positive_pairs)))
        sampled_negative = rng.sample(negative_pairs, min(half_size, len(negative_pairs)))

        subset = sampled_positive + sampled_negative
        rng.shuffle(subset)
        dataset.all_pairs = subset
        print(f"New subset size: {len(dataset.all_pairs)} (Positive: {len(sampled_positive)}, Negative: {len(sampled_negative)})")

    elif hasattr(dataset, 'triplets'):
        print("Subsetting a TripletDataset...")
        original_size = len(dataset.triplets)
        print(f"Original dataset - Size: {original_size} triplets")

        # Triplets need no class balancing; just take a random sample.
        if original_size > size:
            dataset.triplets = rng.sample(dataset.triplets, size)

        print(f"New subset size: {len(dataset.triplets)} triplets")

    else:
        print("Warning: Unknown dataset type in create_small_subset_for_testing. Skipping subset creation.")

    return dataset

In [None]:
def create_dataloaders(config, lang_vocab, num_languages):
    """
    Build the train/validation DataLoaders for one experiment config.

    For ``model_type == 'pair'`` both splits are ``PairDataset``s. For
    ``'triplet'`` the train split is an ``AllMoviesDataset`` (single movies
    for in-batch mining) and the validation split a fixed-triplet dataset.
    In that case the full set of positive pairs is also returned for the
    mining logic in the training loop.

    Args:
        config: Experiment dict. Required: ``model_type`` ('pair' or
            'triplet') and ``hyperparameters['batch_size']``. Optional:
            ``test_mode`` (default False), ``test_size`` (default 10000),
            ``num_workers`` (default 4), and ``data_path`` / ``bucket_name``
            / ``text_csv`` / ``ablation_mode`` (default to the project's S3
            locations and 'full').
        lang_vocab: Shared language vocabulary passed to every dataset.
        num_languages: Not used here; kept for call-site compatibility.

    Returns:
        ``({'train': DataLoader, 'val': DataLoader}, positive_pairs_set)``.
        ``positive_pairs_set`` is a set of sorted id tuples for triplet
        models and ``None`` for pair models.

    Raises:
        ValueError: If ``model_type`` is neither 'pair' nor 'triplet'.
    """
    model_type = config['model_type']
    # Validate up front: an unknown type would otherwise leave train_dataset
    # unbound and fail later with an UnboundLocalError.
    if model_type not in ('pair', 'triplet'):
        raise ValueError(f"Unknown model_type: {model_type}")

    batch_size = config['hyperparameters']['batch_size']
    test_mode = config.get('test_mode', False)
    test_size = config.get('test_size', 10000)
    num_workers = config.get('num_workers', 4)

    # Dataset locations; overridable per experiment, defaults unchanged.
    data_path = config.get('data_path', 's3://md-data-content-recommendation/SNN-training-data/optimized4/')
    bucket_name = config.get('bucket_name', 'md-data-content-recommendation')
    text_csv = config.get('text_csv', 's3://md-data-content-recommendation/cleaned_textual_trailers_dataset_with_languages.csv')
    ablation_mode = config.get('ablation_mode', 'full')

    positive_pairs_set = None

    if model_type == 'pair':
        print("--- Creating Dataloaders for Pair-based training ---")
        train_dataset = PairDataset(data_path, bucket_name, text_csv, ablation_mode, split='train', lang_vocab=lang_vocab)
        val_dataset = PairDataset(data_path, bucket_name, text_csv, ablation_mode, split='validation', lang_vocab=lang_vocab)

        if test_mode:
            print(f"--- Applying Test Mode: Subsetting datasets to size ~{test_size} ---")
            train_dataset = create_small_subset_for_testing(train_dataset, test_size)
            val_dataset = create_small_subset_for_testing(val_dataset, test_size // 5)

    else:  # 'triplet'
        print("--- Creating Dataloaders for Online Triplet Mining ---")
        train_dataset = AllMoviesDataset(data_path, bucket_name, text_csv, ablation_mode, split='train', lang_vocab=lang_vocab)
        val_dataset = TripletDatasetForValidation(data_path, bucket_name, text_csv, ablation_mode, split='validation', lang_vocab=lang_vocab)

        # The full set of training positive pairs drives in-batch mining.
        # NOTE(review): always built with 'full' (as before), independent of
        # ablation_mode - presumably the pair list does not depend on it; confirm.
        pos_pair_helper = TripletDatasetForOnlineTraining(data_path, bucket_name, text_csv, 'full', split='train', lang_vocab=lang_vocab)
        positive_pairs_set = {tuple(sorted(p)) for p in pos_pair_helper.anchor_positive_pairs}
        del pos_pair_helper

        if test_mode:
            print(f"--- Applying Test Mode: Subsetting datasets to size ~{test_size} ---")
            # AllMoviesDataset is subset by sampling its list of movie ids.
            original_size = len(train_dataset.movie_ids)
            train_dataset.movie_ids = random.sample(train_dataset.movie_ids, min(test_size, original_size))
            print(f"Subsetted Triplet training movies from {original_size} to {len(train_dataset.movie_ids)}")

            val_dataset = create_small_subset_for_testing(val_dataset, test_size // 5)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    return {'train': train_loader, 'val': val_loader}, positive_pairs_set

In [None]:
def check_gradient_flow(model):
    """
    Print a gradient-health report for ``model`` and return the underlying stats.

    Call after ``loss.backward()``. Only parameters with ``requires_grad`` and a
    populated ``.grad`` are inspected. Per-parameter L2 norms below 1e-7 are
    flagged as small and above 100 as large. NaN/Inf norms are flagged and
    excluded from the total norm. If no parameter has a gradient at all (e.g.
    grads were cleared by ``zero_grad()`` or ``backward()`` never ran), this is
    reported explicitly rather than misdiagnosed as exploding gradients.

    Args:
        model: A ``torch.nn.Module``.

    Returns:
        dict with ``total_norm`` (float; NaN when no finite gradient exists),
        ``param_count``, ``small_grad_count``, ``large_grad_count`` and
        ``nan_grad_count``.
    """
    total_norm_sq = 0.0
    param_count = 0
    small_grad_count = 0
    large_grad_count = 0
    nan_grad_count = 0

    for name, param in model.named_parameters():
        if not param.requires_grad or param.grad is None:
            continue

        param_norm = param.grad.detach().norm(2).item()
        param_count += 1

        if not math.isfinite(param_norm):
            nan_grad_count += 1
            print(f"   ❌ NAN/INF gradient: {name}")
            continue  # Keep it out of the total norm

        total_norm_sq += param_norm ** 2

        if param_norm < 1e-7:
            small_grad_count += 1
            print(f"   ⚠️  SMALL gradient: {name} = {param_norm:.2e}")
        elif param_norm > 100:
            large_grad_count += 1
            print(f"   ⚠️  LARGE gradient: {name} = {param_norm:.2e}")

    finite_count = param_count - nan_grad_count
    total_norm = total_norm_sq ** 0.5 if finite_count > 0 else float('nan')

    print("🔍 GRADIENT DIAGNOSIS:")
    print(f"   Total gradient norm: {total_norm:.4f}")
    print(f"   Parameters with gradients: {param_count}")
    print(f"   Very small gradients: {small_grad_count}")
    print(f"   Very large gradients: {large_grad_count}")
    print(f"   ❌ NaN/Inf gradients: {nan_grad_count}")

    if param_count == 0:
        print("   ⚠️  No gradients found: call this after loss.backward() and before optimizer.zero_grad().")
    elif nan_grad_count > 0:
        print(f"   ❌ CRITICAL: {nan_grad_count} NaN/Inf gradients detected! Gradient flow is unhealthy.")
    elif not math.isfinite(total_norm):
        print("   ❌ NAN or INF in total gradient norm! Likely exploding gradients!")
    elif total_norm < 1e-4:
        print("   ⚠️  VANISHING GRADIENTS DETECTED!")
    elif total_norm > 1000:  # Previous explosions were >1000
        print("   ⚠️  EXPLODING GRADIENTS DETECTED!")
    else:
        print("   ✅ Gradient flow looks healthy (for non-NaN/Inf parameters)")

    return {
        'total_norm': total_norm,
        'param_count': param_count,
        'small_grad_count': small_grad_count,
        'large_grad_count': large_grad_count,
        'nan_grad_count': nan_grad_count,
    }

In [None]:
# PLOT FINAL LEARNING CURVES
def plot_final_learning_curves(log_dir, experiment_name):
    """
    Plot training and validation loss curves from a TensorBoard log directory.

    Reads the ``Loss/Train`` and ``Loss/Validation`` scalars from the newest
    ``events.out.tfevents*`` file in ``log_dir`` and marks the step with the
    lowest validation loss. The figure is saved as
    ``final_learning_curve_<experiment_name>.png`` inside ``log_dir``.

    Args:
        log_dir: Directory written by a ``SummaryWriter``.
        experiment_name: Used in the plot title and output filename.

    Returns:
        The saved plot path as ``str``, or ``None`` when the directory, an
        event file, the scalar tags, or any validation values are missing.
    """
    try:
        event_files = [
            os.path.join(log_dir, f)
            for f in os.listdir(log_dir)
            if f.startswith('events.out.tfevents')
        ]
    except (FileNotFoundError, NotADirectoryError):
        print(f"Error: Log directory {log_dir} not found for {experiment_name}")
        return None

    if not event_files:
        print(f"Error: No event file found in {log_dir} for {experiment_name}")
        return None

    # os.listdir order is arbitrary; if several event files exist, use the newest.
    event_file = max(event_files, key=os.path.getmtime)

    ea = EventAccumulator(event_file).Reload()

    try:
        train_loss = [(s.step, s.value) for s in ea.Scalars('Loss/Train')]
        val_loss = [(s.step, s.value) for s in ea.Scalars('Loss/Validation')]
    except KeyError as e:
        print(f"Error: Could not find scalar tag in log for {experiment_name}: {e}")
        return None

    if not val_loss:
        # np.argmin on an empty sequence would raise ValueError.
        print(f"Error: No validation loss values logged for {experiment_name}")
        return None

    train_steps = [step for step, _ in train_loss]
    train_values = [value for _, value in train_loss]
    val_steps = [step for step, _ in val_loss]
    val_values = [value for _, value in val_loss]

    # Highlight the step with the lowest validation loss (the "best model").
    best_idx = int(np.argmin(val_values))
    best_epoch = val_steps[best_idx]
    best_val_loss = val_values[best_idx]

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(train_steps, train_values, label='Training Loss', marker='o')
        ax.plot(val_steps, val_values, label='Validation Loss', marker='o')
        # Steps are logged 0-based, so the legend shows the 1-based epoch number.
        ax.axvline(x=best_epoch, color='g', linestyle='--', label=f'Best Model (Epoch {best_epoch+1})')
        ax.scatter(best_epoch, best_val_loss, s=100, color='g', zorder=5)

        ax.set_title(f'Learning Curves: {experiment_name}', fontsize=16)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Loss', fontsize=12)
        ax.legend()
        ax.grid(True)
        fig.tight_layout()

        plot_path = Path(log_dir) / f'final_learning_curve_{experiment_name}.png'
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

    print(f"✅ Final learning curve plot saved to {plot_path}")
    return str(plot_path)

## Training 

In [None]:
def setup_tensorboard(local_log_dir='master_tmp_logs', s3_bucket='md-data-content-recommendation', s3_prefix='n2n-model/logs'):
    """
    Create a TensorBoard ``SummaryWriter`` logging to a fresh, timestamped local directory.

    The directory is ``<local_log_dir>/<YYYYmmdd-HHMMSS>``. If that already
    exists (two runs started within the same second), a suffix ``_1``, ``_2``,
    ... is appended so runs never share an event directory. The intended S3
    destination for a later sync is stored on the writer as
    ``writer.s3_bucket`` and ``writer.s3_prefix``
    (``<s3_prefix>/<run directory name>``). Nothing is uploaded here.

    Args:
        local_log_dir: Parent directory for local logs.
        s3_bucket: S3 bucket name recorded on the writer.
        s3_prefix: Key prefix in the bucket recorded on the writer.

    Returns:
        The ``SummaryWriter`` instance only; its local directory is available
        as ``writer.log_dir``.
    """
    timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    run_dir_name = timestamp
    log_dir = f"{local_log_dir}/{run_dir_name}"
    suffix = 1
    while os.path.exists(log_dir):
        run_dir_name = f"{timestamp}_{suffix}"
        log_dir = f"{local_log_dir}/{run_dir_name}"
        suffix += 1
    os.makedirs(log_dir, exist_ok=True)

    writer = SummaryWriter(log_dir=log_dir)

    # Store S3 information for later syncing
    writer.s3_bucket = s3_bucket
    writer.s3_prefix = f"{s3_prefix}/{run_dir_name}"

    print(f"TensorBoard logs will be saved to: {log_dir}")
    return writer

In [None]:
# ==================================================================
# UNIFIED PAIR TRAINING AND HELPER FUNCTIONS
# ==================================================================

def _calculate_and_log_pair_metrics(y_true, distances, loss_type, margin, writer, prefix, epoch, threshold=None):
    """
    Compute and log classification, ranking and distance metrics for a pair model.

    Args:
        y_true: Sequence or 1-D array of 0/1 labels (1 = similar pair).
        distances: Sequence or 1-D array of pair distances aligned with ``y_true``.
        loss_type: 'cosine' (distances are ``1 - cos_sim``); any other value
            is treated as Euclidean.
        margin: Loss margin; the default decision threshold is ``0.8 * margin``.
        writer: TensorBoard-style writer (``add_scalar``; ``add_figure`` when
            ``prefix == 'Val'``).
        prefix: Tag prefix such as 'Train' or 'Val'; 'Val' also logs a
            confusion-matrix figure.
        epoch: Step used for logging.
        threshold: Optional explicit distance threshold below which a pair is
            predicted "similar". Defaults to ``margin * 0.8``.

    Returns:
        dict of metric name -> value, or ``{}`` if there is no data.
        Classification/AUC metrics are only present when both classes occur.
    """
    # len()-based check: ``not array`` raises for multi-element NumPy arrays.
    if y_true is None or distances is None or len(y_true) == 0 or len(distances) == 0:
        print(f"Warning: Empty labels or distances for {prefix}. Skipping metrics.")
        return {}

    y_true_np = np.asarray(y_true)
    distances_np = np.asarray(distances)

    # Turn distances into similarity scores (higher = more similar) for ranking metrics.
    if loss_type == 'cosine':
        similarity_scores_np = 1.0 - distances_np
    else:  # euclidean
        similarity_scores_np = np.exp(-distances_np)

    # Binary predictions: "similar" when the distance is below the threshold.
    if threshold is None:
        threshold = margin * 0.8  # Heuristic default
    y_pred_np = (distances_np < threshold).astype(int)

    metrics = {}
    if len(np.unique(y_true_np)) > 1:
        metrics['accuracy'] = accuracy_score(y_true_np, y_pred_np)
        metrics['precision'] = precision_score(y_true_np, y_pred_np, zero_division=0)
        metrics['recall'] = recall_score(y_true_np, y_pred_np, zero_division=0)
        metrics['f1'] = f1_score(y_true_np, y_pred_np, zero_division=0)
        try:
            metrics['auc_roc'] = roc_auc_score(y_true_np, similarity_scores_np)
            metrics['auc_pr'] = average_precision_score(y_true_np, similarity_scores_np)
        except ValueError:
            metrics['auc_roc'], metrics['auc_pr'] = 0.0, 0.0

        # Confusion matrix figure for the validation set only.
        if prefix == 'Val':
            cm = confusion_matrix(y_true_np, y_pred_np, labels=[0, 1])
            fig, ax = plt.subplots(figsize=(6, 5))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                        xticklabels=['Pred Different', 'Pred Similar'], yticklabels=['True Different', 'True Similar'])
            ax.set_title(f'Confusion Matrix - Epoch {epoch+1}')
            writer.add_figure('ConfusionMatrix', fig, epoch)
            plt.close(fig)

    pos_dists = distances_np[y_true_np == 1]
    neg_dists = distances_np[y_true_np == 0]
    metrics['avg_pos_dist'] = np.mean(pos_dists) if len(pos_dists) > 0 else 0
    metrics['avg_neg_dist'] = np.mean(neg_dists) if len(neg_dists) > 0 else 0
    metrics['dist_separation'] = metrics['avg_neg_dist'] - metrics['avg_pos_dist']

    # Only float-like values are logged (the int 0 fallbacks are skipped).
    for name, value in metrics.items():
        if isinstance(value, (float, np.number)):
            writer.add_scalar(f'Metrics_{prefix}/{name}', value, epoch)

    return metrics

def train_pair_model(model, loaders, criterion, optimizer, scheduler, device, config):
    """
    Unified training loop for pair-based Siamese models with full diagnostics.

    Each epoch does the following:
      1. Train on ``loaders['train']``.
      2. Evaluate on ``loaders['val']``.
      3. Log losses and pair metrics to TensorBoard.
      4. Step the scheduler on the validation loss.
      5. Save a checkpoint, flagging the best one.
    Training stops early after ``patience`` epochs without improvement.

    Args:
        model: Pair model called with the batch minus ``movie1_id``,
            ``movie2_id`` and ``label``; returns two embedding tensors.
        loaders: ``{'train': ..., 'val': ...}`` iterables of dict batches
            containing a ``label`` tensor.
        criterion: ``criterion(emb1, emb2, labels)`` exposing ``.margin``.
        optimizer: Optimizer over ``model.parameters()``.
        scheduler: Scheduler stepped with the validation loss.
        device: Torch device for tensors.
        config: Reads ``hyperparameters['epochs']``, ``checkpoint_s3_prefix``,
            ``loss_type``, ``run_name`` and optionally ``start_epoch`` (0)
            and ``patience`` (5).

    Returns:
        ``(model, writer)``; the TensorBoard writer is already closed.
    """
    epochs = config['hyperparameters']['epochs']
    s3_prefix = config['checkpoint_s3_prefix']
    loss_type = config['loss_type']
    start_epoch = config.get('start_epoch', 0)
    patience = config.get('patience', 5)
    non_model_keys = {'movie1_id', 'movie2_id', 'label'}

    # Distance used for diagnostics; hoisted out of the batch loops.
    if loss_type == 'cosine':
        def pair_distance(e1, e2):
            return 1 - F.cosine_similarity(e1, e2)
    else:
        pair_distance = F.pairwise_distance

    def to_model_args(batch):
        # Move tensors to the device and drop keys the model does not accept.
        return {k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items() if k not in non_model_keys}

    def fmt(metrics, key):
        # Metrics can be absent (e.g. single-class validation set); formatting
        # a string fallback with ':.4f' would raise ValueError mid-training.
        value = metrics.get(key)
        return f"{value:.4f}" if isinstance(value, (int, float, np.number)) else 'N/A'

    writer = setup_tensorboard(s3_prefix=f"n2n-model/logs/{config['run_name']}")
    best_val_loss = float('inf')
    no_improve = 0

    print(f"--- Starting Pair-Based Training ({loss_type.capitalize()} Distance) ---")

    for epoch in range(start_epoch, epochs):
        # --- TRAINING PHASE ---
        model.train()
        train_loss_sum, train_steps = 0.0, 0
        train_epoch_labels, train_epoch_distances = [], []

        progress_bar_train = tqdm(loaders['train'], desc=f"Epoch {epoch+1}/{epochs} Training", leave=False)
        for batch in progress_bar_train:
            optimizer.zero_grad()

            embedding1, embedding2 = model(**to_model_args(batch))
            loss = criterion(embedding1, embedding2, batch['label'].to(device))

            # Skip non-finite losses (NaN and Inf) instead of back-propagating them.
            if not torch.isfinite(loss):
                print(f"❌ LOSS IS NaN/Inf at Epoch {epoch+1}. Skipping batch.")
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()

            loss_value = loss.item()
            train_loss_sum += loss_value
            train_steps += 1
            progress_bar_train.set_postfix(loss=loss_value)

            with torch.no_grad():
                train_epoch_labels.extend(batch['label'].cpu().numpy())
                train_epoch_distances.extend(pair_distance(embedding1, embedding2).cpu().numpy())

        # Average over batches that actually stepped; skipped ones no longer deflate it.
        avg_train_loss = train_loss_sum / train_steps if train_steps else float('nan')
        writer.add_scalar('Loss/Train', avg_train_loss, epoch)

        # --- VALIDATION PHASE ---
        model.eval()
        val_loss_sum, val_steps = 0.0, 0
        val_epoch_labels, val_epoch_distances = [], []

        with torch.no_grad():
            for batch in tqdm(loaders['val'], desc=f"Epoch {epoch+1}/{epochs} Validation", leave=False):
                embedding1, embedding2 = model(**to_model_args(batch))
                val_loss_sum += criterion(embedding1, embedding2, batch['label'].to(device)).item()
                val_steps += 1

                val_epoch_labels.extend(batch['label'].cpu().numpy())
                val_epoch_distances.extend(pair_distance(embedding1, embedding2).cpu().numpy())

        # Guard against an empty validation loader (previously ZeroDivisionError).
        avg_val_loss = val_loss_sum / val_steps if val_steps else float('nan')
        writer.add_scalar('Loss/Validation', avg_val_loss, epoch)

        # --- METRICS & LOGGING ---
        _calculate_and_log_pair_metrics(train_epoch_labels, train_epoch_distances, loss_type, criterion.margin, writer, 'Train', epoch)
        val_metrics = _calculate_and_log_pair_metrics(val_epoch_labels, val_epoch_distances, loss_type, criterion.margin, writer, 'Val', epoch)

        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val F1: {fmt(val_metrics, 'f1')} | Val AUC: {fmt(val_metrics, 'auc_roc')}")
        print(f"  Distances -> Pos: {fmt(val_metrics, 'avg_pos_dist')} | Neg: {fmt(val_metrics, 'avg_neg_dist')} | Sep: {fmt(val_metrics, 'dist_separation')}")

        # --- CHECKPOINTING & EARLY STOPPING ---
        scheduler.step(avg_val_loss)
        is_best = avg_val_loss < best_val_loss
        if is_best:
            best_val_loss = avg_val_loss
            no_improve = 0
            print(f"🎉 New best model found at epoch {epoch+1}!")
        else:
            no_improve += 1

        save_checkpoint({
            'epoch': epoch + 1, 'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(),
            'val_loss': avg_val_loss, 'config': config, 'metrics': val_metrics
        }, is_best=is_best, s3_prefix=s3_prefix)

        if no_improve >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}.")
            break

    writer.close()
    print("--- Pair-Based Training Finished ---")
    return model, writer

In [None]:
def _log_triplet_diagnostics(writer, pos_dists, neg_dists, margin, loss_type, epoch, prefix):
    """
    Log distance statistics for a set of (anchor, positive, negative) triplets.

    Writes to the writer, and prints:
      * the mean anchor-positive distance;
      * the mean anchor-negative distance;
      * their separation;
      * the percentage of "satisfied" triplets, i.e. those with
        ``d_ap - d_an + margin <= 0`` (a hinge triplet loss is exactly zero
        there, so the boundary counts as satisfied).
    For ``prefix == 'Val'`` a histogram of both distributions is also logged.

    Args:
        writer: TensorBoard-style writer (``add_scalar``; ``add_figure`` for 'Val').
        pos_dists: Sequence or 1-D array of anchor-positive distances.
        neg_dists: Sequence or 1-D array of anchor-negative distances,
            element-aligned with ``pos_dists``.
        margin: Triplet margin.
        loss_type: Distance name, used only in the scalar tag names.
        epoch: Step used for logging.
        prefix: Tag prefix, e.g. 'Val'.

    Returns:
        None. Does nothing when either distance collection is ``None`` or empty.
    """
    if pos_dists is None or neg_dists is None:
        return
    # asarray + .size works for lists and NumPy arrays alike (``not array`` is ambiguous).
    pos = np.asarray(pos_dists)
    neg = np.asarray(neg_dists)
    if pos.size == 0 or neg.size == 0:
        return

    avg_pos_dist = np.mean(pos)
    avg_neg_dist = np.mean(neg)
    separation = avg_neg_dist - avg_pos_dist

    writer.add_scalar(f'Metrics_{prefix}/Avg_Positive_Dist_{loss_type}', avg_pos_dist, epoch)
    writer.add_scalar(f'Metrics_{prefix}/Avg_Negative_Dist_{loss_type}', avg_neg_dist, epoch)
    writer.add_scalar(f'Metrics_{prefix}/Dist_Separation_{loss_type}', separation, epoch)

    satisfied_triplets = np.mean(pos - neg + margin <= 0)
    writer.add_scalar(f'Metrics_{prefix}/Satisfied_Triplets_Percent', satisfied_triplets * 100, epoch)

    print(f"{prefix} Metrics -> PosDist: {avg_pos_dist:.4f} | NegDist: {avg_neg_dist:.4f} | Sep: {separation:.4f} | Satisfied: {satisfied_triplets:.2%}")

    if prefix == 'Val':
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.histplot(pos, color="green", label=f'Anchor-Positive (Mean: {avg_pos_dist:.2f})', kde=True, stat="density", element="step", ax=ax)
        sns.histplot(neg, color="red", label=f'Anchor-Negative (Mean: {avg_neg_dist:.2f})', kde=True, stat="density", element="step", ax=ax)
        ax.axvline(x=margin, color='blue', linestyle='--', linewidth=2, label=f'Margin ({margin})')
        ax.set_title(f'Validation Distance Distribution - Epoch {epoch+1}')
        ax.legend()
        writer.add_figure('DistanceDistribution', fig, epoch)
        plt.close(fig)

def train_triplet_model(model, train_loader, val_loader, positive_pairs_set, criterion, optimizer, scheduler, device, config):
    """
    Unified training loop for triplet-based models with in-batch semi-hard mining.

    Training batches hold single movies. For each anchor, the first in-batch
    movie forming a pair in ``positive_pairs_set`` is used as the positive. A
    negative is then drawn at random among in-batch semi-hard candidates,
    i.e. ``d_ap < d_an < d_ap + margin``. Anchors without a positive or without
    a semi-hard negative are skipped. Validation uses fixed
    (anchor, positive, negative) triplets.

    Args:
        model: Embedding model called with a batch dict's entries as kwargs.
        train_loader: Iterable of dict batches containing ``movie_id``.
        val_loader: Iterable of dicts with ``anchor_data``, ``positive_data``
            and ``negative_data`` sub-dicts.
        positive_pairs_set: Set of sorted ``(id_a, id_b)`` tuples.
        criterion: ``criterion(anchor, positive, negative)`` exposing ``.margin``.
        optimizer: Optimizer over ``model.parameters()``.
        scheduler: Scheduler stepped with the validation loss.
        device: Torch device for tensors.
        config: Reads ``hyperparameters['epochs']``, ``checkpoint_s3_prefix``,
            ``loss_type``, ``run_name`` and optionally ``start_epoch`` (0),
            ``patience`` (5) and ``grad_check_interval`` (100 optimizer
            steps; 0/None disables the gradient check).

    Returns:
        ``(model, writer)``; the TensorBoard writer is already closed.
    """
    epochs = config['hyperparameters']['epochs']
    s3_prefix = config['checkpoint_s3_prefix']
    loss_type = config['loss_type']
    start_epoch = config.get('start_epoch', 0)
    patience = config.get('patience', 5)
    grad_check_interval = config.get('grad_check_interval', 100)

    # Distance for validation diagnostics; hoisted out of the batch loop.
    if loss_type == 'cosine':
        def pair_distance(e1, e2):
            return 1 - F.cosine_similarity(e1, e2)
    else:
        pair_distance = F.pairwise_distance

    def to_device(data):
        return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in data.items()}

    writer = setup_tensorboard(s3_prefix=f"n2n-model/logs/{config['run_name']}")
    best_val_loss = float('inf')
    no_improve = 0
    global_step = 0

    print(f"--- Starting Triplet Training ({loss_type.capitalize()} Distance) ---")

    for epoch in range(start_epoch, epochs):
        # --- TRAINING PHASE ---
        model.train()
        train_loss_sum, train_steps = 0.0, 0

        progress_bar_train = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} Training", leave=False)
        for batch in progress_bar_train:
            optimizer.zero_grad()

            batch_data = to_device(batch)
            embeddings = model(**batch_data)
            movie_ids = batch_data['movie_id']
            ids = movie_ids.tolist() if hasattr(movie_ids, 'tolist') else list(movie_ids)

            # --- In-batch semi-hard negative mining ---
            # Distances only *select* triplets, so no autograd graph is needed.
            with torch.no_grad():
                if loss_type == 'euclidean':
                    distance_matrix = torch.cdist(embeddings, embeddings)
                else:
                    # Normalise first so this is a true cosine distance, matching the
                    # 1 - cosine_similarity used in validation (a raw dot product only
                    # equals it for unit-norm embeddings).
                    unit = F.normalize(embeddings, dim=1)
                    distance_matrix = 1 - unit @ unit.t()

            anchors, positives, negatives = [], [], []
            for j, anchor_id in enumerate(ids):
                pos_indices = [k for k, other_id in enumerate(ids)
                               if k != j and tuple(sorted((anchor_id, other_id))) in positive_pairs_set]
                if not pos_indices:
                    continue

                pos_idx = pos_indices[0]  # One positive per anchor
                dist_ap = distance_matrix[j, pos_idx]

                # Negatives: everything except the anchor and its positives.
                neg_mask = torch.ones(len(ids), dtype=torch.bool, device=distance_matrix.device)
                neg_mask[j] = False
                neg_mask[pos_indices] = False

                dist_an = distance_matrix[j, neg_mask]
                semi_hard_mask = (dist_an > dist_ap) & (dist_an < dist_ap + criterion.margin)
                if not torch.any(semi_hard_mask):
                    continue

                candidates = torch.where(neg_mask)[0][semi_hard_mask].tolist()
                neg_idx = random.choice(candidates)

                anchors.append(embeddings[j])
                positives.append(embeddings[pos_idx])
                negatives.append(embeddings[neg_idx])

            if not anchors:
                continue

            loss = criterion(torch.stack(anchors), torch.stack(positives), torch.stack(negatives))
            if not torch.isfinite(loss):
                continue

            loss.backward()
            # Periodic diagnosis on raw (pre-clipping) gradients. The old check ran
            # after validation using a stale batch index, so it fired essentially at
            # random, and raised NameError on an empty loader.
            if grad_check_interval and global_step % grad_check_interval == 0:
                check_gradient_flow(model)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()
            global_step += 1

            loss_value = loss.item()
            train_loss_sum += loss_value
            train_steps += 1
            progress_bar_train.set_postfix(loss=loss_value, triplets=len(anchors))

        # Average over batches that actually produced a step.
        avg_train_loss = train_loss_sum / train_steps if train_steps else float('nan')
        writer.add_scalar('Loss/Train', avg_train_loss, epoch)

        # --- VALIDATION PHASE ---
        model.eval()
        val_loss_sum, val_steps = 0.0, 0
        val_pos_dists, val_neg_dists = [], []
        with torch.no_grad():
            progress_bar_val = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} Validation", leave=False)
            for batch in progress_bar_val:
                # Embed each part of the triplet separately to limit peak memory.
                anchor_emb = model(**to_device(batch['anchor_data']))
                positive_emb = model(**to_device(batch['positive_data']))
                negative_emb = model(**to_device(batch['negative_data']))

                val_loss_sum += criterion(anchor_emb, positive_emb, negative_emb).item()
                val_steps += 1

                val_pos_dists.extend(pair_distance(anchor_emb, positive_emb).cpu().numpy())
                val_neg_dists.extend(pair_distance(anchor_emb, negative_emb).cpu().numpy())

        avg_val_loss = val_loss_sum / val_steps if val_steps else float('nan')
        writer.add_scalar('Loss/Validation', avg_val_loss, epoch)

        # --- METRICS & LOGGING ---
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        _log_triplet_diagnostics(writer, val_pos_dists, val_neg_dists, criterion.margin, loss_type, epoch, 'Val')

        # --- CHECKPOINTING & EARLY STOPPING ---
        scheduler.step(avg_val_loss)
        is_best = avg_val_loss < best_val_loss
        if is_best:
            best_val_loss = avg_val_loss
            no_improve = 0
            print(f"🎉 New best model found at epoch {epoch+1}!")
        else:
            no_improve += 1

        save_checkpoint({
            'epoch': epoch + 1, 'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(),
            'val_loss': avg_val_loss, 'config': config
        }, is_best=is_best, s3_prefix=s3_prefix)

        if no_improve >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}.")
            break

    writer.close()
    print("--- Triplet-Based Training Finished ---")
    return model, writer

# Main

In [None]:
# Load the TensorBoard notebook extension and launch it on port 6001.
# It reads the local log root 'master_tmp_logs', which is also the default
# `local_log_dir` of setup_tensorboard(). Each run writes to its own
# timestamped subdirectory there.
# `--bind_all` makes the server listen on all network interfaces rather
# than localhost only.
%load_ext tensorboard
%tensorboard --logdir master_tmp_logs --port 6001 --bind_all

In [None]:
if __name__ == "__main__":

    # --- Global Constants ---
    # NOTE(review): these become notebook globals; other cells (datasets,
    # modules) may read some of them by name - check before removing any.
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    BUCKET_NAME = 'md-data-content-recommendation'
    TEXT_CSV = 's3://md-data-content-recommendation/cleaned_textual_trailers_dataset_with_languages.csv'
    ABLATION_MODE = 'full'
    NUM_FRAMES_FROM_PT, DINO_DIM, VGGISH_DIM = 16, 384, 128
    SAMPLE_RATE, MAX_AUDIO_LENGTH_SEC = 16000, 10
    DINO_PER_FRAME_EMBEDDING_S3_PREFIX = 'movie_trailers_dino_per_frame_embeddings/'
    AUDIO_VGGISH_WAVE_S3_PREFIX = 'movie_trailers_audio_embeddings_vggish/'

    # === THE CONTROL PANEL FOR ALL EXPERIMENTS ===
    experiments_to_run = {
        "E2E_Pair_Cosine": {
            "model_type": "pair",
            "loss_type": "cosine",
            "fusion_network": "Hybrid",
            "checkpoint_s3_prefix": "n2n-model/25-cosine-with-languages-checkpoints/",
            "hyperparameters": {"lr": 5e-5, "batch_size": 16, "epochs": 50},
            "patience": 8, # increase patience
            "test_mode": True,      
            "test_size": 25000 
        },
        "E2E_Pair_Euclidean": {
            "model_type": "pair",
            "loss_type": "euclidean",
            "fusion_network": "Simple",
            "checkpoint_s3_prefix": "n2n-model/25-euclidean-with-languages-checkpoints/",
            "hyperparameters": {"lr":5e-5, "batch_size": 16, "epochs": 50},
            "patience": 8, # increase patience
            "test_mode": True,
            "test_size": 25000
        },
        "E2E_Triplet_Cosine": {
            "model_type": "triplet",
            "loss_type": "cosine",
            "fusion_network": "Simple",
            "checkpoint_s3_prefix": "n2n-model/25-SEMI-HARD-triplet-cosine-checkpoints/",
            "hyperparameters": {"lr": 1e-4, "batch_size": 16, "epochs": 100},
            "patience": 10, # increase patience
            "test_mode": True,
            "test_size": 25000
        },
        "E2E_Triplet_Euclidean": {
            "model_type": "triplet",
            "loss_type": "euclidean",
            "fusion_network": "UltraSimple",
            "checkpoint_s3_prefix": "n2n-model/25-SEMI-HARD-triplet-euclidean-checkpoints/", # normal = 2e-4, 2- = 1e-4, patience = 10
            "hyperparameters": {"lr": 1e-4, "batch_size": 16, "epochs": 100},
            "patience": 10, # increase patience
            "test_mode": True,
            "test_size": 25000
        },
    }

    # --- Pre-build the language vocabulary ONCE so all experiments share it ---
    master_text_dataset = TextDataset(TEXT_CSV, lang_vocab=None)
    num_languages = master_text_dataset.num_languages
    shared_lang_vocab = master_text_dataset.lang_vocab

    # Per-experiment summary: {name: {'log_dir': ..., 'plot_path': ...}}
    all_experiment_results = {}

    # --- MASTER LOOP ---
    for name, config in experiments_to_run.items():
        print(f"\n{'='*30}\nSTARTING EXPERIMENT: {name}\n{'='*30}")

        # 0. Drop references to the previous experiment's objects before building
        #    new ones, so two models/optimizers never sit on the GPU at once.
        model = criterion = optimizer = scheduler = loaders = positive_pairs_set = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        config['run_name'] = name  # Used to name TensorBoard/S3 log locations
        set_seed(42)  # Reset seed for each experiment for reproducibility

        # 1. Create Dataloaders for this experiment
        loaders, positive_pairs_set = create_dataloaders(config, shared_lang_vocab, num_languages)

        # 2. Create Model, Criterion, Optimizer and Scheduler
        model, criterion, optimizer, scheduler = create_model_and_criterion(config, DEVICE, num_languages)

        # 3. Dispatch to the matching training loop. Both loops return
        #    (model, writer); keep the writer so step 4 can locate the logs.
        final_writer = None
        if config['model_type'] == 'pair':
            print(">>> Dispatching to PAIR training loop...")
            model, final_writer = train_pair_model(
                model, loaders, criterion, optimizer, scheduler, DEVICE, config
            )
        elif config['model_type'] == 'triplet':
            print(">>> Dispatching to TRIPLET training loop...")
            model, final_writer = train_triplet_model(
                model, loaders['train'], loaders['val'], positive_pairs_set,
                criterion, optimizer, scheduler, DEVICE, config
            )

        # 4. Generate final plots
        plot_path = None
        if final_writer is not None:
            print("\nFlushing and closing TensorBoard writer...")
            # The training loops already closed the writer; SummaryWriter's
            # flush()/close() return early on an already-closed writer.
            final_writer.flush()
            final_writer.close()
            time.sleep(2)  # Give filesystem a moment
            plot_path = plot_final_learning_curves(final_writer.log_dir, name)

        all_experiment_results[name] = {
            'log_dir': getattr(final_writer, 'log_dir', None),
            'plot_path': plot_path,
        }

        print(f"\n{'='*30}\nFINISHED EXPERIMENT: {name}\n{'='*30}")

    print("\nAll experiments have been completed.")

In [None]:
import os
import shutil

# Archive all local TensorBoard logs into master_logs.zip, equivalent to
# `zip -r master_logs.zip master_tmp_logs/`. A bare shell command is a
# SyntaxError in a Python cell; the standard library also avoids depending
# on a `zip` binary being installed.
LOGS_TO_ARCHIVE = 'master_tmp_logs'
if os.path.isdir(LOGS_TO_ARCHIVE):
    archive_path = shutil.make_archive('master_logs', 'zip', root_dir='.', base_dir=LOGS_TO_ARCHIVE)
    print(f"Archived logs to {archive_path}")
else:
    print(f"No '{LOGS_TO_ARCHIVE}' directory found; nothing to archive.")