In [1]:
import IPython
from IPython.display import Image, Audio
from midi2audio import FluidSynth
from music21 import corpus, converter, instrument, note, stream, chord, duration
import matplotlib.pyplot as plt
import time

import os
import pickle

import midi2audio
import music21

import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from music21 import converter

In [2]:
import keras
from keras.callbacks import ModelCheckpoint, EarlyStopping

import os
import numpy as np
import glob

from keras.layers import LSTM, Input, Dropout, Dense, Activation, Embedding, Concatenate, Reshape, GlobalAveragePooling1D
from keras.layers import Flatten, RepeatVector, Permute, TimeDistributed
from keras.layers import Multiply, Lambda, Softmax
import keras.backend as K 
from keras.models import Model
from tensorflow.keras.optimizers import RMSprop

from keras.utils import to_categorical
import tensorflow as tf

In [None]:
# Quick check that FluidSynth can render one corpus file to audio.
# Pass sound_font='path/to/font.sf2' to FluidSynth() to choose a specific SoundFont.
fs = FluidSynth()
sample_midi = os.path.join('db2', 'bach', 'bach_846.mid')
sample_wav = os.path.splitext(os.path.basename(sample_midi))[0] + '.wav'  # 'bach_846.wav'
fs.midi_to_audio(sample_midi, sample_wav)
Audio(sample_wav)

---------------------
# Extracting the data

The code below loops through each score and extracts the pitch and duration of every note (and rest) into two parallel lists. A chord is stored as a single string whose individual notes are separated by dots (only its first two pitches are kept). The number after each note name is the octave the note belongs to.

## Defining Helper Functions

In [20]:
def helper_function(x):
    """Sum `x` over axis 1.

    Models whose attention pooling is a Lambda layer over this function refer to
    it by name when reloaded (custom_objects), so keep the name stable.
    """
    return tf.math.reduce_sum(x, axis=1)

def get_music_list(data_folder):
    """Collect every ``*.mid`` file under `data_folder`, recursively.

    Files are returned in a deterministic order (top-down, directories and files
    sorted by name). Order matters downstream: the extracted token stream follows
    it, and Keras' `validation_split` takes the last samples of that stream.

    Args:
        data_folder: root directory to search.

    Returns:
        (file_list, parser): the list of MIDI paths and the music21 `converter`
        module used to parse them.
    """
    file_list = []
    for root, dirnames, _ in os.walk(data_folder):
        # Sorting in place makes os.walk visit subdirectories in a stable order.
        dirnames.sort()
        # Escape the directory part so names containing '[', '*' or '?' are matched literally.
        file_list.extend(sorted(glob.glob(os.path.join(glob.escape(root), "*.mid"))))
    return file_list, converter

def create_network(n_notes, n_durations, embed_size = 100, rnn_units = 256, use_attention = False):
    """Build and compile the two-headed LSTM that predicts the next pitch and duration.

    Args:
        n_notes: size of the pitch-token vocabulary (embedding input and softmax output).
        n_durations: size of the duration vocabulary.
        embed_size: embedding dimension of each input stream.
        rnn_units: number of units in each LSTM layer.
        use_attention: if True, pool the LSTM outputs with learned attention weights
            instead of taking the final state of a last LSTM layer.

    Returns:
        (model, att_model): the compiled training model and, when `use_attention`
        is True, a model that outputs the per-timestep attention weights
        (otherwise None).
    """
    from keras.layers import Dot

    notes_in = Input(shape = (None,))
    durations_in = Input(shape = (None,))

    x1 = Embedding(n_notes, embed_size)(notes_in)
    x2 = Embedding(n_durations, embed_size)(durations_in)
    x = Concatenate()([x1, x2])
    x = LSTM(rnn_units, return_sequences=True)(x)

    if use_attention:
        x = LSTM(rnn_units, return_sequences=True)(x)
        # One score per timestep, normalised across the sequence with softmax.
        e = Dense(1, activation='tanh')(x)
        e = Reshape([-1])(e)
        alpha = Activation('softmax')(e)
        # Context vector c = sum_t alpha_t * x_t. Dot contracts the time axis of
        # alpha (batch, T) with x (batch, T, rnn_units) -> (batch, rnn_units).
        # Unlike a Lambda layer, it reloads without safe_mode=False.
        c = Dot(axes=1)([alpha, x])
    else:
        c = LSTM(rnn_units)(x)

    notes_out = Dense(n_notes, activation = 'softmax', name = 'pitch')(c)
    durations_out = Dense(n_durations, activation = 'softmax', name = 'duration')(c)

    model = Model([notes_in, durations_in], [notes_out, durations_out])

    if use_attention:
        # Shares layers with `model`; used at generation time to inspect attention.
        att_model = Model([notes_in, durations_in], alpha)
    else:
        att_model = None

    opti = RMSprop(learning_rate = 0.001)
    model.compile(loss=['categorical_crossentropy', 'categorical_crossentropy'], optimizer=opti)

    return model, att_model


def get_distinct(elements):
    """Return the sorted unique values in `elements` and how many there are.

    Returns:
        A tuple ``(element_names, n_elements)``.
    """
    unique_sorted = sorted(set(elements))
    return unique_sorted, len(unique_sorted)

def create_lookups(element_names):
    """Build forward and reverse index tables for a vocabulary.

    Returns:
        ``(element_to_int, int_to_element)``. The first maps each element to its
        position in `element_names`; the second maps each position back to its element.
    """
    int_to_element = dict(enumerate(element_names))
    element_to_int = {element: index for index, element in int_to_element.items()}
    return element_to_int, int_to_element

def prepare_sequences(notes, durations, lookups, distincts, seq_len =32):
    """Turn the parallel note/duration token lists into sliding-window training data.

    Each example is a window of `seq_len` consecutive tokens. Its target is the
    token that immediately follows the window.

    Args:
        notes, durations: parallel token lists.
        lookups: ``[note_to_int, int_to_note, duration_to_int, int_to_duration]``.
        distincts: ``[note_names, n_notes, duration_names, n_durations]``.
        seq_len: window length.

    Returns:
        ``(network_input, network_output)``, where
        ``network_input = [note_windows, duration_windows]`` (each of shape
        ``(n_patterns, seq_len)``) and ``network_output`` holds the matching
        one-hot targets.
    """
    note_to_int, _, duration_to_int, _ = lookups
    _, n_notes, _, n_durations = distincts

    starts = range(len(notes) - seq_len)

    note_windows = [[note_to_int[tok] for tok in notes[s:s + seq_len]] for s in starts]
    note_targets = [note_to_int[notes[s + seq_len]] for s in starts]
    duration_windows = [[duration_to_int[tok] for tok in durations[s:s + seq_len]] for s in starts]
    duration_targets = [duration_to_int[durations[s + seq_len]] for s in starts]

    window_shape = (len(note_windows), seq_len)
    network_input = [
        np.reshape(note_windows, window_shape),
        np.reshape(duration_windows, window_shape),
    ]
    network_output = [
        to_categorical(note_targets, num_classes=n_notes),
        to_categorical(duration_targets, num_classes=n_durations),
    ]
    return (network_input, network_output)

def sample_with_temp(preds, temperature):
    """Sample an index from a probability vector after temperature scaling.

    Args:
        preds: 1-D array of non-negative probabilities (e.g. one softmax output row).
        temperature: 0 returns the argmax. Any other value re-weights each
            probability p as p ** (1 / temperature) and renormalises, so values
            below 1 sharpen the distribution and values above 1 flatten it.

    Returns:
        The sampled index.
    """
    if temperature == 0:
        return np.argmax(preds)

    # Work in log space and subtract the maximum before exponentiating. Otherwise a
    # small temperature underflows every weight to 0, the normalisation divides 0/0,
    # and np.random.choice rejects the resulting NaN probabilities.
    with np.errstate(divide='ignore'):  # log(0) -> -inf is intended: exp(-inf) == 0
        logits = np.log(np.asarray(preds, dtype=np.float64)) / temperature
    logits -= np.max(logits)
    weights = np.exp(logits)
    probs = weights / np.sum(weights)
    return np.random.choice(len(probs), p=probs)


# Corpus location, and what the next cell does ('build' parses the MIDI files).
data_folder = 'db2'

mode = 'build'

# data params
intervals = range(1)  # offsets passed to score.transpose(); range(1) -> only 0, i.e. no augmentation
seq_len = 32          # tokens of context per training example (also the START padding length)

# model params
embed_size = 100
rnn_units = 256
use_attention = True

# Seed Python, NumPy and the Keras backend so weight initialisation, shuffling and
# sampling repeat between runs (some GPU kernels may still be nondeterministic).
SEED = 42
keras.utils.set_random_seed(SEED)

In [None]:
from tqdm.notebook import tqdm


def _clamp_duration(quarter_length):
    """Replace very short, non-zero durations (< 0.5 quarter notes) with 1.0.

    This keeps the duration vocabulary small. Zero-length elements and durations
    of at least 0.5 are returned unchanged.
    """
    if quarter_length != 0 and quarter_length < .5:
        return 1.0
    return quarter_length


def _element_token(element, max_chord_pitches=2):
    """Return the pitch token for a music21 element, or None if it is not a note, rest or chord.

    - Rests become 'rest', the token the generation cell turns back into a Rest.
    - Single notes use their name with octave, matching the chord tokens.
    - Chords keep at most `max_chord_pitches` pitches, joined by dots.
    """
    # note.Rest is not a subclass of note.Note, so rests need their own check.
    if isinstance(element, note.Rest):
        return 'rest'
    if isinstance(element, note.Note):
        return element.nameWithOctave
    if isinstance(element, chord.Chord):
        return '.'.join(p.nameWithOctave for p in element.pitches[:max_chord_pitches])
    return None


if mode == 'build':
    music_list, parser = get_music_list(data_folder)
    print(len(music_list), 'files in total')

    notes = []
    durations = []
    skipped_files = []
    for file in tqdm(music_list, desc="Processing files"):
        try:
            original_score = parser.parse(file).chordify()
        except Exception as exc:  # one malformed MIDI file should not abort the whole corpus
            skipped_files.append(file)
            tqdm.write(f'Skipping {file}: {exc}')
            continue

        for interval in intervals:
            score = original_score.transpose(interval)

            # Mark the start of each piece with seq_len START tokens (the same padding
            # the generation cell uses as its seed).
            notes.extend(['START'] * seq_len)
            durations.extend([0] * seq_len)

            for element in score.flatten():
                token = _element_token(element)
                if token is None:
                    continue
                notes.append(token)
                durations.append(_clamp_duration(element.duration.quarterLength))

    print(len(skipped_files), 'files skipped')

------------------------------------------
# Embedding Note and Duration

To create the training dataset, we first convert every pitch token and every duration into an integer index. The particular integers assigned do not matter, because an embedding layer maps each index to a learned vector.

In [15]:
# Distinct vocabularies for the two input streams: pitch tokens and durations.
note_names, n_notes = get_distinct(notes)
duration_names, n_durations = get_distinct(durations)

# Token <-> integer tables used to encode model inputs and decode sampled outputs.
note_to_int, int_to_note = create_lookups(note_names)
duration_to_int, int_to_duration = create_lookups(duration_names)

distincts = [note_names, n_notes, duration_names, n_durations]
lookups = [note_to_int, int_to_note, duration_to_int, int_to_duration]


In [None]:
# Vocabulary size and a preview of the first 20 token -> index entries
# (displaying the whole mapping floods the output when the vocabulary is large).
len(note_to_int), dict(list(note_to_int.items())[:20])

In [None]:
print('\nduration_to_int')
# (vocabulary size, full duration -> index mapping); the duration vocabulary is small.
(len(duration_to_int), duration_to_int)

In [10]:
# Training sequences are built once, in the Modeling section below
# (`network_input, network_output = prepare_sequences(...)`).

The training set is built by sliding a 32-step window over the data: each input is a sequence of 32 consecutive pitches and durations, and the target is the next pitch and duration in the sequence.

----------------------------------
# Modeling

In [18]:
# Windows of seq_len tokens as inputs; the following pitch and duration (one-hot) as targets.
network_input, network_output = prepare_sequences(
    notes, durations, lookups, distincts, seq_len=seq_len
)

In [None]:
# Show the first training example: its two input windows and its two one-hot targets.
for label, value in [
    ('pitch input', network_input[0][0]),
    ('duration input', network_input[1][0]),
    ('pitch target', network_output[0][0]),
    ('duration target', network_output[1][0]),
]:
    print(label)
    print(value)

In [None]:
model, att_model = create_network(
    n_notes,
    n_durations,
    embed_size=embed_size,
    rnn_units=rnn_units,
    use_attention=use_attention,
)
model.summary()

---------------------------------------------
# Training

In [None]:
# Per the Keras docs, validation_split holds out the last 20% of the samples
# (taken before shuffling); shuffle=True reshuffles the training part each epoch.
model.fit(
    network_input,
    network_output,
    epochs=10,
    batch_size=64,
    validation_split=0.2,
    shuffle=True,
)

----------------------------------------
# Predicting

In [30]:
# Generation settings. A temperature of 0 picks the most likely token; larger values add variety.
notes_temp = 0.9
duration_temp = 0.9
max_extra_notes = 210   # upper bound on tokens generated after the seed
max_seq_len = 32        # longest context window fed back into the model
seq_len = 32

# Seed the generator with a window made entirely of the 'START' padding token.
# NOTE(review): this rebinds the global `notes`/`durations` that held the training
# corpus; re-run the extraction cell before rebuilding lookups or training sequences.
notes, durations = ['START'], [0]

if seq_len is not None:
    notes = ['START'] * (seq_len - len(notes)) + notes
    durations = [0] * (seq_len - len(durations)) + durations

sequence_length = len(notes)

In [None]:
def _token_to_midi(token):
    """Return the MIDI number for a single-pitch token such as 'C4', else None.

    'START', 'rest' and dotted chord tokens ('C4.E4') have no single MIDI pitch.
    Anything music21 cannot parse as a note also yields None.
    """
    if token in ('START', 'rest') or '.' in token:
        return None
    try:
        return note.Note(token).pitch.midi
    except Exception:  # music21 raises its own exception types for unparseable names
        return None


prediction_output = []
notes_input_sequence = []
durations_input_sequence = []
overall_preds = []

# Encode the seed sequence. Every non-START seed token also gets a one-hot piano-roll column.
for n, d in zip(notes, durations):
    notes_input_sequence.append(note_to_int[n])
    durations_input_sequence.append(duration_to_int[d])
    prediction_output.append([n, d])

    if n != 'START':
        seed_column = np.zeros(128)
        seed_midi = _token_to_midi(n)
        if seed_midi is not None:  # chords/rests have no single pitch to mark
            seed_column[seed_midi] = 1
        overall_preds.append(seed_column)

# Vocabulary indices that name a single MIDI pitch, and those pitches. The mapping does not
# change between steps, so it is built once instead of re-parsing every token at each step.
vocab_indices = []
vocab_midis = []
for idx, token in int_to_note.items():
    midi = _token_to_midi(token)
    if midi is not None:
        vocab_indices.append(idx)
        vocab_midis.append(midi)
vocab_indices = np.array(vocab_indices, dtype=int)
vocab_midis = np.array(vocab_midis, dtype=int)

att_matrix = np.zeros(shape=(max_extra_notes + sequence_length, max_extra_notes))

for note_index in range(max_extra_notes):
    prediction_input = [
        np.array([notes_input_sequence]),
        np.array([durations_input_sequence]),
    ]

    notes_prediction, durations_prediction = model.predict(prediction_input, verbose=0)
    if use_attention:
        att_prediction = att_model.predict(prediction_input, verbose=0)[0]
        # Column note_index holds this step's attention over the tokens in its input window.
        att_matrix[(note_index - len(att_prediction) + sequence_length):(note_index + sequence_length), note_index] = att_prediction

    # Piano-roll column of predicted pitch probabilities. Tokens that map to the same
    # MIDI pitch (e.g. enharmonic spellings) have their probabilities summed.
    pitch_column = np.zeros(128)
    np.add.at(pitch_column, vocab_midis, notes_prediction[0][vocab_indices])
    overall_preds.append(pitch_column)

    i1 = sample_with_temp(notes_prediction[0], notes_temp)
    i2 = sample_with_temp(durations_prediction[0], duration_temp)

    note_result = int_to_note[i1]
    duration_result = int_to_duration[i2]

    prediction_output.append([note_result, duration_result])

    notes_input_sequence.append(i1)
    durations_input_sequence.append(i2)

    # Keep only the most recent max_seq_len tokens as context.
    if len(notes_input_sequence) > max_seq_len:
        notes_input_sequence = notes_input_sequence[1:]
        durations_input_sequence = durations_input_sequence[1:]

    # In the training data, START tokens mark the beginning of a new piece, so stop here.
    if note_result == 'START':
        break

overall_preds = np.transpose(np.array(overall_preds))
print('Generated sequence of {} notes'.format(len(prediction_output)))

In [None]:
def _pattern_to_element(token, quarter_length):
    """Turn one generated (token, duration) pair into a music21 Chord, Rest or Note.

    Dotted tokens ('C4.E4') become chords, 'rest' becomes a Rest, and any other
    token is parsed as a single note. Returns None for the 'START' padding token.
    """
    if token == 'START':
        return None
    if '.' in token:
        members = []
        for name in token.split('.'):
            member = note.Note(name)
            member.duration = duration.Duration(quarter_length)
            members.append(member)
        element = chord.Chord(members)
    elif token == 'rest':
        element = note.Rest()
    else:
        element = note.Note(token)
    element.duration = duration.Duration(quarter_length)
    return element


midi_stream = stream.Stream()

# Create chord, rest and note objects from the values generated by the model.
for note_pattern, duration_pattern in prediction_output:
    element = _pattern_to_element(note_pattern, duration_pattern)
    if element is not None:
        midi_stream.append(element)

midi_stream = midi_stream.chordify()
# chordify() builds new chord objects, so the instrument is set on the resulting stream:
# an Instrument at offset 0 tells the MIDI export which program (cello) to use.
midi_stream.insert(0, instrument.Violoncello())

timestr = time.strftime("%Y%m%d-%H%M%S")
new_file = 'output-' + timestr + '.mid'
midi_stream.write('midi', new_file)

In [None]:
# Render the generated MIDI file (new_file) to a wav file with FluidSynth.
fs = FluidSynth()
fs.midi_to_audio(new_file, 'new_output.wav')

In [None]:
# Inline audio player for the rendered generation.
Audio("new_output.wav")

# Saving and Loading

<hr style="border: solid 3px blue;">

In [27]:
# Save the main model using the new .keras format
model.save('music_generation_model.keras')

# If you're using attention, save the attention model as well
# (it is built from the same layers as `model` and outputs the attention weights).
if use_attention:
    att_model.save('music_generation_attention_model.keras')

In [28]:
from tensorflow.keras.models import load_model

# A saved graph that contains a Lambda layer (calling `helper_function`) can only be
# deserialized with safe_mode=False, which allows arbitrary code to run: load trusted files only.
model = load_model('music_generation_model.keras', safe_mode=False,
                   custom_objects={'helper_function': helper_function, 'tf': tf})

# Restore the attention-weights model too, so the generation cells (which call
# att_model.predict when use_attention is True) also work after a kernel restart.
if use_attention:
    att_model = load_model('music_generation_attention_model.keras', safe_mode=False,
                           custom_objects={'helper_function': helper_function, 'tf': tf})

In [None]:
# Confirm the reloaded model's architecture.
model.summary()