In [1]:
!pip install pretty_midi numpy torch tqdm

Collecting pretty_midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/5.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.5/5.6 MB[0m [31m14.3 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━[0m [32m3.5/5.6 MB[0m [31m49.6 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m5.6/5.6 MB[0m [31m63.2 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m46.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mido>=1.1.16 (from pretty_midi)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64

In [3]:
# ========== Download and extract MAESTRO Dataset ==========
import os
import requests
import zipfile

current_dir = os.getcwd()
print(f"Current Dir: {current_dir}")

url = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip"
zip_path = os.path.join(current_dir, "maestro-v3.0.0-midi.zip")
extract_path = os.path.join(current_dir, "maestro")

# Download
if not os.path.exists(zip_path):
    print("Downloading MAESTRO Dataset...")
    with requests.get(url, stream=True) as r:
        with open(zip_path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
    print("Download Complete")

# Extract
if not os.path.exists(extract_path):
    print("Extracting...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_path)
    print(f"Extracted to:{extract_path}")
else:
    print("Dataset already exists")


Current Dir: /content
Downloading MAESTRO Dataset...
Download Complete
Extracting...
Extracted to:/content/maestro


In [13]:
from pretty_midi import PrettyMIDI
import numpy as np

note_events = []
midi_files = 0
for root, _, files in os.walk(extract_path):
    for fname in files:
        if "Recital1-3" in fname and fname.endswith(".midi"):
        # if fname.endswith(".mid"):
            try:
                midi_files += 1
                midi = PrettyMIDI(os.path.join(root, fname))
                for instrument in midi.instruments:
                    if instrument.is_drum:
                        continue
                    for note in instrument.notes:
                        pitch = note.pitch
                        velocity = note.velocity
                        start = note.start
                        end = note.end
                        duration = end - start
                        note_events.append((pitch, velocity, start, duration))
            except Exception as e:
                print(f"Skip {fname}，Error: {e}")

# Sort by start time
note_events.sort(key=lambda x: x[2])
print(len(note_events))

102537


In [11]:
# 3. Gather unique pitches and velocities
pitches = sorted({evt[0] for evt in note_events})
velocities = sorted({evt[1] for evt in note_events})

# 4. Quantize durations to 0.05s bins and count unique
quantized_durations = [round(evt[3]*20) / 20 for evt in note_events]
unique_durations = sorted(set(quantized_durations))

# 5. Display statistics
print(f"Total MIDI files        : {midi_files}")
print(f"Total note events       : {len(note_events)}")
print(f"Unique pitch values     : {len(pitches)}")
print(f"Unique velocity bins    : {len(velocities)}")
print(f"Quantized durations     : {len(unique_durations)}")

Total MIDI files        : 1276
Total note events       : 7040164
Unique pitch values     : 88
Unique velocity bins    : 126
Quantized durations     : 367


In [14]:
pitches = sorted(set(n[0] for n in note_events))
velocities = sorted(set(n[1] for n in note_events))
# durations = sorted(set(round(n[3], 2) for n in note_events))
dur_quantized = [round(n[3] * 20) / 20 for n in note_events]
durations = sorted({d for d in dur_quantized})

pitch2id = {p: i for i, p in enumerate(pitches)}
vel2id = {v: i for i, v in enumerate(velocities)}
dur2id = {d: i for i, d in enumerate(durations)}
id2pitch = {i: p for p, i in pitch2id.items()}
id2vel = {i: v for v, i in vel2id.items()}
id2dur = {i: d for d, i in dur2id.items()}

print(f"pitch classes: {len(pitches)}, velocity classes: {len(velocities)}, duration classes: {len(durations)}")

pitch classes: 88, velocity classes: 121, duration classes: 165


In [15]:
seq_len = 128
X, y = [], []

# note_ids = [(pitch2id[n[0]], vel2id[n[1]], dur2id[round(n[3], 2)]) for n in note_events]
note_ids = [(pitch2id[n[0]], vel2id[n[1]], dur2id[round(n[3] * 20) / 20]) for n in note_events]

for i in range(len(note_ids) - seq_len):
    X.append(note_ids[i:i+seq_len])
    y.append(note_ids[i+seq_len])

X = np.array(X)
y = np.array(y)

In [5]:
import torch
import torch.nn as nn

class PolyBiLSTM(nn.Module):
    def __init__(self, pitch_size, vel_size, dur_size, embed_dim=128, hidden=256, num_layers=2):
        super(PolyBiLSTM, self).__init__()
        self.pitch_emb = nn.Embedding(pitch_size, embed_dim)
        self.vel_emb = nn.Embedding(vel_size, embed_dim)
        self.dur_emb = nn.Embedding(dur_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim * 3, hidden, num_layers=num_layers,
                   batch_first=True, bidirectional=True, dropout=0.2)
        self.pitch_out = nn.Linear(hidden * 2, pitch_size)
        self.vel_out = nn.Linear(hidden * 2, vel_size)
        self.dur_out = nn.Linear(hidden * 2, dur_size)

    def forward(self, x):
        pitch, vel, dur = x[:,:,0], x[:,:,1], x[:,:,2]
        p = self.pitch_emb(pitch)
        v = self.vel_emb(vel)
        d = self.dur_emb(dur)
        x_cat = torch.cat([p, v, d], dim=-1)
        _, (h, _) = self.lstm(x_cat)
        # h = h[-1]
        h = torch.cat((h[-2], h[-1]), dim=1)
        return self.pitch_out(h), self.vel_out(h), self.dur_out(h)

class PolyTransformer(nn.Module):
    def __init__(self, pitch_size, vel_size, dur_size, embed_dim=128, nhead=8, num_layers=4, dim_feedforward=512, dropout=0.1):
        super(PolyTransformer, self).__init__()
        self.pitch_emb = nn.Embedding(pitch_size, embed_dim)
        self.vel_emb = nn.Embedding(vel_size, embed_dim)
        self.dur_emb = nn.Embedding(dur_size, embed_dim)
        self.pos_emb = nn.Parameter(torch.randn(1, 512, embed_dim * 3))

        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim * 3, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.pitch_out = nn.Linear(embed_dim * 3, pitch_size)
        self.vel_out = nn.Linear(embed_dim * 3, vel_size)
        self.dur_out = nn.Linear(embed_dim * 3, dur_size)

    def forward(self, x):
        pitch, vel, dur = x[:, :, 0], x[:, :, 1], x[:, :, 2]
        p = self.pitch_emb(pitch)
        v = self.vel_emb(vel)
        d = self.dur_emb(dur)
        x_cat = torch.cat([p, v, d], dim=-1)

        seq_len = x_cat.size(1)
        pos_emb = self.pos_emb[:, :seq_len, :]
        x_cat = x_cat + pos_emb

        x_cat = x_cat.transpose(0, 1)
        h = self.transformer(x_cat)
        h = h[-1]

        return self.pitch_out(h), self.vel_out(h), self.dur_out(h)


In [16]:
# ========== Train ==========
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

X_tensor = torch.tensor(X, dtype=torch.long)
y_tensor = torch.tensor(y, dtype=torch.long)
dataset = TensorDataset(X_tensor, y_tensor)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PolyBiLSTM(len(pitch2id), len(vel2id), len(dur2id)).to(device)
# model = PolyTransformer(len(pitch2id), len(vel2id), len(dur2id)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

num_epochs = 15
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
    for xb, yb in loop:
        xb, yb = xb.to(device), yb.to(device)
        out_p, out_v, out_d = model(xb)
        loss = criterion(out_p, yb[:,0]) + criterion(out_v, yb[:,1]) + criterion(out_d, yb[:,2])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        loop.set_postfix(loss=loss.item())
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}, Avg Loss: {avg_loss:.2f}")



Epoch 1, Avg Loss: 10.55




Epoch 2, Avg Loss: 10.37




Epoch 3, Avg Loss: 10.26




Epoch 4, Avg Loss: 10.13




Epoch 5, Avg Loss: 9.97




Epoch 6, Avg Loss: 9.79




Epoch 7, Avg Loss: 9.59




Epoch 8, Avg Loss: 9.38




Epoch 9, Avg Loss: 9.18




Epoch 10, Avg Loss: 9.00




Epoch 11, Avg Loss: 8.82




Epoch 12, Avg Loss: 8.65




Epoch 13, Avg Loss: 8.50




Epoch 14, Avg Loss: 8.35


                                                                           

Epoch 15, Avg Loss: 8.22




In [17]:
num_samples = X.shape[0]
train_end = int(0.8 * num_samples)
val_end   = int(0.9 * num_samples)

X_train, y_train = X[:train_end], y[:train_end]
X_val,   y_val   = X[train_end:val_end], y[train_end:val_end]
X_test,  y_test  = X[val_end:], y[val_end:]

# Create TensorDatasets + DataLoaders
train_ds = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
val_ds   = TensorDataset(torch.tensor(X_val),   torch.tensor(y_val))
test_ds  = TensorDataset(torch.tensor(X_test),  torch.tensor(y_test))

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=64, shuffle=False)
test_loader  = DataLoader(test_ds,  batch_size=64, shuffle=False)

# Evaluation loop for test set
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = PolyBiLSTM(len(pitch2id), len(vel2id), len(dur2id)).to(device)
model.eval()
total_pitch, total_vel, total_dur = 0, 0, 0
count = 0
with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        p_logits, v_logits, d_logits = model(xb)
        # Compute top-1 accuracy
        p_pred = torch.argmax(p_logits, dim=1)
        v_pred = torch.argmax(v_logits, dim=1)
        d_pred = torch.argmax(d_logits, dim=1)
        total_pitch += (p_pred == yb[:,0]).sum().item()
        total_vel   += (v_pred == yb[:,1]).sum().item()
        total_dur   += (d_pred == yb[:,2]).sum().item()
        count += xb.size(0)

pitch_acc = total_pitch / count
vel_acc   = total_vel   / count
dur_acc   = total_dur   / count

print(f"Test Pitch Acc: {pitch_acc:.3f}, Vel Acc: {vel_acc:.3f}, Dur Acc: {dur_acc:.3f}")

Test Pitch Acc: 0.462, Vel Acc: 0.365, Dur Acc: 0.594


In [14]:
# --- 2. Compute Most-Frequent Baseline from training labels ---
# Extract train labels as NumPy arrays for convenience
y_train_np = np.array(y_train)  # shape (N_train, 3)
pitch_counts = np.bincount(y_train_np[:, 0])
vel_counts   = np.bincount(y_train_np[:, 1])
dur_counts   = np.bincount(y_train_np[:, 2])

# Find the most frequent ID in each category
mf_pitch = int(np.argmax(pitch_counts))
mf_vel   = int(np.argmax(vel_counts))
mf_dur   = int(np.argmax(dur_counts))

print(f"Most frequent pitch_id = {mf_pitch}, vel_id = {mf_vel}, dur_id = {mf_dur}")

# Evaluate Most-Frequent baseline on test set
total_mf_pitch, total_mf_vel, total_mf_dur = 0, 0, 0
count = 0

for xb, yb in test_loader:
    yb = yb.numpy()  # (batch_size, 3)
    batch_size = yb.shape[0]
    # Create arrays of most frequent predictions
    mf_pitch_preds = np.full(batch_size, mf_pitch, dtype=int)
    mf_vel_preds   = np.full(batch_size, mf_vel,   dtype=int)
    mf_dur_preds   = np.full(batch_size, mf_dur,   dtype=int)
    # Compare to ground truth
    total_mf_pitch += (mf_pitch_preds == yb[:, 0]).sum()
    total_mf_vel   += (mf_vel_preds   == yb[:, 1]).sum()
    total_mf_dur   += (mf_dur_preds   == yb[:, 2]).sum()
    count += batch_size

mf_pitch_acc = total_mf_pitch / count
mf_vel_acc   = total_mf_vel / count
mf_dur_acc   = total_mf_dur / count

print("Most-Frequent Baseline Accuracies:")
print(f"  Pitch Acc: {mf_pitch_acc:.3f}, Vel Acc: {mf_vel_acc:.3f}, Dur Acc: {mf_dur_acc:.3f}")

Most frequent pitch_id = 41, vel_id = 65, dur_id = 1
Most-Frequent Baseline Accuracies:
  Pitch Acc: 0.032, Vel Acc: 0.016, Dur Acc: 0.245


In [15]:
# --- 3. Evaluate Random (Uniform) Baseline on test set ---
# We'll draw a single random prediction for each sample, uniformly over the respective vocab size
pitch_vocab_size = len(pitch2id)
vel_vocab_size   = len(vel2id)
dur_vocab_size   = len(dur2id)

total_rand_pitch, total_rand_vel, total_rand_dur = 0, 0, 0
count = 0

# Set a fixed seed for reproducibility
random.seed(0)

for xb, yb in test_loader:
    yb = yb.numpy()  # (batch_size, 3)
    batch_size = yb.shape[0]
    # Generate random uniform predictions
    rand_pitch_preds = np.random.randint(0, pitch_vocab_size, size=batch_size)
    rand_vel_preds   = np.random.randint(0, vel_vocab_size,   size=batch_size)
    rand_dur_preds   = np.random.randint(0, dur_vocab_size,   size=batch_size)
    # Compare to ground truth
    total_rand_pitch += (rand_pitch_preds == yb[:, 0]).sum()
    total_rand_vel   += (rand_vel_preds   == yb[:, 1]).sum()
    total_rand_dur   += (rand_dur_preds   == yb[:, 2]).sum()
    count += batch_size

rand_pitch_acc = total_rand_pitch / count
rand_vel_acc   = total_rand_vel / count
rand_dur_acc   = total_rand_dur / count

print("Random Uniform Baseline Accuracies:")
print(f"  Pitch Acc: {rand_pitch_acc:.3f}, Vel Acc: {rand_vel_acc:.3f}, Dur Acc: {rand_dur_acc:.3f}")

Random Uniform Baseline Accuracies:
  Pitch Acc: 0.013, Vel Acc: 0.009, Dur Acc: 0.006


In [8]:
# Save model and dictionary
import pickle

torch.save(model.state_dict(), "poly_lstm.pth")
with open("lstm_midi_maps.pkl", "wb") as f:
    pickle.dump({"pitch2id": pitch2id, "vel2id": vel2id, "dur2id": dur2id,
                 "id2pitch": id2pitch, "id2vel": id2vel, "id2dur": id2dur}, f)

In [7]:
# Load model and dictionary (if needed)
from torch.utils.data import DataLoader, TensorDataset
import torch
import pickle
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open("lstm_midi_maps (1).pkl", "rb") as f:
    midi_maps = pickle.load(f)

pitch2id = midi_maps["pitch2id"]
vel2id = midi_maps["vel2id"]
dur2id = midi_maps["dur2id"]
id2pitch = midi_maps["id2pitch"]
id2vel = midi_maps["id2vel"]
id2dur = midi_maps["id2dur"]

# model = PolyBiLSTM(len(pitch2id), len(vel2id), len(dur2id)).to(device)
model = PolyTransformer(len(pitch2id), len(vel2id), len(dur2id)).to(device)
model.load_state_dict(torch.load("poly_lstm (1).pth"))

<All keys matched successfully>

In [9]:
# ========== Generate music ==========
from pretty_midi import PrettyMIDI, Instrument, Note
import random
import torch.nn.functional as F

model.eval()

def sample_from_logits(logits, temperature=1.0, top_k=None):
    logits = logits / temperature
    probs = F.softmax(logits, dim=-1)
    if top_k is not None and top_k > 0:
        topk_vals, topk_idx = torch.topk(probs, top_k)
        mask = torch.zeros_like(probs)
        mask.scatter_(1, topk_idx, topk_vals)
        probs = mask / mask.sum(dim=-1, keepdim=True)
    sampled_id = torch.multinomial(probs, num_samples=1).item()
    return sampled_id

# baseMidi = PrettyMIDI('/content/maestro/maestro-v3.0.0/2018/MIDI-Unprocessed_Recital1-3_MID--AUDIO_01_R1_2018_wav--1.midi')
baseMidi = PrettyMIDI('/content/maestro/maestro-v3.0.0/2018/MIDI-Unprocessed_Recital1-3_MID--AUDIO_01_R1_2018_wav--2.midi')
baseNotes = []
for instrument in baseMidi.instruments:
    if instrument.is_drum:
        continue
    for note in instrument.notes:
        pitch = note.pitch
        velocity = note.velocity
        start = note.start
        end = note.end
        duration = end - start
        baseNotes.append((pitch, velocity, start, duration))

basenote_ids = [(pitch2id[n[0]], vel2id[n[1]], dur2id[round(n[3] * 20) / 20]) for n in baseNotes]
generated = basenote_ids[:seq_len]

for _ in range(128):
    inp = torch.tensor([generated[-seq_len:]], dtype=torch.long).to(device)
    with torch.no_grad():
        p, v, d = model(inp)
    pitch = sample_from_logits(p, temperature=0.8, top_k=40)
    vel   = sample_from_logits(v, temperature=0.8, top_k=None)
    dur   = sample_from_logits(d, temperature=0.8, top_k=None)
    generated.append([pitch, vel, dur])

# Write into midi file
start = 0.0
notes = []
for p, v, d in generated:
    pitch = id2pitch[p]
    velocity = id2vel[v]
    duration = id2dur[d]
    notes.append((pitch, velocity, start, start + duration))
    start += duration

midi = PrettyMIDI()
piano = Instrument(program=0)
for pitch, vel, s, e in notes:
    piano.notes.append(Note(velocity=vel, pitch=pitch, start=s, end=e))
midi.instruments.append(piano)
midi.write("generated_lstm_polyphonic.mid")
print("Generate complete：generated_lstm_polyphonic.mid")

Generate complete：generated_lstm_polyphonic.mid


In [17]:
import numpy as np
import random

# --- 3. Compute Most-Frequent IDs from Training Set ---
num_samples = X.shape[0]
train_end = int(0.8 * num_samples)
val_end   = int(0.9 * num_samples)

X_train, y_train = X[:train_end], y[:train_end]
X_val,   y_val   = X[train_end:val_end], y[train_end:val_end]
X_test,  y_test  = X[val_end:], y[val_end:]

train_ds = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
val_ds   = TensorDataset(torch.tensor(X_val),   torch.tensor(y_val))
test_ds  = TensorDataset(torch.tensor(X_test),  torch.tensor(y_test))

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=64, shuffle=False)
test_loader  = DataLoader(test_ds,  batch_size=64, shuffle=False)


y_train_np = np.array(y_train)  # shape (N_train, 3)
mf_pitch = int(np.argmax(np.bincount(y_train_np[:, 0])))
mf_vel   = int(np.argmax(np.bincount(y_train_np[:, 1])))
mf_dur   = int(np.argmax(np.bincount(y_train_np[:, 2])))
print(f"Most-Frequent Baseline IDs → Pitch: {mf_pitch}, Vel: {mf_vel}, Dur: {mf_dur}")

# --- 4. Initialize Accumulators ---
# Model metrics
ce_pitch = nn.CrossEntropyLoss()
ce_vel   = nn.CrossEntropyLoss()
ce_dur   = nn.CrossEntropyLoss()

total_ce_pitch = 0.0
total_ce_vel   = 0.0
total_ce_dur   = 0.0

total_mse_pitch = 0.0
total_mse_vel   = 0.0
total_mse_dur   = 0.0

total_mae_pitch = 0.0
total_mae_vel   = 0.0
total_mae_dur   = 0.0

count = 0

# Baseline accumulators
mf_mse_pitch = 0.0
mf_mse_vel   = 0.0
mf_mse_dur   = 0.0

mf_mae_pitch = 0.0
mf_mae_vel   = 0.0
mf_mae_dur   = 0.0

rand_mse_pitch = 0.0
rand_mse_vel   = 0.0
rand_mse_dur   = 0.0

rand_mae_pitch = 0.0
rand_mae_vel   = 0.0
rand_mae_dur   = 0.0

pitch_vocab_size = len(pitch2id)
vel_vocab_size   = len(vel2id)
dur_vocab_size   = len(dur2id)

# Set seed for reproducibility
random.seed(0)

# --- 5. Iterate Over Test Set ---
with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        batch_size = xb.size(0)

        # ---- Model Predictions ----
        p_logits, v_logits, d_logits = model(xb)

        # Cross-Entropy Loss
        loss_p = ce_pitch(p_logits, yb[:, 0])
        loss_v = ce_vel(  v_logits, yb[:, 1])
        loss_d = ce_dur(  d_logits, yb[:, 2])

        total_ce_pitch += loss_p.item() * batch_size
        total_ce_vel   += loss_v.item() * batch_size
        total_ce_dur   += loss_d.item() * batch_size

        # Predicted IDs
        p_pred = torch.argmax(p_logits, dim=1)
        v_pred = torch.argmax(v_logits, dim=1)
        d_pred = torch.argmax(d_logits, dim=1)

        # MSE & MAE for model
        mse_p = (p_pred.float() - yb[:, 0].float()).pow(2).sum().item()
        mse_v = (v_pred.float() - yb[:, 1].float()).pow(2).sum().item()
        mse_d = (d_pred.float() - yb[:, 2].float()).pow(2).sum().item()

        mae_p = torch.abs(p_pred.float() - yb[:, 0].float()).sum().item()
        mae_v = torch.abs(v_pred.float() - yb[:, 1].float()).sum().item()
        mae_d = torch.abs(d_pred.float() - yb[:, 2].float()).sum().item()

        total_mse_pitch += mse_p
        total_mse_vel   += mse_v
        total_mse_dur   += mse_d

        total_mae_pitch += mae_p
        total_mae_vel   += mae_v
        total_mae_dur   += mae_d

        # ---- Most-Frequent Baseline ----
        yb_np = yb.cpu().numpy()
        # MSE
        mf_mse_pitch += ((mf_pitch - yb_np[:, 0]) ** 2).sum()
        mf_mse_vel   += ((mf_vel   - yb_np[:, 1]) ** 2).sum()
        mf_mse_dur   += ((mf_dur   - yb_np[:, 2]) ** 2).sum()
        # MAE
        mf_mae_pitch += np.abs(mf_pitch - yb_np[:, 0]).sum()
        mf_mae_vel   += np.abs(mf_vel   - yb_np[:, 1]).sum()
        mf_mae_dur   += np.abs(mf_dur   - yb_np[:, 2]).sum()

        # ---- Random Uniform Baseline ----
        rand_p = np.random.randint(0, pitch_vocab_size,   size=batch_size)
        rand_v = np.random.randint(0, vel_vocab_size,     size=batch_size)
        rand_d = np.random.randint(0, dur_vocab_size,     size=batch_size)

        # MSE
        rand_mse_pitch += ((rand_p - yb_np[:, 0]) ** 2).sum()
        rand_mse_vel   += ((rand_v - yb_np[:, 1]) ** 2).sum()
        rand_mse_dur   += ((rand_d - yb_np[:, 2]) ** 2).sum()
        # MAE
        rand_mae_pitch += np.abs(rand_p - yb_np[:, 0]).sum()
        rand_mae_vel   += np.abs(rand_v - yb_np[:, 1]).sum()
        rand_mae_dur   += np.abs(rand_d - yb_np[:, 2]).sum()

        count += batch_size

# --- 6. Compute Model Metrics ---
avg_ce_pitch = total_ce_pitch / count
avg_ce_vel   = total_ce_vel   / count
avg_ce_dur   = total_ce_dur   / count

perplexity_pitch = math.exp(avg_ce_pitch)
perplexity_vel   = math.exp(avg_ce_vel)
perplexity_dur   = math.exp(avg_ce_dur)

mse_pitch = total_mse_pitch / count
mse_vel   = total_mse_vel   / count
mse_dur   = total_mse_dur   / count

mae_pitch = total_mae_pitch / count
mae_vel   = total_mae_vel   / count
mae_dur   = total_mae_dur   / count

# --- 7. Compute Baseline Metrics ---
mf_mse_pitch /= count
mf_mse_vel   /= count
mf_mse_dur   /= count

mf_mae_pitch /= count
mf_mae_vel   /= count
mf_mae_dur   /= count

rand_mse_pitch /= count
rand_mse_vel   /= count
rand_mse_dur   /= count

rand_mae_pitch /= count
rand_mae_vel   /= count
rand_mae_dur   /= count

# --- 8. Print Results ---
print("===== Model (BiLSTM) Metrics =====")
print(f"Cross-Entropy (Pitch): {avg_ce_pitch:.4f}, Perplexity: {perplexity_pitch:.2f}")
print(f"Cross-Entropy (Vel):   {avg_ce_vel:.4f}, Perplexity: {perplexity_vel:.2f}")
print(f"Cross-Entropy (Dur):   {avg_ce_dur:.4f}, Perplexity: {perplexity_dur:.2f}\n")

print(f"MSE (Pitch): {mse_pitch:.3f}, MAE (Pitch): {mae_pitch:.3f}")
print(f"MSE (Vel):   {mse_vel:.3f},   MAE (Vel):   {mae_vel:.3f}")
print(f"MSE (Dur):   {mse_dur:.3f},   MAE (Dur):   {mae_dur:.3f}\n")

print("===== Most-Frequent Baseline Metrics =====")
print(f"MSE (Pitch): {mf_mse_pitch:.3f}, MAE (Pitch): {mf_mae_pitch:.3f}")
print(f"MSE (Vel):   {mf_mse_vel:.3f},   MAE (Vel):   {mf_mae_vel:.3f}")
print(f"MSE (Dur):   {mf_mse_dur:.3f},   MAE (Dur):   {mf_mae_dur:.3f}\n")

print("===== Random Uniform Baseline Metrics =====")
print(f"MSE (Pitch): {rand_mse_pitch:.3f}, MAE (Pitch): {rand_mae_pitch:.3f}")
print(f"MSE (Vel):   {rand_mse_vel:.3f},   MAE (Vel):   {rand_mae_vel:.3f}")
print(f"MSE (Dur):   {rand_mse_dur:.3f},   MAE (Dur):   {rand_mae_dur:.3f}")

Most-Frequent Baseline IDs → Pitch: 41, Vel: 65, Dur: 1
===== Model (BiLSTM) Metrics =====
Cross-Entropy (Pitch): 4.0844, Perplexity: 59.40
Cross-Entropy (Vel):   4.4245, Perplexity: 83.47
Cross-Entropy (Dur):   2.3970, Perplexity: 10.99

MSE (Pitch): 197.186, MAE (Pitch): 11.461
MSE (Vel):   441.791,   MAE (Vel):   17.843
MSE (Dur):   130.246,   MAE (Dur):   4.344

===== Most-Frequent Baseline Metrics =====
MSE (Pitch): 194.101, MAE (Pitch): 11.442
MSE (Vel):   413.645,   MAE (Vel):   16.624
MSE (Dur):   130.246,   MAE (Dur):   4.344

===== Random Uniform Baseline Metrics =====
MSE (Pitch): 830.302, MAE (Pitch): 24.104
MSE (Vel):   1672.456,   MAE (Vel):   34.050
MSE (Dur):   8185.461,   MAE (Dur):   77.149
