In [None]:
import midi_score

# Input recording and its note sequence, shared by every processor below.
midi_recording = "example/heartgrace.midi"
notes = midi_score.read_note_sequence(midi_recording)

# One processor model per prediction task.
beat_pro = midi_score.RNNJointBeatProcessor()
quant_pro = midi_score.RNNJointQuantisationProcessor()
part_pro = midi_score.RNNHandPartProcessor()
time_sig_pro = midi_score.RNNTimeSignatureProcessor()
key_sig_pro = midi_score.RNNKeySignatureProcessor()

# Beat tracking, quantisation and hand separation on the same notes.
beats1 = beat_pro.process(notes)
beats2, onset_positions, note_values = quant_pro.process(notes)
hand_parts = part_pro.process(notes)

# Signature predictions; both are printed for inspection.
time_sig_changes = time_sig_pro.process(notes)
key_sig_changes = key_sig_pro.process(notes)
print("Time signature changes:", time_sig_changes, sep="\n")
print("Key signature changes:", key_sig_changes, sep="\n")

In [None]:
import pymusicxml as mxml
from fractions import Fraction
import numpy as np
import pretty_midi as pm
import pandas as pd


class KeyChangeTracker:
    """Walk through a list of key changes as time advances.

    ``key_changes`` is indexed as a sequence of ``(time, key)`` pairs.
    Callers feed increasing times to :meth:`update_time`; whenever the next
    pending change lies less than ``EPS`` after the given time (or before it),
    that change is recorded for the current bar.  :meth:`get_measure_keysig`
    then reports and clears what was recorded.
    """

    # Tolerance when comparing a change time against the current time.
    EPS = 0.01

    def __init__(self, key_changes) -> None:
        self.key_changes = key_changes
        # Index of the last consumed change; None until the first is consumed.
        self._idx = None
        # Keys consumed since the last call to get_measure_keysig().
        self._changes_in_bar = []

    def update_time(self, time: float):
        """Consume at most one pending key change that is due at ``time``."""
        total = len(self.key_changes)
        assert self._idx is None or 0 <= self._idx < total
        candidate = self._idx + 1 if self._idx is not None else 0
        if candidate < total:
            change_time, change_key = self.key_changes[candidate]
            if change_time - time < self.EPS:
                self._changes_in_bar.append(change_key)
                self._idx = candidate

    def get_measure_keysig(self):
        """Return the first key recorded for this bar (or None) and reset."""
        pending, self._changes_in_bar = self._changes_in_bar, []
        return pending[0] if pending else None


def to_frac(x):
    """Approximate ``x`` by the closest fraction whose denominator is <= 32."""
    exact = Fraction(x)
    return exact.limit_denominator(max_denominator=32)


def pitch_to_str(pitch):
    """Convert a MIDI note number to a pitch name such as ``"C#4"``.

    Uses scientific pitch notation, where MIDI note 60 is middle C (``C4``),
    so the octave is ``pitch // 12 - 1``.  Accidentals are always spelled as
    sharps.

    :param pitch: MIDI note number; anything ``int()`` accepts (e.g. a float
        taken from a numeric array).
    :return: note name followed by the octave number.
    """
    names = ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B")
    octave, pitch_class = divmod(int(pitch), 12)
    # MIDI octave numbering starts at -1: notes 0-11 form octave -1.
    return f"{names[pitch_class]}{octave - 1}"


def cut_to_simple_time(duration: float):
    """Snap a duration to the nearest value of the form ``k * 2**m``, k in {2, 3, 4}.

    The duration is scaled by a power of two into the range ``[2, 4)``,
    rounded to the nearest integer with ``np.round`` and scaled back.  The
    result is therefore a power of two or 1.5 times a power of two.

    :param duration: strictly positive, finite duration.
    :return: the snapped duration as a Python ``float``.
    :raises ValueError: if ``duration`` is not a positive finite number.
    """
    if not (duration > 0 and np.isfinite(duration)):
        raise ValueError(f"duration must be positive and finite, got {duration!r}")
    # Power of two that shifts the duration into 2 <= duration < 4.
    shift = 1 - int(np.floor(np.log2(duration)))
    scaled = duration * 2.0**shift
    # Kept outside the assert so the value still exists under `python -O`.
    assert 2 <= scaled < 4
    units = int(np.round(scaled))
    return float(units / 2**shift)


# TODO: should use time signature -- current DNN predictions are wrong
# Fixed (numerator, denominator) time signature applied to every measure in
# place of the predicted `time_sig_changes`.
timesig = (4, 4)


def get_measures_of_part(hand_part, clef):
    """Build the list of ``mxml.Measure`` objects for one hand.

    Notes sharing an ``onset_position`` become one note or chord.  Its length
    is the gap to the next onset (or the smallest ``note_value`` for the last
    group), snapped with ``cut_to_simple_time``.  A measure is closed as soon
    as the accumulated written length reaches one bar of the module-level
    ``timesig``.  Key changes come from the module-level ``key_sig_changes``.

    :param hand_part: DataFrame with columns ``onset_position``,
        ``start_time``, ``note_value`` and ``pitch`` (pitch as a string
        accepted by ``mxml.Pitch.from_string``).
    :param clef: clef passed to the first measure of the part.
    :return: list of measures; a final measure holding any leftover notes is
        always appended, even when it is empty.
    """
    measures, measure_notes = [], []
    groups = list(hand_part.groupby("onset_position"))
    key_tracker = KeyChangeTracker(key_sig_changes)

    def measure_attrs():
        """Attributes for the measure about to be appended."""
        attrs = {}
        # Clef and time signature belong on the first measure of the part,
        # whichever note group happens to complete it.
        if not measures:
            attrs.update({"clef": clef, "time_signature": timesig})
        new_keysig = key_tracker.get_measure_keysig()
        if new_keysig is not None:
            attrs.update({"key": new_keysig})
        return attrs

    for i, (onset, group) in enumerate(groups):
        # timesig_idx, timesig = update_idx(timesig_idx, time_sig_changes, start_time, "4/4")
        key_tracker.update_time(group.iloc[0]["start_time"])
        if i + 1 < len(groups):
            duration = float(groups[i + 1][0] - onset)
        else:
            duration = group["note_value"].min()  # TODO: should use note length
        dur_in_beats = float(cut_to_simple_time(duration))
        pitches = [mxml.Pitch.from_string(s) for s in group["pitch"]]
        dur_in_quarter = dur_in_beats * 4 / timesig[1]
        if len(pitches) > 1:
            measure_notes.append(mxml.Chord(pitches, dur_in_quarter))
        else:
            measure_notes.append(mxml.Note(pitches[0], dur_in_quarter))
        # Fraction of a bar filled so far; close the measure once it is full.
        filled = sum(n.duration.written_length / timesig[0] for n in measure_notes)
        if filled >= 1:
            measures.append(mxml.Measure(measure_notes, **measure_attrs()))
            measure_notes = []
    # Trailing (possibly incomplete or empty) measure.  It only carries the
    # clef / time signature when it is also the first measure of the part.
    measures.append(mxml.Measure(measure_notes, **measure_attrs()))
    return measures


# Predicted per-note columns first, followed by the raw note columns.
cat_output = np.stack([onset_positions, note_values, hand_parts], axis=1)
input_output = np.concatenate([cat_output, notes], axis=1)
df = pd.DataFrame(
    input_output,
    columns=[
        "onset_position",
        "note_value",
        "hand_part",
        "pitch",
        "start_time",
        "last_time",
        "velocity",
    ],
).sort_values("onset_position")
# df["onset_position"] = df["onset_position"].apply(to_frac)
df["pitch"] = df["pitch"].apply(pitch_to_str)

# Hand label 0 is engraved on the bass staff, label 1 on the treble staff.
left_hand_df = df[df["hand_part"] == 0]
right_hand_df = df[df["hand_part"] == 1]

piano_left = mxml.Part("piano_left", get_measures_of_part(left_hand_df, "bass"))
piano_right = mxml.Part("piano_right", get_measures_of_part(right_hand_df, "treble"))
score = mxml.Score([piano_left, piano_right])
score.export_to_file("example/heartgrace.xml")

In [None]:
import matplotlib.pyplot as plt


def get_piano_roll(midi_file, fs=100):
    """Render all notes of a MIDI file into a binary piano-roll matrix.

    Notes of every instrument are merged.  The time axis starts at the first
    note onset and ends at the last note end.

    :param midi_file: path handed to ``pm.PrettyMIDI``.
    :param fs: number of piano-roll columns per second (default 100, i.e.
        10 ms per column).
    :return: ``(roll, start_time, end_time)`` where ``roll`` has shape
        ``(128, int((end_time - start_time) * fs))`` and holds 1 wherever a
        pitch is sounding.
    :raises ValueError: if the file contains no notes.
    """
    notes = [
        (note.pitch, note.start, note.end)
        for instrument in pm.PrettyMIDI(midi_file).instruments
        for note in instrument.notes
    ]
    if not notes:
        raise ValueError(f"no notes found in {midi_file!r}")
    start_time = min(start for _, start, _ in notes)
    end_time = max(end for _, _, end in notes)
    pr = np.zeros((128, int((end_time - start_time) * fs)))
    for pitch, start, end in notes:
        # Column indices are relative to the first onset.
        first_col = int((start - start_time) * fs)
        last_col = int((end - start_time) * fs)
        pr[pitch, first_col:last_col] = 1
    return pr, start_time, end_time


fig, ax = plt.subplots(figsize=(20, 5))
pr_seg, start_time, end_time = get_piano_roll(midi_recording)
ax.imshow(pr_seg, aspect="auto", origin="lower", cmap="gray")
# Optional overlay of the beat-tracker output on the lower half (disabled):
# for b in beats1[np.logical_and(beats1 >= start_time, beats1 <= end_time)]:
#     ax.axvline(x=(b - start_time) * 100, ymin=0, ymax=0.5, color="g")
# Quantisation-model beats inside the recording, drawn on the upper half.
# The factor 100 converts seconds to piano-roll columns.
for b in beats2[(beats2 >= start_time) & (beats2 <= end_time)]:
    ax.axvline(x=(b - start_time) * 100, ymin=0.5, ymax=1, color="r")
ax.set_xbound(lower=0, upper=6000)

In [None]:
def move_by_pow_of_2(x, target):
    """Scale ``x`` by the power of two that brings it closest to ``target``.

    Closeness is measured on a log2 scale.

    :return: ``(scaled_x, offset)`` where ``offset`` is the rounded exponent
        and ``scaled_x`` equals ``x * 2**offset`` up to floating-point error.
    """
    log_x = np.log2(x)
    offset = np.round(np.log2(target) - log_x)
    scaled = 2 ** (log_x + offset)
    return scaled, offset


def bpm_over_time(
    beat_lens: np.ndarray,
    anchor_bpm: float,
    bin_precision: float = 0.05,
    wsize: int = 10,
    fill_nan: bool = True,
):
    """Estimate a smoothed tempo curve from consecutive beat lengths.

    For each beat index ``idx >= wsize`` the previous ``wsize`` beat lengths
    are converted to BPM and histogrammed into geometrically spaced bins
    (ratio ``1 + bin_precision``).  If the fullest bin holds at least half of
    the samples, its mean is moved by a power of two towards ``anchor_bpm``
    and used as the tempo; otherwise the window counts as diverged.

    :param beat_lens: beat durations in seconds; expected to be positive.
    :param anchor_bpm: reference tempo for the power-of-two normalisation and
        the fallback when no estimate exists yet.
    :param bin_precision: relative width of one histogram bin.
    :param wsize: number of beats per window.
    :param fill_nan: if True, diverged windows and the first ``wsize`` beats
        are NaN; otherwise they repeat the previous / first estimate (or
        ``anchor_bpm`` when there is none).
    :return: array with one tempo value per entry of ``beat_lens``.
    """

    def bpm_from_mode(samples, n_samples, bins):
        """Mean of the samples in the fullest bin, or None without a majority."""
        mode_idx = n_samples.argmax()
        if n_samples[mode_idx] < n_samples.sum() / 2:
            return None
        mode_min, mode_max = bins[mode_idx], bins[mode_idx + 1]
        mode_samples = samples[np.logical_and(samples >= mode_min, samples <= mode_max)]
        return np.mean(mode_samples)

    ret = []
    for idx in range(wsize, len(beat_lens)):
        samples = 60 / beat_lens[idx - wsize : idx]
        # Pad the range slightly so the extreme samples fall inside the bins.
        bin_min, bin_max = np.log2(samples.min() - 0.01), np.log2(samples.max() + 0.01)
        bin_width = np.log2(1 + bin_precision)
        bins = 2 ** np.arange(bin_min, bin_max + bin_width, bin_width)
        n_samples, _ = np.histogram(samples, bins)
        window_bpm = bpm_from_mode(samples, n_samples, bins)
        if window_bpm is None:
            # print(f"! BPM in window diverged: {samples}")
            ret.append(np.nan if fill_nan else ret[-1] if ret else anchor_bpm)
            continue
        ret.append(move_by_pow_of_2(window_bpm, anchor_bpm)[0])
    # The first beats have no full window behind them.  Cap the prefix so the
    # output always has one value per beat, even for fewer than wsize beats.
    n_prefix = min(wsize, len(beat_lens))
    if fill_nan:
        prefix_value = np.nan
    else:
        prefix_value = ret[0] if ret else anchor_bpm
    return np.array([prefix_value] * n_prefix + ret)


# Beat times from the quantisation model; the first beat length is measured
# from the onset of the first note.
xs = beats2
beat_lens = np.diff(xs, prepend=start_time)
bpms = bpm_over_time(beat_lens, 120, bin_precision=0.1, wsize=10, fill_nan=True)
# Power-of-two exponent that brings each raw beat tempo nearest to the
# smoothed tempo at that beat.
pow_offsets = [
    move_by_pow_of_2(60 / length, smoothed)[1]
    for length, smoothed in zip(beat_lens, bpms)
]
plt.plot(xs, bpms, label="smoothed")
# plt.plot(xs, pow_offsets, label='beat_pow_offset')
# plt.plot(xs, 60 / np.diff(xs, prepend=0), label='original')
plt.legend()
plt.show()

In [None]:
from fractions import Fraction

# Work on a copy so the remapping below does not modify `df`.
df_ = df.copy()


def remap_note_lens(onsets, note_lens, pow_offsets):
    """Rescale onsets and note lengths after beats are stretched by powers of two.

    Beat ``i`` gets the new length ``2 ** pow_offsets[i]``.  An onset is split
    into its whole beat index and its fraction; the new onset is the summed
    length of all earlier beats plus the (unscaled) fraction.  Each note
    length is multiplied by the new length of the beat it starts in.

    :param onsets: pandas Series of onset positions in beats.
    :param note_lens: pandas Series of note lengths, aligned with ``onsets``.
    :param pow_offsets: one exponent per beat.  Non-finite entries (e.g. NaN
        where no tempo estimate exists) leave that beat at length 1.
    :return: ``(onsets, note_lens)`` holding ``Fraction`` values with
        denominators limited to 32 before scaling.
    """

    def series_to_frac(xs):
        return xs.apply(lambda x: Fraction(x).limit_denominator(32))

    def split_beat(xs):
        # np.modf yields (fractional part, integral part).
        fractional, whole = np.modf(xs)
        return series_to_frac(fractional), whole.astype(int)

    def scaled_beat_len(offset):
        # Fraction() cannot represent NaN/inf, so such beats stay unscaled.
        if not np.isfinite(offset):
            return Fraction(1)
        return Fraction(2**offset)

    # The trailing 0 lets a note sitting exactly on the final beat index
    # be looked up without an IndexError.
    beat_lens = np.array(
        [scaled_beat_len(offset) for offset in pow_offsets] + [0], dtype=object
    )
    # beat_remap[k] is the new start position of beat k.
    beat_remap = np.cumsum(np.concatenate([[0], beat_lens]))
    onset_fraction, beat_idx = split_beat(onsets)
    beat_idx = beat_idx.to_numpy()
    # NOTE(review): the within-beat fraction is not multiplied by the beat's
    # new length although note lengths are -- confirm this is intended.
    onsets = onset_fraction + beat_remap[beat_idx]
    note_lens = series_to_frac(note_lens) * beat_lens[beat_idx]
    return onsets, note_lens


# The frame's columns are named "onset_position" / "note_value" (singular);
# the plural spellings raised a KeyError.  Both columns are overwritten with
# their remapped values.
df_["onset_position"], df_["note_value"] = remap_note_lens(
    df_["onset_position"], df_["note_value"], pow_offsets
)
df_.sort_values("onset_position").head(20)