In [1]:
import exmel
import pickle

In [5]:
# Load the evaluation songs; each exposes .name, .melody, .performance and .ground_truth.
dataset = exmel.Dataset("dataset_v1")

In [2]:
# Precomputed candidate matches, keyed by song name (looked up via song.name below).
# NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
with open("song_name_to_candidates.pkl", "rb") as f:
    song_name_to_candidates = pickle.load(f)

In [6]:
from tqdm.auto import tqdm
from exmel.sequence import song_stats

# Score every song's candidate matches with the pretrained XGBoost model.
# name_to_scores[song] is later zipped with song_name_to_candidates[song],
# so the model is expected to return one score per candidate.
score_model = exmel.XGBoostModel('xgb_hop1_miss2_len8_large.json')
name_to_scores: dict[str, list[float]] = {}
for song in tqdm(dataset):
    # The model is stateful: the song's statistics are loaded right before
    # that song's candidates are scored, so keep these two steps together.
    score_model.load_song_stats(song_stats(song.melody, song.performance))
    candidates = song_name_to_candidates[song.name]
    name_to_scores[song.name] = score_model(candidates)

  0%|          | 0/34 [00:00<?, ?it/s]

In [7]:
from itertools import product
from exmel.wisp import weighted_interval_scheduling
from exmel.alignment import Alignment, concat_matches, FrozenMatch
import numpy as np


# Sweep (penalty, min_score) pairs. For each pair, every model score is shifted
# down by the penalty. Candidates whose shifted score is below min_score are
# dropped, and weighted interval scheduling picks the best non-overlapping
# subset. The resulting alignment is evaluated against the ground truth and
# metrics are averaged over the dataset.
# The penalty is currently fixed at 0, so only min_score varies. In the
# recorded run below, mean F1 peaked at m=8 (~0.954): precision rose while
# recall dropped only slightly.
pen_min_tuple = list(product([0], range(31)))

f1_list: list[float] = []
precision_list: list[float] = []
recall_list: list[float] = []


def _discarded_matches(candidates: list[FrozenMatch],
                       chosen: list[FrozenMatch]) -> list[FrozenMatch]:
    """Return the candidates not selected into ``chosen``, in candidate order.

    Membership is tested against a set when the matches are hashable. Songs
    have up to ~475k candidates, so scanning the ``chosen`` list once per
    candidate made this step quadratic. If the match type is unhashable,
    this falls back to list membership, which gives the same result.
    """
    try:
        lookup = set(chosen)
    except TypeError:  # unhashable match type: keep the plain list scan
        lookup = chosen
    return [match for match in candidates if match not in lookup]


def _align_song(song, penalty: float, min_score: float) -> Alignment:
    """Build the optimal alignment of one song for a (penalty, min_score) pair.

    Args:
        song: a dataset entry; only ``song.name`` is read here.
        penalty: subtracted from every model score of the song.
        min_score: candidates whose shifted score is below this are dropped.

    Returns:
        The ``Alignment`` built from the weighted-interval-scheduling optimum.
    """
    scores = np.array(name_to_scores[song.name]) - penalty
    # Keep the candidates that clear the threshold, rescored with their
    # shifted score.
    candidates: list[FrozenMatch] = [
        candidate.update_score(score)
        for candidate, score in zip(song_name_to_candidates[song.name], scores)
        if score >= min_score
    ]
    opt_score, opt_subset = weighted_interval_scheduling(
        candidates, return_subset=True, verbose=False)
    discarded = _discarded_matches(candidates, opt_subset)
    concat_events = concat_matches(opt_subset)
    return Alignment(concat_events, opt_subset, discarded, opt_score,
                     sum(match.sum_miss for match in opt_subset),
                     sum(match.sum_error for match in opt_subset))


for p, m in pen_min_tuple:
    results = []
    for song in tqdm(dataset, desc=f"{p=}, {m=}"):
        assert song.ground_truth is not None
        alignment = _align_song(song, p, m)
        results.append(
            exmel.evaluate_melody(song.ground_truth, alignment.events, plot=False))
    # Dataset-level means. sum() adds left to right, so the result is identical
    # to a running total.
    f1_score = sum(r.f1_score for r in results) / len(dataset)
    precision = sum(r.precision for r in results) / len(dataset)
    recall = sum(r.recall for r in results) / len(dataset)
    f1_list.append(f1_score)
    precision_list.append(precision)
    recall_list.append(recall)
    print(f"{p=}, {m=}, {f1_score=}, {precision=}, {recall=}")

p=0, m=0:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=0, f1_score=0.9366007610850913, precision=0.9330983836544795, recall=0.9409448372095421


p=0, m=1:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=1, f1_score=0.939840474391724, precision=0.9394254901163056, recall=0.9409448372095421


p=0, m=2:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=2, f1_score=0.9405893685429446, precision=0.941171501986201, recall=0.9407447571775294


p=0, m=3:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=3, f1_score=0.942064275633886, precision=0.9442442841943257, recall=0.940522961377299


p=0, m=4:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=4, f1_score=0.9441551275365185, precision=0.9484158482149283, recall=0.9405074309369647


p=0, m=5:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=5, f1_score=0.9476293811022448, precision=0.9558056836341573, recall=0.940173992678634


p=0, m=6:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=6, f1_score=0.9510651935207719, precision=0.9623191175847933, recall=0.9406956107118729


p=0, m=7:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=7, f1_score=0.953398215745433, precision=0.9675131981610418, recall=0.9402168533973672


p=0, m=8:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=8, f1_score=0.9540747074878847, precision=0.971733374269161, recall=0.9375443367688759


p=0, m=9:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=9, f1_score=0.9535580824499649, precision=0.9753310956304266, recall=0.9334737965490255


p=0, m=10:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=10, f1_score=0.9515122122942602, precision=0.9760442575570525, recall=0.9290323162860707


p=0, m=11:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=11, f1_score=0.9509700747776357, precision=0.9763025628924966, recall=0.9277782250844367


p=0, m=12:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=12, f1_score=0.9494253551066556, precision=0.9778263063564518, recall=0.9237369813700704


p=0, m=13:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=13, f1_score=0.948555830780605, precision=0.9771936925729302, recall=0.9225940784075505


p=0, m=14:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=14, f1_score=0.9467079730649919, precision=0.9781088341756883, recall=0.9185034017267792


p=0, m=15:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=15, f1_score=0.9446349461725688, precision=0.9773487490234177, recall=0.9153476469309311


p=0, m=16:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=16, f1_score=0.943477316007967, precision=0.9784256830009542, recall=0.9123712778351075


p=0, m=17:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=17, f1_score=0.9410670229076172, precision=0.9781463471451204, recall=0.9083555340729856


p=0, m=18:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=18, f1_score=0.9400725014145986, precision=0.97829733746378, recall=0.9063581828195961


p=0, m=19:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=19, f1_score=0.9392196497438693, precision=0.9783179274751542, recall=0.9048336131253754


p=0, m=20:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=20, f1_score=0.9353304350042422, precision=0.977817203368799, recall=0.8983026247186411


p=0, m=21:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=21, f1_score=0.9335315468788731, precision=0.9782959331995289, recall=0.8948483200362586


p=0, m=22:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=22, f1_score=0.9296585002677251, precision=0.9776900920080471, recall=0.8888875484961075


p=0, m=23:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=23, f1_score=0.9300577471066663, precision=0.9798239103575833, recall=0.8883958544784806


p=0, m=24:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=24, f1_score=0.9289580345612096, precision=0.9812491104915155, recall=0.8856108084226205


p=0, m=25:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=25, f1_score=0.9286480426951711, precision=0.9826294759983616, recall=0.8840247281627169


p=0, m=26:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=26, f1_score=0.9262975447017392, precision=0.9820011646187178, recall=0.8809078232139755


p=0, m=27:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=27, f1_score=0.9219960359796191, precision=0.9820140295309769, recall=0.8739872790245715


p=0, m=28:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=28, f1_score=0.9199533738852507, precision=0.9807301936627605, recall=0.871633298565883


p=0, m=29:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=29, f1_score=0.9189976259189181, precision=0.981431263455327, recall=0.8700583230292888


p=0, m=30:   0%|          | 0/34 [00:00<?, ?it/s]

p=0, m=30, f1_score=0.9170750753513791, precision=0.9821234706890118, recall=0.8672721136462793


# Data: build the per-candidate feature table (detection labels + match features)

In [12]:
from dataclasses import dataclass

@dataclass(frozen=True, slots=True)
class Detection:
    """Per-candidate detection record, as stored in ``detections.pkl``.

    Defined here because pickle resolves classes by module and name when
    loading; the loaded entries are ``Detection`` instances (see the repr
    printed further down). Slotted dataclasses pickle their state in field
    order, so keep the class name and field order in sync with whatever wrote
    the file.
    """
    tp: int  # presumably true-positive note count for the candidate - TODO confirm
    fp: int  # presumably false-positive note count for the candidate - TODO confirm
    start: float  # start time; values look like seconds - TODO confirm units
    end: float  # end time, same units as ``start``

# Per-song lists of Detection records, keyed by song name. Each list is
# index-aligned with song_name_to_candidates[name]; the feature loop zips them.
# NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
with open("detections.pkl", "rb") as f:
    detections = pickle.load(f)

In [13]:
from typing import TypedDict
from exmel.alignment import MatchLike
import numpy as np
from collections import Counter

class Entry(TypedDict):
    """One feature-table row: a candidate match's detection counts and features.

    Rows are built from ``detections`` (tp/fp) and the ``ScoreFunctions``
    extractors, then turned into a DataFrame and pickled.
    """
    name: str  # song name the candidate belongs to
    tp: int  # copied from Detection.tp
    fp: int  # copied from Detection.fp
    length: int  # number of events in the match
    misses: int  # match.sum_miss
    error: float  # match.sum_error
    velocity: float  # mean of the events' velocities
    duration: float  # last event time minus first event time
    note_mean: float  # mean of the events' note values
    note_std: float  # population std (np.std, ddof=0) of the note values
    note_entropy: float  # Shannon entropy (bits) of the note histogram
    note_unique: int  # number of distinct note values
    note_change: int  # number of consecutive event pairs whose note differs

class ScoreFunctions:
    """Scalar feature extractors computed for each candidate match.

    Every method is a ``staticmethod``. Each takes a match that exposes
    ``events`` (items with ``note``, ``velocity`` and ``time`` attributes)
    and/or the ``sum_miss`` / ``sum_error`` totals, and returns one feature.
    """

    @staticmethod
    def length(match: MatchLike) -> int:
        """Number of events in the match."""
        n_events = len(match.events)
        return n_events

    @staticmethod
    def misses(match: MatchLike) -> int:
        """The match's accumulated miss count (``sum_miss``)."""
        total = match.sum_miss
        return total

    @staticmethod
    def error(match: MatchLike) -> float:
        """The match's accumulated error (``sum_error``)."""
        total = match.sum_error
        return total

    @staticmethod
    def velocity(match: MatchLike) -> float:
        """Arithmetic mean of the events' velocities."""
        velocities = [ev.velocity for ev in match.events]
        return float(np.mean(velocities))

    @staticmethod
    def duration(match: MatchLike) -> float:
        """Span from the first to the last event, in the units of ``time``."""
        events = match.events
        return events[-1].time - events[0].time

    @staticmethod
    def note_mean(match: MatchLike) -> float:
        """Mean of the events' note values."""
        notes = [ev.note for ev in match.events]
        return sum(notes) / len(notes)

    @staticmethod
    def note_std(match: MatchLike) -> float:
        """Population standard deviation (``np.std``, ddof=0) of the notes."""
        notes = [ev.note for ev in match.events]
        return float(np.std(notes))

    @staticmethod
    def note_entropy(match: MatchLike) -> float:
        """Shannon entropy, in bits, of the note histogram."""
        histogram = Counter(ev.note for ev in match.events)
        counts = np.array(list(histogram.values()))
        probs = counts / counts.sum()
        return float(-np.sum(probs * np.log2(probs)))

    @staticmethod
    def note_unique(match: MatchLike) -> int:
        """Number of distinct note values."""
        return len({ev.note for ev in match.events})

    @staticmethod
    def note_change(match: MatchLike) -> int:
        """Number of consecutive event pairs whose note differs."""
        notes = [ev.note for ev in match.events]
        return sum(cur != prev for prev, cur in zip(notes, notes[1:]))


# Short alias used when building the feature rows.
sf = ScoreFunctions

In [None]:
# NOTE(review): reloads the same dataset as the first cells of the notebook and
# was never executed (In [None]); keep only one load.
dataset = exmel.Dataset("dataset_v1")


In [19]:
# Sanity check for one song: one detection per candidate (the lengths should match).
len(detections['All of Me(John Legend)']), len(song_name_to_candidates['All of Me(John Legend)'])

(89455, 89455)

In [21]:
# Inspect the first detection record of one song.
detections['All of Me(John Legend)'][0]

Detection(tp=8, fp=0, start=17.115885416666668, end=19.973958333333332)

In [22]:
from tqdm.auto import tqdm

# Build one Entry row per (song, candidate match): detection counts from
# `detections` plus the ScoreFunctions features of the match.
data: list[Entry] = []

for i, song in enumerate(dataset, 1):
    # Every song is expected to carry a ground truth. This is checked once per
    # song rather than once per candidate.
    assert song.ground_truth is not None
    candidates = song_name_to_candidates[song.name]
    # detections[song] is index-aligned with the candidate list (one Detection
    # per candidate). strict=True raises instead of silently truncating if the
    # two pickles ever disagree in length.
    for match, dt in tqdm(zip(candidates, detections[song.name], strict=True),
                          total=len(candidates), desc=f"{i}/{len(dataset)}"):
        data.append(Entry(
            name=song.name,
            tp=dt.tp,
            fp=dt.fp,
            length=sf.length(match),
            misses=sf.misses(match),
            error=sf.error(match),
            velocity=sf.velocity(match),
            duration=sf.duration(match),
            note_mean=sf.note_mean(match),
            note_std=sf.note_std(match),
            note_entropy=sf.note_entropy(match),
            note_unique=sf.note_unique(match),
            note_change=sf.note_change(match)))

1/34


  0%|          | 0/89455 [00:00<?, ?it/s]

2/34


  0%|          | 0/42679 [00:00<?, ?it/s]

3/34


  0%|          | 0/143774 [00:00<?, ?it/s]

4/34


  0%|          | 0/43752 [00:00<?, ?it/s]

5/34


  0%|          | 0/191239 [00:00<?, ?it/s]

6/34


  0%|          | 0/110853 [00:00<?, ?it/s]

7/34


  0%|          | 0/114819 [00:00<?, ?it/s]

8/34


  0%|          | 0/99185 [00:00<?, ?it/s]

9/34


  0%|          | 0/193203 [00:00<?, ?it/s]

10/34


  0%|          | 0/474568 [00:00<?, ?it/s]

11/34


  0%|          | 0/94470 [00:00<?, ?it/s]

12/34


  0%|          | 0/100287 [00:00<?, ?it/s]

13/34


  0%|          | 0/267854 [00:00<?, ?it/s]

14/34


  0%|          | 0/183897 [00:00<?, ?it/s]

15/34


  0%|          | 0/69795 [00:00<?, ?it/s]

16/34


  0%|          | 0/71091 [00:00<?, ?it/s]

17/34


  0%|          | 0/64259 [00:00<?, ?it/s]

18/34


  0%|          | 0/267908 [00:00<?, ?it/s]

19/34


  0%|          | 0/44618 [00:00<?, ?it/s]

20/34


  0%|          | 0/60017 [00:00<?, ?it/s]

21/34


  0%|          | 0/88307 [00:00<?, ?it/s]

22/34


  0%|          | 0/81310 [00:00<?, ?it/s]

23/34


  0%|          | 0/119737 [00:00<?, ?it/s]

24/34


  0%|          | 0/205465 [00:00<?, ?it/s]

25/34


  0%|          | 0/94701 [00:00<?, ?it/s]

26/34


  0%|          | 0/156823 [00:00<?, ?it/s]

27/34


  0%|          | 0/83886 [00:00<?, ?it/s]

28/34


  0%|          | 0/55454 [00:00<?, ?it/s]

29/34


  0%|          | 0/176213 [00:00<?, ?it/s]

30/34


  0%|          | 0/58540 [00:00<?, ?it/s]

31/34


  0%|          | 0/251115 [00:00<?, ?it/s]

32/34


  0%|          | 0/50815 [00:00<?, ?it/s]

33/34


  0%|          | 0/45451 [00:00<?, ?it/s]

34/34


  0%|          | 0/160851 [00:00<?, ?it/s]

In [24]:
import pandas as pd
# One row per candidate match; the columns are the Entry keys.
df = pd.DataFrame(data)
# Optional CSV export (disabled):
# df.to_csv("matches.csv", index=False, encoding="utf-8")

In [26]:
# Persist the raw rows (list of Entry dicts) for later use. The file name
# presumably records the candidate-generation settings (hop=1, miss=2, len=8),
# matching the scoring model's file name - confirm.
with open("hop1_miss2_len8_matches.pkl", "wb") as f:
    pickle.dump(data, f)