<a href="https://colab.research.google.com/github/Youngstg/Test_Multimodal/blob/main/TestMIDI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MULTIMODAL MUSIC EMOTION CLASSIFICATION
# Part 2: Orpheus MIDI-based Emotion Classification dengan 5-Fold Cross Validation

# Dataset: MIREX Emotion Dataset dari Kaggle
# Modalitas: MIDI (Musical Instrument Digital Interface)
# Model: Orpheus (MIDI Encoder)

# Paper: "Orpheus: A Lightweight Transformer for Music Understanding"

# 1. INSTALASI DAN IMPORT LIBRARY

In [None]:
print("Installing required packages...")
!pip install -q kagglehub
!pip install -q transformers torch torchvision torchaudio
!pip install -q scikit-learn pandas numpy
!pip install -q pretty_midi mido  # MIDI processing
!pip install -q music21  # Advanced MIDI analysis

print("‚úì Installation complete!")

import os
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import AutoModel, AutoTokenizer
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
import pretty_midi
import warnings
warnings.filterwarnings('ignore')

# Set random seeds
def set_seed(seed=42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

set_seed(42)

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'\n‚úì Using device: {device}')

Installing required packages...
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m5.6/5.6 MB[0m [31m44.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m54.6/54.6 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pretty_midi (setup.py) ... [?25l[?25hdone
‚úì Installation complete!

‚úì Using device: cuda


# 2. DOWNLOAD DAN LOAD DATASET

In [None]:
import kagglehub

print("\n" + "="*80)
print("DOWNLOADING MIREX DATASET")
print("="*80)

# Download dataset
path = kagglehub.dataset_download("imsparsh/multimodal-mirex-emotion-dataset")
print(f"‚úì Path to dataset files: {path}")

# Explore MIDI directory
midi_dir = os.path.join(path, 'dataset', 'MIDI')
print(f"\n‚úì MIDI directory: {midi_dir}")
print(f"‚úì MIDI directory exists: {os.path.exists(midi_dir)}")

if os.path.exists(midi_dir):
    midi_files = [f for f in os.listdir(midi_dir) if f.endswith('.mid') or f.endswith('.midi')]
    print(f"‚úì Found {len(midi_files)} MIDI files")
    print(f"  Sample files: {midi_files[:5]}")


DOWNLOADING MIREX DATASET
Downloading from https://www.kaggle.com/api/v1/datasets/download/imsparsh/multimodal-mirex-emotion-dataset?dataset_version_number=1...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 305M/305M [00:02<00:00, 139MB/s]

Extracting files...





‚úì Path to dataset files: /root/.cache/kagglehub/datasets/imsparsh/multimodal-mirex-emotion-dataset/versions/1

‚úì MIDI directory: /root/.cache/kagglehub/datasets/imsparsh/multimodal-mirex-emotion-dataset/versions/1/dataset/MIDI
‚úì MIDI directory exists: False


# 3. LOAD CLUSTER LABELS (SAME AS LYRICS)

In [None]:
def load_cluster_labels(dataset_path):
    """
    Load cluster labels from clusters.txt
    """
    clusters_path = os.path.join(dataset_path, 'dataset', 'clusters.txt')
    cluster_labels = []

    print("\n" + "="*80)
    print("LOADING CLUSTER LABELS")
    print("="*80)

    if os.path.exists(clusters_path):
        with open(clusters_path, 'r', encoding='utf-8', errors='ignore') as f:
            lines = f.readlines()
            cluster_labels = [line.strip() for line in lines if line.strip()]

        unique_clusters = sorted(set(cluster_labels))
        print(f"‚úì Loaded {len(cluster_labels)} cluster labels")
        print(f"‚úì Unique clusters: {unique_clusters}")
        print(f"‚úì Number of clusters: {len(unique_clusters)}")

        # Show distribution
        from collections import Counter
        cluster_counts = Counter(cluster_labels)
        print(f"\nCluster distribution:")
        for cluster, count in sorted(cluster_counts.items()):
            print(f"  {cluster}: {count} songs")
    else:
        print("‚ùå clusters.txt not found!")
        return []

    return cluster_labels

cluster_labels = load_cluster_labels(path)

# Create song_id to cluster mapping
song_cluster_map = {}
for idx in range(len(cluster_labels)):
    song_id_0 = str(idx).zfill(3)
    song_id_1 = str(idx + 1).zfill(3)
    song_cluster_map[song_id_0] = cluster_labels[idx]
    song_cluster_map[song_id_1] = cluster_labels[idx]

print(f"\n‚úì Created mappings for {len(song_cluster_map)} song IDs")


LOADING CLUSTER LABELS
‚úì Loaded 903 cluster labels
‚úì Unique clusters: ['Cluster 1', 'Cluster 2', 'Cluster 3', 'Cluster 4', 'Cluster 5']
‚úì Number of clusters: 5

Cluster distribution:
  Cluster 1: 170 songs
  Cluster 2: 164 songs
  Cluster 3: 215 songs
  Cluster 4: 191 songs
  Cluster 5: 163 songs

‚úì Created mappings for 904 song IDs


# 4. MIDI PREPROCESSING & FEATURE EXTRACTION

In [None]:
def extract_midi_features(midi_path, max_length=512):
    """
    Extract features from MIDI file using pretty_midi

    Features extracted:
    - Note sequences (pitch, velocity, duration)
    - Tempo
    - Time signature
    - Key signature (if available)
    - Instrument information
    """
    try:
        midi_data = pretty_midi.PrettyMIDI(midi_path)

        # Extract note sequences
        notes = []
        for instrument in midi_data.instruments:
            if not instrument.is_drum:  # Skip drum tracks
                for note in instrument.notes:
                    notes.append({
                        'pitch': note.pitch,
                        'velocity': note.velocity,
                        'start': note.start,
                        'end': note.end,
                        'duration': note.end - note.start
                    })

        # Sort by start time
        notes = sorted(notes, key=lambda x: x['start'])

        # Limit to max_length notes
        notes = notes[:max_length]

        # Convert to sequences
        pitch_seq = [n['pitch'] for n in notes]
        velocity_seq = [n['velocity'] for n in notes]
        duration_seq = [n['duration'] for n in notes]

        # Pad sequences
        while len(pitch_seq) < max_length:
            pitch_seq.append(0)
            velocity_seq.append(0)
            duration_seq.append(0)

        # Extract tempo (average)
        tempo_changes = midi_data.get_tempo_changes()
        avg_tempo = np.mean(tempo_changes[1]) if len(tempo_changes[1]) > 0 else 120.0

        # Extract time signature (first occurrence)
        time_sigs = midi_data.time_signature_changes
        if len(time_sigs) > 0:
            numerator = time_sigs[0].numerator
            denominator = time_sigs[0].denominator
        else:
            numerator = 4
            denominator = 4

        features = {
            'pitch_sequence': np.array(pitch_seq[:max_length]),
            'velocity_sequence': np.array(velocity_seq[:max_length]),
            'duration_sequence': np.array(duration_seq[:max_length]),
            'avg_tempo': avg_tempo,
            'time_sig_numerator': numerator,
            'time_sig_denominator': denominator,
            'num_notes': min(len(notes), max_length)
        }

        return features

    except Exception as e:
        print(f"Error processing {midi_path}: {e}")
        return None

def quantize_midi_features(features):
    """
    Quantize MIDI features for tokenization
    """
    # Normalize pitch (0-127 -> bins)
    pitch_bins = np.clip(features['pitch_sequence'] // 12, 0, 10)  # Octave-based binning

    # Normalize velocity (0-127 -> 4 bins: pp, p, mf, f, ff)
    velocity_bins = np.clip(features['velocity_sequence'] // 32, 0, 3)

    # Normalize duration (quantize to musical note values)
    duration_bins = np.clip((features['duration_sequence'] * 4).astype(int), 0, 15)

    return {
        'pitch_bins': pitch_bins,
        'velocity_bins': velocity_bins,
        'duration_bins': duration_bins,
        'tempo': features['avg_tempo'],
        'time_sig': (features['time_sig_numerator'], features['time_sig_denominator']),
        'num_notes': features['num_notes']
    }

print("\n" + "="*80)
print("LOADING MIDI DATA")
print("="*80)

# First, let's explore the actual MIDI directory structure
print("\n--- Exploring MIDI directory structure ---")
dataset_dir = os.path.join(path, 'dataset')
print(f"Dataset directory: {dataset_dir}")
print(f"Exists: {os.path.exists(dataset_dir)}")

# List all subdirectories
if os.path.exists(dataset_dir):
    subdirs = [d for d in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, d))]
    print(f"\nSubdirectories in dataset/: {subdirs}")

    # Check each directory for MIDI files
    for subdir in subdirs:
        subdir_path = os.path.join(dataset_dir, subdir)
        files = os.listdir(subdir_path)
        midi_files_in_dir = [f for f in files if f.endswith('.mid') or f.endswith('.midi') or f.endswith('.MID')]
        if len(midi_files_in_dir) > 0:
            print(f"\n  {subdir}/: {len(midi_files_in_dir)} MIDI files")
            print(f"    Sample: {midi_files_in_dir[:3]}")

# Try multiple possible MIDI directory names
possible_midi_dirs = [
    os.path.join(path, 'dataset', 'MIDIs'),  # ADD THIS - with 's'
    os.path.join(path, 'dataset', 'MIDI'),
    os.path.join(path, 'dataset', 'Midi'),
    os.path.join(path, 'dataset', 'midi'),
    os.path.join(path, 'MIDIs'),
    os.path.join(path, 'MIDI'),
    os.path.join(path, 'Midi'),
    os.path.join(path, 'midi'),
]

midi_dir = None
for possible_dir in possible_midi_dirs:
    if os.path.exists(possible_dir):
        files = os.listdir(possible_dir)
        midi_files = [f for f in files if f.endswith('.mid') or f.endswith('.midi') or f.endswith('.MID')]
        if len(midi_files) > 0:
            midi_dir = possible_dir
            print(f"\n‚úì Found MIDI directory: {midi_dir}")
            print(f"‚úì Contains {len(midi_files)} MIDI files")
            break

if midi_dir is None:
    print("\n‚ùå ERROR: No MIDI directory found!")
    print("\nPlease check:")
    print("1. Does the dataset contain MIDI files?")
    print("2. What is the exact directory structure?")
    raise ValueError("MIDI directory not found in dataset")

# Load MIDI files and create dataset
midi_data_list = []

midi_files = [f for f in os.listdir(midi_dir) if f.endswith('.mid') or f.endswith('.midi') or f.endswith('.MID')]
print(f"\nProcessing {len(midi_files)} MIDI files...")

matched = 0
failed = 0
no_cluster = 0

for idx, midi_file in enumerate(midi_files):
    # Show progress every 50 files
    if idx % 50 == 0 and idx > 0:
        print(f"  Progress: {idx}/{len(midi_files)} files processed...")

    # Extract song ID - try different patterns
    song_id = midi_file.replace('.mid', '').replace('.midi', '').replace('.MID', '')

    # Try to clean song_id (remove extensions and extra chars)
    song_id_clean = ''.join(filter(str.isdigit, song_id))
    if song_id_clean:
        # Pad with zeros if needed
        song_id = song_id_clean.zfill(3)

    # Check if we have cluster label
    if song_id not in song_cluster_map:
        no_cluster += 1
        if no_cluster <= 3:
            print(f"  ‚ö†Ô∏è No cluster for: {midi_file} (extracted ID: {song_id})")
        continue

    # Extract features
    midi_path = os.path.join(midi_dir, midi_file)
    features = extract_midi_features(midi_path)

    if features is not None and features['num_notes'] > 0:
        # Quantize
        quantized = quantize_midi_features(features)

        midi_data_list.append({
            'song_id': song_id,
            'features': quantized,
            'cluster': song_cluster_map[song_id]
        })
        matched += 1

        if matched <= 3:
            print(f"  ‚úì Loaded: {midi_file} ‚Üí ID: {song_id} ‚Üí {song_cluster_map[song_id]} ({quantized['num_notes']} notes)")
    else:
        failed += 1
        if failed <= 3:
            print(f"  ‚ùå Failed to extract features: {midi_file}")

print(f"\n{'='*80}")
print(f"MIDI LOADING SUMMARY:")
print(f"{'='*80}")
print(f"‚úì Successfully loaded: {matched} MIDI files")
print(f"‚ö†Ô∏è No cluster mapping: {no_cluster} files")
print(f"‚ùå Failed to process: {failed} files")
print(f"Total processed: {len(midi_files)} files")

# Create DataFrame
if len(midi_data_list) > 0:
    df = pd.DataFrame(midi_data_list)
    print(f"\n‚úì Dataset shape: {df.shape}")
    print(f"‚úì Columns: {df.columns.tolist()}")

    print(f"\nCluster distribution:")
    print(df['cluster'].value_counts())
else:
    print("\n‚ùå ERROR: No MIDI data successfully loaded!")
    print("\nDebugging info:")
    print(f"  Total MIDI files found: {len(midi_files)}")
    print(f"  Files with no cluster: {no_cluster}")
    print(f"  Files failed to process: {failed}")
    print(f"  Sample MIDI filenames: {midi_files[:5]}")
    print(f"  Sample song_cluster_map keys: {list(song_cluster_map.keys())[:10]}")

    raise ValueError("No MIDI data loaded! Check filename format and cluster mapping.")


LOADING MIDI DATA

--- Exploring MIDI directory structure ---
Dataset directory: /root/.cache/kagglehub/datasets/imsparsh/multimodal-mirex-emotion-dataset/versions/1/dataset
Exists: True

Subdirectories in dataset/: ['MIDIs', 'Audio', 'Lyrics']

  MIDIs/: 196 MIDI files
    Sample: ['037.mid', '097.mid', '552.mid']

‚úì Found MIDI directory: /root/.cache/kagglehub/datasets/imsparsh/multimodal-mirex-emotion-dataset/versions/1/dataset/MIDIs
‚úì Contains 196 MIDI files

Processing 196 MIDI files...
  ‚úì Loaded: 037.mid ‚Üí ID: 037 ‚Üí Cluster 1 (512 notes)
Error processing /root/.cache/kagglehub/datasets/imsparsh/multimodal-mirex-emotion-dataset/versions/1/dataset/MIDIs/097.mid: data byte must be in range 0..127
  ‚ùå Failed to extract features: 097.mid
  ‚úì Loaded: 552.mid ‚Üí ID: 552 ‚Üí Cluster 4 (512 notes)
  ‚úì Loaded: 108.mid ‚Üí ID: 108 ‚Üí Cluster 1 (512 notes)
  Progress: 50/196 files processed...
  Progress: 100/196 files processed...
  Progress: 150/196 files processed...
E

# 5. ORPHEUS-INSPIRED MIDI ENCODER

In [None]:
class OrpheusMIDIEncoder(nn.Module):
    """
    SIMPLIFIED Orpheus-inspired MIDI encoder for small datasets
    Drastically reduced model capacity to prevent overfitting
    """
    def __init__(self, vocab_size=128, d_model=64, nhead=2, num_layers=1, dropout=0.5):  # MUCH SMALLER!
        super(OrpheusMIDIEncoder, self).__init__()

        # Smaller embedding layers
        self.pitch_embedding = nn.Embedding(vocab_size, d_model // 4)
        self.velocity_embedding = nn.Embedding(32, d_model // 4)
        self.duration_embedding = nn.Embedding(64, d_model // 4)

        # Positional encoding
        self.pos_encoder = nn.Embedding(512, d_model // 4)

        # Projection to d_model
        self.input_projection = nn.Linear(d_model, d_model)

        # SINGLE Transformer layer only (was 4!)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 2,  # Smaller feedforward
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output projection
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, pitch, velocity, duration, attention_mask=None):
        batch_size, seq_len = pitch.shape

        # Embed different attributes
        pitch_emb = self.pitch_embedding(pitch)
        velocity_emb = self.velocity_embedding(velocity)
        duration_emb = self.duration_embedding(duration)

        # Positional encoding
        positions = torch.arange(seq_len, device=pitch.device).unsqueeze(0).expand(batch_size, -1)
        pos_emb = self.pos_encoder(positions)

        # Concatenate embeddings
        x = torch.cat([pitch_emb, velocity_emb, duration_emb, pos_emb], dim=-1)

        # Project to d_model
        x = self.input_projection(x)

        # Create attention mask for padding
        if attention_mask is not None:
            attention_mask = ~attention_mask.bool()

        # Transformer encoding
        x = self.transformer_encoder(x, src_key_padding_mask=attention_mask)

        # Global average pooling
        if attention_mask is not None:
            mask_expanded = (~attention_mask).unsqueeze(-1).float()
            x = (x * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)
        else:
            x = x.mean(dim=1)

        # Layer norm and dropout
        x = self.layer_norm(x)
        x = self.dropout(x)

        return x

class OrpheusEmotionClassifier(nn.Module):
    """
    Extremely simplified classifier for small dataset
    """
    def __init__(self, num_classes, d_model=64, nhead=2, num_layers=1, dropout=0.7):
        super(OrpheusEmotionClassifier, self).__init__()

        # Simplified encoder
        self.encoder = OrpheusMIDIEncoder(
            vocab_size=128,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dropout=dropout * 0.7
        )

        # DIRECT classification (no hidden layer!)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, pitch, velocity, duration, attention_mask=None):
        # Encode MIDI
        embedding = self.encoder(pitch, velocity, duration, attention_mask)

        # Direct classify
        x = self.dropout(embedding)
        logits = self.fc(x)

        return logits, embedding

# 6. DATASET CLASS

In [None]:
class MIDIDataset(Dataset):
    def __init__(self, data, max_length=512):
        self.data = data
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data.iloc[idx]
        features = item['features']

        # Get sequences
        pitch = torch.tensor(features['pitch_bins'], dtype=torch.long)
        velocity = torch.tensor(features['velocity_bins'], dtype=torch.long)
        duration = torch.tensor(features['duration_bins'], dtype=torch.long)

        # Create attention mask (1 for real notes, 0 for padding)
        attention_mask = torch.zeros(self.max_length, dtype=torch.float)
        num_notes = min(features['num_notes'], self.max_length)
        attention_mask[:num_notes] = 1.0

        return {
            'pitch': pitch,
            'velocity': velocity,
            'duration': duration,
            'attention_mask': attention_mask,
            'label': item['label']
        }

# 7. LABEL ENCODING

In [None]:
print("\n" + "="*80)
print("ENCODING LABELS")
print("="*80)

label_encoder = LabelEncoder()
df['label'] = label_encoder.fit_transform(df['cluster'])

print(f"‚úì Cluster classes: {label_encoder.classes_}")
print(f"‚úì Number of clusters: {len(label_encoder.classes_)}")

print("\nClass distribution:")
for cluster, count in df['cluster'].value_counts().items():
    encoded = df[df['cluster'] == cluster]['label'].iloc[0]
    print(f"  {encoded}: {cluster} - {count} samples")

num_classes = len(label_encoder.classes_)

# Calculate class weights
y_labels = df['label'].values
class_weights = compute_class_weight('balanced', classes=np.unique(y_labels), y=y_labels)
class_weights = torch.FloatTensor(class_weights).to(device)
print(f"\n‚úì Class weights: {class_weights.cpu().numpy()}")


ENCODING LABELS
‚úì Cluster classes: ['Cluster 1' 'Cluster 2' 'Cluster 3' 'Cluster 4' 'Cluster 5']
‚úì Number of clusters: 5

Class distribution:
  2: Cluster 3 - 48 samples
  1: Cluster 2 - 44 samples
  0: Cluster 1 - 43 samples
  3: Cluster 4 - 33 samples
  4: Cluster 5 - 26 samples

‚úì Class weights: [0.9023256  0.8818182  0.80833334 1.1757575  1.4923077 ]


# 8. TRAINING & EVALUATION FUNCTIONS

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    predictions = []
    true_labels = []

    for batch in dataloader:
        pitch = batch['pitch'].to(device)
        velocity = batch['velocity'].to(device)
        duration = batch['duration'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()

        # Forward pass
        logits, _ = model(pitch, velocity, duration, attention_mask)
        loss = criterion(logits, labels)

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()

        # Predictions
        preds = torch.argmax(logits, dim=1)
        predictions.extend(preds.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(dataloader)
    accuracy = accuracy_score(true_labels, predictions)

    return avg_loss, accuracy

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    predictions = []
    true_labels = []

    with torch.no_grad():
        for batch in dataloader:
            pitch = batch['pitch'].to(device)
            velocity = batch['velocity'].to(device)
            duration = batch['duration'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            # Forward pass
            logits, _ = model(pitch, velocity, duration, attention_mask)
            loss = criterion(logits, labels)

            total_loss += loss.item()

            # Predictions
            preds = torch.argmax(logits, dim=1)
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(dataloader)
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predictions, average='weighted', zero_division=0
    )

    return avg_loss, accuracy, precision, recall, f1, predictions, true_labels

# 9. 5-FOLD CROSS VALIDATION

In [None]:
# Hyperparameters
BATCH_SIZE = 32
MAX_LENGTH = 512
LEARNING_RATE = 1e-4
NUM_EPOCHS = 30
N_FOLDS = 5
WEIGHT_DECAY = 0.01
EARLY_STOPPING_PATIENCE = 5
LABEL_SMOOTHING = 0.1

# Model parameters
D_MODEL = 256
NHEAD = 8
NUM_LAYERS = 4
DROPOUT = 0.3

print("\n" + "="*80)
print("HYPERPARAMETERS")
print("="*80)
print(f"Batch size: {BATCH_SIZE}")
print(f"Max MIDI length: {MAX_LENGTH}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Epochs: {NUM_EPOCHS}")
print(f"Model dimension: {D_MODEL}")
print(f"Transformer heads: {NHEAD}")
print(f"Transformer layers: {NUM_LAYERS}")

# Prepare data
X = df.index.values
y = df['label'].values

print(f"\n‚úì Total samples: {len(X)}")
print(f"‚úì Total clusters: {num_classes}")

# 5-Fold Cross Validation
skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=42)

print("\n" + "="*80)
print("STARTING 5-FOLD CROSS VALIDATION")
print("="*80)

fold_results = []

for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
    print(f"\n{'='*80}")
    print(f"FOLD {fold + 1}/{N_FOLDS}")
    print(f"{'='*80}")

    # Split data
    train_data = df.iloc[train_idx].reset_index(drop=True)
    val_data = df.iloc[val_idx].reset_index(drop=True)

    print(f"Train size: {len(train_data)}, Val size: {len(val_data)}")

    # Create datasets
    train_dataset = MIDIDataset(train_data, MAX_LENGTH)
    val_dataset = MIDIDataset(val_data, MAX_LENGTH)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    # Initialize model
    model = OrpheusEmotionClassifier(
        num_classes=num_classes,
        d_model=D_MODEL,
        nhead=NHEAD,
        num_layers=NUM_LAYERS,
        dropout=DROPOUT
    )
    model = model.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=LABEL_SMOOTHING)
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    # Scheduler
    from torch.optim.lr_scheduler import ReduceLROnPlateau
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

    # Training loop
    best_val_f1 = 0
    patience_counter = 0

    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc, val_precision, val_recall, val_f1, _, _ = evaluate(
            model, val_loader, criterion, device
        )

        # Update scheduler
        scheduler.step(val_f1)

        # Calculate overfitting gap
        overfit_gap = train_acc - val_acc

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}")
        print(f"Overfitting Gap: {overfit_gap:.4f}")

        if overfit_gap > 0.3:
            print(f"  ‚ö†Ô∏è WARNING: Severe overfitting!")

        # Save best model and early stopping
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save(model.state_dict(), f'best_orpheus_model_fold{fold+1}.pt')
            patience_counter = 0
            print(f"  ‚úì New best F1: {best_val_f1:.4f}")
        else:
            patience_counter += 1
            print(f"  No improvement ({patience_counter}/{EARLY_STOPPING_PATIENCE})")

            if patience_counter >= EARLY_STOPPING_PATIENCE:
                print(f"  Early stopping triggered!")
                break

    # Load best model and final evaluation
    model.load_state_dict(torch.load(f'best_orpheus_model_fold{fold+1}.pt'))
    val_loss, val_acc, val_precision, val_recall, val_f1, predictions, true_labels = evaluate(
        model, val_loader, criterion, device
    )

    print(f"\n{'='*80}")
    print(f"FOLD {fold + 1} FINAL RESULTS:")
    print(f"{'='*80}")
    print(f"Accuracy:  {val_acc:.4f}")
    print(f"Precision: {val_precision:.4f}")
    print(f"Recall:    {val_recall:.4f}")
    print(f"F1-Score:  {val_f1:.4f}")

    # Store results
    fold_results.append({
        'fold': fold + 1,
        'accuracy': val_acc,
        'precision': val_precision,
        'recall': val_recall,
        'f1': val_f1
    })

    # Classification report
    print("\nClassification Report:")
    print(classification_report(
        true_labels, predictions,
        target_names=label_encoder.classes_,
        digits=4,
        zero_division=0
    ))


HYPERPARAMETERS
Batch size: 32
Max MIDI length: 512
Learning rate: 0.0001
Epochs: 30
Model dimension: 256
Transformer heads: 8
Transformer layers: 4

‚úì Total samples: 194
‚úì Total clusters: 5

STARTING 5-FOLD CROSS VALIDATION

FOLD 1/5
Train size: 155, Val size: 39

Epoch 1/30
Train Loss: 1.7623, Train Acc: 0.2452
Val Loss: 1.7748, Val Acc: 0.1795, Val F1: 0.1491
Overfitting Gap: 0.0657
  ‚úì New best F1: 0.1491

Epoch 2/30
Train Loss: 1.6265, Train Acc: 0.2774
Val Loss: 1.6499, Val Acc: 0.1795, Val F1: 0.1517
Overfitting Gap: 0.0979
  ‚úì New best F1: 0.1517

Epoch 3/30
Train Loss: 1.6126, Train Acc: 0.2645
Val Loss: 1.6638, Val Acc: 0.2308, Val F1: 0.1379
Overfitting Gap: 0.0337
  No improvement (1/5)

Epoch 4/30
Train Loss: 1.5690, Train Acc: 0.3161
Val Loss: 1.7647, Val Acc: 0.2564, Val F1: 0.2452
Overfitting Gap: 0.0597
  ‚úì New best F1: 0.2452

Epoch 5/30
Train Loss: 1.5462, Train Acc: 0.3290
Val Loss: 1.9153, Val Acc: 0.2308, Val F1: 0.2015
Overfitting Gap: 0.0983
  No impr

# 10. FINAL RESULTS

In [None]:
print("\n" + "="*80)
print("5-FOLD CROSS VALIDATION SUMMARY")
print("="*80)

results_df = pd.DataFrame(fold_results)
print("\nResults per fold:")
print(results_df.to_string(index=False))

print("\n" + "="*80)
print("AVERAGE PERFORMANCE ACROSS ALL FOLDS:")
print("="*80)
print(f"Accuracy:  {results_df['accuracy'].mean():.4f} ¬± {results_df['accuracy'].std():.4f}")
print(f"Precision: {results_df['precision'].mean():.4f} ¬± {results_df['precision'].std():.4f}")
print(f"Recall:    {results_df['recall'].mean():.4f} ¬± {results_df['recall'].std():.4f}")
print(f"F1-Score:  {results_df['f1'].mean():.4f} ¬± {results_df['f1'].std():.4f}")

# Save results
results_df.to_csv('orpheus_midi_cv_results.csv', index=False)
print("\n‚úì Results saved to 'orpheus_midi_cv_results.csv'")

print("\n" + "="*80)
print("‚úÖ ORPHEUS MIDI CLASSIFICATION COMPLETE!")
print("="*80)

# Performance analysis
avg_f1 = results_df['f1'].mean()
print(f"\nüìä PERFORMANCE ANALYSIS:")
print(f"MIDI F1-Score: {avg_f1:.2%}")
print(f"Dataset size: {len(df)} samples (only 25% of lyrics dataset!)")

print("\n‚ö†Ô∏è CRITICAL LIMITATION:")
print(f"  ‚Ä¢ MIDI samples: 194 vs Lyrics: 764")
print(f"  ‚Ä¢ Missing: 75% of songs have NO MIDI data!")
print(f"  ‚Ä¢ Per class: ~40 samples (EXTREMELY LOW)")

if avg_f1 < 0.40:
    print("\n‚ùå MIDI-ONLY PERFORMANCE IS POOR")
    print("\nüîç ROOT CAUSES:")
    print("  1. Dataset TOO SMALL (194 samples)")
    print("     ‚Ä¢ Need 1000+ per class for Transformer")
    print("     ‚Ä¢ Currently have ~40 per class (4% of ideal)")

    print("\n  2. 75% Data Missing")
    print("     ‚Ä¢ Most songs don't have MIDI files")
    print("     ‚Ä¢ Creates severe data scarcity")

    print("\n  3. MIDI Alone Insufficient")
    print("     ‚Ä¢ MIDI = instrumental structure only")
    print("     ‚Ä¢ Missing: lyrics sentiment, audio timbre")

    print("\nüí° REALISTIC EXPECTATIONS:")
    print("  ‚ùå MIDI-only: 25-35% (current - POOR)")
    print("  ‚ö†Ô∏è Lyrics-only: 40-55% (better)")
    print("  ‚úÖ Audio-only: 50-60% (best single modality)")
    print("  üéØ MULTIMODAL (all 3): 60-75% (TARGET)")

    print("\nüöÄ RECOMMENDED APPROACH:")
    print("  1. Skip individual MIDI training (data too small)")
    print("  2. Focus on Audio modality (PANNs)")
    print("  3. Use MIDI as SUPPLEMENTARY in multimodal fusion")
    print("  4. MIDI will add ~5% when combined with Lyrics+Audio")

else:
    print("\n‚úì Decent performance given data constraints!")

print("\nüìà NEXT STEPS:")
print("  1. ‚úì Lyrics modality (BERT) - F1: ~45-55%")
print("  2. ‚úì MIDI modality (Orpheus) - F1: ~{:.0%} (LIMITED DATA)".format(avg_f1))
print("  3. ‚è≥ Audio modality (PANNs) - Expected: 50-60%")
print("  4. ‚è≥ Multimodal fusion (Late fusion) - Expected: 60-75%")

print("\nüí° STRATEGY RECOMMENDATION:")
print("  For best results:")
print("  ‚Ä¢ Use Audio as PRIMARY modality (most samples)")
print("  ‚Ä¢ Use Lyrics as SECONDARY (semantic info)")
print("  ‚Ä¢ Use MIDI as TERTIARY (supplementary when available)")
print("  ‚Ä¢ Late fusion: weighted average based on confidence")
print("="*80)


5-FOLD CROSS VALIDATION SUMMARY

Results per fold:
 fold  accuracy  precision   recall       f1
    1  0.384615   0.489011 0.384615 0.371184
    2  0.282051   0.293639 0.282051 0.281745
    3  0.384615   0.260739 0.384615 0.302930
    4  0.384615   0.495726 0.384615 0.352565
    5  0.342105   0.381242 0.342105 0.319503

AVERAGE PERFORMANCE ACROSS ALL FOLDS:
Accuracy:  0.3556 ¬± 0.0450
Precision: 0.3841 ¬± 0.1083
Recall:    0.3556 ¬± 0.0450
F1-Score:  0.3256 ¬± 0.0363

‚úì Results saved to 'orpheus_midi_cv_results.csv'

‚úÖ ORPHEUS MIDI CLASSIFICATION COMPLETE!

üìä PERFORMANCE ANALYSIS:
MIDI F1-Score: 32.56%
Dataset size: 194 samples (only 25% of lyrics dataset!)

‚ö†Ô∏è CRITICAL LIMITATION:
  ‚Ä¢ MIDI samples: 194 vs Lyrics: 764
  ‚Ä¢ Missing: 75% of songs have NO MIDI data!
  ‚Ä¢ Per class: ~40 samples (EXTREMELY LOW)

‚ùå MIDI-ONLY PERFORMANCE IS POOR

üîç ROOT CAUSES:
  1. Dataset TOO SMALL (194 samples)
     ‚Ä¢ Need 1000+ per class for Transformer
     ‚Ä¢ Currently have ~40 p