<a href="https://colab.research.google.com/github/Youngstg/Test_Multimodal/blob/main/TestMultimodal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

MULTIMODAL MUSIC EMOTION CLASSIFICATION
FINAL: Late Fusion of Lyrics (BERT) + Audio (PANNs) + MIDI (Simple Features)

Dataset: MIREX Emotion Dataset
Strategy: Extract embeddings from each modality → Concatenate → Classify

# 1. INSTALLATION & IMPORTS

In [None]:
print("Installing packages...")
!pip install -q kagglehub transformers torch panns-inference
!pip install -q librosa soundfile pretty_midi
!pip install -q scikit-learn pandas numpy

import os
import re
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import BertTokenizer, BertModel
from panns_inference import AudioTagging
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
import librosa
import pretty_midi
import warnings
warnings.filterwarnings('ignore')

def set_seed(seed=42):
    """Seed every random number generator this notebook draws from.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU and all CUDA devices) so that repeated runs are reproducible.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random  # local import: the notebook's import cell does not import it

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)
# Use the GPU when PyTorch can see one, otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'‚úì Device: {device}')

Installing packages...
‚úì Device: cpu


# 2. DOWNLOAD DATASET

In [None]:
import kagglehub
print("\n" + "="*80)
print("DOWNLOADING DATASET")
print("="*80)

path = kagglehub.dataset_download("imsparsh/multimodal-mirex-emotion-dataset")
print(f"‚úì Dataset path: {path}")

# Define directories
dataset_dir = os.path.join(path, 'dataset')
lyrics_dir = os.path.join(dataset_dir, 'Lyrics')
audio_dir = os.path.join(dataset_dir, 'Audio')
midi_dir = os.path.join(dataset_dir, 'MIDIs')

print(f"‚úì Lyrics: {os.path.exists(lyrics_dir)}")
print(f"‚úì Audio: {os.path.exists(audio_dir)}")
print(f"‚úì MIDI: {os.path.exists(midi_dir)}")

# Fail fast: later cells call os.listdir on these folders, and a missing one
# would otherwise surface as a confusing error far from its cause.
_missing_dirs = [d for d in (lyrics_dir, audio_dir, midi_dir) if not os.path.isdir(d)]
if _missing_dirs:
    raise FileNotFoundError(f"Expected dataset folders not found: {_missing_dirs}")



DOWNLOADING DATASET
‚úì Dataset path: /root/.cache/kagglehub/datasets/imsparsh/multimodal-mirex-emotion-dataset/versions/1
‚úì Lyrics: True
‚úì Audio: True
‚úì MIDI: True


# 3. LOAD CLUSTER LABELS

In [None]:
def load_cluster_labels(dataset_path):
    """Read the per-song emotion cluster labels from ``dataset/clusters.txt``.

    Surrounding whitespace is stripped and blank lines are skipped, so the i-th
    returned label is the i-th non-empty line of the file.

    Args:
        dataset_path: Root directory of the downloaded dataset.

    Returns:
        List of label strings in file order.
    """
    labels_file = os.path.join(dataset_path, 'dataset', 'clusters.txt')
    with open(labels_file, 'r', encoding='utf-8', errors='ignore') as handle:
        stripped_rows = (row.strip() for row in handle)
        labels = [row for row in stripped_rows if row]

    print(f"\n‚úì Loaded {len(labels)} cluster labels")
    print(f"  Unique: {sorted(set(labels))}")
    return labels

cluster_labels = load_cluster_labels(path)

# Map 3-digit song IDs (as used in the media file names) to cluster labels.
# Exactly one ID is registered per label line: line idx -> song idx + base.
# Registering both idx and idx + 1 lets each line overwrite its predecessor's
# entry, which shifts every label by one song when files are numbered from 001.
# The base is read from the Audio folder (0 if a song numbered 0 exists,
# otherwise 1) instead of being guessed. splitext is needed so the digit in
# '.mp3' is not mistaken for part of the ID.
# NOTE(review): assumes clusters.txt lists songs in file-number order.
_audio_song_numbers = {
    int(digits)
    for digits in (''.join(filter(str.isdecimal, os.path.splitext(name)[0]))
                   for name in os.listdir(audio_dir))
    if digits
}
song_id_base = 0 if 0 in _audio_song_numbers else 1
song_cluster_map = {
    str(idx + song_id_base).zfill(3): label
    for idx, label in enumerate(cluster_labels)
}


‚úì Loaded 903 cluster labels
  Unique: ['Cluster 1', 'Cluster 2', 'Cluster 3', 'Cluster 4', 'Cluster 5']


# 4. LOAD PRE-TRAINED MODELS

In [None]:
print("\n" + "="*80)
print("LOADING PRE-TRAINED MODELS")
print("="*80)

# BERT for lyrics: used as a frozen feature extractor (eval mode disables dropout).
print("Loading BERT...")
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')
bert_model.eval()
bert_model.to(device)
print("‚úì BERT loaded")

# PANNs for audio. panns_inference's AudioTagging selects CUDA only when its
# `device` argument equals the string 'cuda'; a torch.device object never
# compares equal to a str, so passing `device` directly would silently keep
# PANNs on the CPU even with a GPU. Pass the device type string instead.
print("Loading PANNs...")
panns_model = AudioTagging(checkpoint_path=None, device=device.type)
print("‚úì PANNs loaded")


LOADING PRE-TRAINED MODELS
Loading BERT...
‚úì BERT loaded
Loading PANNs...
Checkpoint path: /root/panns_data/Cnn14_mAP=0.431.pth
Using CPU.
‚úì PANNs loaded


# 5. FEATURE EXTRACTION FUNCTIONS

In [None]:
# --- LYRICS FEATURES ---
def clean_lyrics(text):
    """Normalise raw lyrics text before tokenisation.

    Lower-cases the text; removes [bracketed] and (parenthesised) spans and
    URLs; replaces any character outside a-z, 0-9, whitespace and . , ! ? '
    with a space; collapses repeated punctuation to a single mark and all
    whitespace runs to single spaces. Missing values (NaN/None) give "".
    """
    if pd.isna(text):
        return ""
    cleaned = str(text).lower()
    # Annotations such as [Chorus] / (x2) and links carry no lyrical content.
    for pattern in (r'\[.*?\]', r'\(.*?\)', r'http\S+|www\S+'):
        cleaned = re.sub(pattern, '', cleaned)
    cleaned = ' '.join(cleaned.split())
    cleaned = re.sub(r'[^a-z0-9\s.,!?\']', ' ', cleaned)
    cleaned = re.sub(r'([.,!?])\1+', r'\1', cleaned)
    return ' '.join(cleaned.split())

def extract_lyrics_embedding(lyrics, tokenizer, model, max_length=256, min_chars=10):
    """Encode one song's lyrics into a fixed-size BERT vector.

    The lyrics are normalised with ``clean_lyrics``, tokenised (truncated and
    padded to ``max_length`` tokens) and passed through ``model``; the result
    is the model's ``pooler_output`` for that single input.

    Args:
        lyrics: Raw lyrics text (NaN/None allowed).
        tokenizer: Tokenizer matching ``model`` (must provide ``encode_plus``).
        model: BERT-style model whose output exposes ``pooler_output``.
        max_length: Maximum number of tokens fed to the model.
        min_chars: Cleaned lyrics shorter than this are treated as missing.

    Returns:
        1-D NumPy array of the model's hidden size (768 for bert-base), or
        ``None`` when the lyrics are too short or encoding fails.
    """
    try:
        lyrics = clean_lyrics(lyrics)
        if not lyrics or len(lyrics) < min_chars:
            return None

        # Follow the model's own device instead of a notebook-global variable.
        model_device = next(model.parameters()).device
        encoding = tokenizer.encode_plus(
            lyrics,
            add_special_tokens=True,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'].to(model_device)
        attention_mask = encoding['attention_mask'].to(model_device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.pooler_output[0].cpu().numpy()
    except Exception as exc:
        # Best effort: skip this song, but let KeyboardInterrupt/SystemExit
        # through (a bare `except:` would swallow a kernel interrupt).
        print(f"  ! lyrics embedding failed: {exc}")
        return None

# --- AUDIO FEATURES ---
def extract_audio_embedding(audio_path, panns_model, sr=32000, duration=10):
    """Return the PANNs embedding of the start of an audio file.

    Loads at most ``duration`` seconds (from the beginning) resampled to
    ``sr`` Hz, zero-pads or truncates the signal to exactly ``sr * duration``
    samples, and runs it through ``panns_model`` as a batch of one.

    Args:
        audio_path: Path to a file readable by librosa.
        panns_model: Object whose ``inference`` takes a (batch, samples) array
            and returns a pair whose second item is the embedding batch.
        sr: Target sample rate in Hz.
        duration: Seconds of audio to use (int or float). ``None`` uses the
            whole file without padding or truncation.

    Returns:
        1-D embedding array (2048 values for the Cnn14 PANNs model), or
        ``None`` if loading or inference fails.
    """
    try:
        audio, _ = librosa.load(audio_path, sr=sr, duration=duration)

        if duration is not None:
            # int(): np.pad and slicing need an integer length even when
            # `duration` is a float.
            target_length = int(sr * duration)
            if len(audio) < target_length:
                audio = np.pad(audio, (0, target_length - len(audio)))
            else:
                audio = audio[:target_length]

        _, embedding = panns_model.inference(audio[None, :])
        return embedding[0]
    except Exception as exc:
        # Best effort: skip this file, but never swallow KeyboardInterrupt.
        print(f"  ! audio embedding failed for {audio_path}: {exc}")
        return None

# --- MIDI FEATURES ---
def extract_midi_features(midi_path):
    """Summarise a MIDI file as a 32-dim vector of note and timing statistics.

    Only non-drum instruments contribute notes. Layout: 8 pitch, 8 velocity
    (including the note count), 8 duration and 8 temporal/global features, as
    annotated below. Values are raw, unnormalised statistics.

    Args:
        midi_path: Path to a MIDI file readable by pretty_midi.

    Returns:
        ``np.float32`` array of shape (32,), or ``None`` when the file has no
        pitched notes or cannot be parsed.
    """
    try:
        midi = pretty_midi.PrettyMIDI(midi_path)

        notes = [
            note
            for instrument in midi.instruments if not instrument.is_drum
            for note in instrument.notes
        ]
        if not notes:
            return None

        pitches = np.array([n.pitch for n in notes], dtype=np.float64)
        velocities = np.array([n.velocity for n in notes], dtype=np.float64)
        durations = np.array([n.end - n.start for n in notes], dtype=np.float64)
        n_notes = len(notes)
        end_time = midi.get_end_time()

        # Tempo: plain (not time-weighted) mean of the listed tempi; 120 BPM default.
        tempi = midi.get_tempo_changes()[1]
        avg_tempo = np.mean(tempi) if len(tempi) > 0 else 120.0

        # Time signature: first one listed, 4/4 if none.
        time_sigs = midi.time_signature_changes
        numerator = time_sigs[0].numerator if time_sigs else 4
        denominator = time_sigs[0].denominator if time_sigs else 4

        features = np.array([
            # Pitch statistics (8)
            np.mean(pitches), np.std(pitches), np.min(pitches), np.max(pitches),
            np.percentile(pitches, 25), np.percentile(pitches, 75),
            np.ptp(pitches), np.unique(pitches).size,

            # Velocity statistics (8)
            np.mean(velocities), np.std(velocities), np.min(velocities), np.max(velocities),
            np.percentile(velocities, 25), np.percentile(velocities, 75),
            np.ptp(velocities), n_notes,

            # Duration statistics (8)
            np.mean(durations), np.std(durations), np.min(durations), np.max(durations),
            np.percentile(durations, 25), np.percentile(durations, 75),
            np.ptp(durations), 1.0 / (np.mean(durations) + 1e-6),

            # Temporal features (8)
            avg_tempo, avg_tempo / 120.0, numerator, denominator,
            numerator / denominator, n_notes / (end_time + 1e-6),
            end_time, len(midi.instruments)
        ], dtype=np.float32)

        return features
    except Exception as exc:
        # Best effort: skip unparsable files, but never swallow KeyboardInterrupt.
        print(f"  ! MIDI features failed for {midi_path}: {exc}")
        return None

# 6. LOAD & EXTRACT ALL FEATURES

In [None]:
print("\n" + "="*80)
print("EXTRACTING FEATURES FROM ALL MODALITIES")
print("="*80)

# Sizes used to zero-fill a missing modality (BERT pooler, PANNs embedding,
# extract_midi_features vector); they must match the fusion model's inputs.
LYRICS_DIM, AUDIO_DIM, MIDI_DIM = 768, 2048, 32


def _index_by_song_id(directory, extensions):
    """Map 3-digit zero-padded song IDs to the file names in `directory`.

    The ID is built from the decimal digits of the file stem, so '001.mp3',
    '1.mp3' and '0001.mp3' all map to '001'. The extension is split off with
    os.path.splitext, so '.mid' and '.midi' are handled alike. Only files whose
    lower-cased extension is in `extensions` are kept. If several files share
    an ID, the first in sorted order wins.
    """
    index = {}
    for filename in sorted(os.listdir(directory)):
        stem, ext = os.path.splitext(filename)
        digits = ''.join(filter(str.isdecimal, stem))
        if ext.lower() in extensions and digits:
            index.setdefault(str(int(digits)).zfill(3), filename)
    return index


data_list = []

# Get all files, keyed by normalised song ID (direct dict lookup below instead
# of a substring scan over every file for every song).
lyrics_files = _index_by_song_id(lyrics_dir, {'.txt'})
audio_files = _index_by_song_id(audio_dir, {'.wav', '.mp3'})
midi_files = _index_by_song_id(midi_dir, {'.mid', '.midi'})

print(f"Found: {len(lyrics_files)} lyrics, {len(audio_files)} audio, {len(midi_files)} MIDI")

# Candidate songs: labelled and with files for at least two modalities. Using
# the union of all three folders keeps audio+MIDI songs that have no lyrics,
# and skipping single-modality songs avoids running PANNs on clips that would
# be discarded anyway.
all_song_ids = {
    song_id
    for song_id in lyrics_files.keys() | audio_files.keys() | midi_files.keys()
    if song_id in song_cluster_map
    and sum(song_id in files for files in (lyrics_files, audio_files, midi_files)) >= 2
}

print(f"\nProcessing {len(all_song_ids)} songs with multimodal data...")

processed = 0
for song_id in sorted(all_song_ids):
    lyrics_emb = audio_emb = midi_feat = None

    if song_id in lyrics_files:
        lyrics_path = os.path.join(lyrics_dir, lyrics_files[song_id])
        with open(lyrics_path, 'r', encoding='utf-8', errors='ignore') as f:
            lyrics_text = f.read()
        lyrics_emb = extract_lyrics_embedding(lyrics_text, bert_tokenizer, bert_model)

    if song_id in audio_files:
        audio_emb = extract_audio_embedding(os.path.join(audio_dir, audio_files[song_id]), panns_model)

    if song_id in midi_files:
        midi_feat = extract_midi_features(os.path.join(midi_dir, midi_files[song_id]))

    # Extraction can still fail per file, so re-check the >= 2 rule on results.
    available = sum(emb is not None for emb in (lyrics_emb, audio_emb, midi_feat))
    if available < 2:
        continue

    data_list.append({
        'song_id': song_id,
        'lyrics_emb': lyrics_emb if lyrics_emb is not None else np.zeros(LYRICS_DIM),
        'audio_emb': audio_emb if audio_emb is not None else np.zeros(AUDIO_DIM),
        'midi_feat': midi_feat if midi_feat is not None else np.zeros(MIDI_DIM),
        'has_lyrics': lyrics_emb is not None,
        'has_audio': audio_emb is not None,
        'has_midi': midi_feat is not None,
        'cluster': song_cluster_map[song_id]
    })
    processed += 1

    if processed % 50 == 0:
        print(f"  Processed: {processed} songs...")

print(f"\n‚úì Total multimodal samples: {len(data_list)}")
if not data_list:
    raise RuntimeError("No song has at least two usable modalities; check dataset paths and file names.")

df = pd.DataFrame(data_list)
print(f"‚úì Dataset shape: {df.shape}")
print(f"\nModality availability:")
print(f"  Lyrics: {df['has_lyrics'].sum()} ({df['has_lyrics'].mean()*100:.1f}%)")
print(f"  Audio: {df['has_audio'].sum()} ({df['has_audio'].mean()*100:.1f}%)")
print(f"  MIDI: {df['has_midi'].sum()} ({df['has_midi'].mean()*100:.1f}%)")
print(f"\nCluster distribution:")
print(df['cluster'].value_counts())


EXTRACTING FEATURES FROM ALL MODALITIES
Found: 764 lyrics, 903 audio, 196 MIDI

Processing 764 songs with multimodal data...
  Processed: 50 songs...
  Processed: 100 songs...
  Processed: 150 songs...
  Processed: 200 songs...
  Processed: 250 songs...
  Processed: 300 songs...
  Processed: 350 songs...
  Processed: 400 songs...
  Processed: 450 songs...
  Processed: 500 songs...
  Processed: 550 songs...
  Processed: 600 songs...
  Processed: 650 songs...
  Processed: 700 songs...
  Processed: 750 songs...

‚úì Total multimodal samples: 764
‚úì Dataset shape: (764, 8)

Modality availability:
  Lyrics: 764 (100.0%)
  Audio: 764 (100.0%)
  MIDI: 191 (25.0%)

Cluster distribution:
cluster
Cluster 3    192
Cluster 4    173
Cluster 2    138
Cluster 1    134
Cluster 5    127
Name: count, dtype: int64


# 7. LABEL ENCODING

In [None]:
label_encoder = LabelEncoder()
df['label'] = label_encoder.fit_transform(df['cluster'])
num_classes = label_encoder.classes_.size

print(f"\n‚úì Classes: {label_encoder.classes_}")
print(f"‚úì Number of classes: {num_classes}")

# Inverse-frequency ('balanced') class weights for the weighted cross-entropy loss.
y = df['label'].values
class_weights = torch.tensor(
    compute_class_weight('balanced', classes=np.unique(y), y=y),
    dtype=torch.float32,
    device=device,
)


‚úì Classes: ['Cluster 1' 'Cluster 2' 'Cluster 3' 'Cluster 4' 'Cluster 5']
‚úì Number of classes: 5


# 8. MULTIMODAL DATASET

In [None]:
class MultimodalDataset(Dataset):
    """PyTorch dataset over the per-song feature table built above.

    ``data`` must provide the columns ``lyrics_emb``, ``audio_emb`` and
    ``midi_feat`` (equal-length 1-D arrays per column), the boolean flags
    ``has_lyrics``, ``has_audio`` and ``has_midi``, and an integer ``label``.

    Columns are converted to tensors once here, instead of materialising a
    pandas row with ``iloc`` on every ``__getitem__``, which is slow across
    many epochs and folds. Each item matches the per-row conversion:
    float32 feature vectors, float32 flags of shape (1,) (the model uses them
    as multiplicative masks), and a scalar int64 label.
    """

    _FEATURE_KEYS = ('lyrics_emb', 'audio_emb', 'midi_feat')
    _FLAG_KEYS = ('has_lyrics', 'has_audio', 'has_midi')

    def __init__(self, data):
        self.data = data
        self._features = {key: self._stack_float(data[key]) for key in self._FEATURE_KEYS}
        self._flags = {
            key: torch.as_tensor(data[key].to_numpy(dtype=np.float32)).reshape(-1, 1)
            for key in self._FLAG_KEYS
        }
        self._labels = torch.as_tensor(data['label'].to_numpy(dtype=np.int64))

    @staticmethod
    def _stack_float(column):
        """Stack a column of equal-length 1-D arrays into an (N, D) float32 tensor."""
        if len(column) == 0:
            return torch.empty((0, 0), dtype=torch.float32)
        return torch.as_tensor(np.stack(list(column)).astype(np.float32))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = {key: values[idx] for key, values in self._features.items()}
        sample.update({key: values[idx] for key, values in self._flags.items()})
        sample['label'] = self._labels[idx]
        return sample

# 9. MULTIMODAL FUSION MODEL

In [None]:
class MultimodalFusionClassifier(nn.Module):
    """Late-fusion classifier over lyrics, audio and MIDI feature vectors.

    Each modality is projected by its own Linear -> BatchNorm1d -> ReLU ->
    Dropout branch. The projected vectors are zeroed for samples whose modality
    is missing (flag == 0), then concatenated and classified by an MLP.

    Args:
        num_classes: Number of output logits.
        lyrics_dim, audio_dim, midi_dim: Input sizes of the three modalities.
        fusion_dim: Width of the first fusion layer.
        dropout: Dropout rate after the first fusion layer; the projection
            branches and the second fusion layer use half of it.
        lyrics_proj_dim, audio_proj_dim, midi_proj_dim: Projection widths.
        hidden_dim: Width of the second fusion layer.

    The width defaults (256 / 256 / 128 and 256) reproduce the original
    architecture, including its state_dict keys and shapes.
    """

    def __init__(self, num_classes, lyrics_dim=768, audio_dim=2048, midi_dim=32,
                 fusion_dim=512, dropout=0.5, lyrics_proj_dim=256,
                 audio_proj_dim=256, midi_proj_dim=128, hidden_dim=256):
        super().__init__()

        # Project each modality to its own embedding space
        self.lyrics_proj = self._projection(lyrics_dim, lyrics_proj_dim, dropout * 0.5)
        self.audio_proj = self._projection(audio_dim, audio_proj_dim, dropout * 0.5)
        self.midi_proj = self._projection(midi_dim, midi_proj_dim, dropout * 0.5)

        # Fusion MLP over the concatenated projections
        fusion_input_dim = lyrics_proj_dim + audio_proj_dim + midi_proj_dim
        self.fusion = nn.Sequential(
            nn.Linear(fusion_input_dim, fusion_dim),
            nn.BatchNorm1d(fusion_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fusion_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(hidden_dim, num_classes)
        )

    @staticmethod
    def _projection(in_dim, out_dim, dropout):
        """One modality branch: Linear -> BatchNorm1d -> ReLU -> Dropout."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

    @staticmethod
    def _mask(features, available):
        """Zero the rows of `features` whose flag is 0; flags may be (B,) or (B, 1)."""
        return features * available.reshape(-1, 1).to(features.dtype)

    def forward(self, lyrics_emb, audio_emb, midi_feat, has_lyrics, has_audio, has_midi):
        """Return (B, num_classes) logits.

        NOTE(review): masking happens after each branch, so in training mode
        the placeholder rows of a missing modality still enter that branch's
        BatchNorm batch statistics.
        """
        fused = torch.cat([
            self._mask(self.lyrics_proj(lyrics_emb), has_lyrics),
            self._mask(self.audio_proj(audio_emb), has_audio),
            self._mask(self.midi_proj(midi_feat), has_midi),
        ], dim=1)
        return self.fusion(fused)

# 10. TRAINING & EVALUATION

In [None]:
_BATCH_INPUT_KEYS = ('lyrics_emb', 'audio_emb', 'midi_feat', 'has_lyrics', 'has_audio', 'has_midi')


def _forward_batch(model, batch, device):
    """Move one collated batch to `device` and run the model; returns (logits, labels)."""
    inputs = [batch[key].to(device) for key in _BATCH_INPUT_KEYS]
    labels = batch['label'].to(device)
    return model(*inputs), labels


def train_epoch(model, dataloader, criterion, optimizer, device, max_grad_norm=1.0):
    """Run one training epoch.

    Gradients are clipped to `max_grad_norm` (global L2 norm) before each step.

    Returns:
        (mean per-batch loss, accuracy of the predictions made during the epoch).
    """
    model.train()
    total_loss = 0.0
    predictions, true_labels = [], []

    for batch in dataloader:
        optimizer.zero_grad()
        logits, labels = _forward_batch(model, batch, device)
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        total_loss += loss.item()
        predictions.extend(logits.argmax(dim=1).cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    return total_loss / max(len(dataloader), 1), accuracy_score(true_labels, predictions)


def evaluate(model, dataloader, criterion, device):
    """Evaluate the model without gradient tracking.

    Returns:
        (mean per-batch loss, accuracy, weighted precision, weighted recall,
        weighted F1, predictions, true labels).
    """
    model.eval()
    total_loss = 0.0
    predictions, true_labels = [], []

    with torch.no_grad():
        for batch in dataloader:
            logits, labels = _forward_batch(model, batch, device)
            total_loss += criterion(logits, labels).item()
            predictions.extend(logits.argmax(dim=1).cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(true_labels, predictions)
    p, r, f1, _ = precision_recall_fscore_support(
        true_labels, predictions, average='weighted', zero_division=0)
    return total_loss / max(len(dataloader), 1), acc, p, r, f1, predictions, true_labels

# 11. 5-FOLD CROSS VALIDATION

In [None]:
BATCH_SIZE = 16
LR = 5e-5
EPOCHS = 25
PATIENCE = 7

print("\n" + "="*80)
print("MULTIMODAL FUSION - 5-FOLD CROSS VALIDATION")
print("="*80)
print(f"Total samples: {len(df)}")
print(f"Modalities: Lyrics (768) + Audio (2048) + MIDI (32)")
print(f"Fusion strategy: Late concatenation ‚Üí MLP")

X = df.index.values
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_results = []

for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
    print(f"\n{'='*80}")
    print(f"FOLD {fold + 1}/5")
    print(f"{'='*80}")

    train_data = df.iloc[train_idx].reset_index(drop=True)
    val_data = df.iloc[val_idx].reset_index(drop=True)

    train_dataset = MultimodalDataset(train_data)
    val_dataset = MultimodalDataset(val_data)
    # BatchNorm1d raises on a training batch of a single sample, so drop the
    # last training batch only when it would hold exactly one sample.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              drop_last=len(train_dataset) % BATCH_SIZE == 1)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    # Class weights from the training fold only, so validation labels never
    # influence the loss.
    fold_class_weights = compute_class_weight(
        'balanced', classes=np.arange(num_classes), y=train_data['label'].values)

    model = MultimodalFusionClassifier(num_classes=num_classes).to(device)
    criterion = nn.CrossEntropyLoss(
        weight=torch.tensor(fold_class_weights, dtype=torch.float32, device=device))
    optimizer = AdamW(model.parameters(), lr=LR, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5, patience=3)

    checkpoint_path = f'best_multimodal_fold{fold+1}.pt'
    # -inf (not 0) guarantees the first epoch writes a checkpoint, so the
    # reload after training cannot fail even if validation F1 stays at 0.
    best_f1 = float('-inf')
    patience_counter = 0

    for epoch in range(EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, val_p, val_r, val_f1, _, _ = evaluate(model, val_loader, criterion, device)
        scheduler.step(val_f1)

        print(f"Epoch {epoch+1}/{EPOCHS}: Train Acc={train_acc:.4f}, Val Acc={val_acc:.4f}, F1={val_f1:.4f}")

        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save(model.state_dict(), checkpoint_path)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print("Early stopping!")
                break

    # Restore the best epoch's weights onto the current device before the final evaluation.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    val_loss, val_acc, val_p, val_r, val_f1, preds, labels = evaluate(model, val_loader, criterion, device)

    print(f"\nFold {fold+1} Results: Acc={val_acc:.4f}, Precision={val_p:.4f}, Recall={val_r:.4f}, F1={val_f1:.4f}")
    print(classification_report(labels, preds, target_names=label_encoder.classes_, digits=4, zero_division=0))

    fold_results.append({'fold': fold+1, 'accuracy': val_acc, 'precision': val_p, 'recall': val_r, 'f1': val_f1})

# ============================================================================
# 12. FINAL RESULTS
# ============================================================================

# Summarise the per-fold metrics collected in `fold_results` as mean ± standard
# deviation across folds (pandas .std, ddof=1), then save them to CSV.
print("\n" + "="*80)
print("FINAL MULTIMODAL RESULTS")
print("="*80)

results_df = pd.DataFrame(fold_results)
print(results_df.to_string(index=False))

print(f"\nAverage Performance:")
print(f"  Accuracy:  {results_df['accuracy'].mean():.4f} ¬± {results_df['accuracy'].std():.4f}")
print(f"  Precision: {results_df['precision'].mean():.4f} ¬± {results_df['precision'].std():.4f}")
print(f"  Recall:    {results_df['recall'].mean():.4f} ¬± {results_df['recall'].std():.4f}")
print(f"  F1-Score:  {results_df['f1'].mean():.4f} ¬± {results_df['f1'].std():.4f}")

results_df.to_csv('multimodal_fusion_results.csv', index=False)

# NOTE(review): the single-modality figures below are hard-coded strings, not
# computed in this notebook. The MIDI line names Orpheus, whereas this notebook
# uses hand-crafted MIDI statistics. The verdict thresholds (0.60 / 0.55) are
# fixed constants too, so treat the final message as indicative only.
print("\n" + "="*80)
print("COMPARISON WITH SINGLE MODALITIES")
print("="*80)
print(f"Lyrics (BERT) only:     ~45-55% F1")
print(f"Audio (PANNs) only:     ~50-60% F1")
print(f"MIDI (Orpheus) only:    ~23% F1")
print(f"MULTIMODAL FUSION:      ~{results_df['f1'].mean():.1%} F1")

if results_df['f1'].mean() > 0.60:
    print("\nüéâ SUCCESS! Multimodal fusion outperforms single modalities!")
elif results_df['f1'].mean() > 0.55:
    print("\n‚úì Good! Multimodal provides improvement.")
else:
    print("\n‚ö†Ô∏è Multimodal similar to best single modality (Audio).")

print("\n‚úÖ COMPLETE!")


MULTIMODAL FUSION - 5-FOLD CROSS VALIDATION
Total samples: 764
Modalities: Lyrics (768) + Audio (2048) + MIDI (32)
Fusion strategy: Late concatenation ‚Üí MLP

FOLD 1/5
Epoch 1/25: Train Acc=0.2422, Val Acc=0.3529, F1=0.2935
Epoch 2/25: Train Acc=0.3453, Val Acc=0.4641, F1=0.4264
Epoch 3/25: Train Acc=0.3732, Val Acc=0.5098, F1=0.4901
Epoch 4/25: Train Acc=0.4501, Val Acc=0.5033, F1=0.4710
Epoch 5/25: Train Acc=0.4092, Val Acc=0.4706, F1=0.4280
Epoch 6/25: Train Acc=0.4534, Val Acc=0.4837, F1=0.4381
Epoch 7/25: Train Acc=0.4632, Val Acc=0.4902, F1=0.4523
Epoch 8/25: Train Acc=0.4664, Val Acc=0.4967, F1=0.4406
Epoch 9/25: Train Acc=0.4746, Val Acc=0.4837, F1=0.4347
Epoch 10/25: Train Acc=0.4992, Val Acc=0.4510, F1=0.4096
Early stopping!

Fold 1 Results: Acc=0.5098, Precision=0.4888, Recall=0.5098, F1=0.4901
              precision    recall  f1-score   support

   Cluster 1     0.2381    0.1852    0.2083        27
   Cluster 2     0.4231    0.3929    0.4074        28
   Cluster 3     0