<a href="https://colab.research.google.com/github/Youngstg/Test_Multimodal/blob/main/TestOrpheus.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

MULTIMODAL MUSIC EMOTION CLASSIFICATION
FINAL: Late Fusion of Lyrics (BERT) + Audio (PANNs) + MIDI (Orpheus/Simple)

Dataset: MIREX Emotion Dataset
Strategy: Extract embeddings from each modality → Concatenate → Classify

ORPHEUS CHECKPOINT SETUP:
--------------------------
To use your downloaded Orpheus checkpoint:

1. Upload your checkpoint file to Colab:
   - Click folder icon on left sidebar
   - Upload your .pth or model folder

2. Set the path below (around line 60):
   ORPHEUS_CHECKPOINT_PATH = "/content/your_checkpoint.pth"
   
   OR if it's a HuggingFace model folder:
   ORPHEUS_CHECKPOINT_PATH = "/content/orpheus_model_folder"

3. The code will automatically:
   - Try to load Orpheus checkpoint
   - Fall back to simple MIDI features if loading fails
   - Adjust MIDI_DIM automatically based on model

EXAMPLE PATHS:
- Local file: "/content/Orpheus_checkpoint.pth"
- HF model: "asigalov61/Orpheus-Music-Transformer"
- Google Drive: "/content/drive/MyDrive/models/orpheus.pth"

# 1. INSTALLATION & IMPORTS

In [1]:
print("Installing packages...")
!pip install -q kagglehub transformers torch panns-inference
!pip install -q librosa soundfile pretty_midi
!pip install -q scikit-learn pandas numpy

import os
import re
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import BertTokenizer, BertModel
from panns_inference import AudioTagging
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
import librosa
import pretty_midi
import warnings
warnings.filterwarnings('ignore')

def set_seed(seed=42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'✓ Device: {device}')

Installing packages...
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m48.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pretty_midi (setup.py) ... [?25l[?25hdone
✓ Device: cuda


# 2. DOWNLOAD DATASET

In [2]:
import kagglehub
print("\n" + "="*80)
print("DOWNLOADING DATASET")
print("="*80)

path = kagglehub.dataset_download("imsparsh/multimodal-mirex-emotion-dataset")
print(f"✓ Dataset path: {path}")

# Define directories
dataset_dir = os.path.join(path, 'dataset')
lyrics_dir = os.path.join(dataset_dir, 'Lyrics')
audio_dir = os.path.join(dataset_dir, 'Audio')
midi_dir = os.path.join(dataset_dir, 'MIDIs')

print(f"✓ Lyrics: {os.path.exists(lyrics_dir)}")
print(f"✓ Audio: {os.path.exists(audio_dir)}")
print(f"✓ MIDI: {os.path.exists(midi_dir)}")


DOWNLOADING DATASET
Using Colab cache for faster access to the 'multimodal-mirex-emotion-dataset' dataset.
✓ Dataset path: /kaggle/input/multimodal-mirex-emotion-dataset
✓ Lyrics: True
✓ Audio: True
✓ MIDI: True


# 3. LOAD CLUSTER LABELS

In [3]:
def load_cluster_labels(dataset_path):
    clusters_path = os.path.join(dataset_path, 'dataset', 'clusters.txt')
    cluster_labels = []

    with open(clusters_path, 'r', encoding='utf-8', errors='ignore') as f:
        cluster_labels = [line.strip() for line in f if line.strip()]

    print(f"\n✓ Loaded {len(cluster_labels)} cluster labels")
    print(f"  Unique: {sorted(set(cluster_labels))}")
    return cluster_labels

cluster_labels = load_cluster_labels(path)

# Create song ID mapping
song_cluster_map = {}
for idx in range(len(cluster_labels)):
    for song_id in [str(idx).zfill(3), str(idx + 1).zfill(3)]:
        song_cluster_map[song_id] = cluster_labels[idx]


✓ Loaded 903 cluster labels
  Unique: ['Cluster 1', 'Cluster 2', 'Cluster 3', 'Cluster 4', 'Cluster 5']


# 4. LOAD PRE-TRAINED MODELS

In [5]:
print("\n" + "="*80)
print("LOADING PRE-TRAINED MODELS")
print("="*80)

# --- BERT for lyrics ---
print("Loading BERT...")
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')
bert_model.eval()
bert_model.to(device)
print("✓ BERT loaded")

# --- PANNs for audio ---
print("Loading PANNs...")
panns_model = AudioTagging(checkpoint_path=None, device=device)
print("✓ PANNs loaded")

# --- Orpheus for MIDI ---
print("Loading Orpheus...")

# Define Orpheus model architecture
from transformers import AutoModel, AutoConfig

ORPHEUS_CHECKPOINT_PATH = "/content/Orpheus_Music_Transformer_Classifier_Trained_Model_23670_steps_0.1837_loss_0.9207_acc.pth"  # Will be set by user

class OrpheusMIDIEncoder(nn.Module):
    """
    Orpheus MIDI encoder - compatible with pre-trained checkpoint
    """
    def __init__(self, checkpoint_path=None):
        super().__init__()

        if checkpoint_path and os.path.exists(checkpoint_path):
            print(f"  Loading Orpheus from checkpoint: {checkpoint_path}")
            try:
                # Try loading as HuggingFace model
                self.model = AutoModel.from_pretrained(
                    checkpoint_path,
                    trust_remote_code=True,
                    local_files_only=True
                )
                self.embed_dim = self.model.config.hidden_size if hasattr(self.model.config, 'hidden_size') else 768
                print(f"  ✓ Orpheus loaded (embedding dim: {self.embed_dim})")
                self.has_model = True
            except Exception as e:
                print(f"  ⚠️ Failed to load Orpheus: {e}")
                print("  Using simple MIDI feature extraction instead")
                self.has_model = False
        else:
            print("  ⚠️ No Orpheus checkpoint provided")
            print("  Using simple MIDI feature extraction instead")
            self.has_model = False

    def extract_simple_features(self, midi_path):
        """
        Fallback: Extract statistical MIDI features
        """
        try:
            midi = pretty_midi.PrettyMIDI(midi_path)

            notes = []
            for instrument in midi.instruments:
                if not instrument.is_drum:
                    for note in instrument.notes:
                        notes.append({
                            'pitch': note.pitch,
                            'velocity': note.velocity,
                            'duration': note.end - note.start
                        })

            if len(notes) == 0:
                return None

            pitches = [n['pitch'] for n in notes]
            velocities = [n['velocity'] for n in notes]
            durations = [n['duration'] for n in notes]

            tempo_changes = midi.get_tempo_changes()
            avg_tempo = np.mean(tempo_changes[1]) if len(tempo_changes[1]) > 0 else 120.0

            time_sigs = midi.time_signature_changes
            numerator = time_sigs[0].numerator if len(time_sigs) > 0 else 4
            denominator = time_sigs[0].denominator if len(time_sigs) > 0 else 4

            features = np.array([
                np.mean(pitches), np.std(pitches), np.min(pitches), np.max(pitches),
                np.percentile(pitches, 25), np.percentile(pitches, 75),
                np.ptp(pitches), len(set(pitches)),
                np.mean(velocities), np.std(velocities), np.min(velocities), np.max(velocities),
                np.percentile(velocities, 25), np.percentile(velocities, 75),
                np.ptp(velocities), len(notes),
                np.mean(durations), np.std(durations), np.min(durations), np.max(durations),
                np.percentile(durations, 25), np.percentile(durations, 75),
                np.ptp(durations), 1.0 / (np.mean(durations) + 1e-6),
                avg_tempo, avg_tempo / 120.0, numerator, denominator,
                numerator / denominator, len(notes) / (midi.get_end_time() + 1e-6),
                midi.get_end_time(), len(midi.instruments)
            ], dtype=np.float32)

            return features
        except:
            return None

# Initialize Orpheus
# USER: Set this path to your downloaded Orpheus checkpoint
# Example: ORPHEUS_CHECKPOINT_PATH = "/path/to/orpheus_checkpoint.pth"
# Or: ORPHEUS_CHECKPOINT_PATH = "asigalov61/Orpheus-Music-Transformer"

orpheus_encoder = OrpheusMIDIEncoder(checkpoint_path=ORPHEUS_CHECKPOINT_PATH)

# Determine MIDI feature dimension
if orpheus_encoder.has_model:
    MIDI_DIM = orpheus_encoder.embed_dim
    print(f"✓ Using Orpheus embeddings ({MIDI_DIM}-dim)")
else:
    MIDI_DIM = 32
    print(f"✓ Using simple MIDI features ({MIDI_DIM}-dim)")

print("\n" + "="*80)
print("MODEL CONFIGURATION")
print("="*80)
print(f"Lyrics: BERT (768-dim)")
print(f"Audio:  PANNs (2048-dim)")
print(f"MIDI:   {'Orpheus' if orpheus_encoder.has_model else 'Simple features'} ({MIDI_DIM}-dim)")
print(f"Total:  {768 + 2048 + MIDI_DIM}-dim")


LOADING PRE-TRAINED MODELS
Loading BERT...


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

✓ BERT loaded
Loading PANNs...
Checkpoint path: /root/panns_data/Cnn14_mAP=0.431.pth
Using CPU.
✓ PANNs loaded
Loading Orpheus...
  Loading Orpheus from checkpoint: /content/Orpheus_Music_Transformer_Classifier_Trained_Model_23670_steps_0.1837_loss_0.9207_acc.pth
  ⚠️ Failed to load Orpheus: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '/content/Orpheus_Music_Transformer_Classifier_Trained_Model_23670_steps_0.1837_loss_0.9207_acc.pth'. Use `repo_type` argument if needed.
  Using simple MIDI feature extraction instead
✓ Using simple MIDI features (32-dim)

MODEL CONFIGURATION
Lyrics: BERT (768-dim)
Audio:  PANNs (2048-dim)
MIDI:   Simple features (32-dim)
Total:  2848-dim


# 5. FEATURE EXTRACTION FUNCTIONS

In [6]:
# --- LYRICS FEATURES ---
def clean_lyrics(text):
    if pd.isna(text):
        return ""
    text = str(text).lower()
    text = re.sub(r'\[.*?\]', '', text)
    text = re.sub(r'\(.*?\)', '', text)
    text = re.sub(r'http\S+|www\S+', '', text)
    text = ' '.join(text.split())
    text = re.sub(r'[^a-z0-9\s.,!?\']', ' ', text)
    text = re.sub(r'([.,!?])\1+', r'\1', text)
    return ' '.join(text.split()).strip()

def extract_lyrics_embedding(lyrics, tokenizer, model, max_length=256):
    try:
        lyrics = clean_lyrics(lyrics)
        if not lyrics or len(lyrics) < 10:
            return None

        encoding = tokenizer.encode_plus(
            lyrics,
            add_special_tokens=True,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            embedding = outputs.pooler_output.cpu().numpy()[0]

        return embedding
    except:
        return None

# --- AUDIO FEATURES ---
def extract_audio_embedding(audio_path, panns_model, sr=32000, duration=10):
    try:
        audio, _ = librosa.load(audio_path, sr=sr, duration=duration)

        target_length = sr * duration
        if len(audio) < target_length:
            audio = np.pad(audio, (0, target_length - len(audio)))
        else:
            audio = audio[:target_length]

        _, embedding = panns_model.inference(audio[None, :])
        return embedding[0]
    except:
        return None

# --- MIDI FEATURES ---
def extract_midi_features(midi_path, orpheus_encoder):
    """
    Extract MIDI features using Orpheus (if available) or simple statistics
    """
    # If Orpheus model is available, use it
    if orpheus_encoder.has_model:
        try:
            # TODO: Implement Orpheus inference
            # This depends on Orpheus's expected input format
            # You may need to tokenize MIDI events first

            # For now, fallback to simple features
            # Once you know Orpheus input format, update this section
            print("  ⚠️ Orpheus inference not implemented yet, using simple features")
            return orpheus_encoder.extract_simple_features(midi_path)
        except Exception as e:
            print(f"  ⚠️ Orpheus inference failed: {e}, using simple features")
            return orpheus_encoder.extract_simple_features(midi_path)
    else:
        # Use simple statistical features
        return orpheus_encoder.extract_simple_features(midi_path)

# 6. LOAD & EXTRACT ALL FEATURES

In [7]:
print("\n" + "="*80)
print("EXTRACTING FEATURES FROM ALL MODALITIES")
print("="*80)

data_list = []

# Get all files
lyrics_files = {f.replace('.txt', ''): f for f in os.listdir(lyrics_dir) if f.endswith('.txt')}
audio_files = {f.replace('.wav', '').replace('.mp3', ''): f for f in os.listdir(audio_dir) if f.endswith(('.wav', '.mp3'))}
midi_files = {f.replace('.mid', '').replace('.midi', ''): f for f in os.listdir(midi_dir) if f.endswith(('.mid', '.midi'))}

print(f"Found: {len(lyrics_files)} lyrics, {len(audio_files)} audio, {len(midi_files)} MIDI")

# Get all unique song IDs that have cluster labels
all_song_ids = set()
for f in lyrics_files.keys():
    song_id = ''.join(filter(str.isdigit, f))
    if song_id:
        all_song_ids.add(song_id.zfill(3))

print(f"\nProcessing {len(all_song_ids)} songs with multimodal data...")

processed = 0
for song_id in sorted(all_song_ids):
    if song_id not in song_cluster_map:
        continue

    # Initialize features
    lyrics_emb = None
    audio_emb = None
    midi_feat = None

    # Extract lyrics
    for key, filename in lyrics_files.items():
        if song_id in key or key.zfill(3) == song_id:
            lyrics_path = os.path.join(lyrics_dir, filename)
            with open(lyrics_path, 'r', encoding='utf-8', errors='ignore') as f:
                lyrics_text = f.read()
            lyrics_emb = extract_lyrics_embedding(lyrics_text, bert_tokenizer, bert_model)
            break

    # Extract audio
    for key, filename in audio_files.items():
        if song_id in key or key.zfill(3) == song_id:
            audio_path = os.path.join(audio_dir, filename)
            audio_emb = extract_audio_embedding(audio_path, panns_model)
            break

    # Extract MIDI
    for key, filename in midi_files.items():
        if song_id in key or key.zfill(3) == song_id:
            midi_path = os.path.join(midi_dir, filename)
            midi_feat = extract_midi_features(midi_path, orpheus_encoder)
            break

    # Only add if at least 2 modalities available
    available = sum([lyrics_emb is not None, audio_emb is not None, midi_feat is not None])
    if available >= 2:
        # Pad MIDI features if needed
        if midi_feat is not None:
            if len(midi_feat) < MIDI_DIM:
                midi_feat = np.pad(midi_feat, (0, MIDI_DIM - len(midi_feat)))
            elif len(midi_feat) > MIDI_DIM:
                midi_feat = midi_feat[:MIDI_DIM]

        data_list.append({
            'song_id': song_id,
            'lyrics_emb': lyrics_emb if lyrics_emb is not None else np.zeros(768),
            'audio_emb': audio_emb if audio_emb is not None else np.zeros(2048),
            'midi_feat': midi_feat if midi_feat is not None else np.zeros(MIDI_DIM),
            'has_lyrics': lyrics_emb is not None,
            'has_audio': audio_emb is not None,
            'has_midi': midi_feat is not None,
            'cluster': song_cluster_map[song_id]
        })
        processed += 1

        if processed % 50 == 0:
            print(f"  Processed: {processed} songs...")

print(f"\n✓ Total multimodal samples: {len(data_list)}")

df = pd.DataFrame(data_list)
print(f"✓ Dataset shape: {df.shape}")
print(f"\nModality availability:")
print(f"  Lyrics: {df['has_lyrics'].sum()} ({df['has_lyrics'].mean()*100:.1f}%)")
print(f"  Audio: {df['has_audio'].sum()} ({df['has_audio'].mean()*100:.1f}%)")
print(f"  MIDI: {df['has_midi'].sum()} ({df['has_midi'].mean()*100:.1f}%)")
print(f"\nCluster distribution:")
print(df['cluster'].value_counts())


EXTRACTING FEATURES FROM ALL MODALITIES
Found: 764 lyrics, 903 audio, 196 MIDI

Processing 764 songs with multimodal data...
  Processed: 50 songs...
  Processed: 100 songs...
  Processed: 150 songs...
  Processed: 200 songs...
  Processed: 250 songs...
  Processed: 300 songs...
  Processed: 350 songs...
  Processed: 400 songs...
  Processed: 450 songs...
  Processed: 500 songs...
  Processed: 550 songs...
  Processed: 600 songs...
  Processed: 650 songs...
  Processed: 700 songs...
  Processed: 750 songs...

✓ Total multimodal samples: 764
✓ Dataset shape: (764, 8)

Modality availability:
  Lyrics: 764 (100.0%)
  Audio: 764 (100.0%)
  MIDI: 191 (25.0%)

Cluster distribution:
cluster
Cluster 3    192
Cluster 4    173
Cluster 2    138
Cluster 1    134
Cluster 5    127
Name: count, dtype: int64


# 7. LABEL ENCODING

In [8]:
label_encoder = LabelEncoder()
df['label'] = label_encoder.fit_transform(df['cluster'])
num_classes = len(label_encoder.classes_)

print(f"\n✓ Classes: {label_encoder.classes_}")
print(f"✓ Number of classes: {num_classes}")

# Class weights
y = df['label'].values
class_weights = compute_class_weight('balanced', classes=np.unique(y), y=y)
class_weights = torch.FloatTensor(class_weights).to(device)


✓ Classes: ['Cluster 1' 'Cluster 2' 'Cluster 3' 'Cluster 4' 'Cluster 5']
✓ Number of classes: 5


# 8. MULTIMODAL DATASET

In [9]:
class MultimodalDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data.iloc[idx]

        return {
            'lyrics_emb': torch.FloatTensor(item['lyrics_emb']),
            'audio_emb': torch.FloatTensor(item['audio_emb']),
            'midi_feat': torch.FloatTensor(item['midi_feat']),
            'has_lyrics': torch.FloatTensor([item['has_lyrics']]),
            'has_audio': torch.FloatTensor([item['has_audio']]),
            'has_midi': torch.FloatTensor([item['has_midi']]),
            'label': torch.tensor(item['label'], dtype=torch.long)
        }

# 9. MULTIMODAL FUSION MODEL (SIMPLIFIED - LIKE PAPER)

In [10]:
class SimpleFusionClassifier(nn.Module):
    """
    SIMPLE Late Fusion - Just concatenate and one linear layer
    No deep MLP to prevent overfitting!
    """
    def __init__(self, num_classes, lyrics_dim=768, audio_dim=2048, midi_dim=32, dropout=0.3):
        super().__init__()

        # Direct concatenation
        fusion_dim = lyrics_dim + audio_dim + midi_dim

        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(fusion_dim, num_classes)
        )

    def forward(self, lyrics_emb, audio_emb, midi_feat, has_lyrics, has_audio, has_midi):
        # Mask unavailable modalities
        lyrics_emb = lyrics_emb * has_lyrics
        audio_emb = audio_emb * has_audio
        midi_feat = midi_feat * has_midi

        # Simple concatenation
        fused = torch.cat([lyrics_emb, audio_emb, midi_feat], dim=1)

        # Direct classification
        logits = self.classifier(fused)

        return logits

class WeightedFusionClassifier(nn.Module):
    """
    ALTERNATIVE: Train separate classifiers then weighted fusion
    """
    def __init__(self, num_classes, lyrics_dim=768, audio_dim=2048, midi_dim=32, dropout=0.2):
        super().__init__()

        # Separate classifier for each modality
        self.lyrics_classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(lyrics_dim, num_classes)
        )

        self.audio_classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(audio_dim, num_classes)
        )

        self.midi_classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(midi_dim, num_classes)
        )

        # Learnable weights for fusion
        self.fusion_weights = nn.Parameter(torch.tensor([1.0, 1.5, 0.3]))  # Audio slightly higher

    def forward(self, lyrics_emb, audio_emb, midi_feat, has_lyrics, has_audio, has_midi):
        # Get predictions from each modality
        lyrics_logits = self.lyrics_classifier(lyrics_emb) * has_lyrics
        audio_logits = self.audio_classifier(audio_emb) * has_audio
        midi_logits = self.midi_classifier(midi_feat) * has_midi

        # Weighted fusion
        weights = torch.softmax(self.fusion_weights, dim=0)
        fused_logits = (weights[0] * lyrics_logits +
                       weights[1] * audio_logits +
                       weights[2] * midi_logits)

        # Normalize by number of available modalities
        n_modalities = has_lyrics + has_audio + has_midi
        fused_logits = fused_logits / (n_modalities + 1e-6)

        return fused_logits

# 10. TRAINING & EVALUATION

In [11]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    predictions = []
    true_labels = []

    for batch in dataloader:
        lyrics = batch['lyrics_emb'].to(device)
        audio = batch['audio_emb'].to(device)
        midi = batch['midi_feat'].to(device)
        has_l = batch['has_lyrics'].to(device)
        has_a = batch['has_audio'].to(device)
        has_m = batch['has_midi'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(lyrics, audio, midi, has_l, has_a, has_m)
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        preds = torch.argmax(logits, dim=1)
        predictions.extend(preds.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

    return total_loss / len(dataloader), accuracy_score(true_labels, predictions)

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    predictions = []
    true_labels = []

    with torch.no_grad():
        for batch in dataloader:
            lyrics = batch['lyrics_emb'].to(device)
            audio = batch['audio_emb'].to(device)
            midi = batch['midi_feat'].to(device)
            has_l = batch['has_lyrics'].to(device)
            has_a = batch['has_audio'].to(device)
            has_m = batch['has_midi'].to(device)
            labels = batch['label'].to(device)

            logits = model(lyrics, audio, midi, has_l, has_a, has_m)
            loss = criterion(logits, labels)

            total_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(true_labels, predictions)
    p, r, f1, _ = precision_recall_fscore_support(true_labels, predictions, average='weighted', zero_division=0)
    return total_loss / len(dataloader), acc, p, r, f1, predictions, true_labels

# 11. 5-FOLD CROSS VALIDATION (SIMPLIFIED TRAINING)

In [12]:
BATCH_SIZE = 32  # Larger batch
LR = 1e-3  # Higher LR for simple model
EPOCHS = 50  # More epochs
PATIENCE = 10  # More patience
WEIGHT_DECAY = 0.1  # Strong L2 regularization
FUSION_TYPE = "simple"  # "simple" or "weighted"

print("\n" + "="*80)
print("SIMPLIFIED MULTIMODAL FUSION - 5-FOLD CROSS VALIDATION")
print("="*80)
print(f"Strategy: {FUSION_TYPE.upper()} fusion (shallow classifier)")
print(f"Total samples: {len(df)}")
print(f"Modalities: Lyrics (768) + Audio (2048) + MIDI (32) = 2848-dim")
print(f"Classifier: Direct linear layer (NO DEEP MLP!)")
print(f"Regularization: Dropout 0.3 + Weight Decay 0.1")

X = df.index.values
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_results = []

for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
    print(f"\n{'='*80}")
    print(f"FOLD {fold + 1}/5")
    print(f"{'='*80}")

    train_data = df.iloc[train_idx].reset_index(drop=True)
    val_data = df.iloc[val_idx].reset_index(drop=True)

    train_dataset = MultimodalDataset(train_data)
    val_dataset = MultimodalDataset(val_data)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    # Choose fusion type
    if FUSION_TYPE == "simple":
        model = SimpleFusionClassifier(
            num_classes=num_classes,
            lyrics_dim=768,
            audio_dim=2048,
            midi_dim=MIDI_DIM,  # Use dynamic MIDI_DIM
            dropout=0.3
        ).to(device)
    else:
        model = WeightedFusionClassifier(
            num_classes=num_classes,
            lyrics_dim=768,
            audio_dim=2048,
            midi_dim=MIDI_DIM,  # Use dynamic MIDI_DIM
            dropout=0.2
        ).to(device)

    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
    optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5, patience=5)

    best_f1 = 0
    patience_counter = 0

    for epoch in range(EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, val_p, val_r, val_f1, _, _ = evaluate(model, val_loader, criterion, device)
        scheduler.step(val_f1)

        gap = train_acc - val_acc

        if epoch % 5 == 0 or val_f1 > best_f1:
            print(f"Epoch {epoch+1}/{EPOCHS}: Train={train_acc:.4f}, Val={val_acc:.4f}, F1={val_f1:.4f}, Gap={gap:.4f}")

        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save(model.state_dict(), f'best_multimodal_fold{fold+1}.pt')
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print("Early stopping!")
                break

    model.load_state_dict(torch.load(f'best_multimodal_fold{fold+1}.pt'))
    val_loss, val_acc, val_p, val_r, val_f1, preds, labels = evaluate(model, val_loader, criterion, device)

    print(f"\nFold {fold+1} Results: Acc={val_acc:.4f}, Precision={val_p:.4f}, Recall={val_r:.4f}, F1={val_f1:.4f}")
    print(classification_report(labels, preds, target_names=label_encoder.classes_, digits=4, zero_division=0))

    fold_results.append({'fold': fold+1, 'accuracy': val_acc, 'precision': val_p, 'recall': val_r, 'f1': val_f1})



SIMPLIFIED MULTIMODAL FUSION - 5-FOLD CROSS VALIDATION
Strategy: SIMPLE fusion (shallow classifier)
Total samples: 764
Modalities: Lyrics (768) + Audio (2048) + MIDI (32) = 2848-dim
Classifier: Direct linear layer (NO DEEP MLP!)
Regularization: Dropout 0.3 + Weight Decay 0.1

FOLD 1/5
Epoch 1/50: Train=0.2602, Val=0.2157, F1=0.1404, Gap=0.0445
Epoch 2/50: Train=0.2422, Val=0.3137, F1=0.2723, Gap=-0.0715
Epoch 5/50: Train=0.3273, Val=0.4314, F1=0.3613, Gap=-0.1040
Epoch 6/50: Train=0.3175, Val=0.3725, F1=0.3311, Gap=-0.0550
Epoch 7/50: Train=0.3372, Val=0.4379, F1=0.4307, Gap=-0.1008
Epoch 11/50: Train=0.3437, Val=0.4510, F1=0.3999, Gap=-0.1073
Epoch 16/50: Train=0.4223, Val=0.3987, F1=0.3716, Gap=0.0236
Early stopping!

Fold 1 Results: Acc=0.4379, Precision=0.4289, Recall=0.4379, F1=0.4307
              precision    recall  f1-score   support

   Cluster 1     0.2000    0.2222    0.2105        27
   Cluster 2     0.3810    0.2857    0.3265        28
   Cluster 3     0.6341    0.6842  

# 12. FINAL RESULTS

In [13]:
print("\n" + "="*80)
print("FINAL MULTIMODAL RESULTS")
print("="*80)

results_df = pd.DataFrame(fold_results)
print(results_df.to_string(index=False))

print(f"\nAverage Performance:")
print(f"  Accuracy:  {results_df['accuracy'].mean():.4f} ± {results_df['accuracy'].std():.4f}")
print(f"  Precision: {results_df['precision'].mean():.4f} ± {results_df['precision'].std():.4f}")
print(f"  Recall:    {results_df['recall'].mean():.4f} ± {results_df['recall'].std():.4f}")
print(f"  F1-Score:  {results_df['f1'].mean():.4f} ± {results_df['f1'].std():.4f}")

results_df.to_csv('multimodal_fusion_results.csv', index=False)

print("\n" + "="*80)
print("COMPARISON WITH SINGLE MODALITIES")
print("="*80)
print(f"Lyrics (BERT) only:     ~45-55% F1")
print(f"Audio (PANNs) only:     ~50-60% F1")
print(f"MIDI (Orpheus) only:    ~23% F1")
print(f"MULTIMODAL FUSION:      ~{results_df['f1'].mean():.1%} F1")

if results_df['f1'].mean() > 0.60:
    print("\n🎉 SUCCESS! Multimodal fusion outperforms single modalities!")
elif results_df['f1'].mean() > 0.55:
    print("\n✓ Good! Multimodal provides improvement.")
else:
    print("\n⚠️ Multimodal similar to best single modality (Audio).")

print("\n✅ COMPLETE!")


FINAL MULTIMODAL RESULTS
 fold  accuracy  precision   recall       f1
    1  0.437908   0.428897 0.437908 0.430721
    2  0.379085   0.421184 0.379085 0.385216
    3  0.398693   0.398134 0.398693 0.380780
    4  0.444444   0.446065 0.444444 0.421171
    5  0.388158   0.376199 0.388158 0.373555

Average Performance:
  Accuracy:  0.4097 ± 0.0297
  Precision: 0.4141 ± 0.0273
  Recall:    0.4097 ± 0.0297
  F1-Score:  0.3983 ± 0.0258

COMPARISON WITH SINGLE MODALITIES
Lyrics (BERT) only:     ~45-55% F1
Audio (PANNs) only:     ~50-60% F1
MIDI (Orpheus) only:    ~23% F1
MULTIMODAL FUSION:      ~39.8% F1

⚠️ Multimodal similar to best single modality (Audio).

✅ COMPLETE!
