In [None]:
def crop_and_shift_midi(midi_path, midi_output, duration, song):
    """Crop a MIDI file to a per-song window and shift it to start at zero.

    The window start, the indices of the instruments to keep and the initial
    tempo are read from ``HWD_DICT[song]``; the window ends ``duration``
    after its start.  Only notes lying entirely inside the window are
    copied, with their start/end shifted back by the window start.  The
    result is written to ``midi_output``.

    Args:
        midi_path: path of the MIDI file to read.
        midi_output: path the cropped MIDI file is written to.
        duration: length of the window, in the same unit as the note times.
        song: key into ``HWD_DICT``.
    """
    source_midi = pretty_midi.PrettyMIDI(midi_path)
    start_time, tokeep_instruments, initial_tempo = HWD_DICT[song]
    end_time = start_time + duration

    # Fresh MIDI object that receives the cropped and shifted notes.
    cropped_midi = pretty_midi.PrettyMIDI(initial_tempo=initial_tempo)

    for idx, instrument in enumerate(source_midi.instruments):
        if idx not in tokeep_instruments:
            continue

        # Only the program number is carried over to the new instrument.
        new_instrument = pretty_midi.Instrument(program=instrument.program)

        for note in instrument.notes:
            # Notes crossing either window boundary are dropped, not clipped.
            if not (note.start >= start_time and note.end <= end_time):
                continue

            # No more sax notes after 15 seconds in Hakuna
            if song == 'Hakuna' and idx == 7 and note.start > start_time + 15:
                continue

            # Copy the note with its times shifted to the window origin.
            new_instrument.notes.append(
                pretty_midi.Note(
                    pitch=note.pitch,
                    start=note.start - start_time,
                    end=note.end - start_time,
                    velocity=note.velocity,
                )
            )

        # Kept instruments are appended even when no note survived the crop.
        cropped_midi.instruments.append(new_instrument)

    cropped_midi.write(midi_output)
    return

import numpy as np 
import pretty_midi
import librosa
from basic_pitch.constants import (
    ANNOTATION_HOP,
    ANNOTATIONS_BASE_FREQUENCY,
    CONTOURS_BINS_PER_SEMITONE,
    NOTES_BINS_PER_SEMITONE,
)


def create_onset(midi_path):
    """Build sparse onset annotations from a MIDI file.

    Every note of every instrument contributes one (frame, bin) pair:
    the frame is ``note.start / ANNOTATION_HOP`` rounded to the nearest
    integer, the bin is
    ``12 * NOTES_BINS_PER_SEMITONE * log2(f / ANNOTATIONS_BASE_FREQUENCY)``
    rounded to the nearest integer, where ``f`` is the note pitch converted
    to Hz.

    Args:
        midi_path: path of the MIDI file to read.

    Returns:
        A tuple ``(onsets_indices, onset_values)``: an integer array of
        shape ``(n_notes, 2)`` holding (frame, bin) rows, and an array of
        ones with one entry per row.
    """
    midi_data = pretty_midi.PrettyMIDI(midi_path)

    # Collect all rows first and convert once; stacking inside the loop
    # re-copies the whole array for every note.
    rows = [
        (note.start, note.pitch)
        for instrument in midi_data.instruments
        for note in instrument.notes
    ]
    # reshape keeps the (0, 2) shape when the file contains no notes.
    onsets_indices = np.array(rows, dtype=float).reshape(-1, 2)

    # Translate from time to frame index
    onsets_indices[:, 0] = np.round(onsets_indices[:, 0] / ANNOTATION_HOP)

    # Translate from MIDI pitch to frequency bin
    frequencies = librosa.midi_to_hz(onsets_indices[:, 1])
    onsets_indices[:, 1] = (
        12.0 * NOTES_BINS_PER_SEMITONE * np.log2(frequencies / ANNOTATIONS_BASE_FREQUENCY)
    )

    # Round before casting: astype(int) alone truncates, so a bin computed
    # as e.g. 38.9999999 because of floating-point error would land one
    # bin too low.
    onsets_indices = np.round(onsets_indices).astype(int)

    # Create onset values
    onset_values = np.ones(onsets_indices.shape[0])
    return onsets_indices, onset_values

def create_notes(midi_path):
    """Build sparse note-activation annotations from a MIDI file.

    Every note contributes one (frame, bin) row per frame between its
    rounded start frame (inclusive) and rounded end frame (exclusive),
    frames being ``time / ANNOTATION_HOP``.  The bin is
    ``12 * NOTES_BINS_PER_SEMITONE * log2(f / ANNOTATIONS_BASE_FREQUENCY)``
    rounded to the nearest integer, where ``f`` is the note pitch in Hz.

    Args:
        midi_path: path of the MIDI file to read.

    Returns:
        A tuple ``(note_indices, note_values)``: an integer array of shape
        ``(n_rows, 2)`` holding (frame, bin) rows, and an array of ones with
        one entry per row.
    """
    midi_data = pretty_midi.PrettyMIDI(midi_path)

    # One (duration, 2) chunk per note, concatenated once at the end;
    # stacking row by row re-copies the whole array for every frame.
    chunks = []
    for instrument in midi_data.instruments:
        for note in instrument.notes:
            start = np.round(note.start / ANNOTATION_HOP)
            end = np.round(note.end / ANNOTATION_HOP)
            # Duration in frames
            duration = int(end - start)
            if duration <= 0:
                # Notes shorter than half a frame contribute no rows.
                continue

            # Translate from MIDI pitch to frequency bin.  Round explicitly:
            # the final integer cast truncates, so a value such as
            # 38.9999999 would otherwise land one bin too low.
            frequency = librosa.midi_to_hz(note.pitch)
            note_bin = np.round(
                12.0 * NOTES_BINS_PER_SEMITONE * np.log2(frequency / ANNOTATIONS_BASE_FREQUENCY)
            )

            frames = start + np.arange(duration)
            chunks.append(np.column_stack((frames, np.full(duration, note_bin))))

    if chunks:
        note_indices = np.concatenate(chunks).astype(int)
    else:
        note_indices = np.empty((0, 2), dtype=int)

    # Create note values
    note_values = np.ones(note_indices.shape[0])
    return note_indices, note_values

def create_contour(midi_path):
    """Build sparse contour annotations from a MIDI file.

    Every note contributes one (frame, bin) row per frame between its
    rounded start frame (inclusive) and rounded end frame (exclusive),
    frames being ``time / ANNOTATION_HOP``.  The bin is
    ``12 * CONTOURS_BINS_PER_SEMITONE * log2(f / ANNOTATIONS_BASE_FREQUENCY)``
    rounded to the nearest integer, where ``f`` is the note pitch in Hz.

    Args:
        midi_path: path of the MIDI file to read.

    Returns:
        A tuple ``(note_indices, note_values)``: an integer array of shape
        ``(n_rows, 2)`` holding (frame, bin) rows, and an array of ones with
        one entry per row.
    """
    midi_data = pretty_midi.PrettyMIDI(midi_path)

    # One (duration, 2) chunk per note, concatenated once at the end;
    # stacking row by row re-copies the whole array for every frame.
    chunks = []
    for instrument in midi_data.instruments:
        for note in instrument.notes:
            start = np.round(note.start / ANNOTATION_HOP)
            end = np.round(note.end / ANNOTATION_HOP)
            # Duration in frames
            duration = int(end - start)
            if duration <= 0:
                # Notes shorter than half a frame contribute no rows.
                continue

            # Translate from MIDI pitch to contour bin.  Round explicitly:
            # the final integer cast truncates, so a value such as
            # 116.9999999 would otherwise land one bin too low.
            frequency = librosa.midi_to_hz(note.pitch)
            note_bin = np.round(
                12.0 * CONTOURS_BINS_PER_SEMITONE * np.log2(frequency / ANNOTATIONS_BASE_FREQUENCY)
            )

            frames = start + np.arange(duration)
            chunks.append(np.column_stack((frames, np.full(duration, note_bin))))

    if chunks:
        note_indices = np.concatenate(chunks).astype(int)
    else:
        note_indices = np.empty((0, 2), dtype=int)

    # Create contour values
    note_values = np.ones(note_indices.shape[0])
    return note_indices, note_values

In [39]:
#print(note_indices, note_values)
#print(onset_indices, onset_values)
import os
import pandas as pd
import warnings
from tqdm import tqdm

import sox  # needed for sox.file_info.duration below

# Silence the RuntimeWarnings raised from within pretty_midi
warnings.filterwarnings("ignore", category=RuntimeWarning, module="pretty_midi")

# Every path below is resolved against the user's home directory.
HOME_DIR = os.path.expanduser('~')
HWD_DIR = 'mir_datasets/hwd'

# Track ids have the form "<Song>_<Interpretation>_<Interpreter>_<file stem>".
df = pd.read_csv(os.path.join(HOME_DIR, HWD_DIR, 'MLEndHWD_Audio_Attributes.csv'))
track_ids = [
    f'{song}_{interpretation}_{interpreter}_{filename.removesuffix(".wav")}'
    for song, interpretation, interpreter, filename in zip(
        df["Song"], df["Interpretation"], df["Interpreter"], df["Public filename"]
    )
]

for track_id in tqdm(track_ids):
    # The song name is everything before the first underscore; the audio
    # file is named after the last four characters of the track id.
    song = track_id[:track_id.find("_")]
    attr_path = os.path.join(HOME_DIR, HWD_DIR, f'MLEndHWD_{song}_Audio_Files', f'{track_id[-4:]}.wav')

    # Catch Exception rather than everything so Ctrl-C still interrupts
    # this long-running loop instead of being reported as a track error.
    try:
        duration = sox.file_info.duration(attr_path)
    except Exception:
        print(f'Error with {track_id}')
        continue
    if duration is None:
        # No usable duration was reported for this file: skip the track
        # instead of failing on the arithmetic below.
        print(f'Error with {track_id}')
        continue

    time_scale = np.arange(0, duration + ANNOTATION_HOP, ANNOTATION_HOP)
    n_time_frames = len(time_scale)

    # The reference MIDI of the song is cropped to the recording's duration
    # and written to the home directory, then turned into annotations.
    source_midi_path = os.path.join(HOME_DIR, HWD_DIR, 'MIDI', f'{song}.mid')
    cropped_midi_path = os.path.join(HOME_DIR, f'{song}.mid')
    crop_and_shift_midi(source_midi_path, cropped_midi_path, duration=duration, song=song)

    note_indices, note_values = create_notes(cropped_midi_path)
    onset_indices, onset_values = create_onset(cropped_midi_path)
    contour_indices, contour_values = create_contour(cropped_midi_path)



 21%|██▏       | 1410/6611 [01:30<07:36, 11.39it/s]

Error with StarWars_Whistle_122_1408


 29%|██▉       | 1942/6611 [02:25<05:19, 14.60it/s]

Error with Panther_Hum_85_1940


100%|██████████| 6611/6611 [07:37<00:00, 14.46it/s]


In [None]:
#Error with StarWars_Whistle_122_1408
#Error with Panther_Hum_85_1940


In [None]:
class HWDFilterInvalidTracks(beam.DoFn):
    """Beam DoFn that drops tracks whose audio cannot be processed.

    For each ``(track_id, split)`` element the audio and MIDI files are
    copied from ``source`` into a temporary directory, the audio is
    resampled with sox and rewritten as 16-bit PCM.  If every step
    succeeds, ``track_id`` is emitted on the output tagged ``split``;
    otherwise nothing is emitted for that element.
    """

    # Files copied locally for every track before validation.
    DOWNLOAD_ATTRIBUTES = ["audio_path", "midi_path"]

    def __init__(self, source: str):
        # Root location the dataset-relative paths are joined onto.
        self.source = source

    def setup(self):
        import apache_beam as beam

        self.filesystem = beam.io.filesystems.FileSystems()

    def process(self, element: Tuple[str, str]):
        """Validate one track.

        Args:
            element: ``(track_id, split)``; elements whose split is
                ``"omitted"`` are skipped.

        Yields:
            ``TaggedOutput(split, track_id)`` when the track is valid.
        """
        import tempfile

        import apache_beam as beam
        import sox
        import soundfile as sf

        from basic_pitch.constants import (
            AUDIO_N_CHANNELS,
            AUDIO_SAMPLE_RATE,
        )

        track_id, split = element
        if split == "omitted":
            return
        print(f"Processing (track_id, split): ({track_id}, {split})")
        logging.info(f"Processing (track_id, split): ({track_id}, {split})")

        # The song name is everything before the first underscore; the audio
        # file is named after the last four characters of the track id.
        song = track_id[:track_id.find("_")]
        audio_path = os.path.join(HWD_DIR, f'MLEndHWD_{song}_Audio_Files', f'{track_id[-4:]}.wav')
        relative_paths = {
            "audio_path": audio_path,
            "midi_path": os.path.join(HWD_DIR, 'MIDI', f'{song}.mid'),
        }

        with tempfile.TemporaryDirectory() as local_tmp_dir:

            for attr in self.DOWNLOAD_ATTRIBUTES:
                attr_path = relative_paths[attr]
                source = os.path.join(self.source, attr_path)
                dest = os.path.join(local_tmp_dir, attr_path)

                # Check that the source file exists.  Testing the truthiness
                # of `dest` can never fail because it is always a non-empty
                # path string.
                if not self.filesystem.exists(source):
                    logging.info(f"Could not find {attr} for {track_id}")
                    print(f"\n\n\n\nCould not find {attr} for {track_id}\n\n\n\n")
                    return
                logging.info(f"Downloading {attr} from {source} to {dest}")
                os.makedirs(os.path.dirname(dest), exist_ok=True)
                with self.filesystem.open(source) as s, open(dest, "wb") as d:
                    d.write(s.read())

            local_audio_path = os.path.join(local_tmp_dir, audio_path)
            local_wav_path = "{}_tmp.wav".format(local_audio_path)

            # Resample / remix the audio; a sox failure marks the track invalid.
            tfm = sox.Transformer()
            tfm.rate(AUDIO_SAMPLE_RATE)
            tfm.channels(AUDIO_N_CHANNELS)
            try:
                tfm.build(local_audio_path, local_wav_path)
            except Exception as e:
                logging.info(f"Could not process {local_wav_path}. Exception: {e}")
                print(f"\n\n\n\nCould not process {local_wav_path}. Exception: {e}\n\n\n\n")
                return

            # Rewrite the result as 16-bit PCM; failure marks the track invalid.
            try:
                data, samplerate = sf.read(local_wav_path)
                sf.write(local_wav_path, data, samplerate, subtype='PCM_16')
            except Exception as e:
                logging.info(f"Could not convert to PCM {local_wav_path}. Exception: {e}")
                print(f"\n\n\n\nCould not convert to PCM {local_wav_path}. Exception: {e}\n\n\n\n")
                return

            # Every check passed: emit the track id on its split's output.
            yield beam.pvalue.TaggedOutput(split, track_id)

