In [1]:
import midi_score

def get_midi(audio_path, output_midi_path):
    """Transcribe a piano recording to MIDI and load the result as a note sequence.

    Args:
        audio_path: Path to the input audio file.
        output_midi_path: Path the transcribed MIDI is written to. The file is
            read back with ``midi_score.read_note_sequence`` afterwards, so this
            should be a real file path rather than ``None``.

    Returns:
        Whatever ``midi_score.read_note_sequence`` returns for the written file.
    """
    # Imported inside the function so these dependencies are only required
    # when a transcription is actually requested.
    import time

    import torch
    from audio_midi import PianoTranscription, load_audio, sample_rate

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    audio, _ = load_audio(audio_path, sr=sample_rate, mono=True)
    transcriptor = PianoTranscription(device=device)

    # perf_counter is monotonic and high-resolution, unlike time.time(), which
    # can jump if the system clock is adjusted during a long transcription.
    start = time.perf_counter()
    transcriptor.transcribe(audio, output_midi_path)
    print(f"Transcribe time: {time.perf_counter() - start:.3f} s")
    return midi_score.read_note_sequence(output_midi_path)

def get_beats(audio_path):
    """Estimate beats in an audio file with BeatNet.

    Builds a BeatNet estimator (first positional argument ``1`` — presumably
    the model index, confirm against BeatNet docs) in offline mode with DBN
    inference, no plotting and no threading, and returns the result of
    ``process`` on ``audio_path`` unchanged.
    """
    from BeatNet.BeatNet import BeatNet

    options = dict(mode="offline", inference_model="DBN", plot=[], thread=False)
    return BeatNet(1, **options).process(audio_path)

def get_keysig(midi_notes):
    """Run ``midi_score.RNNKeySignatureProcessor`` on ``midi_notes`` and return its output."""
    return midi_score.RNNKeySignatureProcessor().process(midi_notes)

def get_hand_parts(midi_notes):
    """Run ``midi_score.RNNHandPartProcessor`` on ``midi_notes`` and return its output."""
    return midi_score.RNNHandPartProcessor().process(midi_notes)


# All example inputs and outputs share this path stem.
EXAMPLE_STEM = "example/heartgrace"
# Set to True to transcribe the MIDI from the audio (needs torch and
# audio_midi) instead of loading the pre-computed MIDI file.
TRANSCRIBE_FROM_AUDIO = False

if TRANSCRIBE_FROM_AUDIO:
    # get_midi reads its output file back, so it needs a real path, not None.
    # A separate file name avoids overwriting the pre-computed MIDI.
    midi = get_midi(EXAMPLE_STEM + ".mp3", EXAMPLE_STEM + ".transcribed.midi")
else:
    midi = midi_score.read_note_sequence(EXAMPLE_STEM + ".midi")
beats = get_beats(EXAMPLE_STEM + ".mp3")
key_change = get_keysig(midi)
hand_parts = get_hand_parts(midi)
midi_score.write_to_xml(
    beats, midi.numpy(), hand_parts.numpy(), key_change, EXAMPLE_STEM + ".xml"
)