# Task 2: Piano Transcription via TorchCREPE Fine-Tuning

**Author:** Anirudh Annabathula
**Date:** _May 31, 2025_

---

**Overview**  
In this notebook, we train a small PyTorch Bi-LSTM head on top of frozen, pretrained TorchCREPE features to transcribe piano WAVs into MIDI (symbolic output). Only the head is trained; the CREPE features are precomputed and never updated. We use MAESTRO 2004, a set of piano recordings with aligned MIDI.  
- **Section 1**: EDA & dataset statistics  
- **Section 2**: Extract CREPE features  
- **Section 3**: Build frame‐level labels  
- **Section 4**: PyTorch `Dataset` class  
- **Section 5**: Model definition & training loop  
- **Section 6**: Inference (WAV → `symbolic_conditioned.mid`)  
- **Section 7**: Example evaluation and results  
- **Section 8**: Save & export  

---

> **Note**: Before running this notebook, run the two preprocessing scripts from a shell:
> 
> ```bash
> # 1) CREPE features
> python scripts/extract_crepe_features.py \
>   --input_dir /home/ubuntu/data/maestro_2004 \
>   --output_dir /home/ubuntu/data/maestro_crepe \
>   --crepe_model full \
>   --device cuda
> 
> # 2) Frame-level labels
> python scripts/build_frame_targets.py \
>   --wav_dir /home/ubuntu/data/maestro_2004 \
>   --midi_dir /home/ubuntu/data/maestro_2004 \
>   --output_dir /home/ubuntu/data/maestro_labels \
>   --sr 16000 \
>   --hop_length 160
> ```
>
> These write one `.npz` per recording under `/home/ubuntu/data/maestro_crepe/` and `/home/ubuntu/data/maestro_labels/`.  
> You need Python 3.10 with PyTorch, TorchCREPE, librosa, and pretty_midi installed in your `env_task2` venv.


In [None]:
# Cell 2: Basic imports & GPU check

import os
import sys
import glob
import numpy as np
import torch

print("Python version:", sys.version.split()[0])
print("PyTorch version:", torch.__version__)
print("CUDA available?", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA device:", torch.cuda.get_device_name(0))


## 1. Exploratory Data Analysis

We first look at the contents of **MAESTRO 2004** to understand how many files we have, their durations, and MIDI note statistics.

- **Audio source**: `/home/ubuntu/data/maestro_2004/2004/*.wav`  
- **MIDI source** : `/home/ubuntu/data/maestro_2004/2004/*.midi`

We will:
1. Count how many WAV and MIDI pairs exist.  
2. Plot the distribution of audio durations.  
3. Count total MIDI notes per piece.  
4. Show an example spectrogram + MIDI‐overlaid plot for a 5 s snippet.
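
Step 1 relies on every WAV having a MIDI file with the same name. A minimal sketch of pairing by shared basename rather than by list position is shown here; `pair_by_stem` is a hypothetical helper and the file names are made up for illustration.

```python
import os

def pair_by_stem(wav_paths, midi_paths):
    """Pair WAV and MIDI paths that share a basename (hypothetical helper)."""
    stem = lambda p: os.path.splitext(os.path.basename(p))[0]
    midi_by_stem = {stem(m): m for m in midi_paths}
    pairs = [(w, midi_by_stem[stem(w)]) for w in sorted(wav_paths)
             if stem(w) in midi_by_stem]
    unmatched = [w for w in wav_paths if stem(w) not in midi_by_stem]
    return pairs, unmatched

# Made-up file names, for illustration only
wavs  = ["2004/a.wav", "2004/b.wav", "2004/c.wav"]
midis = ["2004/b.midi", "2004/a.midi"]
pairs, unmatched = pair_by_stem(wavs, midis)
print(pairs)      # WAV/MIDI tuples with matching stems
print(unmatched)  # WAVs that have no MIDI partner
```

Pairing by name reports a missing file explicitly, whereas zipping two sorted lists of equal length could silently pair the wrong recordings.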


In [None]:
# Cell 4: Count WAV/MIDI files and compute durations & note counts

import librosa
import pretty_midi
import matplotlib.pyplot as plt

wav_paths  = sorted(glob.glob("/home/ubuntu/data/maestro_2004/2004/*.wav"))
midi_paths = sorted(glob.glob("/home/ubuntu/data/maestro_2004/2004/*.midi"))
assert len(wav_paths) == len(midi_paths), "Mismatch between WAV and MIDI count"

print(f"Found {len(wav_paths)} WAV files and {len(midi_paths)} MIDI files.\n")

durations = []
note_counts = []

for wav, midi in zip(wav_paths, midi_paths):
    # 1. Audio duration in seconds, read from the file header without decoding the audio.
    #    (librosa >= 0.10 renames the `filename=` keyword to `path=`.)
    dur = librosa.get_duration(filename=wav, sr=16000)
    durations.append(dur)
    # 2. Total MIDI notes
    pm = pretty_midi.PrettyMIDI(midi)
    total_notes = sum(len(inst.notes) for inst in pm.instruments)
    note_counts.append(total_notes)

# Plot distribution of audio durations
plt.figure(figsize=(6,3))
plt.hist(durations, bins=15, color="C0", edgecolor="k")
plt.xlabel("Duration (s)")
plt.ylabel("Number of pieces")
plt.title("MAESTRO 2004 Audio Durations")
plt.tight_layout()
plt.show()

# Plot distribution of total MIDI notes
plt.figure(figsize=(6,3))
plt.hist(note_counts, bins=15, color="C1", edgecolor="k")
plt.xlabel("Total MIDI notes per piece")
plt.ylabel("Number of pieces")
plt.title("MAESTRO 2004 MIDI Note Counts")
plt.tight_layout()
plt.show()

# Display summary stats
print("Duration (s):  min=", np.min(durations), 
      "  max=", np.max(durations), 
      "  mean=", round(np.mean(durations),2))
print("MIDI notes:    min=", np.min(note_counts), 
      "  max=", np.max(note_counts), 
      "  mean=", round(np.mean(note_counts),2))


In [None]:
# Cell 5: Example 5-second snippet: mel spectrogram + MIDI piano-roll

import librosa.display

# Pick the first WAV for demonstration
wav_example = wav_paths[0]
midi_example = midi_paths[0]
snippet_start = 30.0  # seconds
snippet_dur   = 5.0   # seconds

# 1. Load 5 s snippet at 16 kHz
y_snip, sr = librosa.load(wav_example, sr=16000, mono=True, 
                          offset=snippet_start, duration=snippet_dur)

# 2. Compute a mel spectrogram
S = librosa.feature.melspectrogram(y=y_snip, sr=sr, n_mels=128, hop_length=256)  # `y` is keyword-only in librosa >= 0.10
S_db = librosa.power_to_db(S, ref=np.max)

# 3. Plot mel spectrogram
plt.figure(figsize=(6,4))
librosa.display.specshow(S_db, sr=sr, hop_length=256, 
                         x_axis="time", y_axis="mel", cmap="magma")
plt.title("5 s Mel Spectrogram (30–35 s snippet)")
plt.colorbar(format="%+2.0f dB")
plt.tight_layout()
plt.show()

# 4. Show the MIDI piano-roll for the same snippet
pm = pretty_midi.PrettyMIDI(midi_example)
plt.figure(figsize=(6,1.5))
for inst in pm.instruments:
    for note in inst.notes:
        # Keep every note that overlaps the window, including ones that began before it
        if note.end > snippet_start and note.start < snippet_start + snippet_dur:
            plt.hlines(note.pitch, 
                       (note.start - snippet_start), 
                       (note.end   - snippet_start), 
                       lw=2, color="cyan")
plt.xlim(0, snippet_dur)
plt.ylim(20, 108)  # piano pitch range
plt.xlabel("Time (s)")
plt.ylabel("MIDI pitch")
plt.title("Piano-Roll of Snippet (30–35 s)")
plt.tight_layout()
plt.show()


## 2. Extract TorchCREPE Features

We use TorchCREPE to estimate **frame-level fundamental frequency (f₀)** and **periodicity (confidence)** at 10 ms hops (16 kHz sampling, hop_length=160).  
The output for each WAV is saved as an `.npz` with two arrays:  
- `f0`: [T] in Hz  
- `conf`: [T] in [0, 1] confidence

These `.npz` files live under `/home/ubuntu/data/maestro_crepe/…`, mirroring the original `maestro_2004/…wav` structure.
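
To relate the `f0` values in Hz to the 88 piano pitches used later, the standard equal-temperament conversion (A4 = 440 Hz = MIDI 69) can be sketched as follows. The helper names `hz_to_midi` and `midi_to_class` are hypothetical and are not part of the extraction script.

```python
import numpy as np

def hz_to_midi(f0_hz):
    """Convert frequency in Hz to fractional MIDI note numbers (A4 = 440 Hz = 69)."""
    f0_hz = np.asarray(f0_hz, dtype=np.float64)
    return 69.0 + 12.0 * np.log2(f0_hz / 440.0)

def midi_to_class(midi_pitch, low=21, high=108):
    """Map a rounded MIDI pitch to a column index 0..87; -1 if outside the piano range."""
    m = np.rint(midi_pitch).astype(int)
    return np.where((m >= low) & (m <= high), m - low, -1)

f0 = np.array([27.5, 261.63, 440.0, 4186.01])   # A0, C4, A4, C8
midi = hz_to_midi(f0)
print(np.rint(midi).astype(int))    # MIDI note numbers
print(midi_to_class(midi))          # matching label columns
print(np.arange(5) * 160 / 16000)   # frame timestamps (s) at 10 ms hops
```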


In [None]:
# Cell 7: Verify few CREPE .npz files exist and inspect their shapes

crepe_files = sorted(glob.glob("/home/ubuntu/data/maestro_crepe/**/*.npz", recursive=True))
print("Found", len(crepe_files), "CREPE .npz files.")

# Load one example
npz_ex = crepe_files[0]
data = np.load(npz_ex)
print("Example:", npz_ex)
print("  f0 shape  :", data["f0"].shape)
print("  conf shape:", data["conf"].shape)

# Plot F0 and confidence for first 100 frames (~1 s)
f0_vals = data["f0"][:100]
conf_vals = data["conf"][:100]
times = np.arange(len(f0_vals)) * 0.01  # seconds

plt.figure(figsize=(6,2))
plt.plot(times, f0_vals, linewidth=1., label="f0 (Hz)")
plt.plot(times, conf_vals * np.max(f0_vals), linestyle="--", 
         label="confidence × max(f0)")
plt.xlabel("Time (s)")
plt.legend(loc="upper right")
plt.title("CREPE Features for First 1 s")
plt.tight_layout()
plt.show()


## 3. Build Frame-Level Labels

Each MIDI file is converted into a `[T × 89]` array, where `T = ⌈audio_samples / hop_length⌉`.  
- Columns 0–87 → MIDI pitches 21..108 (A0..C8).  
- Column 88 → “No-note” (when no pitch is active).  

We saved these under `/home/ubuntu/data/maestro_labels/… .npz` with key `"labels"`.  
Below, we load one and inspect its distribution.
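
For orientation, here is a minimal sketch of how such a `[T × 89]` matrix could be built from a list of notes. It assumes multi-hot pitch columns plus a no-note column; `notes_to_frame_labels` is a hypothetical helper, and the actual `build_frame_targets.py` may round frame boundaries differently.

```python
import math
import numpy as np

def notes_to_frame_labels(notes, n_samples, sr=16000, hop_length=160):
    """Build a [T x 89] matrix from (start_s, end_s, midi_pitch) tuples.

    Hypothetical helper; the real preprocessing script may differ.
    """
    T = math.ceil(n_samples / hop_length)
    labels = np.zeros((T, 89), dtype=np.float32)
    for start_s, end_s, pitch in notes:
        if not 21 <= pitch <= 108:
            continue                                          # outside the 88 piano keys
        a = int(round(start_s * sr)) // hop_length            # first active frame
        b = min(T, -(-int(round(end_s * sr)) // hop_length))  # ceil: one past the last frame
        labels[a:b, pitch - 21] = 1.0
    labels[labels[:, :88].sum(axis=1) == 0, 88] = 1.0         # no-note column
    return labels

# One second of audio with two overlapping toy notes: A4 (69) and C5 (72)
labels = notes_to_frame_labels([(0.10, 0.50, 69), (0.30, 0.60, 72)], n_samples=16000)
print(labels.shape)               # (T, 89)
print(int(labels[:, 88].sum()))   # number of no-note frames
```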


In [None]:
# Cell 9: Verify a label .npz and plot pitch‐distribution for one piece

label_files = sorted(glob.glob("/home/ubuntu/data/maestro_labels/**/*.npz", recursive=True))
print("Found", len(label_files), "label .npz files.")

npz_lab = label_files[0]
lab_data = np.load(npz_lab)["labels"]  # shape (T, 89)

print("Example:", npz_lab)
print("  labels shape:", lab_data.shape)
print("  Plotting frame counts per pitch bin (columns 0..87)")
pitch_sums = lab_data[:, :88].sum(axis=0)
plt.figure(figsize=(6,2))
plt.bar(np.arange(21,109), pitch_sums, width=1.0, color="C2")
plt.xlabel("MIDI pitch")
plt.ylabel("Frame‐count")
plt.title("Pitch‐Histogram (frames) for First Piece")
plt.tight_layout()
plt.show()

# Count how many frames are “no-note”
no_note_count = (lab_data[:,88] == 1).sum()
print("No-note frames:", no_note_count, "/", lab_data.shape[0])


## 4. PyTorch `Dataset`: Pairing CREPE & Label Files

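We create a `MaestroFrameDataset` that, for each index, returns:
- `features`: FloatTensor [T × 2] (`[f0, conf]`)  
- `targets`: LongTensor [T] (class index in 0..88, taken as the `argmax` over the 89 label columns)  

Because `argmax` returns the first maximum, a frame with several active pitches collapses to the lowest one, so the targets are monophonic. Setting `max_frames` pads or clips every example to a fixed length, which the default `DataLoader` collation needs whenever `batch_size > 1`.  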
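
The conversion from label rows to class indices, and the padding applied to short examples, can be illustrated on a toy array. The values below are made up, and the rows are assumed to be multi-hot.

```python
import numpy as np

# Toy [T x 89] labels: 4 frames (values are made up)
lbl = np.zeros((4, 89), dtype=np.float32)
lbl[0, 88] = 1.0                 # frame 0: silence
lbl[1, 48] = 1.0                 # frame 1: A4 only (MIDI 69 -> column 48)
lbl[2, [48, 51, 55]] = 1.0       # frame 2: chord A4 + C5 + E5
lbl[3, [51, 55]] = 1.0           # frame 3: C5 + E5

targets = np.argmax(lbl, axis=1).astype(np.int64)
print(targets)                   # chord frames keep only their lowest pitch

# Pad to a fixed length of 6 frames, marking padding as "no-note" (class 88)
max_frames = 6
pad = max_frames - len(targets)
padded = np.pad(targets, (0, pad), mode="constant", constant_values=88)
print(padded)
```

Padding with class 88 means padded frames are scored as ordinary silence by the loss and the accuracy metric, which can inflate both when examples are much shorter than `max_frames`.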


In [None]:
# Cell 11: Define MaestroFrameDataset

import torch
from torch.utils.data import Dataset

class MaestroFrameDataset(Dataset):
    def __init__(self, crepe_dir, label_dir, max_frames=None):
        self.pairs = []
        for crepe_npz in sorted(glob.glob(os.path.join(crepe_dir, "**", "*.npz"), recursive=True)):  # sorted: stable index order
            rel = os.path.relpath(crepe_npz, crepe_dir)
            label_npz = os.path.join(label_dir, os.path.splitext(rel)[0] + ".npz")
            if os.path.exists(label_npz):
                self.pairs.append((crepe_npz, label_npz))
        self.max_frames = max_frames

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        crepe_npz, label_npz = self.pairs[idx]
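        data = np.load(crepe_npz)
        lbl  = np.load(label_npz)["labels"]            # (T_lbl, 89)
        # The two pipelines may disagree by a frame or so; trim both to the shorter length
        T = min(len(data["f0"]), lbl.shape[0])
        f0   = data["f0"][:T]       # (T,)
        conf = data["conf"][:T]     # (T,)
        feats = np.stack([f0, conf], axis=1).astype(np.float32)  # (T,2)

        # argmax keeps the first (lowest) active column per frame
        targets = np.argmax(lbl[:T], axis=1).astype(np.int64)  # (T,)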

        if self.max_frames is not None:
            T = feats.shape[0]
            if T < self.max_frames:
                pad = self.max_frames - T
                feats   = np.pad(feats,   ((0,pad),(0,0)), mode="constant")
                targets = np.pad(targets, (0,pad), mode="constant", constant_values=88)
            else:
                feats   = feats[:self.max_frames]
                targets = targets[:self.max_frames]

        return torch.from_numpy(feats), torch.from_numpy(targets)

# Quick sanity check
ds = MaestroFrameDataset("/home/ubuntu/data/maestro_crepe", "/home/ubuntu/data/maestro_labels")
print("Dataset size:", len(ds))
f0_feats, targs = ds[0]
print("Example features shape:", f0_feats.shape, "  targets shape:", targs.shape)


## 5. Model Definition & Training

We define a simple Bi-LSTM head:

- **Input** → `[batch, T, 2]` (CREPE `f0, conf`)  
- **Bi-LSTM** (2 layers, hidden_dim = 128)  
- **Dense** → 89-way softmax (`88` pitches + `1` no-note)  

**Training details**  
- Loss: CrossEntropy on flattened `(batch×T)` predictions vs. targets  
- Optimizer: Adam, learning rate = 1e-4  
- Batch size: 8  
- 10 epochs, split 10% validation  
- Save best validation‐accuracy checkpoint to `transcriber_best.pt`
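
The flattening step in the loss can be checked on random tensors. The sketch below also shows one possible variant, which is an assumption and not what the notebook does: marking padded frames with `ignore_index` so they contribute to neither the loss nor the accuracy. The `PAD` value and the toy shapes are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, C = 2, 5, 89                       # toy batch: 2 sequences, 5 frames, 89 classes
logits  = torch.randn(B, T, C)           # stand-in for the model output
targets = torch.randint(0, C, (B, T))    # stand-in for frame labels

PAD = -100                               # assumed padding marker, not used in the notebook
targets[1, 3:] = PAD                     # pretend the second sequence has 3 real frames

# Flatten [B, T, C] -> [B*T, C] and [B, T] -> [B*T]; PAD frames are skipped
criterion = nn.CrossEntropyLoss(ignore_index=PAD)
loss = criterion(logits.reshape(-1, C), targets.reshape(-1))

# Accuracy over real frames only
mask = targets != PAD
acc = (logits.argmax(dim=2)[mask] == targets[mask]).float().mean()
print(loss.item(), acc.item(), int(mask.sum()))
```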


In [None]:
# Cell 13: Define model, training & validation loops

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

class FrameTranscriber(nn.Module):
    def __init__(self, hidden_dim=128, num_layers=2, dropout=0.3, num_classes=89):
        super().__init__()
        self.lstm = nn.LSTM(2, hidden_dim, num_layers=num_layers,
                            batch_first=True, bidirectional=True, dropout=dropout)
        self.fc   = nn.Linear(2*hidden_dim, num_classes)

    def forward(self, x):
        out, _ = self.lstm(x)        # out: [B, T, 2*hidden_dim]
        logits = self.fc(out)        # [B, T, num_classes]
        return logits

def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    epoch_loss = 0.0
    for feats, targets in loader:
        feats   = feats.to(device)      # [B, T, 2]
        targets = targets.to(device)    # [B, T]
        optimizer.zero_grad()
        logits = model(feats)           # [B, T, 89]
        loss   = criterion(logits.view(-1,89), targets.view(-1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(loader)

def eval_epoch(model, loader, criterion, device):
    model.eval()
    val_loss = 0.0
    correct = 0
    total   = 0
    with torch.no_grad():
        for feats, targets in loader:
            feats   = feats.to(device)
            targets = targets.to(device)
            logits  = model(feats)
            loss    = criterion(logits.view(-1,89), targets.view(-1))
            val_loss += loss.item()
            preds = logits.argmax(dim=2)   # [B, T]
            correct += (preds == targets).sum().item()
            total   += targets.numel()
    return val_loss / len(loader), correct / total


In [None]:
# Cell 14: Run training

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# 1. Instantiate dataset & split
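# max_frames must be set: the default collate function cannot stack pieces of
# different lengths. 3000 frames (30 s at 10 ms hops) is an assumed value; tune it.
dataset = MaestroFrameDataset("/home/ubuntu/data/maestro_crepe", 
                              "/home/ubuntu/data/maestro_labels",
                              max_frames=3000)
val_n = int(len(dataset) * 0.1)
train_n = len(dataset) - val_n
# Fixed seed (arbitrary) so the train/validation split is reproducible
train_ds, val_ds = random_split(dataset, [train_n, val_n],
                                generator=torch.Generator().manual_seed(0))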

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, drop_last=True)
val_loader   = DataLoader(val_ds,   batch_size=8, shuffle=False)

# 2. Build model, loss, optimizer
model     = FrameTranscriber(hidden_dim=128, num_layers=2, dropout=0.3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

best_acc = 0.0
for epoch in range(1, 11):
    tr_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    vl_loss, vl_acc = eval_epoch(model, val_loader, criterion, device)
    print(f"Epoch {epoch:02d}  TrainLoss {tr_loss:.4f}  ValLoss {vl_loss:.4f}  ValAcc {vl_acc:.4f}")
    if vl_acc > best_acc:
        best_acc = vl_acc
        torch.save(model.state_dict(), "/home/ubuntu/assignment2/transcriber_best.pt")
        print(f"→ Saved best model (ValAcc={vl_acc:.4f})")
print("Training complete. Best ValAcc:", best_acc)


## 6. Inference: WAV → MIDI

Given a new WAV (16 kHz or auto-resampled), we:

1. Extract CREPE features (`[T × 2]`).  
2. Load the trained `transcriber_best.pt` Bi-LSTM.  
3. Forward‐pass to get frame‐wise predictions in `[0..88]` (88 pitches + “no-note”).  
4. Merge consecutive frames of the same pitch into a single MIDI note (duration = until pitch changes).  
5. Write final `symbolic_conditioned.mid`.

Below is the code to perform inference on one example file.
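
Step 4 is a run-length grouping of the frame predictions. A compact sketch that does not depend on `pretty_midi` is shown here; `merge_frames` and its `min_frames` filter are hypothetical additions for illustration, and each note ends at the frame where the pitch changes.

```python
from itertools import groupby

def merge_frames(preds, hop=160, sr=16000, no_note=88, min_frames=1):
    """Collapse frame-wise class indices into (midi_pitch, start_s, end_s) tuples."""
    notes, t = [], 0
    for cls, run in groupby(preds):
        n = len(list(run))                       # length of this run in frames
        if cls != no_note and n >= min_frames:
            notes.append((cls + 21, t * hop / sr, (t + n) * hop / sr))
        t += n
    return notes

preds = [88, 88, 48, 48, 48, 88, 51, 51]
print(merge_frames(preds))                 # two notes: MIDI 69 and MIDI 72
print(merge_frames(preds, min_frames=3))   # drops the 2-frame note
```

A minimum-duration filter like `min_frames` is one way to suppress single-frame flicker in the predictions, at the cost of losing very short notes.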


In [None]:
# Cell 16: Inference function & example run

import torchcrepe
import pretty_midi

def extract_crepe(wav_path, model="full", device="cuda"):
    y, sr = librosa.load(wav_path, sr=16000, mono=True)
    audio = torch.tensor(y, dtype=torch.float32)[None].to(device)
    with torch.no_grad():
        f0, periodicity = torchcrepe.predict(
            audio, 16000, model=model, hop_length=160,
            fmin=65.41, fmax=1975.5, device=device, return_periodicity=True,
            batch_size=2048  # assumed value; bounds memory use on long recordings
        )
    return np.stack([f0[0].cpu().numpy(), periodicity[0].cpu().numpy()], axis=1), len(y)

def frames_to_midi(preds, audio_len, hop=160, sr=16000, out_midi="out.mid"):
    T = len(preds)
    pm = pretty_midi.PrettyMIDI()
    piano = pretty_midi.Instrument(program=0)
    cur, start = None, None
    for t in range(T):
        p = int(preds[t])
        if p != cur:
            # close previous note
            if cur is not None and cur != 88:
                s = start * hop / sr
                e = t * hop / sr
                note = pretty_midi.Note(velocity=80, pitch=cur + 21, start=s, end=e)
                piano.notes.append(note)
            # start new
            if p != 88:
                start = t
                cur = p
            else:
                cur = 88
                start = None
    # handle last note
    if cur is not None and cur != 88:
        s = start * hop / sr
        e = max(audio_len / sr, (start + 1) * hop / sr)  # never emit a zero-length note
        note = pretty_midi.Note(velocity=80, pitch=cur + 21, start=s, end=e)
        piano.notes.append(note)
    pm.instruments.append(piano)
    pm.write(out_midi)
    print(f"Wrote MIDI {out_midi} with {len(piano.notes)} notes.")

# Load trained model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = FrameTranscriber(hidden_dim=128, num_layers=2, dropout=0.3).to(device)
state = torch.load("/home/ubuntu/assignment2/transcriber_best.pt", map_location=device)
model.load_state_dict(state); model.eval()

# Example WAV (pick a held-out MAESTRO 2004 file or any piano WAV)
test_wav = "/home/ubuntu/data/maestro_2004/2004/MIDI-Unprocessed_01_R1_2004-06-01_02.wav"
feat_matrix, audio_length = extract_crepe(test_wav, model="full", device=device)
feats_t = torch.from_numpy(feat_matrix)[None].to(device)  # [1, T, 2]

with torch.no_grad():
    logits = model(feats_t)  # [1, T, 89]
    preds  = logits.argmax(dim=2)[0].cpu().numpy()

# Write MIDI
output_midi = "/home/ubuntu/assignment2/symbolic_conditioned.mid"
frames_to_midi(preds, audio_length, hop=160, sr=16000, out_midi=output_midi)


## 7. Example Evaluation

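We compare our model’s MIDI against the ground-truth MIDI for one piece. Because the split in Section 5 is random, the piece counts as held out only if it landed in the validation set: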

1. **Frame‐level accuracy** was printed during training.  
2. **Note‐level F₁**: we define a simple function to match predicted vs. true notes (onset within ±50 ms, same pitch).  
3. We compute precision, recall, and F₁ for that test piece.
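
The matching rule can be sanity-checked on a few hand-written notes before running it on MIDI files. This is a minimal sketch on in-memory `(onset_s, pitch)` tuples; `match_onsets` is a hypothetical stand-in and the toy notes are made up.

```python
def match_onsets(gt, pred, tol=0.05):
    """Greedy one-to-one matching of (onset_s, pitch) tuples; returns (precision, recall, f1)."""
    used, tp = set(), 0
    for p_on, p_pitch in pred:
        for i, (g_on, g_pitch) in enumerate(gt):
            if i not in used and g_pitch == p_pitch and abs(g_on - p_on) <= tol:
                used.add(i)      # each ground-truth note can be matched once
                tp += 1
                break
    prec = tp / len(pred) if pred else 0.0
    rec  = tp / len(gt) if gt else 0.0
    f1   = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
    return prec, rec, f1

gt   = [(0.00, 60), (0.50, 64), (1.00, 67), (1.50, 72)]
pred = [(0.02, 60), (0.58, 64), (1.01, 67)]   # second onset is 80 ms late, so it misses
prec, rec, f1 = match_onsets(gt, pred)
print(f"Precision: {prec:.3f}  Recall: {rec:.3f}  F1: {f1:.3f}")
```

Here two of three predictions match and two of four ground-truth notes are found, so precision is 2/3, recall is 1/2, and F₁ is 4/7.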


In [None]:
# Cell 18: Note-level F1 for one piece

def note_f1(gt_midi_path, pred_midi_path, tol=0.05):
    """
    Match predicted notes to ground-truth:
      + A predicted note is TP if there exists a GT note of same pitch
        whose start_time is within ±tol seconds of the predicted start.
      + FP if no match.  
      + FN if a GT note is unmatched.
    Return (precision, recall, f1).
    """
    gt_pm   = pretty_midi.PrettyMIDI(gt_midi_path)
    pred_pm = pretty_midi.PrettyMIDI(pred_midi_path)

    gt_notes   = [(n.start, n.pitch) for inst in gt_pm.instruments for n in inst.notes]
    pred_notes = [(n.start, n.pitch) for inst in pred_pm.instruments for n in inst.notes]

    matches = []
    used_gt = set()
    for p_start, p_pitch in pred_notes:
        best = None
        for i, (g_start, g_pitch) in enumerate(gt_notes):
            if i in used_gt: continue
            if g_pitch != p_pitch: continue
            if abs(g_start - p_start) <= tol:
                best = i
                break
        if best is not None:
            matches.append(best)
            used_gt.add(best)

    tp = len(matches)
    fp = len(pred_notes) - tp
    fn = len(gt_notes) - tp
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec  = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1   = 2 * prec * rec / (prec + rec + 1e-8)
    return prec, rec, f1

# Run on example piece
gt_midi   = "/home/ubuntu/data/maestro_2004/2004/MIDI-Unprocessed_01_R1_2004-06-01_02.midi"
pred_midi = "/home/ubuntu/assignment2/symbolic_conditioned.mid"
prec, rec, f1 = note_f1(gt_midi, pred_midi, tol=0.05)
print(f"Precision: {prec:.3f}  Recall: {rec:.3f}  F1: {f1:.3f}")


## 8. Save & Export

- **Notebook HTML**: File → Download as → HTML (save as `workbook.html`).  
- **Generated MIDI**: `symbolic_conditioned.mid` is in the root `~/assignment2/`.  
- **Video**: Record a ~20 min walkthrough and upload to Google Drive; place the shareable link in `video_url.txt`.  

<mark style="color:green">Your `symbolic_conditioned.mid` is now ready for submission.</mark>
