In [5]:
import pretty_midi
import os
import h5py
import random
import numpy as np

In [6]:

def split_midi_by_bars(pm: "pretty_midi.PrettyMIDI", bars_per_segment: int) -> list:
    """
    Split a PrettyMIDI object into consecutive segments of a fixed number of bars.

    Bar boundaries are taken from ``pm.get_downbeats()``; ``pm.get_end_time()``
    is appended as a closing boundary so the trailing (possibly shorter) group
    of bars is emitted as well. Every segment is rebuilt as a fresh
    ``pretty_midi.PrettyMIDI`` whose note times are shifted so that the segment
    starts at time 0.

    Only notes are copied. A note belongs to the segment in which it *starts*;
    its end time is shifted but not clipped, so a note sustained across the
    segment boundary keeps its full duration. Instruments without any note in
    a segment are left out of that segment, and segments without any note are
    dropped. Nothing else stored on ``pm`` or its instruments (for example
    control changes) is carried over.

    Args:
        pm: The PrettyMIDI object to split.
        bars_per_segment: The number of bars each segment should contain.
            Must be at least 1.

    Returns:
        A list of new PrettyMIDI objects, one per non-empty segment, in
        chronological order.

    Raises:
        ValueError: If ``bars_per_segment`` is less than 1.
    """
    # Fail early with a clear message: 0 would otherwise surface as an obscure
    # range() error and a negative value would silently yield no segments.
    if bars_per_segment < 1:
        raise ValueError(
            "bars_per_segment must be a positive integer, got {!r}".format(bars_per_segment)
        )

    # Bar start times, plus the end of the piece as the closing boundary.
    boundaries = np.append(pm.get_downbeats(), pm.get_end_time())
    last_index = len(boundaries) - 1

    segments = []
    for first_bar in range(0, last_index, bars_per_segment):
        start_time = boundaries[first_bar]
        # The final group may hold fewer than bars_per_segment bars.
        end_time = boundaries[min(first_bar + bars_per_segment, last_index)]

        # Skip zero-length windows (e.g. the piece ends exactly on a downbeat).
        if start_time >= end_time:
            continue

        segment_pm = pretty_midi.PrettyMIDI()

        for instrument in pm.instruments:
            # Notes starting inside [start_time, end_time), re-timed relative
            # to the segment start.
            shifted_notes = [
                pretty_midi.Note(
                    velocity=note.velocity,
                    pitch=note.pitch,
                    start=note.start - start_time,
                    end=note.end - start_time,
                )
                for note in instrument.notes
                if start_time <= note.start < end_time
            ]
            if not shifted_notes:
                continue

            segment_instrument = pretty_midi.Instrument(
                program=instrument.program,
                is_drum=instrument.is_drum,
                name=instrument.name,
            )
            segment_instrument.notes.extend(shifted_notes)
            segment_pm.instruments.append(segment_instrument)

        if segment_pm.instruments:
            segments.append(segment_pm)

    return segments

In [7]:

### GLOBAL VARIABLES ###
# Define the number of cols per bar.
COLS_PER_BAR = 16
# Number of bars in each segment cut from a source file.
BARS_PER_SEGMENT = 50
# File name suffixes treated as MIDI input (compared case-insensitively).
MIDI_EXTENSIONS = (".mid", ".midi")
# Define the directory of the dataset.
INPUT_DIR_PATH = [
    "./raw/maestro-v3.0.0/2018/",
#    "./raw/maestro-v3.0.0/2017/",
#    "./raw/maestro-v3.0.0/2015/",
#    "./raw/maestro-v3.0.0/2014/"
]

# Define the output directory path.
OUT_DIR_PATH = "./raw/maestro-v3.0.0/splitted_2018/"

# midi_data accumulates every segment from every input file.
# midi_splits holds the segments of the most recently processed file.
# counter is the number of MIDI files loaded.
midi_data = []
midi_splits = []
counter = 0
for input_dir in INPUT_DIR_PATH:
    # Sorted so the order of midi_data does not depend on os.listdir ordering.
    for filename in sorted(os.listdir(input_dir)):
        # Skip anything that is not a MIDI file instead of letting the parser fail on it.
        if not filename.lower().endswith(MIDI_EXTENSIONS):
            continue
        counter += 1
        # Open the midi file and cut it into fixed-size bar groups.
        midi = pretty_midi.PrettyMIDI(os.path.join(input_dir, filename))
        midi_splits = split_midi_by_bars(midi, BARS_PER_SEGMENT)
        midi_data.extend(midi_splits)

In [9]:
# Report how many segments were collected into midi_data.
print(len(midi_data))

1033


In [10]:
# Number of segments to export, and the seed that makes the selection repeatable.
NUM_SAVED_SEGMENTS = 80
SAMPLE_SEED = 0

# A dedicated Random instance gives the same draw on every run (for the same
# midi_data order) without reseeding the global `random` module state.
# random.sample raises ValueError if midi_data holds fewer than
# NUM_SAVED_SEGMENTS segments.
midi_saved = random.Random(SAMPLE_SEED).sample(midi_data, NUM_SAVED_SEGMENTS)

In [13]:
# Make sure the destination directory exists before writing into it.
os.makedirs(OUT_DIR_PATH, exist_ok=True)

# Iterate over what was actually sampled rather than a hard-coded count, so
# this cell stays correct if the sample size changes.
for index, segment in enumerate(midi_saved):
    segment.write(os.path.join(OUT_DIR_PATH, "file_{}.midi".format(index)))