# Melody extraction evaluation

## Load in data

In [25]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from pretty_midi import PrettyMIDI

from deep_pianist_identification.extractors import MelodyExtractor, note_list_to_midi, RollExtractor
from deep_pianist_identification import utils, plotting

In [3]:
RESOLUTION = 0.1
KEEP = ["mehldau", "jarrett", "peterson", "evans", "corea", "barron", "hicks",]

In [4]:
eval_root = Path(utils.get_project_root()) / "references/skyline_evaluation"
full_midi_paths = list(eval_root.rglob("**/*_full.MID"))

### Load as MIDI, truncate to 30 seconds

In [5]:
def proc(midi_fpath):
    read = PrettyMIDI(str(midi_fpath))
    notes = read.instruments[0].notes
    notes_fmt = [(i.start, i.end, i.pitch, i.velocity) for i in notes if 0 <= i.start <= 30. and 0 <= i.end <= 30]
    midi_out = note_list_to_midi(notes_fmt, resolution=read.resolution, )
    melody = MelodyExtractor(midi_out, quantize_resolution=RESOLUTION)
    return melody

## Do analysis

In [68]:
class PianoRollMelodyComparison(plotting.BasePlot):
    def __init__(self, full_roll, manual_roll, auto_roll, startx: int, pianist: str = None):
        super().__init__()
        self.pianist = pianist
        self.startx = startx
        self.full_roll = self._process_roll(full_roll)
        self.manual_roll = self._process_roll(manual_roll)
        self.auto_roll = self._process_roll(auto_roll)
        self.fig, self.ax = plt.subplots(3, 1, sharex=True, sharey=True, figsize=(plotting.WIDTH, plotting.WIDTH / 1.5))

    def _process_roll(self, roll):
        roll[roll > 1] = 1
        roll[roll == 0] = np.nan
        roll = np.flipud(roll)
        return roll[:, self.startx:self.startx + 500]

    def _create_plot(self) -> None:
        for r, ax in zip([self.full_roll, self.manual_roll, self.auto_roll], self.ax.flatten()):
            ax.imshow(r, aspect="auto")

    def _format_ax(self):
        self.ax[0].set(title="Full transcription")
        self.ax[1].set(title="Melody (annotated by hand)")
        self.ax[2].set(title="Melody (algorithm)")

        for ax in self.ax.flatten():
            plt.setp(ax.spines.values(), linewidth=plotting.LINEWIDTH)
            ax.tick_params(axis='both', width=plotting.TICKWIDTH)
            ax.set(
                xticklabels=[0, 1, 2, 3, 4, 5],
                xticks=[0, 100, 200, 300, 400, 500],
                yticks=range(0, 88, 12),
                yticklabels=[f"C{n}" for n in range(0, 8)][::-1]
            )

    def _format_fig(self) -> None:
        if self.pianist is not None:
            self.fig.suptitle("Pianist: " + self.pianist.title())
        self.fig.supxlabel("Time (seconds)")
        self.fig.supylabel("Pitch")
        self.fig.tight_layout()

    def save_fig(self):
        out = str(Path(utils.get_project_root()) /  f'reports/figures/{self.pianist}_{self.startx}')
        plotting.save_fig_all_exts(out, self.fig)


In [71]:
full_notes, hand_notes, auto_notes = [], [], []
res = []

for midi_path in full_midi_paths:
    pianist = midi_path.stem.split("_")[0]
    if not any((c.lower() == pianist.lower() for c in KEEP)):
        continue

    full_midi = PrettyMIDI(str(midi_path))
    full_notes.append(len(full_midi.instruments[0].notes))

    # load up manual annotation melody
    hand_annotated = str(midi_path.resolve()).replace("_full", "_manual")
    hand_annotated_midi = RollExtractor(PrettyMIDI(str(hand_annotated)))
    hand_annotated_note_list = [(i.start, i.pitch) for i in hand_annotated_midi.output_midi.instruments[0].notes]

    hand_notes.append(len(hand_annotated_note_list))

    # compute automatically extracted melody
    auto_extracted_midi = proc(midi_path)
    auto_extracted_note_list = [(i.start, i.pitch) for i in auto_extracted_midi.skylined.instruments[0].notes]

    total_auto_notes = len(auto_extracted_note_list)
    auto_notes.append(total_auto_notes)

    # iterate over all manual annotation melody notes
    hits = 0
    for hand_annotated_note in hand_annotated_note_list:
        for auto_note in auto_extracted_note_list:
            # setting resolution * 2 accounts for both directions
            if auto_note[1] == hand_annotated_note[1] and abs(auto_note[0] - hand_annotated_note[0]) <= (RESOLUTION * 2):
                hits += 1
                break

    # calculate error rate in same way as Norgaard paper, supplementary material
    error_rate = abs(total_auto_notes - hits) / total_auto_notes
    res.append(error_rate)

    # create plot for good piano roll
    if error_rate < 0.1:
        procer = lambda m: [(i.start, i.end, i.pitch, i.velocity) for i in m.instruments[0].notes]
        first_10_full = auto_extracted_midi.create_roll(procer(full_midi))
        first_10_auto = auto_extracted_midi.create_roll(procer(auto_extracted_midi.skylined))
        first_10_man = auto_extracted_midi.create_roll(procer(hand_annotated_midi.output_midi))

        for x in [0, 500, 1000, 1500, 2000, 2500]:
            rp = PianoRollMelodyComparison(first_10_full, first_10_man, first_10_auto, x, pianist=pianist)
            rp.create_plot()
            rp.save_fig()

print("Mean error rate: ", np.mean(res) * 100, np.std(res) * 100)
print("Total full MIDI notes: ", np.sum(full_notes))
print("Mean full MIDI notes", np.mean(full_notes), np.std(full_notes))
print("Total hand-annotated notes: ", np.sum(hand_notes))
print("Mean hand-annotated notes", np.mean(hand_notes), np.std(hand_notes))
print("Total automatically extracted notes: ", np.sum(auto_notes))
print("Mean automatically extracted notes", np.mean(auto_notes), np.std(auto_notes))

  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(
  ax.set(


Mean error rate:  11.878750848159443
Total full MIDI notes:  3544
Mean full MIDI notes 506.2857142857143 175.7430870704769
Total hand-annotated notes:  1017
Mean hand-annotated notes 145.28571428571428 71.12908644693469
Total automatically extracted notes:  1042
Mean automatically extracted notes 148.85714285714286 60.880277292940335
