Pre-processing MIDI data into a piano-roll time-series format.

The x axis (matrix columns) represents time, in steps of 0.25 quarter-note beats (one sixteenth note).
The y axis (matrix rows) represents the MIDI pitch number of the note (0–127).

In [1]:
import music21 as m21
import os
import json
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch

import torch.backends.cudnn as cudnn
# Release unoccupied memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()
cudnn.benchmark = True  # let cuDNN pick the fastest algorithms for this hardware

# Run on the first GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Preprocessing file used to transform piano MIDI files into binary piano-roll tensors.
# music21 translates the MIDI files into music21 objects from which note pitches and
# durations are extracted; NumPy builds the piano-roll matrices and PyTorch wraps them
# as output tensors.

# --- Paths -------------------------------------------------------------------------
PIANO_DATAPATH = "Dataset/bach"            # directory walked for input MIDI files
SAVE_MUSIC_PATH = "Processed_data"
SINGLE_FILE_PATH = "single_file_dataset.txt"
SONG_MAPPING_PATH = "Song_Mapping.json"
LABEL_MAPPING_PATH = "Label_Mapping.json"  # JSON label -> integer mapping
OUTPUT_FILES = "Output_files"              # default output directory of convert_stream

# --- Sizes -------------------------------------------------------------------------
TIME_STEP = 0.25      # width of one piano-roll column, in quarter lengths
SEQUENCE_LENGTH = 128
SIZE = 512            # number of time-step columns per piano roll

# Note durations (quarter lengths) kept by acceptable_note_durations:
# every multiple of TIME_STEP from 0.25 up to 4.
acceptable_durations = [
    0.25, 0.5, 0.75, 1,
    1.25, 1.5, 1.75, 2,
    2.25, 2.5, 2.75, 3,
    3.25, 3.5, 3.75, 4,
]

cpu


In [2]:
def load_data(data_path):
    """
    Walk ``data_path`` recursively and parse every MIDI file found into a music21 object.

    Each directory visited (including ``data_path`` itself) contributes one label (its base
    name) and one song count, so ``labels[i]`` and ``no_songs_in_dir[i]`` describe the same
    directory, and that directory's songs form a contiguous run of ``songs``. A leading
    ``'Dataset'`` label (the dataset root) is dropped together with its count.

    Directories and files are visited in sorted order so the song order - and therefore any
    label assignment built from it - is reproducible across machines and runs.

    :params data_path: directory containing the MIDI files, directly or in sub-directories
    :returns (songs, labels, no_songs_in_dir): parsed music21 objects, one label per visited
             directory, and the number of songs parsed from each of those directories.
             All three are empty lists when ``data_path`` does not exist.
    """
    songs = []
    labels = []
    no_songs_in_dir = []

    for path, subdirs, filenames in os.walk(data_path):
        # Sorting in place makes os.walk descend into sub-directories deterministically.
        subdirs.sort()
        # normpath strips a trailing separator, which would otherwise give an empty basename.
        labels.append(os.path.basename(os.path.normpath(path)))

        count = 0
        for file in sorted(filenames):
            # Accept both .mid and .midi extensions, in any letter case.
            if file.lower().endswith((".mid", ".midi")):
                songs.append(m21.converter.parse(os.path.join(path, file)))
                count += 1
        no_songs_in_dir.append(count)

    # Get rid of the 'Dataset' label as this doesn't represent anything meaningful.
    # The emptiness check avoids an IndexError when os.walk yields nothing (missing path).
    if labels and labels[0] == 'Dataset':
        labels = labels[1:]
        no_songs_in_dir = no_songs_in_dir[1:]

    return songs, labels, no_songs_in_dir

# Parse every MIDI file under the dataset directory once; the cells below reuse the results.
songs, labels, no_songs_in_dir = load_data(data_path=PIANO_DATAPATH)




In [3]:
def piano_roll(cols, song):
    """
    Render a song as a binary piano-roll matrix.

    Rows are MIDI pitch numbers (0-127) and columns are time steps of ``TIME_STEP``
    quarter lengths; a cell is 1 while that pitch is sounding. Every note or chord is placed
    at its own offset in the flattened song. As a result, rests advance time, every pitch of
    a chord is marked, and notes from different parts that sound together overlap rather
    than being laid end to end. Anything starting at or after column ``cols`` is dropped,
    and notes running past the last column are truncated.

    :params cols, song: number of time-step columns of the output and the music21 stream to convert
    :returns a (128, cols) numpy array of 0s and 1s.
    """
    roll = np.zeros((128, cols))
    flat = song.flatten()

    for element in flat.notesAndRests:
        # Offsets are in quarter lengths from the start of the song; convert to a column index.
        start = int(flat.elementOffset(element) // TIME_STEP)
        if start >= cols:
            # The flattened stream is sorted by offset, so everything after is out of range too.
            break

        if isinstance(element, m21.note.Note):
            pitches = [element.pitch]
        elif isinstance(element, m21.chord.Chord):
            pitches = element.pitches
        else:
            # Rests mark nothing; the time they take is already reflected in later offsets.
            continue

        # Number of whole steps the element lasts (a duration of 0.5 spans two steps).
        # Very short notes still occupy one step so they are not lost entirely.
        steps = max(int(element.duration.quarterLength // TIME_STEP), 1)

        for pitch in pitches:
            midi = int(pitch.midi)
            if 0 <= midi < roll.shape[0]:
                # numpy clips slices that run past the last column.
                roll[midi, start:start + steps] = 1

    return roll
        
   



In [4]:
def acceptable_note_durations(song):
    """
    Remove every note, chord and rest whose duration is not in ``acceptable_durations``.

    The song is modified in place and also returned for convenience. Removal is recursive
    because in a parsed score the notes live inside nested streams (parts/measures), not at
    the top level of ``song``; a non-recursive remove would silently find nothing. Offsets of
    the remaining elements are left unchanged, so removed notes leave gaps in time.

    :params song: the music21 stream to filter
    :returns the same song object with the unacceptable elements removed.
    """
    # Collect first, then remove, so the stream is not mutated while it is being walked.
    rejected = [
        element
        for element in song.flatten().notesAndRests
        if element.duration.quarterLength not in acceptable_durations
    ]

    for element in rejected:
        song.remove(element, recurse=True)

    return song



In [14]:

def convert_stream(matrix, format="midi", file_name='output.mid',filepath=OUTPUT_FILES, step_duration=TIME_STEP):
    """
    Convert a piano-roll matrix back into a music21 stream and write it to a file.

    Every maximal run of consecutive 1s in a row becomes one note. The row index is the MIDI
    pitch, the run's first column times ``step_duration`` is the note's offset, and the run
    length times ``step_duration`` is its duration. Runs that reach the last column are kept,
    and any number of simultaneous notes is supported. Repeated notes of the same pitch with
    no gap between them cannot be told apart and come back as one longer note.

    :params matrix: 2D piano roll matrix (rows = pitches, columns = time steps); any
            array-like accepted by ``np.asarray``; a cell counts as "on" only when it equals 1
    :params format: format file type to write the stream
    :params file_name: the file name of the output file
    :params filepath: the output directory; created if it doesn't exist
    :params step_duration: the size of one column of the matrix, in quarter lengths
    :returns None.
    """
    matrix = np.asarray(matrix)
    rows, cols = matrix.shape

    # Pad a column of zeros on each side so every run has both a rising and a falling edge.
    active = np.zeros((rows, cols + 2), dtype=np.int8)
    active[:, 1:-1] = matrix == 1
    # edges[r, c] == 1  -> a run of pitch r starts at column c
    # edges[r, c] == -1 -> a run of pitch r ends just before column c (exclusive end)
    edges = np.diff(active, axis=1)

    new_stream = m21.stream.Stream()
    for midi_pitch in range(rows):
        starts = np.flatnonzero(edges[midi_pitch] == 1)
        ends = np.flatnonzero(edges[midi_pitch] == -1)
        # Within a row starts and ends alternate, so they pair up in order.
        for start, end in zip(starts, ends):
            note = m21.note.Note(midi_pitch)
            note.quarterLength = int(end - start) * step_duration
            new_stream.insert(int(start) * step_duration, note)

    # Make the output directory if needed and write the stream in the requested format.
    os.makedirs(filepath, exist_ok=True)
    new_stream.write(format, fp=os.path.join(filepath, file_name))
 

# Round-trip check: render the second loaded song as a piano roll and write it back out as MIDI.
convert_stream(matrix=piano_roll(cols=SIZE, song=songs[1]))

In [11]:
def labels_mapping(labels, path):
    """
    Assign a distinct integer to every unique label and save the mapping as JSON.

    Unique labels are numbered in sorted order, so the same set of labels always yields the
    same mapping. Iterating a ``set`` of strings directly gives a different order on each
    interpreter run because of string hash randomisation.

    :params labels: iterable of sortable, JSON-serialisable labels (e.g. directory names)
    :params path: file path the JSON mapping is written to (overwritten if present)
    :returns the ``{label: int}`` mapping that was written.
    """
    mappings = {label: i for i, label in enumerate(sorted(set(labels)))}

    with open(path, "w", encoding="utf-8") as fp:
        json.dump(mappings, fp, indent=4)

    return mappings
    


In [13]:
def convert_labels_to_int(labels, path=None):
    """
    Convert each label to its integer using a JSON label mapping file.

    :params labels: iterable of labels to convert; each must be a key of the mapping
    :params path: JSON file holding the ``{label: int}`` mapping; defaults to
            ``LABEL_MAPPING_PATH``
    :returns label_ints: list with one integer per input label, in the same order
    :raises KeyError: if a label is not present in the mapping
    """
    if path is None:
        path = LABEL_MAPPING_PATH

    with open(path, "r", encoding="utf-8") as fp:
        mappings = json.load(fp)

    # Look up every given label, rather than listing the mapping's own values.
    return [mappings[label] for label in labels]



In [12]:
def preprocess(data_path, songs, labels, no_songs_in_dir, size=SIZE):
    """
    Turn parsed MIDI songs into a piano-roll input tensor and an integer label tensor.

    The label mapping is (re)written to ``LABEL_MAPPING_PATH`` and used to encode each
    directory label. Each song is filtered with ``acceptable_note_durations`` and then
    rendered with ``piano_roll``. The filtering modifies the passed song objects in place.

    :params data_path: directory of MIDI files; only read (via ``load_data``) when ``songs``,
            ``labels`` or ``no_songs_in_dir`` is None
    :params songs, labels, no_songs_in_dir: output of ``load_data``. Songs are grouped by
            directory, with one label and one song count per directory, in the same order.
    :params size: number of time-step columns of each piano roll
    :returns (inputs, labels): tensors of shape (num_songs, 128, size) and (num_songs,)
    :raises ValueError: if the per-directory song counts don't add up to the number of songs
    """
    # Reuse the already-parsed songs; parsing MIDI with music21 is slow.
    if songs is None or labels is None or no_songs_in_dir is None:
        songs, labels, no_songs_in_dir = load_data(data_path)

    # Initialize the label mapping and encode each directory label as an integer.
    labels_mapping(labels, LABEL_MAPPING_PATH)
    labels_as_ints = convert_labels_to_int(labels)

    # Expand to one label per song. Songs from the same directory are contiguous, and
    # directories without songs (any number of them, anywhere) contribute nothing.
    song_labels = [
        label
        for label, n_songs in zip(labels_as_ints, no_songs_in_dir)
        for _ in range(n_songs)
    ]
    if len(song_labels) != len(songs):
        raise ValueError(
            f"Song counts per directory sum to {len(song_labels)}, "
            f"but {len(songs)} songs were given"
        )

    # Filter out unacceptable durations for each song and create a piano roll matrix.
    matrix_list = [piano_roll(size, acceptable_note_durations(song)) for song in songs]

    # Convert to numpy arrays before creating tensors.
    inputs = torch.tensor(np.array(matrix_list))
    label_tensor = torch.tensor(np.array(song_labels))

    print("Input tensor dimensions: ", inputs.shape)
    print("Label tensor dimensions: ", label_tensor.shape)
    return inputs, label_tensor

# Build the model inputs (piano rolls) and their integer composer labels from the parsed songs.
inputs, labels = preprocess(
    data_path=PIANO_DATAPATH,
    songs=songs,
    labels=labels,
    no_songs_in_dir=no_songs_in_dir,
)


[0]
Input tensor dimensions:  torch.Size([3, 128, 512])
Label tensor dimensions:  torch.Size([3])
