Preprocessing functions that transform midi datafiles into one hot encoded tensors.



In [1]:
import music21 as m21
import os
import json
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch

import torch.backends.cudnn as cudnn
torch.cuda.empty_cache()
cudnn.benchmark = True  # Optimise for hardware

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(device)
"""
Preprocessing file used to transform piano midi files into one hot encoded tensors. Uses music21 library to translate the 
midi files into music21 objects where the note pitches and durations can be extracted. 
PyTorch is used to one hot encode and transform into output tensors. 
"""
# Constants for datapaths of each file and other key values.
PIANO_DATAPATH = "Dataset"
MUSIC_DURATION = 0.25
SAVE_MUSIC_PATH = "Processed_data"
SEQUENCE_LENGTH = 128
SINGLE_FILE_PATH = "single_file_dataset.txt"
SONG_MAPPING_PATH = "Song_Mapping.json"
LABEL_MAPPING_PATH = "Lable_Mapping.json"

acceptable_durations = [
    0.25, 
    0.5,
    0.75,
    1, 
    1.25,
    1.5, 
    1.75, 
    2, 
    2.25,
    2.5,
    2.75,
    3,
    3.25,
    3.5, 
    3.75,
    4
]


cpu


In [104]:
song = m21.converter.parse("Dataset/bach/bach_847.mid")
song_mel = song.getElementsByClass(m21.stream.Part)[0]

DURATION = 0.25


array = np.zeros((128,512))


index = 0

def piano_roll(cols, song):

    array = np.zeros((128, cols))
    index = 0
    # Iterates through every note in the song. 
    for symbol in song.flatten().notesAndRests:

        if isinstance(symbol, m21.note.Note):
            # Finds the midi symbol of the note
            note = int(symbol.pitch.midi)
            
            # Finds the duration of the note
            duration = symbol.duration.quarterLengthNoTuplets

            # Converts this to the step size. If duration = 0.5 this equates to two time steps. 
            steps = int(duration // DURATION)
            
            # Set the value of this midi note to 1 for the amount of time steps.
            array[note, index: index + steps + 1] = 1
            
            # Resets the index to the next time position.
            index += steps 
         
        elif index > cols:
            break
    
    return array
        
   
array = piano_roll(512, song)

print(type(song))


<class 'music21.stream.base.Score'>


In [113]:
def acceptable_notes(song):
    new_song = m21.stream.base.Score()
    lst = []
    for symbol in song.flatten().notesAndRests:
        lst.append(symbol.duration.quarterLength)
        if symbol.duration.quarterLength not in acceptable_durations:
            
            symbol.duration.quarterLength = 0.25

        elif isinstance(symbol, m21.note.Note) == False:
            
            continue

        new_song.append(symbol)
    print(set(lst))
    return new_song

new = acceptable_notes(song)
new.show("midi")

{0.25, 0.75, 1.5, 0.5, 1.75, 3.5, 3.25, 4.0, 1.0, 2.0, 2.25, 3.75, 1.25, 2.5, 3.0}


In [3]:
def load_data(data_path):
    """
    This function takes in a datapath and uses the os library to iterate through every file in the datapath directories
    and transforms them into music21 objects.

    :params data path containing the directory path to a folder of midi files
    :returns a list of music21 objects 
    """
    # Generate the labels for each composer.
    songs = []
    labels = []
    no_songs_in_dir = []
    
    # Iterates through each file in every directory from the input file path
    for path, subdir, filenames in os.walk(data_path):
        base = os.path.basename(path)
        labels.append(base)
        count = 0
        for file in filenames:
            
            # For every midi file, convert to a music21 object
            if file.endswith("mid"):
                song = m21.converter.parse(os.path.join(path, file))
                # Take just the melody of the song
                song = song.getElementsByClass(m21.stream.Part)
                songs.append(song[0])
                count += 1
        no_songs_in_dir.append(count)

        # Get rid of the 'Dataset' label as this doesn't represent anything meaningful
    labels = labels[1:]
    no_songs_in_dir = no_songs_in_dir[1:]

    return songs, labels, no_songs_in_dir

songs, labels, no_songs_in_dir = load_data(PIANO_DATAPATH)
print(no_songs_in_dir, "\n", labels)


[14, 3, 1, 29, 7, 9, 9, 48, 7, 3, 16, 21, 16, 15, 21, 8, 24, 12] 
 ['albeniz', 'bach', 'balakir', 'beeth', 'borodin', 'brahms', 'burgm', 'chopin', 'debussy', 'granados', 'grieg', 'haydn', 'liszt', 'mendelssohn', 'mozart', 'muss', 'schumann', 'tschai']


In [4]:
def encode_songs(song, time_step=0.125):
    """
    This function takes an input song and encodes it uses the following format:
    each note is represented using its relevant midi value.
    the duration of each note is extracted and added to the list. The time step is 0.25 or 1/4 of a beat.

    This is the representation we are encoding. The underscore represents the duration of the note. So in the example it would be
    midi note 60 for 1 beat and note 54 for half of a beat
    ['60', _, _, _, '54', _,]

    :params one song (music21 object) and predefined time step of 0.25 beats
    :returns the encoded melody for the input song
    """

    # For each song, need to grab the notes and rest, flatten the list and song object 
    encoded_melody = []
    
    for element in song.flat.notesAndRests:
        
         # handle notes
        if isinstance(element, m21.note.Note):
            symbol = element.pitch.midi # 60
        # handle rests
        elif isinstance(element, m21.note.Rest):
            symbol = "r"
        elif isinstance(element, m21.chord.Chord):
            continue
            

        # convert the note/rest into time series notation
        steps = int(element.duration.quarterLength / time_step)

        # Early stop if the end of the song is detected. Waits to see if there is a 4 beat rest and will break the loop
        if steps >= 16 and symbol == "r":
            break

        for step in range(steps):
            
            # if it's the first time we see a note/rest, let's encode it. Otherwise, it means we're carrying the same
            # symbol in a new time step
            if step == 0:
                encoded_melody.append(symbol)
            else:
                encoded_melody.append("_")

    # cast encoded song to str
    encoded_melody = " ".join(map(str, encoded_melody))
    

    return encoded_melody



In [5]:
def preprocess(data_path, songs, labels, no_songs_in_dir):
    """ 
    This function loads, encodes and saves the midi songs as the encoded version of the midi file.
    :params data path containing the input midi files
    :returns an output directory containing the encoded files.
    """
    # Load the songs, labels and the counter
    songs, labels, no_songs_in_dir = load_data(data_path)
    index = 0
    count = no_songs_in_dir[index]
    label = labels[index]
    
    for i, song in enumerate(songs):

        # Encode the song
        encoded_mel = encode_songs(song)
        #print(f"i: {i}, count: {count}")
        if i == count:
            if label != "tschai":
                
                index += 1
                count = sum(no_songs_in_dir[:index + 1])
                label = labels[index]
            
            print(label, count, i)
        

        # Create the filename from the directory and name for each song
        save_dir_mel = os.path.join(SAVE_MUSIC_PATH, label + str(i) + ".txt")

        # Make the directory if it doesn't exist
        os.makedirs(os.path.dirname(save_dir_mel), exist_ok=True)
        
        # Write to the new file
        with open(save_dir_mel, 'w') as fp:
            fp.write(encoded_mel)

    fp.close()

    return labels, no_songs_in_dir

labels, no_songs_in_dir = preprocess(PIANO_DATAPATH, songs, labels, no_songs_in_dir)


  return self.iter().getElementsByClass(classFilterList)


bach 17 14
balakir 18 17
beeth 47 18
borodin 54 47
brahms 63 54
burgm 72 63
chopin 120 72
debussy 127 120
granados 130 127
grieg 146 130
haydn 167 146
liszt 183 167
mendelssohn 198 183
mozart 219 198
muss 227 219
schumann 251 227
tschai 263 251


In [6]:
print(no_songs_in_dir)

[14, 3, 1, 29, 7, 9, 9, 48, 7, 3, 16, 21, 16, 15, 21, 8, 24, 12]


In [7]:
def load(file_path):
    """ 
    This function opens a file in read mode
    :params file path of a song
    :returns the contents of the file as a string
    """
    with open(file_path, "r") as fp:
        song = fp.read()
    return song

In [8]:
def create_single_file(dataset_path, single_file_path, sequence_length):
    """ 
    Combines each file into one big file. Each song separated by delimeters of length sequence length
    :params datapath directory, path to write the single file to and the length of each sequence
    :returns None
    """

    # Creates the blank song string and the number of delimiters for each sequence.
    new_song_delimiter = "/ " * sequence_length
    songs = ""
    
    # Load the data into a single string with the delimiter
    for path, subdir, files in os.walk(dataset_path):

        for file in files:
            # Loads the file from the saved music path. 
            file_path = os.path.join(path, file)

            song = load(file_path)
            
            songs = songs + "".join(song) + " " + new_song_delimiter

    # remove empty space from last character of string
    songs = songs[:-1]
    
    # save string that contains all the dataset
    with open(single_file_path, "w") as fp:
        fp.write(songs)

    return songs# remove empty space from last character of string

single_songs = create_single_file(SAVE_MUSIC_PATH, SINGLE_FILE_PATH, SEQUENCE_LENGTH)

In [9]:

def create_mapping(songs, mapping_path):
    """Creates a json file that maps the symbols in the song dataset onto integers

    :param songs (str): String with all songs
    :param mapping_path (str): Path where to save mapping
    :return:
    """
    mappings = {}

    # identify the vocabulary
    songs = songs.split()
    vocabulary = list(set(songs))
    
    # create mappings
    for i, symbol in enumerate(vocabulary):
        mappings[symbol] = i

    # save voabulary to a json file
    with open(mapping_path, "w") as fp:
        json.dump(mappings, fp, indent=4)

create_mapping(single_songs, SONG_MAPPING_PATH)


In [10]:
def labels_mapping(labels, path):
    """
    Generate a mapping of the labels
    """
    mappings = {}

    unique = list(set(labels))
    print(unique)
    for i, symbol in enumerate(unique):
        mappings[symbol] = i
    
    with open(path, "w") as fp:
        json.dump(mappings, fp, indent=4)
    
labels_mapping(labels, LABEL_MAPPING_PATH)

['haydn', 'chopin', 'brahms', 'grieg', 'liszt', 'mendelssohn', 'debussy', 'borodin', 'mozart', 'beeth', 'albeniz', 'schumann', 'bach', 'balakir', 'muss', 'tschai', 'burgm', 'granados']


In [11]:
def convert_labels_to_int(labels):
    with open(LABEL_MAPPING_PATH, "r") as fp:
        mappings = json.load(fp)

    label_ints = []

    for i, symbol in enumerate(mappings):
        label_ints.append(mappings[symbol])

    return label_ints

labels_as_ints = convert_labels_to_int(labels)



In [12]:

def convert_songs_to_int(songs, no_songs_in_dir):
    """ 
    Converts the symbol in the dictionary to the equivalent integer value. 
    Counts the number of symbols in each directory. We will use this to add labels to each sequence
    :params the list of music21 objects (songs)
    :returns a list of mapped integers
    """
    int_songs = []

    # load mappings
    with open(SONG_MAPPING_PATH, "r") as fp:
        mappings = json.load(fp)

    # transform songs string to list
    songs = songs.split()
    
    # map songs to int
    symbol_in_songs = []
    dash_count = 0
    prev = 0
    index = 0
    
    for i, symbol in enumerate(songs):
        int_songs.append(mappings[symbol])
        
        # If the symbol is a dash, count them
        if symbol == "/":
            dash_count += 1
           
        # If the dash count is the same as the number of songs - 1 * sequence length, we have reached the end of that 'class'
        # Add the symbol count into the list and reset the variables.
        if dash_count == (no_songs_in_dir[index]) * SEQUENCE_LENGTH:
            dash_count = 0
            count = i - prev
            prev = i
            symbol_in_songs.append(count)
            index += 1
            
    return int_songs, symbol_in_songs



int_songs, symbol_in_songs = convert_songs_to_int(single_songs, no_songs_in_dir)



In [13]:

def generate_training_sequences(sequence_length, symbol_in_songs, labels):
    """Create input and output data samples for training. Each sample is a sequence.

    :param sequence_length (int): Length of each sequence. With a quantisation at 16th notes, 64 notes equates to 4 bars

    :return inputs (ndarray): Training inputs
    :return targets (ndarray): Training targets
    """

    # load songs and map them to int
    songs = load(SINGLE_FILE_PATH)
    int_songs, symbol_in_songs = convert_songs_to_int(songs, no_songs_in_dir)
    print(symbol_in_songs)
    labels_as_ints = convert_labels_to_int(labels)

    inputs = []
    targets = []
    labels_seq = []
    idx = 0
    # generate the training sequences
    num_sequences = len(int_songs) - sequence_length
    new_label = labels_as_ints[idx]
    print(labels_as_ints)
    for i in range(num_sequences):
        # If the next value in the sequence is in the next data path, change the label

        if i == symbol_in_songs[idx] + 1:
            idx += 1
            new_label = labels_as_ints[idx]
            

        inputs.append(int_songs[i:i+sequence_length])
        targets.append(int_songs[i+sequence_length])
        labels_seq.append(new_label)

   
    # one-hot encode the sequences
    vocabulary_size = len(set(int_songs))
    inputs = torch.tensor(inputs)
    labels_seq = torch.tensor(labels_seq)

    # inputs size: (# of sequences, sequence length, vocabulary size)
    inputs = torch.nn.functional.one_hot(inputs, num_classes=vocabulary_size)
    labels_seq = torch.nn.functional.one_hot(labels_seq, num_classes=len(symbol_in_songs))
    targets = torch.tensor(targets)

    print(inputs.shape)
    print(targets.shape)
    print(labels_seq.shape)
    
    print(f"There are {len(inputs)} sequences.")

    return inputs, targets, labels

inputs, targets, gen_labels = generate_training_sequences(SEQUENCE_LENGTH, symbol_in_songs, labels)
                                  

[12616, 3388, 316, 26110, 3384, 4190, 7312, 36020, 4080, 2816, 12245, 17389, 11637, 9116, 25085, 1268, 17208, 5299]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
torch.Size([199352, 128, 71])
torch.Size([199352])
torch.Size([199352, 18])
There are 199352 sequences.


In [1]:
class MusicDataset(Dataset):
    def __init__(self, inputs, targets, labels):
        self.inputs = inputs
        self.targets = targets
        self.labels = labels

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):

        input = self.inputs[idx]
        target = self.targets[idx]
        label = self.labels[idx]

        return input, target, label

# Create dataset and dataloader
dataset = MusicDataset(inputs, targets, gen_labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


NameError: name 'Dataset' is not defined