In [None]:
import os
from pathlib import Path
from typing import List, Union, Optional, Tuple

import pretty_midi
from basic_pitch.inference import predict

# Classes for the ML transcription model (audio -> MIDI)

In [None]:
def trim(midi_data: "pretty_midi.PrettyMIDI") -> "pretty_midi.PrettyMIDI":
    """Shift all notes so that the earliest note in the file starts at time 0.

    The object is modified in place and also returned for convenience.

    Args:
        midi_data: MIDI data whose ``instruments[*].notes`` expose mutable
            ``start`` and ``end`` times.

    Returns:
        The same ``midi_data`` object. If it contains no notes at all
        (e.g. nothing was transcribed), it is returned unchanged.
    """
    note_starts = [note.start
                   for instrument in midi_data.instruments
                   for note in instrument.notes]
    if not note_starts:
        # Nothing to align; indexing the first note would raise IndexError.
        return midi_data

    # Use the earliest onset across *all* instruments rather than the first
    # note of the first instrument: that note is not necessarily the earliest,
    # and shifting by a later time would leave negative start times.
    offset = min(note_starts)
    for instrument in midi_data.instruments:
        for note in instrument.notes:
            note.start -= offset
            note.end -= offset

    return midi_data


class BasicPitcher:
    """Transcribe an audio file to MIDI with Basic Pitch and save the result.

    Args:
        default_path: File the MIDI is written to when ``__call__`` is given
            no ``output_path``.
        onset_threshold: Forwarded to ``basic_pitch.inference.predict``.
        minimum_frequency: Forwarded to ``predict`` as the lower frequency bound.
        maximum_frequency: Forwarded to ``predict`` as the upper frequency bound.

    The defaults reproduce the values this class previously hard-coded.
    """

    def __init__(self, default_path: Path,
                 onset_threshold: float = 0.6,
                 minimum_frequency: float = 130.813,
                 maximum_frequency: float = 1278.75):
        self.default_path = default_path
        self.onset_threshold = onset_threshold
        self.minimum_frequency = minimum_frequency
        self.maximum_frequency = maximum_frequency

    def __call__(self, audio_file_path: Union[Path, str],
                 output_path: Optional[Union[Path, str]] = None) -> pretty_midi.PrettyMIDI:
        """Transcribe ``audio_file_path``, shift it to start at 0 and write it.

        Args:
            audio_file_path: Audio file passed to ``predict``.
            output_path: Destination MIDI file; ``self.default_path`` is used
                when omitted or ``None``.

        Returns:
            The trimmed ``PrettyMIDI`` object that was written to disk.
        """
        # predict() also returns the raw model output and the note events;
        # only the MIDI object is used here.
        _, midi_data, _ = predict(audio_file_path,
                                  onset_threshold=self.onset_threshold,
                                  minimum_frequency=self.minimum_frequency,
                                  maximum_frequency=self.maximum_frequency)

        if output_path is None:
            output_path = self.default_path
        midi_data = trim(midi_data)
        # str(): some pretty_midi releases only treat a str as a filename,
        # and default_path is a Path.
        midi_data.write(str(output_path))

        return midi_data


# Classes for sheet-music generation (MIDI -> MusicXML)

In [None]:

import music21.stream

from music21 import environment, stream, converter
from music21.note import Note

# External programs music21 is configured to use (see SheetGenerator).
# The defaults assume a standard Windows install of MuseScore 3 and LilyPond;
# set the MUSESCORE_PATH / LILYPOND_PATH environment variables to point at
# the executables on other machines instead of editing this cell.
_MUSESCORE_EXE = os.environ.get('MUSESCORE_PATH',
                                'C:/Program Files/MuseScore 3/bin/MuseScore3.exe')
_LILYPOND_EXE = os.environ.get('LILYPOND_PATH',
                               'C:/Program Files/LilyPond/usr/bin/lilypond.exe')

SETTINGS = {
    'musescoreDirectPNGPath': _MUSESCORE_EXE,
    'musicxmlPath': _MUSESCORE_EXE,
    'lilypondPath': _LILYPOND_EXE,
}


class SheetGenerator:
    """Convert transcribed MIDI into a single-voice MusicXML score via music21.

    Note and rest durations are measured in beats of the estimated tempo and
    snapped to the nearest allowed value, which is used as the music21
    ``quarterLength``.

    Args:
        fractions: Allowed note durations (quarter lengths).
        pause_fractions: Allowed rest durations (quarter lengths).
        default_path: Output file used when ``__call__`` gets no ``output_path``.
        rest_threshold: A gap between consecutive notes longer than this many
            beats is written as a rest. Defaults to 0.7, the value previously
            hard-coded.
    """

    def __init__(self, fractions: List[float], pause_fractions: List[float], default_path: Path,
                 rest_threshold: float = 0.7):
        self.fractions = fractions
        self.pause_fractions = pause_fractions
        self.default_path = default_path
        self.rest_threshold = rest_threshold
        # NOTE(review): this stores the Environment *class*, not an instance,
        # and nothing in this class reads it; kept for backward compatibility.
        self.env = environment.Environment

        # Create the music21 user settings file if it does not exist yet,
        # then point music21 at the external programs listed in SETTINGS.
        us = environment.UserSettings()
        if not os.path.exists(us.getSettingsPath()):
            us.create()
        for key in ('musescoreDirectPNGPath', 'musicxmlPath', 'lilypondPath'):
            us[key] = SETTINGS[key]

    def get_notes_from_midi(self, midi_data: pretty_midi.PrettyMIDI) -> Tuple[List[pretty_midi.Note], float]:
        """Collect every note in ``midi_data`` and estimate its tempo.

        Returns:
            ``(notes, tempo)``: the notes of all instruments sorted by onset
            time, and the tempo estimate in beats per minute. If
            ``estimate_tempo`` raises ``ValueError``, returns ``([], 0)``.
        """
        try:
            tempo = midi_data.estimate_tempo()
        except ValueError:
            return [], 0

        notes = [note for instrument in midi_data.instruments for note in instrument.notes]
        # _preprocess_notes compares each note with the *next* one, so the
        # combined list (instruments are concatenated) must be in onset order.
        notes.sort(key=lambda note: note.start)
        return notes, tempo

    def _preprocess_notes(self, notes: List[pretty_midi.Note], tempo: float) -> stream.Stream:
        """Quantize ``notes`` into a music21 stream of notes and rests.

        Args:
            notes: Notes in onset order. They are read, never modified.
            tempo: Tempo in beats per minute.

        Returns:
            A stream with one ``Note`` per input note, each followed by a
            ``Rest`` when the silence before the next note exceeds
            ``self.rest_threshold`` beats. The stream is empty when there are
            no notes or the tempo is not positive.
        """
        stream1 = stream.Stream()
        if not notes or tempo <= 0:
            # get_notes_from_midi reports tempo 0 when no tempo could be
            # estimated; dividing by it below would raise ZeroDivisionError.
            return stream1

        beats_per_second = tempo / 60
        # Length of one beat in seconds (rounded to 2 decimals as before).
        one_time = round(1 / beats_per_second, 2)
        if one_time <= 0:
            # Only possible for absurdly fast tempi; avoid dividing by zero.
            return stream1

        for i, _note in enumerate(notes):
            # Work on a local end time so the caller's note objects (and the
            # PrettyMIDI they belong to) are left untouched.
            end = _note.end
            rest = None
            if i + 1 < len(notes):
                next_start = notes[i + 1].start
                # Clip overlaps so the note ends no later than the next onset.
                end = min(end, next_start)
                pause_fraction = (next_start - end) / one_time
                if pause_fraction > self.rest_threshold:
                    rest_fraction = min(self.pause_fractions, key=lambda x: abs(x - pause_fraction))
                    rest = music21.note.Rest(quarterLength=rest_fraction)

            note_fraction = (end - _note.start) / one_time
            note_fraction = min(self.fractions, key=lambda x: abs(x - note_fraction))
            stream1.append(Note(_note.pitch, quarterLength=note_fraction))

            if rest is not None:
                stream1.append(rest)

        return stream1

    @staticmethod
    def read_xml(xml_path: Path) -> stream.Stream:
        """Parse a MusicXML file and return its first part's notes and rests
        as a flat stream."""
        xml_score = converter.parse(xml_path)
        xml_stream = stream.Stream(xml_score.parts[0].flatten().notesAndRests)

        return xml_stream

    def __call__(self, midi_data: pretty_midi.PrettyMIDI, output_path: Optional[Union[Path, str]] = None) -> str:
        """Convert ``midi_data`` to MusicXML and write it to disk.

        Args:
            midi_data: Transcribed MIDI to convert.
            output_path: Destination file; ``self.default_path`` when omitted.

        Returns:
            The path that was written, as a string.
        """
        notes, tempo = self.get_notes_from_midi(midi_data)
        stream1 = self._preprocess_notes(notes, tempo)

        if output_path is None:
            output_path = self.default_path

        stream1.write('musicxml', fp=output_path)

        return str(output_path)


# Example: running the full audio -> MIDI -> sheet-music pipeline

In [None]:
# Input recording and the files written by each pipeline stage.
audio_path = 'data/1.wav'
midi_path = 'data/1.mid'
sheet_path = 'data/1.xml'

# Allowed note and rest durations, in quarter-note units.
note_durations = [0.25, 0.5, 1, 2, 4]
rest_durations = [0.25, 0.5, 1, 2, 4]

basic_pitcher = BasicPitcher(default_path=Path(midi_path))
sheet_generator = SheetGenerator(
    fractions=note_durations,
    pause_fractions=rest_durations,
    default_path=Path(sheet_path),
)

# Stage 1: audio -> MIDI (the MIDI file is also saved to midi_path).
midi_data = basic_pitcher(audio_path, midi_path)
# Stage 2: MIDI -> MusicXML; the returned output path is the cell's output.
sheet_generator(midi_data, sheet_path)