# Compose: Training a model to generate music

In [None]:
import os
import pickle
import numpy
import glob

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.utils import plot_model
import tensorflow.keras

from mido import MidiFile ,MetaMessage, Message, MidiTrack

from models.LSTM_SELFTRY import get_distinct, create_lookups, prepare_sequences, create_network

## Set parameters

In [None]:
# run params
section = 'compose'
run_id = '1117'
music_name = 'midis'

# All artefacts of this run live below run/<section>/<run_id>_<music_name>
run_folder = 'run/{}/'.format(section)
run_folder += '_'.join([run_id, music_name])

store_folder = os.path.join(run_folder, 'store')
data_folder = os.path.join('data', music_name)

# Create the run folder together with all of its sub-folders.
# makedirs(..., exist_ok=True) also creates missing parent directories and
# completes a run folder that already exists but lacks some sub-folders
# (the previous exists()-guard skipped the sub-folders in that case).
for sub_folder in ('store', 'output', 'weights', 'viz'):
    os.makedirs(os.path.join(run_folder, sub_folder), exist_ok=True)

mode = 'build' # 'load' #

# data params
#intervals = range(1)
seq_len = 64

# model params
embed_size = 200
rnn_units = 256
use_attention = True

## Extract the notes

In [None]:
# Searches all .mid files in the given directory and its subdirectories
def readFiles(directory):
    """Append the path of every ``.mid`` file below *directory* to ``music_list``.

    The module-level list ``music_list`` is extended in place; it is created
    when it does not exist yet, so the function can also be called on its own.
    The extension is matched case-insensitively (``song.MID`` is found too) and
    directories/files are visited in sorted order so that repeated runs yield
    the same file order.

    Args:
        directory: Root folder that is searched recursively.

    Returns:
        The (global) list the paths were appended to.
    """
    global music_list
    try:
        collected = music_list
    except NameError:
        # Called before get_Music_List() initialised the global list
        collected = music_list = []

    for root, dirs, files in os.walk(directory):
        # Sorting in place makes os.walk descend in a deterministic order
        dirs.sort()
        for file_name in sorted(files):
            if file_name.lower().endswith(".mid"):
                collected.append(os.path.join(root, file_name))
    return collected

# (Source: https://stackoverflow.com/questions/3964681/find-all-files-in-a-directory-with-extension-txt-in-python)

In [None]:
# Returns a list of all the midi files in the given directory and its subfolders
def get_Music_List(directory):
    """Rebuild the global ``music_list`` for *directory* and return it.

    The global list is replaced by a fresh empty list first, so results of an
    earlier call are discarded; ``readFiles`` then fills it.
    """
    global music_list
    music_list = list()
    readFiles(directory)
    return music_list

In [None]:
# Collect every MIDI file below data_folder. The bare len(...) expression is
# the cell's last statement, so the notebook displays the number of files.
music_list = get_Music_List(data_folder)
len(music_list)

In [None]:
isOnlyMonoTrackMode = True
# Parsed MIDI files that consist of a single track (MIDI file type 0)
monoTrackMidis = []

if mode == 'build':

    music_list = get_Music_List(data_folder)
    print(len(music_list), 'files in total')

    for file_number, midi_path in enumerate(music_list, start=1):
        print(file_number, "Parsing %s" % midi_path)
        midi_score = MidiFile(midi_path)
        #print('Ticks per beat: ' + str(midi_score.ticks_per_beat))
        if midi_score.type != 0:
            # Only type 0 files are kept; the tokenisation below reads tracks[0] only.
            # The format string is now actually applied (it used to be passed
            # to print() as a separate argument and was printed literally).
            print("Skipped: %s" % midi_score)
            continue
        print(midi_score)
        monoTrackMidis.append(midi_score)
    print(monoTrackMidis)

## Preprocess the data for tokenization

In [None]:
# Five parallel lists: index k of every list describes the same event.
commands = []
channels = []
values1 = []
values2 = []
durations = []

# Meta messages that are intentionally not turned into events
_IGNORED_META_TYPES = (
    'key_signature', 'lyrics', 'track_name', 'copyright', 'instrument_name',
    'marker', 'text', 'smpte_offset', 'channel_prefix',
)


def _append_event(command, channel, value1, value2, duration):
    """Append one event to all five parallel lists so they stay aligned."""
    commands.append(command)
    channels.append(channel)
    values1.append(value1)
    values2.append(value2)
    durations.append(duration)


for midi in monoTrackMidis:

    print('\n' + str(midi))

    # Marks the beginning of a file; -1 is the placeholder for "no value"
    _append_event('START', -1, -1, -1, 0)

    # Delta time of messages that were skipped since the last recorded event.
    # It is added to the next recorded event so skipping a message does not
    # shift everything that follows it.
    carried_time = 0

    # All midis in the list only have one track so we just need to look at that one
    for msg in midi.tracks[0]:
        delta_time = msg.time + carried_time
        carried_time = 0

        if isinstance(msg, MetaMessage):
            if msg.type == 'set_tempo':
                _append_event(msg.type, -1, msg.tempo, -1, delta_time)
            elif msg.type == 'time_signature':
                _append_event(
                    msg.type,
                    -1,
                    str(msg.numerator)+'/'+str(msg.denominator),
                    str(msg.clocks_per_click)+'/'+ str(msg.notated_32nd_notes_per_beat),
                    delta_time,
                )
            elif msg.type == 'end_of_track':
                _append_event(msg.type, -1, -1, -1, delta_time)
            else:
                # Not recorded; meta types that are not known to be ignorable are reported
                if msg.type not in _IGNORED_META_TYPES:
                    print(msg)
                carried_time = delta_time
        elif isinstance(msg, Message):
            if msg.type == 'note_on' or msg.type == 'note_off':
                _append_event(msg.type, msg.channel, msg.note, msg.velocity, delta_time)
            elif msg.type == 'control_change':
                _append_event(msg.type, msg.channel, msg.control, msg.value, delta_time)
            elif msg.type == 'program_change':
                _append_event(msg.type, msg.channel, msg.program, -1, delta_time)
            elif msg.type == 'pitchwheel':
                _append_event(msg.type, msg.channel, msg.pitch, -1, delta_time)
            else:
                # Unsupported message: skip it completely. Previously command,
                # channel and duration were appended here without matching
                # values1/values2 entries, which misaligned the lists.
                # sysex is skipped silently, everything else is reported.
                if msg.type != 'sysex':
                    print(msg)
                carried_time = delta_time


## Create the lookup tables

In [None]:
# Distinct values (and their count) of every event stream. The two value
# streams mix ints and strings, so they are converted to strings first.
command_names, n_commands = get_distinct(commands)
channel_names, n_channel = get_distinct(channels)
value1_names, n_value1 = get_distinct([str(value) for value in values1])
value2_names, n_value2 = get_distinct([str(value) for value in values2])
duration_names, n_durations = get_distinct(durations)

distincts = [
    command_names, n_commands,
    channel_names, n_channel,
    value1_names, n_value1,
    value2_names, n_value2,
    duration_names, n_durations,
]

# Persist the distinct sets in the store folder of this run
with open(os.path.join(store_folder, 'distincts'), 'wb') as distincts_file:
    pickle.dump(distincts, distincts_file)

# Forward and reverse lookup dictionaries for every event stream
command_to_int, int_to_command = create_lookups(command_names)
channel_to_int, int_to_channel = create_lookups(channel_names)
value1_to_int, int_to_value1 = create_lookups(value1_names)
value2_to_int, int_to_value2 = create_lookups(value2_names)
duration_to_int, int_to_duration = create_lookups(duration_names)

lookups = [
    command_to_int, int_to_command,
    channel_to_int, int_to_channel,
    value1_to_int, int_to_value1,
    value2_to_int, int_to_value2,
    duration_to_int, int_to_duration,
]

# Persist the lookup dictionaries next to the distinct sets
with open(os.path.join(store_folder, 'lookups'), 'wb') as lookups_file:
    pickle.dump(lookups, lookups_file)

In [None]:
# Display the command lookup table. The label now names the table that is
# actually shown (it used to say 'value2_to_int').
print('\ncommand_to_int')
command_to_int

In [None]:
# Display the duration lookup table; the bare expression is the cell's last
# statement, so the notebook renders the dictionary below the label.
print('\nduration_to_int')
duration_to_int

## Prepare the sequences used by the Neural Network

In [None]:
# The value streams were converted to strings when the lookups were built,
# so the same conversion is applied before they are passed on.
values1_strings = [str(value) for value in values1]
values2_strings = [str(value) for value in values2]
network_input, network_output = prepare_sequences(
    commands,
    channels,
    values1_strings,
    values2_strings,
    durations,
    lookups,
    distincts,
    seq_len,
)

In [None]:
# Print the first training sample of every stream.
# NOTE(review): assumes network_output is indexed [stream][sample] exactly like
# network_input (one array per model output) - confirm against
# prepare_sequences. The previous output prints used network_output[0][k],
# which shows further samples of the first stream under the other labels.
stream_labels = ['command', 'channel', 'values1', 'values2', 'durations']

print('#####  INPUT  ####')
for stream_index, stream_label in enumerate(stream_labels):
    print(stream_label + ' input')
    print(network_input[stream_index][0])

print('\n#####  OUTPUT  ####')
for stream_index, stream_label in enumerate(stream_labels):
    print(stream_label + ' output')
    print(network_output[stream_index][0])

## Create the structure of the neural network

In [None]:
# Build the network from the five vocabulary sizes and the model parameters
# set at the top of the notebook. create_network returns two objects; only
# `model` is used in the cells below (att_model is presumably a companion
# model for inspecting attention - TODO confirm in models.LSTM_SELFTRY).
model, att_model = create_network(n_commands, n_channel, n_value1, n_value2, n_durations, embed_size, rnn_units, use_attention)
model.summary()  # print the layer overview

## Train the neural network

In [None]:
# Folder inside the run folder that holds the weight files
weights_folder = os.path.join(run_folder, 'weights')
# Uncomment to start from previously saved weights instead of a fresh model
# model.load_weights(os.path.join(weights_folder, "weights.h5"))

In [None]:
weights_folder = os.path.join(run_folder, "weights")

# Defined for convenience, but currently NOT part of callbacks_list
early_stopping = EarlyStopping(
    monitor='loss',
    patience=10,
    restore_best_weights=True,
)

# Writes a new file (named after epoch and loss) whenever the training loss improves
checkpoint1 = ModelCheckpoint(
    filepath=os.path.join(weights_folder, "weights-improvement-{epoch:02d}-{loss:.4f}-bigger_custom.h5"),
    monitor='loss',
    mode='min',
    save_best_only=True,
    verbose=0,
)

# Overwrites one fixed file whenever the training loss improves
checkpoint2 = ModelCheckpoint(
    filepath=os.path.join(weights_folder, "weights_custom.h5"),
    monitor='loss',
    mode='min',
    save_best_only=True,
    verbose=0,
)

callbacks_list = [
    checkpoint1,
    checkpoint2,
    # early_stopping,
]

# Store the initial weights, then train; with this epoch count and no early
# stopping the run effectively continues until it is interrupted manually.
model.save_weights(os.path.join(weights_folder, "weights_custom.h5"))
model.fit(
    x=network_input,
    y=network_output,
    batch_size=128,
    epochs=2000000,
    validation_split=0.2,
    shuffle=True,
    callbacks=callbacks_list,
)



In [None]:
# Save the weights the model holds when training ended or was interrupted
model.save_weights(os.path.join(weights_folder, "weights_finished.h5"))