### Set Up

#### Standard library imports

In [None]:
import os
import datetime
import pdb

#### Third-party imports

In [None]:
import tensorflow as tf

#### Local imports

In [None]:
import modules.batch as batch
import modules.midi_related as midi
import modules.preprocessing as prep
import modules.subclasses as sub

#### Setting relative directories

In [None]:
# Which trained checkpoint to load: training-epoch number, run tag, and
# the date sub-folder the checkpoint was saved under.
Epoch_number = 64
Epoch = "epoch_{}".format(Epoch_number)
# Run tag used as a path/file-name prefix; note the trailing underscore.
alpha_beta = "alpha_0.01_beta_1.0_"
Date = "20211010"

In [None]:
# All locations are resolved relative to the parent of the current working
# directory (presumably the notebook lives one level below the project root).
Working_Directory = os.getcwd()
Project_Directory = os.path.abspath(os.path.join(Working_Directory, os.pardir))

# Input MIDI corpus.
Music_In_Directory = "{}/data/chopin_midi/".format(Project_Directory)

# Saved models / checkpoints.
Output_Directory = "{}/outputs/".format(Project_Directory)
Model_Directory = "{}models/".format(Output_Directory)
Checkpoint_Directory = "{}ckpt/".format(Model_Directory)
Checkpoint_Date_Directory = "{}{}/".format(Checkpoint_Directory, Date)
Checkpoint_Date_Epoch_Directory = "{}{}{}_model".format(
    Checkpoint_Date_Directory, alpha_beta, Epoch
)

# Generated MIDI output.
Music_Out_Directory = "{}midi/".format(Output_Directory)
Music_Out_Genereating_Directory = "{}generated/".format(Music_Out_Directory)

### Load model

In [None]:
# Restore the trained Keras model from the selected checkpoint directory.
load_path = Checkpoint_Date_Epoch_Directory
model = tf.keras.models.load_model(filepath=load_path)

### MIDI generation

#### Generate new MIDI files with the trained model, either from scratch or from a primer

In [None]:
# Music generation from scratch or from a primer.
# Set `primer` to a MIDI file stem in Music_In_Directory (e.g. 'chop2804') to
# seed generation with that piece, or to False to start from an all-zero state.
primer = False  # 'chop2804'
num_notes = 88
n_bars = 4
batch_size_gen = 4
n_time_steps_per_sixtheenth = 3
max_sixteenth_index = 16  # default bar length in sixteenths; overwritten by the primer

# Capture the choice now: below, `primer` is rebound from the file stem to the
# loaded note-state matrix, whose truthiness may not be well defined.
use_primer = bool(primer)

if use_primer:
    primer = midi.midiToNoteStateMatrix(Music_In_Directory + primer + ".mid",
                                        verbose=False,
                                        verbose_ts=False)
    # Element [0][3] of each time step is read as its sixteenth-note position
    # (presumably the beat feature -- TODO confirm); the max gives the bar length.
    sixteenth_index = [b[0][3] for b in primer]
    max_sixteenth_index = max(sixteenth_index)

# Bar length in model time steps, and the context window fed to the model
# (half of the bars to generate).
n_time_steps_ber_bar = max_sixteenth_index * n_time_steps_per_sixtheenth
num_timesteps = n_time_steps_ber_bar * (n_bars // 2)

if use_primer:
    tmp_data, start_out = prep.createDataSet2(primer,
                                              num_time_steps=num_timesteps + 1,
                                              batch_size=batch_size_gen,
                                              start_old=0)
    # Seed with the targets of the LAST batch the dataset yields.
    notes_gen_initial = None
    for _, y_train in tmp_data:
        notes_gen_initial = y_train
    if notes_gen_initial is None:
        # Without this check the failure surfaces later as a NameError.
        raise ValueError(
            "prep.createDataSet2 yielded no batches for the primer "
            "(num_time_steps={}, batch_size={}); use a longer primer, "
            "fewer bars or a smaller batch.".format(num_timesteps + 1, batch_size_gen))
    name = 'primer'
else:
    # Three zero-initialised note features per (batch, note, step); the beat
    # feature is appended as the fourth.
    notes_gen_initial = tf.zeros((batch_size_gen, num_notes, num_timesteps, 3))
    sixteenths_per_bar = n_time_steps_ber_bar // n_time_steps_per_sixtheenth
    # Beat feature: 1-based sixteenth position within the bar for each step.
    beats_initial = [(t // n_time_steps_per_sixtheenth) % sixteenths_per_bar + 1
                     for t in range(num_timesteps)]
    beats_initial = tf.constant(value=beats_initial,
                                shape=(1, 1, num_timesteps, 1),
                                dtype=tf.float32)
    beats_initial = tf.tile(beats_initial, multiples=[batch_size_gen, num_notes, 1, 1])
    notes_gen_initial = tf.concat([notes_gen_initial, beats_initial], axis=3)
    name = 'from_scratch'

# Total number of time steps to generate.
t_gen = n_bars * n_time_steps_ber_bar


In [None]:
# Initial States
note_state_matrix_gen = notes_gen_initial


# Generate note_state_matrix
for t in tf.range(t_gen):

    beat = int(t / n_time_steps_per_sixtheenth) % int(n_time_steps_ber_bar / n_time_steps_per_sixtheenth) + 1

    X  = prep.inputKernel(note_state_matrix_gen[:,:,-num_timesteps:,:])
    _ , y_pred_velocity_train, y_pred_note_train = model.predict_on_batch(X)
    new_note = tf.concat([y_pred_note_train[:,:,-1:,:], y_pred_velocity_train[:,:,-1:,:]], axis=-1)
    new_note_p   = new_note[:,:,:,0]
    new_note_a   = new_note[:,:,:,1] * new_note[:,:,:,0]
    new_note_vel = new_note[:,:,:,2] * new_note[:,:,:,0]
    new_note_beat = tf.cast(tf.fill((batch_size_gen, num_notes, 1),beat), dtype=tf.float32)
    new_note = tf.stack([new_note_p, new_note_a, new_note_vel, new_note_beat], axis=-1)
    note_state_matrix_gen = tf.concat([note_state_matrix_gen, new_note], axis=2)

In [None]:
# Display the shape of the generated note-state matrix. In the from-scratch
# case this is (batch_size_gen, num_notes, num_timesteps + t_gen, 4): one step
# is appended along axis 2 per generation iteration.
note_state_matrix_gen.shape

In [None]:
# Write each generated sample in the batch to its own MIDI file, grouped in a
# per-day folder (YYYYMMDD). Re-running on the same day with the same settings
# overwrites files of the same name.
current_time_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
date_str = current_time_str.partition("-")[0]  # the YYYYMMDD part
out_dir = Music_Out_Genereating_Directory + date_str + '/'
for i in range(batch_size_gen):
    # alpha_beta already ends in "_", so no extra separator before Epoch
    # (same convention as the checkpoint directory name).
    file_name = 'generated_batch_{}_{}{}_{}'.format(i, alpha_beta, Epoch, name)
    midi.generate_audio(note_state_matrix_gen[i:i + 1],
                        out_dir,
                        file_name,
                        sample=False,
                        verbose=False)

#### Take a look at the features input to the model

In [None]:
# Summarise the model-input features for a single entry of the kernel output.
# Indices 1, 30, 2 pick one arbitrary element; assumes the leading axes are
# (batch, note, time) -- TODO confirm against prep.inputKernel.
prep.noteRNNInputSummary(
    prep.inputKernel(note_state_matrix_gen)[1, 30, 2, :]
)