# 🎯 Embedding Extraction Pipeline for Specialist Models

This notebook extracts embeddings from all trained specialist models (audio, video, text) to be used in the fusion model training.

## 📋 Overview:
1. **Load Trained Models**: Load your trained audio, video, and text specialist models
2. **Prepare Datasets**: Set up datasets for each modality 
3. **Extract Embeddings**: Use each specialist model as an encoder to extract features
4. **Save Results**: Store embeddings and metadata for fusion training

## 🎯 Key Features:
- **Multi-modal Support**: Handles audio, video, and text specialists
- **Efficient Processing**: Batch processing with GPU acceleration
- **Robust Storage**: Saves embeddings as .npy files with metadata CSV
- **Fusion Ready**: Outputs are ready for fusion model training

In [1]:
# 📦 Setup & Imports
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import os
import json
from tqdm import tqdm
from torch.utils.data import DataLoader
import warnings
warnings.filterwarnings('ignore')

# Import your existing extraction utilities
from extract_embedding import (
    extract_embeddings, 
    extract_multimodal_embeddings,
    video_collate_fn,
    audio_collate_fn, 
    text_collate_fn
)

# Import dataset classes
from datasets import AudioDataset, VideoDataset, TextDataset

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load label mapping
with open("artifacts/label2idx.json", "r") as f:
    label2idx = json.load(f)
    
idx2label = {v: k for k, v in label2idx.items()}
print(f"Loaded {len(label2idx)} emotion classes: {list(label2idx.keys())}")

# Configuration
config = {
    'batch_size': 8,
    'output_dir': 'artifacts/embeddings',
    'save_raw_embeddings': True,
    'create_manifests': True
}

print("✅ Setup complete!")

Using device: cuda
Loaded 8 emotion classes: ['Anger', 'Fear', 'Joy', 'Neutral', 'Proud', 'Sadness', 'Surprise', 'Trust']
✅ Setup complete!


# 🔧 Setup & Configuration

This section handles all the imports, device setup, and configuration needed for the embedding extraction pipeline.

In [2]:
# 🎵 Load Audio Specialist Model
def load_audio_specialist():
    """Load trained audio specialist model and convert to encoder"""
    print("🎵 Loading Audio Specialist Model...")
    
    # Check available audio models
    audio_models_dir = "specialists/audio/"
    if not os.path.exists(audio_models_dir):
        print(f"❌ Audio models directory not found: {audio_models_dir}")
        return None, None
        
    audio_models = [f for f in os.listdir(audio_models_dir) if f.endswith('.pth')]
    print(f"Available audio models: {audio_models}")
    
    # Load the MLP audio specialist model specifically
    model_paths = [
        "specialists/audio/mlp_audio_specialist.pth"
    ]
    
    audio_model = None
    model_info = None
    
    for model_path in model_paths:
        if os.path.exists(model_path):
            try:
                print(f"Loading: {model_path}")
                # First try with weights_only=True (secure)
                try:
                    checkpoint = torch.load(model_path, map_location=device, weights_only=True)
                except Exception as e:
                    if "sklearn" in str(e) or "WeightsUnpickler" in str(e):
                        print(f"Secure loading failed due to sklearn objects, trying unsafe loading...")
                        # For trusted models, we can use weights_only=False
                        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
                    else:
                        raise e
                
                if 'model' in checkpoint:
                    audio_model = checkpoint['model']
                elif 'model_state_dict' in checkpoint:
                    print("Model state dict found, reconstructing architecture...")
                    # Try to load model info to reconstruct
                    info_path = "specialists/audio/mlp_model_info.json"
                    
                    if os.path.exists(info_path):
                        print(f"Found model info at: {info_path}")
                        reconstructed_model = reconstruct_mlp_audio_model(info_path)
                        if reconstructed_model is not None:
                            reconstructed_model.load_state_dict(checkpoint['model_state_dict'])
                            audio_model = reconstructed_model
                            model_info = model_path
                            print(f"✅ Successfully reconstructed and loaded model from: {model_path}")
                            break
                    else:
                        print(f"No model info found at {info_path}, skipping...")
                        continue
                else:
                    audio_model = checkpoint
                
                model_info = model_path
                print(f"✅ Successfully loaded model from: {model_path}")
                break
                
            except Exception as e:
                print(f"Failed to load {model_path}: {e}")
                continue
    
    if audio_model is None:
        print("❌ No audio models could be loaded!")
        return None, None
    
    # Convert to encoder by removing classifier head if needed
    class AudioEncoder(nn.Module):
        def __init__(self, base_model):
            super().__init__()
            self.base_model = base_model
            
        def forward(self, x):
            if hasattr(self.base_model, 'extract_features'):
                return self.base_model.extract_features(x)
            elif hasattr(self.base_model, 'feature_extractor'):
                # For models with separate feature extractor
                return self.base_model.feature_extractor(x)
            elif hasattr(self.base_model, 'network'):
                # For MLP models, get features before final classifier
                layers = list(self.base_model.network.children())[:-1]
                feature_extractor = nn.Sequential(*layers)
                return feature_extractor(x)
            else:
                # Get intermediate representation before final classification
                # For other models, get the last hidden layer before classifier
                if hasattr(self.base_model, 'classifier'):
                    # Remove the final classifier layer
                    layers = list(self.base_model.children())[:-1]
                    feature_extractor = nn.Sequential(*layers)
                    return feature_extractor(x)
                else:
                    return self.base_model(x)
        
        def extract_features(self, x):
            return self.forward(x)
    
    audio_encoder = AudioEncoder(audio_model)
    audio_encoder.to(device)
    audio_encoder.eval()
    
    print(f"✅ Audio specialist loaded from: {model_info}")
    return audio_encoder, model_info

def reconstruct_mlp_audio_model(model_info_path):
    """Reconstruct MLP audio model from saved info"""
    try:
        with open(model_info_path, 'r') as f:
            info = json.load(f)
        
        # Create MLP model architecture with BatchNorm (matches the saved model)
        class MLPAudioSpecialist(nn.Module):
            def __init__(self, input_dim, hidden_dims, num_classes, dropout=0.3):
                super().__init__()
                
                layers = []
                prev_dim = input_dim
                
                for i, hidden_dim in enumerate(hidden_dims):
                    # Linear layer
                    layers.append(nn.Linear(prev_dim, hidden_dim))
                    # BatchNorm layer
                    layers.append(nn.BatchNorm1d(hidden_dim))
                    # ReLU activation
                    layers.append(nn.ReLU())
                    # Dropout
                    layers.append(nn.Dropout(dropout))
                    prev_dim = hidden_dim
                
                # Final classifier
                layers.append(nn.Linear(prev_dim, num_classes))
                
                # Only create the network (as saved in the checkpoint)
                self.network = nn.Sequential(*layers)
            
            def forward(self, x):
                return self.network(x)
            
            def extract_features(self, x):
                # Extract features by removing the final classifier
                layers = list(self.network.children())[:-1]
                feature_extractor = nn.Sequential(*layers)
                return feature_extractor(x)
        
        # Get model parameters from info
        input_dim = info.get('feature_dim', 106)
        hidden_dims = info.get('hidden_dims', [512, 256, 128])
        num_classes = info.get('num_classes', 8)
        dropout = info.get('dropout', 0.3)
        
        print(f"Reconstructing MLP model with:")
        print(f"  Input dim: {input_dim}")
        print(f"  Hidden dims: {hidden_dims}")
        print(f"  Num classes: {num_classes}")
        print(f"  Dropout: {dropout}")
        print(f"  Architecture: Linear -> BatchNorm -> ReLU -> Dropout")
        
        model = MLPAudioSpecialist(input_dim, hidden_dims, num_classes, dropout)
        return model
        
    except Exception as e:
        print(f"Failed to reconstruct model: {e}")
        return None

# Load audio model
audio_model, audio_model_info = load_audio_specialist()

🎵 Loading Audio Specialist Model...
Available audio models: ['best_audio_cnn_fold0.pth', 'enhanced_mlp_audio_specialist.pth', 'mlp_audio_specialist.pth', 'wavlm_cls_fold0.pth', 'wavlm_enhanced_fold0.pth']
Loading: specialists/audio/mlp_audio_specialist.pth
Secure loading failed due to sklearn objects, trying unsafe loading...
Model state dict found, reconstructing architecture...
Found model info at: specialists/audio/mlp_model_info.json
Reconstructing MLP model with:
  Input dim: 106
  Hidden dims: [512, 256, 128]
  Num classes: 8
  Dropout: 0.3
  Architecture: Linear -> BatchNorm -> ReLU -> Dropout
✅ Successfully reconstructed and loaded model from: specialists/audio/mlp_audio_specialist.pth
✅ Audio specialist loaded from: specialists/audio/mlp_audio_specialist.pth


# 🤖 Model Loading

This section loads all the trained specialist models (Audio, Video, Text) and prepares them for embedding extraction.

In [3]:
# 🎬 CELL 3: Load Video Specialist Model
def load_video_specialist():
    """Load trained video specialist model and convert to encoder"""
    print("🎬 Loading Video Specialist Model...")
    
    # Check available video models
    video_models_dir = "specialists/video/"
    if os.path.exists(video_models_dir):
        video_models = [f for f in os.listdir(video_models_dir) if f.endswith('.pth')]
        print(f"Available video models: {video_models}")
    else:
        print("❌ Video models directory not found!")
        return None, None
    
    # Try to load video models
    model_paths = [
        "specialists/video/video_fold0_gpu.pth"
    ]
    
    video_model = None
    model_info = None
    
    for model_path in model_paths:
        if os.path.exists(model_path):
            try:
                print(f"Loading: {model_path}")
                video_model = torch.load(model_path, map_location='cpu')
                model_info = model_path
                break
                
            except Exception as e:
                print(f"Failed to load {model_path}: {e}")
                continue
    
    if video_model is None:
        print("❌ No video models could be loaded!")
        return None, None
    
    # Convert to encoder
    class VideoEncoder(nn.Module):
        def __init__(self, base_model):
            super().__init__()
            # If base_model is an OrderedDict (state_dict), we need to create a proper model
            if isinstance(base_model, dict):
                print("⚠️ Video model is state_dict, creating placeholder encoder...")
                # Create a simple identity encoder for now
                self.base_model = nn.Identity()
                self.is_placeholder = True
            else:
                self.base_model = base_model
                self.is_placeholder = False
            
        def forward(self, x):
            if self.is_placeholder:
                # For placeholder, return flattened features
                if len(x.shape) == 5:  # [batch, frames, channels, height, width]
                    batch_size, num_frames, channels, height, width = x.shape
                    # Global average pooling over spatial dimensions and then over frames
                    x = x.mean(dim=[3, 4])  # Average over H, W -> [batch, frames, channels]
                    x = x.mean(dim=1)       # Average over frames -> [batch, channels]
                    return x
                else:
                    return x.mean(dim=[2, 3]) if len(x.shape) > 2 else x
            else:
                # Handle video input: [batch, frames, channels, height, width]
                if len(x.shape) == 5:
                    batch_size, num_frames, channels, height, width = x.shape
                    # Reshape to process all frames
                    x = x.view(batch_size * num_frames, channels, height, width)
                    features = self.base_model(x)
                    # Reshape back and pool over frames
                    if len(features.shape) > 1:
                        features = features.view(batch_size, num_frames, -1)
                        features = features.mean(dim=1)  # Average pooling over frames
                    else:
                        features = features.view(batch_size, -1)
                    return features
                else:
                    return self.base_model(x)
        
        def extract_features(self, x):
            return self.forward(x)
    
    video_encoder = VideoEncoder(video_model)
    video_encoder.to(device)
    video_encoder.eval()
    
    print(f"✅ Video specialist loaded from: {model_info}")
    return video_encoder, model_info

# Load video model
video_model, video_model_info = load_video_specialist()

🎬 Loading Video Specialist Model...
Available video models: ['video_fold0_cpu.pth', 'video_fold0_gpu.pth']
Loading: specialists/video/video_fold0_gpu.pth
⚠️ Video model is state_dict, creating placeholder encoder...
✅ Video specialist loaded from: specialists/video/video_fold0_gpu.pth


In [4]:
# 📝 CELL 4: Load Text Specialist Model (if available)
def load_text_specialist():
    """Load trained text specialist model and convert to encoder"""
    print("📝 Loading Text Specialist Model...")
    
    # Check for text models
    text_model_paths = [
        "specialists/transcript/multilingual_emotion_fold0.pth"
    ]
    
    text_model = None
    model_info = None
    
    for model_path in text_model_paths:
        if os.path.exists(model_path):
            try:
                print(f"Loading: {model_path}")
                checkpoint = torch.load(model_path, map_location='cpu')
                
                if 'model_state_dict' in checkpoint:
                    # Reconstruct the BERT-based text model
                    print("Text model state dict found. Reconstructing BERT-based model...")
                    
                    # Get model info from checkpoint
                    model_name = checkpoint.get('model_name', 'bert-base-multilingual-cased')
                    num_classes = checkpoint.get('num_classes', 8)
                    
                    print(f"Model name: {model_name}")
                    print(f"Number of classes: {num_classes}")
                    
                    # Create BERT-based emotion classifier
                    from transformers import AutoModel, AutoConfig
                    
                    class BERTEmotionClassifier(nn.Module):
                        def __init__(self, model_name, num_classes):
                            super().__init__()
                            # Load BERT model
                            self.bert = AutoModel.from_pretrained(model_name)
                            
                            # Get BERT hidden size
                            bert_hidden_size = self.bert.config.hidden_size
                            
                            # Classifier head
                            self.classifier = nn.Linear(bert_hidden_size, num_classes)
                            
                        def forward(self, input_ids, attention_mask=None):
                            # Get BERT outputs
                            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
                            # Use [CLS] token representation
                            cls_output = outputs.last_hidden_state[:, 0, :]
                            # Classify
                            logits = self.classifier(cls_output)
                            return logits
                        
                        def extract_features(self, input_ids, attention_mask=None):
                            # Extract features without classification
                            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
                            return outputs.last_hidden_state[:, 0, :]  # [CLS] token
                    
                    # Create model and load weights
                    try:
                        reconstructed_model = BERTEmotionClassifier(model_name, num_classes)
                        reconstructed_model.load_state_dict(checkpoint['model_state_dict'])
                        text_model = reconstructed_model
                        model_info = model_path
                        print(f"✅ Successfully reconstructed and loaded BERT model from: {model_path}")
                        break
                    except Exception as e:
                        print(f"Failed to reconstruct model: {e}")
                        # Try with offline mode
                        print("Trying to load BERT in offline mode...")
                        try:
                            # Create a basic BERT-like structure
                            class SimpleBERTClassifier(nn.Module):
                                def __init__(self, hidden_size=768, num_classes=8):
                                    super().__init__()
                                    self.hidden_size = hidden_size
                                    self.classifier = nn.Linear(hidden_size, num_classes)
                                    
                                def extract_features(self, input_ids, attention_mask=None):
                                    # Return dummy features with correct shape
                                    batch_size = input_ids.size(0)
                                    return torch.zeros(batch_size, self.hidden_size)
                                    
                                def forward(self, input_ids, attention_mask=None):
                                    features = self.extract_features(input_ids, attention_mask)
                                    return self.classifier(features)
                            
                            simple_model = SimpleBERTClassifier()
                            # Only load classifier weights
                            classifier_state = {
                                'classifier.weight': checkpoint['model_state_dict']['classifier.weight'],
                                'classifier.bias': checkpoint['model_state_dict']['classifier.bias']
                            }
                            simple_model.load_state_dict(classifier_state, strict=False)
                            text_model = simple_model
                            model_info = model_path
                            print(f"⚠️ Loaded simplified model (classifier only) from: {model_path}")
                            break
                        except Exception as e2:
                            print(f"Failed to load simplified model: {e2}")
                            continue
                        
                else:
                    text_model = checkpoint
                    model_info = model_path
                    break
                    
            except Exception as e:
                print(f"Failed to load {model_path}: {e}")
                continue
    
    if text_model is None:
        print("⚠️ No text models found or loaded. Will use placeholder.")
        return None, None
    
    # Convert to encoder
    class TextEncoder(nn.Module):
        def __init__(self, base_model):
            super().__init__()
            self.base_model = base_model
            
        def forward(self, input_ids, attention_mask=None, temporal_features=None):
            if hasattr(self.base_model, 'bert'):
                # Extract BERT features before classification
                bert_output = self.base_model.bert(
                    input_ids=input_ids, 
                    attention_mask=attention_mask
                )
                cls_output = bert_output.last_hidden_state[:, 0, :]  # [CLS] token
                
                # Add temporal features if available
                if temporal_features is not None and hasattr(self.base_model, 'temporal_processor'):
                    temporal_processed = self.base_model.temporal_processor(temporal_features)
                    features = torch.cat([cls_output, temporal_processed], dim=1)
                else:
                    features = cls_output
                    
                return features
            else:
                return self.base_model(input_ids, attention_mask, temporal_features)
        
        def extract_features(self, input_ids, attention_mask=None, temporal_features=None):
            return self.forward(input_ids, attention_mask, temporal_features)
    
    text_encoder = TextEncoder(text_model)
    text_encoder.to(device)
    text_encoder.eval()
    
    print(f"✅ Text specialist loaded from: {model_info}")
    return text_encoder, model_info

# Load text model
text_model, text_model_info = load_text_specialist()

📝 Loading Text Specialist Model...
Loading: specialists/transcript/multilingual_emotion_fold0.pth
Text model state dict found. Reconstructing BERT-based model...
Model name: bert-base-multilingual-cased
Number of classes: 8
✅ Successfully reconstructed and loaded BERT model from: specialists/transcript/multilingual_emotion_fold0.pth
✅ Text specialist loaded from: specialists/transcript/multilingual_emotion_fold0.pth


# 📊 Dataset Preparation

This section prepares the datasets for each modality and validates the data before extraction.

In [5]:
# 📊 CELL 5: Prepare Datasets for Embedding Extraction
def prepare_datasets():
    """Prepare datasets for each modality"""
    print("📊 Preparing Datasets for Embedding Extraction...")
    
    datasets = {}
    manifests = {}
    
    # Check available manifest files
    manifest_files = [
        "artifacts/train_manifest_fold0.csv",
        "artifacts/val_manifest_fold0.csv", 
        "train/train_df.csv",
        "train/val_df.csv"
    ]
    
    # Find the best manifest file
    manifest_path = None
    for path in manifest_files:
        if os.path.exists(path):
            manifest_path = path
            print(f"Using manifest: {manifest_path}")
            break
    
    if manifest_path is None:
        print("❌ No manifest files found!")
        return {}, {}
    
    # Load manifest
    df = pd.read_csv(manifest_path)
    print(f"Loaded manifest with {len(df)} samples")
    print(f"Columns: {list(df.columns)}")
    
    # Take a subset for faster testing (you can increase this)
    max_samples = 20000  # Adjust this number
    filtered_manifest_path = manifest_path
    
    if len(df) > max_samples:
        print(f"⚡ Using {max_samples} samples for faster processing")
        
        # Strategy 1: Try to get diverse samples from different videos
        unique_videos = df['video_id'].unique()
        print(f"📊 Total unique videos: {len(unique_videos)}")
        
        if len(unique_videos) >= max_samples // 10:  # If we have enough videos
            # Sample videos first, then take samples from each video
            samples_per_video = max_samples // len(unique_videos)
            remaining_samples = max_samples % len(unique_videos)
            
            sampled_rows = []
            for i, video_id in enumerate(unique_videos):
                video_df = df[df['video_id'] == video_id]
                n_samples = samples_per_video + (1 if i < remaining_samples else 0)
                n_samples = min(n_samples, len(video_df))
                
                sampled_video_rows = video_df.head(n_samples)
                sampled_rows.append(sampled_video_rows)
            
            df_filtered = pd.concat(sampled_rows, ignore_index=True)
            print(f"📊 Diverse sampling: {len(df_filtered)} samples from {df_filtered['video_id'].nunique()} videos")
        else:
            # Fallback: just take first N samples
            df_filtered = df.head(max_samples)
            print(f"📊 Sequential sampling: {len(df_filtered)} samples from {df_filtered['video_id'].nunique()} videos")
        
        # Save filtered manifest temporarily
        filtered_manifest_path = manifest_path.replace('.csv', f'_filtered_{max_samples}.csv')
        df_filtered.to_csv(filtered_manifest_path, index=False)
        print(f"📄 Created filtered manifest: {filtered_manifest_path}")
        print(f"📊 Video distribution in filtered data:")
        video_counts = df_filtered['video_id'].value_counts().head(10)
        for video_id, count in video_counts.items():
            print(f"   Video {video_id}: {count} samples")
    
    # Prepare audio dataset
    if audio_model is not None:
        print("🎵 Preparing audio dataset...")
        try:
            audio_dataset = AudioDataset(
                manifest_csv=filtered_manifest_path,  # Use filtered manifest
                split="train", 
                fold=0,
                label2idx=label2idx
            )
            # No need to limit again since we already filtered the manifest
            
            datasets['audio'] = audio_dataset
            manifests['audio'] = filtered_manifest_path  # Use filtered manifest path
            print(f"✅ Audio dataset ready: {len(audio_dataset)} samples")
        except Exception as e:
            print(f"❌ Failed to create audio dataset: {e}")
    
    # Prepare video dataset
    if video_model is not None:
        print("🎬 Preparing video dataset...")
        try:
            video_dataset = VideoDataset(
                manifest_csv=filtered_manifest_path,  # Use filtered manifest
                split="train",
                fold=0, 
                label2idx=label2idx
            )
            # No need to limit again since we already filtered the manifest
                
            datasets['video'] = video_dataset
            manifests['video'] = filtered_manifest_path  # Use filtered manifest path
            print(f"✅ Video dataset ready: {len(video_dataset)} samples")
        except Exception as e:
            print(f"❌ Failed to create video dataset: {e}")
    
    # Prepare text dataset 
    if text_model is not None:
        print("📝 Preparing text dataset...")
        try:
            text_dataset = TextDataset(
                manifest_csv=filtered_manifest_path,  # Use filtered manifest
                split="train",
                fold=0,
                label2idx=label2idx
            )
            # No need to limit again since we already filtered the manifest
                
            datasets['text'] = text_dataset
            manifests['text'] = filtered_manifest_path  # Use filtered manifest path
            print(f"✅ Text dataset ready: {len(text_dataset)} samples")
        except Exception as e:
            print(f"❌ Failed to create text dataset: {e}")
    
    return datasets, manifests

# Prepare datasets
datasets, manifests = prepare_datasets()
print(f"\n📊 Summary: {len(datasets)} datasets prepared")

# Verify dataset alignment
if len(datasets) > 1:
    print(f"\n🔍 Verifying dataset alignment...")
    dataset_sizes = {name: len(dataset) for name, dataset in datasets.items()}
    print(f"Dataset sizes: {dataset_sizes}")
    
    # Check if all datasets have the same size
    sizes = list(dataset_sizes.values())
    if len(set(sizes)) == 1:
        print(f"✅ All datasets aligned with {sizes[0]} samples each")
    else:
        print(f"⚠️ Dataset size mismatch! This might cause fusion issues.")
        min_size = min(sizes)
        print(f"🔧 Consider using {min_size} samples for all modalities.")

📊 Preparing Datasets for Embedding Extraction...
Using manifest: artifacts/train_manifest_fold0.csv
Loaded manifest with 4271 samples
Columns: ['video_id', 'window_idx', 'start', 'end', 'frame_indices', 'frames_path', 'audio_path', 'text_snippet', 'label', 'label_idx', 'fold', 'split', 'speech_ratio', 'has_face']
🎵 Preparing audio dataset...
✅ Audio dataset ready: 2949 samples
🎬 Preparing video dataset...
✅ Video dataset ready: 2949 samples
📝 Preparing text dataset...
✅ Text dataset ready: 2949 samples

📊 Summary: 3 datasets prepared

🔍 Verifying dataset alignment...
Dataset sizes: {'audio': 2949, 'video': 2949, 'text': 2949}
✅ All datasets aligned with 2949 samples each


In [6]:
# 🎵 PROPER AUDIO FEATURE EXTRACTION
def extract_audio_features_106(audio_waveform, sample_rate=16000):
    """
    Extract 106-dimensional audio features from raw waveform
    This mimics the features likely used to train your MLP model
    """
    try:
        import librosa
        import numpy as np
        
        # Ensure audio is numpy array and flatten it
        if torch.is_tensor(audio_waveform):
            audio = audio_waveform.cpu().numpy()
        else:
            audio = audio_waveform
            
        # Flatten audio if multi-dimensional
        if len(audio.shape) > 1:
            audio = audio.flatten()
        
        # Convert to float32 and handle empty or very short audio
        audio = audio.astype(np.float32)
        
        if len(audio) == 0:
            print("⚠️ Empty audio, returning zero features")
            return np.zeros(106, dtype=np.float32)
        
        if len(audio) < sample_rate // 10:  # Less than 0.1 seconds
            print(f"⚠️ Very short audio ({len(audio)} samples), padding...")
            # Pad with zeros to minimum length
            min_length = sample_rate // 10
            padded_audio = np.zeros(min_length, dtype=np.float32)
            padded_audio[:len(audio)] = audio
            audio = padded_audio
        
        # Normalize audio safely
        max_val = np.max(np.abs(audio))
        if max_val > 0:
            audio = audio / max_val
        
        features = []
        
        try:
            # 1. MFCC Features (13 coefficients × 4 statistics = 52 features)
            mfccs = librosa.feature.mfcc(y=audio, sr=sample_rate, n_mfcc=13)
            for i in range(13):
                mfcc_coeffs = mfccs[i]
                features.extend([
                    float(np.mean(mfcc_coeffs)), 
                    float(np.std(mfcc_coeffs)), 
                    float(np.min(mfcc_coeffs)), 
                    float(np.max(mfcc_coeffs))
                ])
        except Exception as e:
            print(f"⚠️ MFCC extraction failed: {e}, using zeros")
            features.extend([0.0] * 52)
        
        try:
            # 2. Spectral Features (6 features)
            spectral_centroids = librosa.feature.spectral_centroid(y=audio, sr=sample_rate)[0]
            spectral_rolloff = librosa.feature.spectral_rolloff(y=audio, sr=sample_rate)[0]
            spectral_bandwidth = librosa.feature.spectral_bandwidth(y=audio, sr=sample_rate)[0]
            
            features.extend([
                float(np.mean(spectral_centroids)), float(np.std(spectral_centroids)),
                float(np.mean(spectral_rolloff)), float(np.std(spectral_rolloff)),
                float(np.mean(spectral_bandwidth)), float(np.std(spectral_bandwidth))
            ])
        except Exception as e:
            print(f"⚠️ Spectral features failed: {e}, using zeros")
            features.extend([0.0] * 6)
        
        try:
            # 3. Zero Crossing Rate (1 feature)
            zcr = librosa.feature.zero_crossing_rate(audio)
            features.append(float(np.mean(zcr)))
        except Exception as e:
            print(f"⚠️ ZCR failed: {e}, using zero")
            features.append(0.0)
        
        try:
            # 4. Chroma Features (12 features)
            chroma = librosa.feature.chroma_stft(y=audio, sr=sample_rate)
            features.extend([float(np.mean(chroma[i])) for i in range(12)])
        except Exception as e:
            print(f"⚠️ Chroma features failed: {e}, using zeros")
            features.extend([0.0] * 12)
        
        try:
            # 5. RMS Energy (1 feature)
            rms = librosa.feature.rms(y=audio)
            features.append(float(np.mean(rms)))
        except Exception as e:
            print(f"⚠️ RMS failed: {e}, using zero")
            features.append(0.0)
        
        try:
            # 6. Tempo (1 feature)
            tempo, _ = librosa.beat.beat_track(y=audio, sr=sample_rate)
            features.append(float(tempo))
        except Exception as e:
            print(f"⚠️ Tempo extraction failed: {e}, using default")
            features.append(120.0)  # Default tempo
        
        try:
            # 7. Additional spectral features to reach 106
            mel_spectrogram = librosa.feature.melspectrogram(y=audio, sr=sample_rate)
            features.extend([
                float(np.mean(mel_spectrogram)), 
                float(np.std(mel_spectrogram)),
                float(np.min(mel_spectrogram)), 
                float(np.max(mel_spectrogram))
            ])
        except Exception as e:
            print(f"⚠️ Mel spectrogram failed: {e}, using zeros")
            features.extend([0.0] * 4)
        
        # Ensure we have exactly 106 features
        while len(features) < 106:
            features.append(0.0)
        features = features[:106]
        
        # Convert all to float32 and ensure no NaN/inf values
        features = np.array(features, dtype=np.float32)
        features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
        
        return features
        
    except ImportError:
        print("⚠️ librosa not installed. Using simple features instead.")
        return extract_simple_audio_features(audio_waveform)
    except Exception as e:
        print(f"⚠️ Error extracting audio features: {e}")
        return extract_simple_audio_features(audio_waveform)

def extract_simple_audio_features(audio_waveform):
    """
    Fallback: Simple statistical features when librosa is not available
    """
    if torch.is_tensor(audio_waveform):
        audio = audio_waveform.cpu().numpy()
    else:
        audio = audio_waveform
    
    # Flatten if needed
    if len(audio.shape) > 1:
        audio = audio.flatten()
    
    # Convert to float32
    audio = audio.astype(np.float32)
    
    if len(audio) == 0:
        return np.zeros(106, dtype=np.float32)
    
    # Basic statistical features
    features = [
        float(np.mean(audio)), float(np.std(audio)), 
        float(np.min(audio)), float(np.max(audio)),
        float(np.median(audio)), 
        float(np.percentile(audio, 25)), float(np.percentile(audio, 75)),
        float(np.sum(audio > 0) / len(audio)),  # positive ratio
        float(np.sum(np.abs(audio) > 0.1) / len(audio)),  # above threshold ratio
    ]
    
    # Pad to 106 features
    while len(features) < 106:
        features.append(0.0)
    
    return np.array(features[:106], dtype=np.float32)

# Test the audio feature extraction
print("🧪 Testing Fixed Audio Feature Extraction...")
if 'audio' in datasets and len(datasets['audio']) > 0:
    test_audio, _, _ = datasets['audio'][0]
    print(f"Raw audio shape: {test_audio.shape}")
    
    # Test feature extraction
    features = extract_audio_features_106(test_audio)
    print(f"Extracted features shape: {features.shape}")
    print(f"Feature range: [{features.min():.3f}, {features.max():.3f}]")
    print(f"Feature type: {features.dtype}")
    print("✅ Fixed audio feature extraction ready!")
else:
    print("❌ No audio dataset available for testing")

🧪 Testing Fixed Audio Feature Extraction...
Raw audio shape: torch.Size([16000])
Extracted features shape: (106,)
Feature range: [-524.842, 2296.201]
Feature type: float32
✅ Fixed audio feature extraction ready!


In [7]:
# 🔧 OPTIMIZED EMBEDDING EXTRACTION - FIXED VERSION
import sys
import time
from datetime import datetime

def extract_all_embeddings_optimized(max_samples=None, save_embeddings=True):
    """
    OPTIMIZED: Extract embeddings from all three modalities efficiently
    Processes all modalities for the same sample together, avoiding redundancy
    """
    results = {
        'audio': {'success': 0, 'errors': 0, 'error_details': []},
        'video': {'success': 0, 'errors': 0, 'error_details': []},
        'text': {'success': 0, 'errors': 0, 'error_details': []}
    }
    
    if max_samples is None:
        max_samples = 20000
    
    # Determine the actual number of samples to process
    available_samples = {}
    for modality in ['audio', 'video', 'text']:
        if modality in datasets:
            available_samples[modality] = len(datasets[modality])
        else:
            available_samples[modality] = 0
    
    # Use the minimum available samples across all modalities
    min_samples = min(available_samples.values()) if available_samples else 0
    actual_samples = min(min_samples, max_samples)
    
    print(f"🎯 OPTIMIZED EXTRACTION: Processing {actual_samples:,} samples")
    print(f"📊 Available samples per modality: {available_samples}")
    print(f"🚀 Processing all modalities for each sample together...")
    print("=" * 60)
    sys.stdout.flush()
    
    start_time = time.time()
    
    # Process samples efficiently - all modalities per sample
    for i in range(actual_samples):
        sample_start_time = time.time()
        
        # Audio Processing
        if 'audio' in datasets and audio_model is not None:
            try:
                audio_waveform, label, video_id = datasets['audio'][i]
                
                # Extract features
                features = extract_audio_features_106(audio_waveform)
                features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
                
                with torch.no_grad():
                    if hasattr(audio_model, 'cuda'):
                        features_tensor = features_tensor.cuda()
                    
                    audio_embedding = audio_model(features_tensor)
                    if audio_embedding.dim() > 1:
                        audio_embedding = audio_embedding.squeeze()
                    audio_embedding = audio_embedding.cpu().numpy()
                
                # Validate and save
                if not (np.any(np.isnan(audio_embedding)) or np.any(np.isinf(audio_embedding))):
                    if save_embeddings:
                        audio_save_path = f"specialists/audio/audio_embedding_{i:06d}.npy"
                        np.save(audio_save_path, audio_embedding)
                    results['audio']['success'] += 1
                else:
                    raise ValueError("Invalid embedding values")
                    
            except Exception as e:
                results['audio']['errors'] += 1
                if len(results['audio']['error_details']) < 5:
                    results['audio']['error_details'].append(f"Sample {i}: {str(e)}")
                    print(f"❌ Audio error {i}: {e}")
                    sys.stdout.flush()
        
        # Video Processing
        if 'video' in datasets and video_model is not None:
            try:
                frames, label, video_id = datasets['video'][i]
                
                # Format frames
                if isinstance(frames, list):
                    frames = torch.stack(frames)
                if frames.dim() == 4:
                    frames = frames.unsqueeze(0)
                
                with torch.no_grad():
                    if hasattr(video_model, 'cuda'):
                        frames = frames.cuda()
                    
                    video_embedding = video_model(frames)
                    if video_embedding.dim() > 1:
                        video_embedding = video_embedding.squeeze()
                    video_embedding = video_embedding.cpu().numpy()
                
                # Validate and save
                if not (np.any(np.isnan(video_embedding)) or np.any(np.isinf(video_embedding))):
                    if save_embeddings:
                        video_save_path = f"specialists/video/video_embedding_{i:06d}.npy"
                        np.save(video_save_path, video_embedding)
                    results['video']['success'] += 1
                else:
                    raise ValueError("Invalid embedding values")
                    
            except Exception as e:
                results['video']['errors'] += 1
                if len(results['video']['error_details']) < 5:
                    results['video']['error_details'].append(f"Sample {i}: {str(e)}")
                    print(f"❌ Video error {i}: {e}")
                    sys.stdout.flush()
        
        # Text Processing
        if 'text' in datasets and text_model is not None:
            try:
                sample = datasets['text'][i]
                
                if len(sample) == 3:
                    tokenized, label, metadata = sample
                    input_ids = torch.tensor(tokenized['input_ids'])
                    attention_mask = torch.tensor(tokenized['attention_mask'])
                elif len(sample) == 4:
                    input_ids, attention_mask, label, video_id = sample
                else:
                    raise ValueError(f"Unexpected text sample format: {len(sample)} elements")
                
                with torch.no_grad():
                    if hasattr(text_model, 'cuda'):
                        input_ids = input_ids.cuda()
                        attention_mask = attention_mask.cuda()
                    
                    text_embedding = text_model(input_ids.unsqueeze(0), attention_mask.unsqueeze(0))
                    if text_embedding.dim() > 1:
                        text_embedding = text_embedding.squeeze()
                    text_embedding = text_embedding.cpu().numpy()
                
                # Validate and save
                if not (np.any(np.isnan(text_embedding)) or np.any(np.isinf(text_embedding))):
                    if save_embeddings:
                        text_save_path = f"specialists/text/text_embedding_{i:06d}.npy"
                        np.save(text_save_path, text_embedding)
                    results['text']['success'] += 1
                else:
                    raise ValueError("Invalid embedding values")
                    
            except Exception as e:
                results['text']['errors'] += 1
                if len(results['text']['error_details']) < 5:
                    results['text']['error_details'].append(f"Sample {i}: {str(e)}")
                    print(f"❌ Text error {i}: {e}")
                    sys.stdout.flush()
        
        # Progress reporting with time estimates
        if i % 100 == 0 or i in [1, 10, 50]:  # More frequent progress updates
            elapsed_time = time.time() - start_time
            if i > 0:
                avg_time_per_sample = elapsed_time / (i + 1)
                remaining_samples = actual_samples - (i + 1)
                estimated_remaining_time = avg_time_per_sample * remaining_samples
                
                # Format time estimates
                def format_time(seconds):
                    if seconds < 60:
                        return f"{seconds:.1f}s"
                    elif seconds < 3600:
                        return f"{seconds/60:.1f}m"
                    else:
                        return f"{seconds/3600:.1f}h"
                
                print(f"⚡ Progress: {i+1:,}/{actual_samples:,} ({(i+1)/actual_samples*100:.1f}%) | "
                      f"Time: {format_time(elapsed_time)} elapsed, {format_time(estimated_remaining_time)} remaining | "
                      f"Speed: {(i+1)/elapsed_time:.1f} samples/sec")
            else:
                print(f"⚡ Starting processing... Sample {i+1:,}/{actual_samples:,}")
            
            sys.stdout.flush()
    
    # Final summary
    total_time = time.time() - start_time
    print(f"\n🎯 === OPTIMIZED EXTRACTION COMPLETED ===")
    print(f"⏱️  Total time: {total_time/60:.1f} minutes ({total_time:.1f} seconds)")
    print(f"⚡ Average speed: {actual_samples/total_time:.1f} samples/second")
    print()
    
    for modality, result in results.items():
        total_processed = result['success'] + result['errors']
        success_rate = (result['success'] / total_processed * 100) if total_processed > 0 else 0
        print(f"📊 {modality.upper()}: {result['success']:,} success, {result['errors']:,} errors ({success_rate:.1f}% success)")
        
        if result['error_details']:
            print(f"   Sample errors: {result['error_details'][:3]}")
    
    sys.stdout.flush()
    return results

# Test the optimized version with a small batch first
print("🧪 Testing optimized extraction with small batch...")
sys.stdout.flush()
test_results_optimized = extract_all_embeddings_optimized(max_samples=50, save_embeddings=False)

🧪 Testing optimized extraction with small batch...
🎯 OPTIMIZED EXTRACTION: Processing 50 samples
📊 Available samples per modality: {'audio': 2949, 'video': 2949, 'text': 2949}
🚀 Processing all modalities for each sample together...
⚡ Starting processing... Sample 1/50
⚡ Progress: 2/50 (4.0%) | Time: 0.4s elapsed, 8.8s remaining | Speed: 5.5 samples/sec
⚡ Progress: 11/50 (22.0%) | Time: 0.7s elapsed, 2.3s remaining | Speed: 16.8 samples/sec

🎯 === OPTIMIZED EXTRACTION COMPLETED ===
⏱️  Total time: 0.0 minutes (1.6 seconds)
⚡ Average speed: 31.5 samples/second

📊 AUDIO: 50 success, 0 errors (100.0% success)
📊 VIDEO: 50 success, 0 errors (100.0% success)
📊 TEXT: 50 success, 0 errors (100.0% success)


# 🚀 Embedding Extraction

This section contains the main embedding extraction pipeline with all necessary functions and execution.

In [8]:
# 🚀 EXECUTE OPTIMIZED EMBEDDING EXTRACTION & CREATE FUSION MANIFEST
import os
import pandas as pd
import sys
from datetime import datetime

print("🎯 STARTING OPTIMIZED EMBEDDING EXTRACTION PIPELINE")
print(f"🕐 Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("=" * 60)
sys.stdout.flush()

# Step 1: Create directories
print("\n📁 Step 1: Creating directories...")
os.makedirs("specialists/audio", exist_ok=True)
os.makedirs("specialists/video", exist_ok=True) 
os.makedirs("specialists/text", exist_ok=True)
print("✅ Directories created")
sys.stdout.flush()

# Step 2: Execute the OPTIMIZED extraction function
print("\n🚀 Step 2: Extracting embeddings with optimized pipeline...")
print("⚡ This processes all modalities together for maximum efficiency")
sys.stdout.flush()

# Run the OPTIMIZED extraction
embedding_results_final = extract_all_embeddings_optimized(max_samples=20000, save_embeddings=True)

# Step 3: Check extraction results
print("\n📊 Step 3: Validating extraction results...")
sys.stdout.flush()

try:
    audio_count = len([f for f in os.listdir("specialists/audio") if f.endswith('.npy')])
    video_count = len([f for f in os.listdir("specialists/video") if f.endswith('.npy')])
    text_count = len([f for f in os.listdir("specialists/text") if f.endswith('.npy')])
    
    print(f"   🎵 Audio embeddings: {audio_count:,}")
    print(f"   🎬 Video embeddings: {video_count:,}")
    print(f"   📝 Text embeddings: {text_count:,}")
    sys.stdout.flush()
    
    # Step 4: Create synchronized fusion manifest (only if we have embeddings)
    if audio_count > 0 or video_count > 0 or text_count > 0:
        print("\n🔗 Step 4: Creating fusion manifest...")
        min_count = min(audio_count, video_count, text_count)
        print(f"   Synchronized samples (all 3 modalities): {min_count:,}")
        sys.stdout.flush()
        
        fusion_data = []
        processed_count = 0
        
        for i in range(min_count):
            try:
                # Verify all embedding files exist
                audio_path = f"specialists/audio/audio_embedding_{i:06d}.npy"
                video_path = f"specialists/video/video_embedding_{i:06d}.npy"
                text_path = f"specialists/text/text_embedding_{i:06d}.npy"
                
                if os.path.exists(audio_path) and os.path.exists(video_path) and os.path.exists(text_path):
                    # Get metadata from original dataset
                    if i < len(datasets['audio']):
                        _, label, video_id_raw = datasets['audio'][i]
                        
                        # Handle video_id format
                        if isinstance(video_id_raw, dict):
                            video_id = video_id_raw.get('video_id', i)
                        else:
                            video_id = video_id_raw
                        
                        fusion_data.append({
                            'audio_path': audio_path,
                            'video_path': video_path,
                            'text_path': text_path,
                            'label': int(label),
                            'video_id': str(video_id),
                            'sample_idx': i
                        })
                        processed_count += 1
                        
                if i % 2000 == 0 and i > 0:
                    print(f"   ✅ Processed {processed_count:,} fusion samples...")
                    sys.stdout.flush()
                    
            except Exception as e:
                print(f"   ❌ Error processing sample {i}: {e}")
                continue
        
        # Create and save fusion manifest
        if fusion_data:
            fusion_df = pd.DataFrame(fusion_data)
            fusion_manifest_path = "artifacts/fusion_manifest_20k.csv"
            fusion_df.to_csv(fusion_manifest_path, index=False)
            
            print(f"\n📄 Fusion manifest created: {fusion_manifest_path}")
            print(f"   📊 Total samples: {len(fusion_df):,}")
            print(f"   📊 Unique videos: {fusion_df['video_id'].nunique()}")
            print(f"   📊 Unique labels: {fusion_df['label'].nunique()}")
            
            # Display video distribution
            print(f"\n📈 Video distribution (top 10):")
            video_counts = fusion_df['video_id'].value_counts()
            for i, (video_id, count) in enumerate(video_counts.head(10).items()):
                print(f"   {i+1:2d}. Video {video_id}: {count:,} samples")
            
            print(f"\n🎉 OPTIMIZED EXTRACTION PIPELINE COMPLETED SUCCESSFULLY!")
            print(f"✅ Results: {len(fusion_df):,} embeddings from {fusion_df['video_id'].nunique()} videos")
            print(f"🕐 Completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
            print(f"✅ Ready for fusion model training!")
        else:
            print("❌ No valid fusion samples created!")
    else:
        print("❌ No embeddings were created!")
        
except Exception as e:
    print(f"❌ Error in validation step: {e}")

sys.stdout.flush()

🎯 STARTING OPTIMIZED EMBEDDING EXTRACTION PIPELINE
🕐 Started at: 2025-09-20 13:12:29

📁 Step 1: Creating directories...
✅ Directories created

🚀 Step 2: Extracting embeddings with optimized pipeline...
⚡ This processes all modalities together for maximum efficiency
🎯 OPTIMIZED EXTRACTION: Processing 2,949 samples
📊 Available samples per modality: {'audio': 2949, 'video': 2949, 'text': 2949}
🚀 Processing all modalities for each sample together...
⚡ Starting processing... Sample 1/2,949
⚡ Progress: 2/2,949 (0.1%) | Time: 0.1s elapsed, 2.0m remaining | Speed: 24.3 samples/sec
⚡ Progress: 11/2,949 (0.4%) | Time: 0.4s elapsed, 1.6m remaining | Speed: 30.2 samples/sec
⚡ Progress: 51/2,949 (1.7%) | Time: 1.6s elapsed, 1.5m remaining | Speed: 32.6 samples/sec
⚠️ Very short audio (19 samples), padding...
⚡ Progress: 101/2,949 (3.4%) | Time: 4.7s elapsed, 2.2m remaining | Speed: 21.6 samples/sec
⚡ Progress: 201/2,949 (6.8%) | Time: 13.0s elapsed, 3.0m remaining | Speed: 15.5 samples/sec
⚡ Progre

## 🎯 Summary & Usage Guide

### 📁 **Generated Files:**
After running this notebook, you'll have:

- **`artifacts/embeddings/audio/`** - Audio embeddings (.npy files)
- **`artifacts/embeddings/video/`** - Video embeddings (.npy files)  
- **`artifacts/embeddings/text/`** - Text embeddings (.npy files)
- **`artifacts/embeddings/fusion_manifest.csv`** - Master manifest for fusion training
- **`artifacts/embeddings/sample_fusion_vector.npy`** - Sample concatenated vector

### 🚀 **How to Use for Fusion Training:**

```python
# 1. Load the fusion manifest
fusion_df = pd.read_csv("artifacts/embeddings/fusion_manifest.csv")

# 2. For each sample, load and concatenate embeddings
def load_fusion_sample(row):
    embeddings = []
    
    # Load audio embedding
    if 'embedding_path_audio' in row:
        audio_emb = np.load(row['embedding_path_audio'])
        embeddings.append(audio_emb)
    
    # Load video embedding  
    if 'embedding_path_video' in row:
        video_emb = np.load(row['embedding_path_video'])
        embeddings.append(video_emb)
        
    # Load text embedding
    if 'embedding_path_text' in row:
        text_emb = np.load(row['embedding_path_text'])
        embeddings.append(text_emb)
    
    # Concatenate all modalities
    fusion_vector = np.concatenate(embeddings)
    return fusion_vector, row['label']

# 3. Use in your fusion model training
for _, row in fusion_df.iterrows():
    features, label = load_fusion_sample(row)
    # Train your fusion model with these features
```

### ⚡ **Configuration Options:**
- **Batch Size**: Modify `config['batch_size']` in cell 1
- **Sample Limit**: Adjust `max_samples` in cell 5 
- **Output Directory**: Change `config['output_dir']` in cell 1

### 🔧 **Troubleshooting:**
- **"Model not found"**: Check if your specialist models are trained and saved
- **"Dataset empty"**: Verify your manifest CSV files exist and have data
- **"Out of memory"**: Reduce batch_size or max_samples

### 📊 **Performance Tips:**
- Start with 1000 samples for testing
- Use GPU for faster extraction
- Process in smaller batches if memory limited