# In [1]:
import librosa
import torch

from muscribe import midi2score
from muscribe.audio2midi import PianoTranscription


def get_midi(audio_path, output_midi_path):
    """Transcribe an audio file with ``PianoTranscription`` and load the notes.

    The transcriber is created on CUDA when ``torch.cuda.is_available()`` and
    on CPU otherwise. The time spent inside ``transcribe`` (model construction
    excluded) is printed to stdout.

    Args:
        audio_path: Path of the audio file handed to the transcriber.
        output_midi_path: Second argument handed to ``transcribe``; presumably
            the MIDI file is written there (confirm against muscribe).

    Returns:
        The result of ``midi2score.read_midi_notes(output_midi_path)``, or
        ``None`` when ``output_midi_path`` is falsy.
    """
    import time

    device = "cuda" if torch.cuda.is_available() else "cpu"
    transcriptor = PianoTranscription(device=device)

    # perf_counter() is monotonic and high-resolution, so the reported duration
    # cannot be skewed by system clock adjustments the way time.time() can.
    started = time.perf_counter()
    transcriptor.transcribe(audio_path, output_midi_path)
    print(f"Transcribe time: {time.perf_counter() - started:.3f} s")

    if not output_midi_path:
        return None
    return midi2score.read_midi_notes(output_midi_path)


@torch.no_grad()
def get_beats(audio_path: str, offset: float = 0.0, duration: float | None = None):
    """Run BeatNet on an audio excerpt and decode beats/downbeats with a DBN.

    Args:
        audio_path: Audio file loaded through ``librosa.load``.
        offset: Forwarded to ``librosa.load`` as ``offset``.
        duration: Forwarded to ``librosa.load`` as ``duration``; ``None``
            leaves it unset.

    Returns:
        Whatever ``DBNDownBeatTrackingProcessor`` yields for the activations
        (presumably beat times with their position in the bar; confirm
        against the madmom documentation).
    """
    from BeatNet.BeatNet import BDA, BeatNet
    from madmom.features.downbeats import DBNDownBeatTrackingProcessor

    estimator = BeatNet(model=1)

    # Load at the sample rate the BeatNet instance reports.
    waveform, _ = librosa.load(
        audio_path,
        sr=estimator.sample_rate,
        offset=offset,
        duration=duration,
    )

    # Features are transposed, given a leading axis of size 1 and moved to the
    # device BeatNet selected.
    features = estimator.proc.process_audio(waveform).T
    batch = torch.from_numpy(features).unsqueeze(0).to(estimator.device)

    network: BDA = estimator.model  # type: ignore
    raw_output = network(batch)[0]
    activations = network.final_pred(raw_output).cpu().detach().numpy()

    # Keep the first two rows and transpose them before decoding.
    beat_activations = activations[:2, :].T

    # Using DBN offline inference to infer beat/downbeats
    decoder = DBNDownBeatTrackingProcessor(beats_per_bar=[4], fps=50)
    return decoder(beat_activations)


def get_keysig(midi_notes):
    """Return ``midi2score.RNNKeySignatureProcessor().process(midi_notes)``.

    A fresh processor is constructed on every call.
    """
    return midi2score.RNNKeySignatureProcessor().process(midi_notes)


def get_hand_parts(midi_notes):
    """Return ``midi2score.RNNHandPartProcessor().process(midi_notes)``.

    A fresh processor is constructed on every call.
    """
    return midi2score.RNNHandPartProcessor().process(midi_notes)


# Example pipeline: transcribe the recording, then build a MusicXML score
# from an excerpt of it.

# Shared inputs, named once so the MIDI read and the beat tracking below are
# guaranteed to use the same files and the same excerpt window.
AUDIO_PATH = "example/sonatine.mp3"
MIDI_PATH = "example/sonatine.midi"
EXCERPT_OFFSET = 60  # presumably seconds, as for librosa.load; confirm for read_midi_notes
EXCERPT_DURATION = 120

# Transcribe the full recording. The returned notes are not kept: only the
# excerpt re-read from MIDI_PATH on the next line is used downstream.
get_midi(AUDIO_PATH, MIDI_PATH)
midi = midi2score.read_midi_notes(
    MIDI_PATH, offset=EXCERPT_OFFSET, duration=EXCERPT_DURATION
)
beats = get_beats(AUDIO_PATH, offset=EXCERPT_OFFSET, duration=EXCERPT_DURATION)

key_change = get_keysig(midi)
hand_parts = get_hand_parts(midi)

builder = midi2score.MusicXMLBuilder(beats)
builder.add_notes(midi.numpy(), hand_parts.numpy())
builder.add_key_changes(key_change)
builder.infer_bpm_changes(diff_size=2, log_bin_size=0.03)
# NOTE(review): the file name says "192+120" while the excerpt above starts at
# offset 60; confirm which one is intended.
builder.render("example/sonatine.192+120.xml")