In [5]:
import pretty_midi
import os
import h5py
import random
import numpy as np
from midi_preprocessing import *

In [6]:
# Number of piano-roll columns counted as one bar; midi_splitter multiplies it
# by its `splits_dim` argument to get the width of each window it cuts.
COLS_PER_BAR = 16

In [7]:

def split_midi_by_bars(pm: pretty_midi.PrettyMIDI, bars_per_segment: int) -> list:
    """
    Splits a PrettyMIDI object into segments of a specified number of bars.

    Segment boundaries are taken from the downbeats reported by ``pm``, with the
    end time of the piece appended so the trailing (possibly shorter) segment is
    kept as well. Only notes are copied into the segments: each segment is a
    freshly constructed ``PrettyMIDI`` object, so nothing else stored on ``pm``
    or its instruments is carried over.

    A note belongs to the segment in which it *starts*. Notes that started in
    an earlier segment and are still sounding are not repeated, and a note that
    would ring past the segment boundary is cut off at that boundary so that no
    segment is longer than its bars.

    Args:
        pm: The PrettyMIDI object to split.
        bars_per_segment: The number of bars each segment should contain.
            Must be at least 1.

    Returns:
        A list of new PrettyMIDI objects, each containing one segment, with
        note times expressed relative to the start of that segment. Segments
        without any note are omitted.

    Raises:
        ValueError: If ``bars_per_segment`` is smaller than 1.
    """
    if bars_per_segment < 1:
        raise ValueError(
            f"bars_per_segment must be a positive integer, got {bars_per_segment!r}"
        )

    # Bar start times, plus the end of the piece to close the last segment.
    boundaries = np.append(pm.get_downbeats(), pm.get_end_time())
    last_index = len(boundaries) - 1

    segments = []
    for first_bar in range(0, last_index, bars_per_segment):
        start_time = boundaries[first_bar]
        # The final segment may contain fewer bars than requested.
        end_time = boundaries[min(first_bar + bars_per_segment, last_index)]

        # Skip degenerate segments (e.g. the piece ends exactly on a downbeat).
        if start_time >= end_time:
            continue

        segment_pm = pretty_midi.PrettyMIDI()

        for instrument in pm.instruments:
            segment_instrument = pretty_midi.Instrument(
                program=instrument.program,
                is_drum=instrument.is_drum,
                name=instrument.name,
            )

            # Keep notes that start inside the segment, shift them so the
            # segment begins at t=0, and clip anything that would ring past
            # the segment boundary.
            segment_instrument.notes = [
                pretty_midi.Note(
                    velocity=note.velocity,
                    pitch=note.pitch,
                    start=note.start - start_time,
                    end=min(note.end, end_time) - start_time,
                )
                for note in instrument.notes
                if start_time <= note.start < end_time
            ]

            # Instruments that are silent in this segment are left out.
            if segment_instrument.notes:
                segment_pm.instruments.append(segment_instrument)

        if segment_pm.instruments:
            segments.append(segment_pm)

    return segments

def midi_splitter(midi, fs = 8, splits_dim = 8):
    """
    Cuts the piano roll of ``midi`` into consecutive, equally wide windows.

    The piano roll is rendered at ``fs`` columns per second and sliced along
    the time axis into windows of ``COLS_PER_BAR * splits_dim`` columns.
    Trailing columns that do not fill a whole window are discarded.

    Args:
        midi: Object exposing ``get_piano_roll(fs=...)``.
        fs: Sampling rate, in columns per second, used to render the roll.
        splits_dim: Number of bars (of ``COLS_PER_BAR`` columns each) per window.

    Returns:
        A list with one ``piano_roll_to_pretty_midi(window, fs)`` result per
        full window (presumably PrettyMIDI objects -- confirm against
        midi_preprocessing).
    """
    roll = midi.get_piano_roll(fs = fs)
    window_cols = COLS_PER_BAR * splits_dim
    # Only complete windows are produced; the remainder is dropped.
    n_windows = int(roll.shape[1] / window_cols)

    return [
        piano_roll_to_pretty_midi(
            roll[:, k * window_cols:k * window_cols + window_cols], fs
        )
        for k in range(n_windows)
    ]

In [8]:

### GLOBAL VARIABLES ###
# Define the number of cols per bar.
COLS_PER_BAR = 16
# Define the directories of the dataset (one entry per folder to load).
INPUT_DIR_PATH =[
    "./raw/maestro-v3.0.0/2018/",
#    "./raw/maestro-v3.0.0/2017/",
#    "./raw/maestro-v3.0.0/2015/",
#    "./raw/maestro-v3.0.0/2014/"
]

# Define the output directory path.
OUT_DIR_PATH = "./raw/maestro-v3.0.0/splitted_2018/"

# Only files with these extensions are parsed; anything else found in the
# input folders (metadata, hidden files, ...) is skipped instead of crashing
# the MIDI parser.
MIDI_EXTENSIONS = (".mid", ".midi")

# All segments from all files, in loading order.
midi_data = []
midi_splits = []
# Number of MIDI files that were loaded.
counter = 0
for input_dir in INPUT_DIR_PATH:
    # os.listdir returns entries in arbitrary order; sorting makes the content
    # of midi_data reproducible from run to run.
    for filename in sorted(os.listdir(input_dir)):
        if not filename.lower().endswith(MIDI_EXTENSIONS):
            continue
        counter += 1
        # Open the midi file.
        midi = pretty_midi.PrettyMIDI(os.path.join(input_dir, filename))
        # Width of the piano roll at the library's default sampling rate;
        # only used for the progress line below.
        dim = midi.get_piano_roll().shape[1]
        # fs = 8 columns per second, 8 bars per segment.
        midi_splits = midi_splitter(midi, 8, 8)
        print(dim, " : ", len(midi_splits))
        midi_data.extend(midi_splits)

58224  :  36
70398  :  43
83463  :  52
109166  :  68
160876  :  100
52864  :  33
111878  :  69
33037  :  20
63386  :  39
161182  :  100
70375  :  43
245793  :  153
79511  :  49
44310  :  27
34968  :  21
125406  :  78
203325  :  127
25650  :  16
16566  :  10
80346  :  50
70149  :  43
64512  :  40
107146  :  66
152819  :  95
164111  :  102
73361  :  45
117921  :  73
36042  :  22
84017  :  52
52089  :  32
24932  :  15
159331  :  99
152201  :  95
72780  :  45
25165  :  15
31322  :  19
29571  :  18
106902  :  66
189719  :  118
130789  :  81
68555  :  42
14140  :  8
63715  :  39
102381  :  63
23549  :  14
99133  :  61
24412  :  15
70056  :  43
84826  :  53
209805  :  131
45618  :  28
76026  :  47
97689  :  61
170061  :  106
111984  :  69
155722  :  97
135856  :  84
132491  :  82
145139  :  90
146129  :  91
94538  :  59
112083  :  70
29407  :  18
182334  :  113
80998  :  50
105461  :  65
62305  :  38
190631  :  119
137190  :  85
20640  :  12
30586  :  19
27070  :  16
21659  :  13
19476  :  12

In [9]:
# Sanity check: number of collected segments, then the container type.
print(len(midi_data), type(midi_data), sep="\n")

6125
<class 'list'>


In [16]:
# Shuffle the segments in place with a fixed seed so the same ordering (and
# therefore the same exported subset) is obtained on every fresh run.
# A dedicated Random instance is used so the module-level generator is not
# reseeded as a side effect.
# NOTE: run this cell once per kernel session; running it again shuffles the
# already shuffled list and yields a different ordering.
SHUFFLE_SEED = 42
random.Random(SHUFFLE_SEED).shuffle(midi_data)

<class 'NoneType'>


In [17]:
# Number of (shuffled) segments to consider for export.
N_EXPORT = 480
# Number of bars a segment must span to be written out.
BARS_PER_SEGMENT = 8

# Make sure the destination folder exists before the first write.
os.makedirs(OUT_DIR_PATH, exist_ok=True)

# Never index past the end of midi_data when fewer segments are available.
for i in range(min(N_EXPORT, len(midi_data))):
    # Render the piano roll once and reuse its width for the check and the log.
    n_cols = midi_data[i].get_piano_roll(fs = 8).shape[1]

    # Skip segments whose re-rendered roll is not exactly BARS_PER_SEGMENT bars
    # wide (presumably segments ending in silence come back shorter -- confirm).
    if n_cols != COLS_PER_BAR * BARS_PER_SEGMENT:
        continue

    print(n_cols / COLS_PER_BAR)
    # The file index is the position in the shuffled list, so numbering has
    # gaps wherever a segment was skipped.
    midi_data[i].write(os.path.join(OUT_DIR_PATH, "file_" + str(i) + ".midi"))

8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
