# Music generation notebook

In this notebook we provide minimal guidelines for replicating the music generation part of our work.
The main code needed to replicate all of our plots and results lives in the (highly unstructured) music_generation.ipynb.

First, we need to import all relevant packages (the virtual environment needed for all of this to work can be provided upon request):

In [None]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"  # uncomment this if you don't want to use the GPU

import pretty_midi
import midi
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Model
from keras.layers import Dense, Input, Lambda, Concatenate, LSTM
from keras.optimizers import Adam
from keras import backend as K

import copy

import tensorflow as tf
import numpy as np

import matplotlib.pyplot as plt
import csv

import sys
from sys import stdout
import random

import librosa.display
import pypianoroll

import scipy.stats as st
from os import path
import pickle


################################### Our code

from loading import *
from models import *
from data import *
from midi_to_statematrix import *

%matplotlib inline

In [None]:
# Report the installed TensorFlow version and whether TensorFlow can use a GPU
tf_version = tf.__version__
gpu_available = tf.test.is_gpu_available()
print("TensorFlow version:", tf_version)
print("GPU is available:", gpu_available)

The above should print:

TensorFlow version: 2.0.0 \
GPU is available: True

## Training model

To train our best model (bi-axial LSTM for both the encoder and the decoder), run the following in a bash console (training takes around 50 hours):

```{bash}
python3 train_biaxial_long.py -lr 0.001 -bs 64 
```

For now the encoder output size is fixed to 32, which can be easily changed in the train_biaxial_long.py script. 

## Generating music

Here we give our generation process in a single Jupyter notebook cell. The model advances one timestep at a time and predicts the entire sequence in the target patch. The high-level parameters of the generation process are exposed below; most of them control how the probabilities output by the model are translated into the notes that are actually played:

In [2]:
##################### GENERATION PARAMETERS #####################
# High-level settings read by the generation cell below.

my_model_name = "biaxial_pn_encoder_concat_deeplstm_cont.h5" # weights file (.h5 format) loaded into the model
foldername    = 'experiment_switch_order3'                   # folder where the generated output is saved (an existing folder with this name is deleted first)

# data 
what_type = 'test'    # can be train or test
train_tms = 40        # length of the context in timesteps
test_tms = 20         # length of the target in timesteps
batch_size = 64       # size of the batch (we will generate batch_size patches)
songs_per_batch = 16  # how many different piano scores we want per batch (must divide batch_size)
seed = 1212           # random seed for replication (applies only to choosing scores and patch start times)


# turn_probabilities_to_notes params
how = 'random'              # sampling mode: 'random', 'random_thresholded' or 'thresholded' (see turn_probabilities_to_notes)
normalize = False           # whether to normalize the probabilities output by the model
remap_to_max = True         # whether to divide probabilities by the max probability in that timestep
turn_on_notes = 8           # how many notes we want to turn on maximally at any timestep (humans have 10 fingers)
divide_prob = 2             # value by which the normalized probabilities are divided (only used when normalize is True)
articulation_prob = 0.0018  # a note that was on in the previous timestep is kept on when its predicted probability is at least this value
remap_prob = 0.35           # if remap_to_max is True, this is the value we multiply the resulting probabilities by


Running the cell below generates the patches and saves them to `foldername`:

In [None]:
def load_model(file, curr_batch, modelname, *modelparams):
    """Build a model and restore its weights from disk.

    Parameters
    ----------
    file : path of the weights file handed to ``load_weights``.
    curr_batch : first positional argument passed to the model constructor.
    modelname : callable that builds the model (e.g. a function from models.py).
    *modelparams : extra positional arguments forwarded to ``modelname``.

    Returns
    -------
    The object returned by ``modelname`` after ``load_weights(file)`` was
    called on it.
    """
    restored = modelname(curr_batch, *modelparams)
    restored.load_weights(file)
    return restored

def turn_probabilities_to_notes(prediction, 
                                turn_on, 
                                how = 'random', 
                                normalize = True, 
                                threshold = 0.1, 
                                divide_prob = 2,
                                remap_to_max = True,
                                remap_prob = None):
    """Turn per-note probabilities of one timestep into 0/1 note decisions.

    Parameters
    ----------
    prediction : 2D float np.ndarray, one row of note probabilities per batch
        element. NOTE: the array is modified in place.
    turn_on : sequence with one entry per row of ``prediction``; the maximal
        number of notes that may be switched on in that row. Rows whose
        budget is <= 1 are silenced completely.
    how : {'random', 'random_thresholded', 'thresholded'}
        - 'random': draw every note from a Bernoulli with its probability.
        - 'random_thresholded': probabilities >= ``threshold`` are raised by
          0.5 (capped at 1), the rest are set to 0, then Bernoulli draws.
        - 'thresholded': notes with probability >= ``threshold`` are played.
    normalize : if True, standardise each row with the mean and standard
        deviation of its positive entries, map it through the normal CDF and
        divide by ``divide_prob``.
    threshold : cut-off used by the two thresholded modes.
    divide_prob : divisor applied after normalisation.
    remap_to_max : if True, divide each row by its maximum and multiply by
        ``remap_prob``.
    remap_prob : scaling factor used with ``remap_to_max``. When None, the
        module-level ``remap_prob`` is used if it exists (the previous
        behaviour), otherwise 0.35.

    Returns
    -------
    np.ndarray of the same shape as ``prediction`` holding 0/1 values.

    Raises
    ------
    ValueError
        If ``how`` is not one of the supported modes.
    """
    if how not in ('random', 'random_thresholded', 'thresholded'):
        raise ValueError(
            "`how` must be one of 'random', 'random_thresholded' or "
            "'thresholded', got {!r}".format(how))

    if remap_prob is None:
        # Backward compatibility: this value used to be read from a global.
        remap_prob = globals().get('remap_prob', 0.35)

    for batch in range(prediction.shape[0]):
        row = prediction[batch, :]  # view: writes go straight into `prediction`

        if turn_on[batch] <= 1:
            row[:] = 0
            continue

        # Silence everything except the `turn_on[batch]` most probable notes
        turn_off = row.argsort()[:-int(turn_on[batch])]
        row[turn_off] = 0

        if normalize:
            active = row[row > 0]
            if active.size > 0:
                spread = np.sqrt(np.var(active))
                # A zero spread would turn the whole row into NaNs
                if spread > 0:
                    row[:] = st.norm.cdf((row - np.mean(active)) / spread) / divide_prob
                    row[turn_off] = 0

        if remap_to_max:
            peak = row.max()
            # An all-zero row stays zero instead of becoming 0/0 = NaN
            if peak > 0:
                row /= peak
                row *= remap_prob

    if how == 'random':

        notes = np.random.binomial(1, p=prediction)

    elif how == 'random_thresholded':

        prediction[prediction >= threshold] += 0.5
        prediction[prediction > 1] = 1
        prediction[prediction < threshold] = 0

        notes = np.random.binomial(1, p=prediction)

    else:  # how == 'thresholded'

        prediction[prediction >= threshold] = 1
        prediction[prediction < threshold] = 0

        notes = prediction

    return notes

############################################# LOAD DATA ####################################################

# CSV file handed to DataObject (also read later to name the saved patches)
file = 'maestro-v2.0.0/maestro-v2.0.0.csv'

# Data object from which the batch we want to predict is taken
data_test = DataObject(
    file,
    what_type = what_type,
    train_tms = train_tms,
    test_tms = test_tms,
    fs = 20,
    window_size = 15,
    seed = seed,
)

# Batch object whose `.data` is copied and modified by the generation code
test_batch = Batch(
    data_test,
    batch_size = batch_size,
    songs_per_batch = songs_per_batch,
)


############################################# START GENERATING #############################################

# Work on a copy so that test_batch.data keeps the original (true) patches
curr_test_batch = copy.deepcopy(test_batch.data)

# Uncomment below line if you want to switch the ordering of the contexts
#curr_test_batch.context[[0,1],:,:,:] = curr_test_batch.context[[1,0],:,:,:]

# Buffer holding the generated notes: 19 steps before, the target patch,
# and 19 steps after; the last axis has 78 entries (one per note).
final_output = np.zeros((test_batch.batch_size, 
                         19+data_test.test_tms+19, 
                         78))

# The next is not necessary but makes the samples a bit better
# Populate from the front
final_output[:,0:19,:] = curr_test_batch.context[0,:,-19:,:]
# Seed position 20 with the first true target timestep (articulation dropped).
# NOTE(review): position 19 is never written and stays all-zero; confirm
# whether the seed was meant to go to index 19 or 20.
final_output[:,20,:] = DataObject.drop_articulation3d(curr_test_batch.target[:,0,:,:])

# Populate from the back 
final_output[:,-19:,:] = curr_test_batch.context[1,:,0:19,:]

# Replace the target that is fed to the model: channel 0 gets the first 20
# steps of the buffer, channel 1 is zeroed.
# (presumably channel 0 = played, channel 1 = articulated; verify in data.py)
curr_test_batch.target[:,0:20,:,0] = final_output[:,0:20,:]
curr_test_batch.target[:,0:20,:,1] = np.zeros(final_output[:,0:20,:].shape)

# Re-run the featurization with the modified target
curr_test_batch.target_split = 0
curr_test_batch.window_size  = 20
curr_test_batch.featurize(use_biaxial = True)

# If you have trained a different model from the models.py file, change the last argument of the next function to that model's name
model = load_model(my_model_name, curr_test_batch, biaxial_pn_encoder_concat_deeplstm)

def take_prediction(t):
    """Return the negative start index used to slice the model output.

    The result is ``-t`` while ``t`` is below 20 and ``-20`` from then on,
    i.e. the slice ``[result:]`` keeps at most the last 20 entries.
    """
    return -t if t < 20 else -20

def take_actual(t):
    """Return a range of integer indices that depends on ``t`` and ``test_tms``.

    For ``t <= test_tms`` the range is ``19 .. 19+t-1``; otherwise it is
    ``t-test_tms+19 .. t-19-1``. Reads the module-level ``test_tms``.
    """
    if t <= test_tms:
        first, stop = 19, 19 + t
    else:
        first, stop = t - test_tms + 19, t - 19
    return np.arange(first, stop, 1)

# Start looping over the target patch

# Note budget per patch; it is the same for every timestep, so build it once
# (this also keeps `turn_on` defined when the loop body never runs).
turn_on = [turn_on_notes]*batch_size

for timestep in range(1,test_tms):
    
    stdout.write('\rtimestep {}/{}'.format(timestep, test_tms))
    stdout.flush()
    
    # Predict and keep only the trailing steps selected by take_prediction
    prediction = model.predict([tf.convert_to_tensor(curr_test_batch.context, dtype = tf.float32), 
                                tf.convert_to_tensor(curr_test_batch.target_train, dtype = tf.float32)],
                               steps = 1)[:,take_prediction(timestep):,:]
    
    # Loop over the kept steps to determine which notes to play
    for t in range(prediction.shape[1]):
        # A note that is on in the previous step of the buffer stays on when
        # its probability reaches articulation_prob
        articulation = np.multiply(prediction[:,t,:], final_output[:,20+t,:])
        articulation[articulation >= articulation_prob] = 1
        articulation[articulation < articulation_prob] = 0
        articulated_notes = np.sum(articulation, axis = -1)
        
        # Notes kept on above count against the budget of new notes.
        # `how` is the configured sampling mode (it used to be hard-coded to
        # 'random' here, silently ignoring the parameter cell).
        play_notes = turn_probabilities_to_notes(prediction[:,t,:], 
                                        turn_on = turn_on - articulated_notes,
                                        how = how, 
                                        normalize = normalize,
                                        divide_prob = divide_prob, 
                                        remap_to_max = remap_to_max)
        
        play_notes = play_notes + articulation
        play_notes[play_notes >= 1] = 1
        play_notes[play_notes < 1] = 0
        
        final_output[:,21+t,:] = play_notes
    
    
    # Now reinitialize the model and everything (quite an inefficient implementation)
    curr_test_batch = copy.deepcopy(test_batch.data)
    
    # NOTE(review): unlike the setup before the loop, channel 1 of the target
    # is not zeroed here, so it keeps the values of the true batch; confirm
    # that this is intended.
    curr_test_batch.target[:,0:20,:,0] = final_output[:,timestep:(20+timestep)]

    curr_test_batch.target_split = 0
    curr_test_batch.window_size  = 20
    curr_test_batch.featurize(use_biaxial = True)

    # End of timestep loop
    
true_batch = copy.deepcopy(test_batch.data)

# This enables us to save the patches with the actual score names they come from.
# The CSV is read once into a lookup table (it used to be re-opened and
# re-scanned for every patch). setdefault keeps the first row seen for a
# link, which matches the previous "first match, then break" behaviour.
# Columns 0 and 1 are presumably composer and title, column 4 the midi file
# name (MAESTRO csv layout; verify against the file).
name_parts_by_link = {}
with open(data_test.file, newline='') as csv_file:
    csv_reader = csv.reader(csv_file, delimiter=',')
    next(csv_reader, None)  # skip the header row
    for row in csv_reader:
        if len(row) > 4:
            name_parts_by_link.setdefault(row[4], (row[0], row[1]))

# Entries stay 0.0 for links that do not appear in the CSV (as before)
song_names = [0.0] * len(true_batch.link)
for i, link in enumerate(true_batch.link):
    if link in name_parts_by_link:
        first_part, second_part = name_parts_by_link[link]
        name = str(first_part) + '_' + str(second_part) + '___' + str(i)
        # Make the name safe to use inside a file name
        song_names[i] = name.replace(" ", "-").replace("/", "")

##########################################################                      

# Start from an empty output folder: delete results of a previous run, then
# (re)create the folder. shutil.rmtree replaces the former `rm -r` shell call,
# which was not portable and broke on folder names containing spaces.
import shutil

if path.isdir(foldername):
    shutil.rmtree(foldername)

os.makedirs(foldername, exist_ok=True)

# Record the generation parameters next to the generated files.
# turn_on_notes is written directly instead of turn_on[0], which only exists
# once the generation loop has run at least once.
with open('{}/setup.txt'.format(foldername), 'w+') as f:
    f.write('what_type = {} \n \
             train_tms = {} \n \
             test_tms  = {} \n \
             batch_size = {} \n \
             songs_per_batch ={} \n \
             how = {} \n \
             normalize = {} \n \
             turn_on = {} \n \
             divide_prob = {} \n \
             articulation_prob = {}'.format(what_type,
                                            str(train_tms),
                                            str(test_tms),
                                            str(batch_size),
                                            str(songs_per_batch),
                                            how,
                                            str(normalize),
                                            str(turn_on_notes),
                                            str(divide_prob),
                                            str(articulation_prob)))

##########################################################

# Fresh copy of the untouched batch; its target (articulation dropped) is the
# ground truth the prediction is saved next to.
true_batch = copy.deepcopy(test_batch.data)
true_batch.target = DataObject.drop_articulation(true_batch.target)

# The two contexts that surround the target patch
context_before = np.squeeze(curr_test_batch.context[:,0,:,:])
context_after = np.squeeze(curr_test_batch.context[:,1,:,:])

# True sample: context, true target, context, joined along axis 1 ...
true_sample = np.concatenate((context_before, true_batch.target, context_after), axis = 1)
# ... and duplicated along a new trailing axis of size 2
true_sample = np.stack((true_sample, true_sample), axis = 3)

# Predicted sample: same layout, with the generated part of final_output
generated_target = final_output[:,20:(20+test_tms),:]
predicted_sample = np.concatenate((context_before, generated_target, context_after), axis = 1)
predicted_sample = np.stack((predicted_sample, predicted_sample), axis = 3)

# Save final midi: one TRUE and one PRED file per patch of the batch

save_indices = np.arange(0,test_batch.batch_size)
for i in save_indices:
    print("saving {}".format(i))

    true_name = '{}/NO_{}_TRUE_{}'.format(foldername,i,song_names[i])
    pred_name = '{}/NO_{}_PRED_{}'.format(foldername,i,song_names[i])

    noteStateMatrixToMidi(true_sample[i,:,:], name = true_name)
    noteStateMatrixToMidi(predicted_sample[i,:,:], name = pred_name)

## Plotting functions

In [None]:
def plot_batch_element2(batch, fig, which_element = 0, cmap_ctx = 'viridis', cmap_tar = 'Reds', num_subplot = 2):
    """Plot one batch element (contexts plus target) into a subplot of ``fig``.

    The two contexts are drawn with ``cmap_ctx`` and the target, placed
    between them, with ``cmap_tar``. Vertical lines are drawn at
    ``data_test.train_tms`` and ``data_test.train_tms + data_test.test_tms``
    (module-level ``data_test``).

    Parameters
    ----------
    batch : object with ``context`` and ``target`` arrays; ``target`` is
        indexed as ``target[which_element, :, :]`` and is used as is.
    fig : matplotlib figure that receives the subplot.
    which_element : index of the batch element to plot.
    cmap_ctx, cmap_tar : colormaps for the contexts and for the target.
    num_subplot : position of the subplot in a 3-row, 1-column grid.

    Returns
    -------
    (fig, ax)
    """
    ax = fig.add_subplot(300 + 10 + num_subplot)
    
    context_before = batch.context[which_element,0,:,:]
    context_after = batch.context[which_element,1,:,:]
    target = batch.target[which_element,:,:]
    
    full_segment = combine_pianoroll(context_before,
                                     np.zeros(target.shape),
                                     context_after)
    
    # The target sits right behind the first context. Deriving the slice from
    # the array lengths (instead of a fixed 40:60) keeps this correct for any
    # context/target length.
    target_start = context_before.shape[0]
    target_stop = target_start + target.shape[0]
    just_target = np.zeros(full_segment.shape)
    just_target[target_start:target_stop, :] = target
    
    plot_pianoroll(ax, full_segment, cmap = cmap_ctx)
    plot_pianoroll(ax, just_target,  cmap = cmap_tar, alpha = 1)
    ax.axvline(data_test.train_tms)
    ax.axvline(data_test.train_tms+data_test.test_tms)
    
    return fig, ax

def pad_with_zeros(pianoroll):
    """Return ``pianoroll`` with 25 zero columns added on each side of axis 1.

    Axis 0 is left untouched; the input must be two-dimensional.
    """
    no_padding = (0, 0)
    column_padding = (25, 25)
    return np.pad(pianoroll,
                  pad_width=(no_padding, column_padding),
                  mode='constant',
                  constant_values=(0, 0))

def combine_pianoroll(*pianorolls):
    """Join pianorolls one after another along axis 0.

    Parameters
    ----------
    *pianorolls : arrays that agree in every dimension except the first.

    Returns
    -------
    The joined array. A single input is returned unchanged (same object).

    Raises
    ------
    ValueError
        If no pianoroll is given.
    """
    if not pianorolls:
        raise ValueError("combine_pianoroll needs at least one pianoroll")

    if len(pianorolls) == 1:
        return pianorolls[0]

    # One concatenate instead of repeated np.append calls, each of which
    # copied everything accumulated so far.
    return np.concatenate(pianorolls, axis = 0)

def plot_batch_element(batch, which_element = 0, cmap_ctx = 'viridis', cmap_tar = 'Reds', num_subplots = 3, figsize = (12,8)):
    """Create a figure and plot one batch element (contexts plus target).

    The target is passed through ``DataObject.drop_articulation3d`` and drawn
    with ``cmap_tar`` between the two contexts, which use ``cmap_ctx``.
    Vertical lines are drawn at ``data_test.train_tms`` and
    ``data_test.train_tms + data_test.test_tms`` (module-level ``data_test``).

    Parameters
    ----------
    batch : object with ``context`` and ``target`` arrays.
    which_element : index of the batch element to plot.
    cmap_ctx, cmap_tar : colormaps for the contexts and for the target.
    num_subplots : number of rows of the subplot grid; the first one is used.
    figsize : size of the new figure.

    Returns
    -------
    (fig, ax)
    """
    fig = plt.figure(figsize = figsize)
    ax = fig.add_subplot(num_subplots*100 + 11)
    
    context_before = batch.context[which_element,0,:,:]
    context_after = batch.context[which_element,1,:,:]
    # Computed once and reused below (it used to be evaluated twice)
    target = DataObject.drop_articulation3d(batch.target[which_element,:,:])
    
    full_segment = combine_pianoroll(context_before,
                                     np.zeros(target.shape),
                                     context_after)
    
    # The target sits right behind the first context. Deriving the slice from
    # the array lengths (instead of a fixed 40:60) keeps this correct for any
    # context/target length.
    target_start = context_before.shape[0]
    target_stop = target_start + target.shape[0]
    just_target = np.zeros(full_segment.shape)
    just_target[target_start:target_stop, :] = target
    
    plot_pianoroll(ax, full_segment, cmap = cmap_ctx)
    plot_pianoroll(ax, just_target,  cmap = cmap_tar, alpha = 1)
    ax.axvline(data_test.train_tms)
    ax.axvline(data_test.train_tms+data_test.test_tms)
    
    return fig, ax

# The next function is a modified function from the package pypianoroll
def plot_pianoroll(
    ax,
    pianoroll,
    is_drum=False,
    beat_resolution=None,
    downbeats=None,
    preset="default",
    cmap="Blues",
    xtick="auto",
    ytick="octave",
    xticklabel=True,
    yticklabel="auto",
    tick_loc=None,
    tick_direction="in",
    label="both",
    grid="both",
    grid_linestyle=":",
    grid_linewidth=0.5,
    num_notes = 78,
    x_start = 1,
    alpha = 1,
):
    """
    Plot a pianoroll given as a numpy array.

    Parameters
    ----------
    ax : matplotlib.axes.Axes object
        A :class:`matplotlib.axes.Axes` object where the pianoroll will be
        plotted on.
    pianoroll : np.ndarray
        A pianoroll to be plotted. The values should be in [0, 1] when data type
        is float, and in [0, 127] when data type is integer.

        - For a 2D array, shape=(num_time_step, num_pitch).
        - For a 3D array, shape=(num_time_step, num_pitch, num_channel), where
          channels can be either RGB or RGBA.

        `num_pitch` must equal `num_notes`.

    is_drum : bool
        A boolean number that indicates whether it is a percussion track.
        Defaults to False.
    beat_resolution : int
        The number of time steps used to represent a beat. Required and only
        effective when `xtick` is 'beat'.
    downbeats : list
        Time steps at which a vertical line is drawn (i.e., the first time
        step of a bar). Ignored when `preset` is 'plain'.

    preset : {'default', 'plain', 'frame'}
        A string that indicates the preset theme to use.

        - In 'default' preset, the ticks, grid and labels are on.
        - In 'frame' preset, the ticks and grid are both off.
        - In 'plain' preset, the x- and y-axis are both off.

    cmap :  `matplotlib.colors.Colormap`
        The colormap to use in :func:`matplotlib.pyplot.imshow`. Defaults to
        'Blues'. Only effective when `pianoroll` is 2D.
    xtick : {'auto', 'beat', 'step', 'off'}
        A string that indicates what to use as ticks along the x-axis. If 'auto'
        is given, automatically set to 'beat' if `beat_resolution` is also given
        and set to 'step', otherwise. Defaults to 'auto'.
    ytick : {'octave', 'pitch', 'off'}
        A string that indicates what to use as ticks along the y-axis.
        Defaults to 'octave'.
    xticklabel : bool
        Whether to add tick labels along the x-axis. Only effective when `xtick`
        is not 'off'.
    yticklabel : {'auto', 'name', 'number', 'off'}
        If 'name', use octave name and pitch name (key name when `is_drum` is
        True) as tick labels along the y-axis. If 'number', use pitch number. If
        'auto', set to 'name' when `ytick` is 'octave' and 'number' when `ytick`
        is 'pitch'. Defaults to 'auto'. Only effective when `ytick` is not
        'off'.
    tick_loc : tuple or list
        The locations to put the ticks. Availables elements are 'bottom', 'top',
        'left' and 'right'. Defaults to ('bottom', 'left').
    tick_direction : {'in', 'out', 'inout'}
        A string that indicates where to put the ticks. Defaults to 'in'. Only
        effective when one of `xtick` and `ytick` is on.
    label : {'x', 'y', 'both', 'off'}
        A string that indicates whether to add labels to the x-axis and y-axis.
        Defaults to 'both'.
    grid : {'x', 'y', 'both', 'off'}
        A string that indicates whether to add grids to the x-axis, y-axis, both
        or neither. Defaults to 'both'.
    grid_linestyle : str
        Will be passed to :meth:`matplotlib.axes.Axes.grid` as 'linestyle'
        argument.
    grid_linewidth : float
        Will be passed to :meth:`matplotlib.axes.Axes.grid` as 'linewidth'
        argument.
    num_notes : int
        The required length of the second axis of `pianoroll` (number of
        pitches). Defaults to 78.
    x_start : int
        The number shown on the first beat label. Only effective when `xtick`
        is 'beat'. Defaults to 1.
    alpha : float
        Opacity passed to :meth:`matplotlib.axes.Axes.imshow`. Defaults to 1.

    Raises
    ------
    ValueError
        If the shape of `pianoroll` or one of the string options is invalid.
    TypeError
        If `xticklabel` is not a bool or the dtype of `pianoroll` is not
        bool, float or integer.

    """
    
    if pianoroll.ndim not in (2, 3):
        raise ValueError("`pianoroll` must be a 2D or 3D numpy array")
    if pianoroll.shape[1] != num_notes:
        # The message used to say 128 although the check is against num_notes
        raise ValueError(
            "The length of the second axis of `pianoroll` must be {}.".format(num_notes)
        )
    if xtick not in ("auto", "beat", "step", "off"):
        raise ValueError("`xtick` must be one of {'auto', 'beat', 'step', 'off'}.")
    if xtick == "beat" and beat_resolution is None:
        raise ValueError("`beat_resolution` must be specified when `xtick` is 'beat'.")
    if ytick not in ("octave", "pitch", "off"):
        raise ValueError("`ytick` must be one of {'octave', 'pitch', 'off'}.")
    if not isinstance(xticklabel, bool):
        raise TypeError("`xticklabel` must be bool.")
    if yticklabel not in ("auto", "name", "number", "off"):
        raise ValueError(
            "`yticklabel` must be one of {'auto', 'name', 'number', 'off'}."
        )
    if tick_direction not in ("in", "out", "inout"):
        raise ValueError("`tick_direction` must be one of {'in', 'out', 'inout'}.")
    if label not in ("x", "y", "both", "off"):
        raise ValueError("`label` must be one of {'x', 'y', 'both', 'off'}.")
    if grid not in ("x", "y", "both", "off"):
        raise ValueError("`grid` must be one of {'x', 'y', 'both', 'off'}.")

    # plotting (time runs along the x-axis, hence the transpose)
    if pianoroll.ndim > 2:
        to_plot = pianoroll.transpose(1, 0, 2)
    else:
        to_plot = pianoroll.T
    if np.issubdtype(pianoroll.dtype, np.bool_) or np.issubdtype(
        pianoroll.dtype, np.floating
    ):
        ax.imshow(
            to_plot,
            cmap=cmap,
            aspect="auto",
            vmin=0,
            vmax=1,
            origin="lower",
            interpolation="none",
            alpha = alpha,
        )
    elif np.issubdtype(pianoroll.dtype, np.integer):
        ax.imshow(
            to_plot,
            cmap=cmap,
            aspect="auto",
            vmin=0,
            vmax=127,
            origin="lower",
            interpolation="none",
            alpha = alpha,
        )
    else:
        raise TypeError("Unsupported data type for `pianoroll`.")

    # tick setting
    if tick_loc is None:
        tick_loc = ("bottom", "left")
    if xtick == "auto":
        xtick = "beat" if beat_resolution is not None else "step"
    if yticklabel == "auto":
        yticklabel = "name" if ytick == "octave" else "number"

    if preset == "plain":
        ax.axis("off")
    elif preset == "frame":
        ax.tick_params(
            direction=tick_direction,
            bottom=False,
            top=False,
            left=False,
            right=False,
            labelbottom=False,
            labeltop=False,
            labelleft=False,
            labelright=False,
        )
    else:
        ax.tick_params(
            direction=tick_direction,
            bottom=("bottom" in tick_loc),
            top=("top" in tick_loc),
            left=("left" in tick_loc),
            right=("right" in tick_loc),
            # `xticklabel` is a bool; comparing it with "off" was always True
            labelbottom=xticklabel,
            labelleft=(yticklabel != "off"),
            labeltop=False,
            labelright=False,
        )

    # x-axis
    if xtick == "beat" and preset != "frame":
        num_beat = pianoroll.shape[0] // beat_resolution
        ax.set_xticks(beat_resolution * np.arange(num_beat) - 0.5)
        ax.set_xticklabels("")
        ax.set_xticks(beat_resolution * (np.arange(num_beat) + 0.5) - 0.5, minor=True)
        # One label per beat, counted from x_start (identical to the previous
        # labels for the default x_start = 1)
        ax.set_xticklabels(np.arange(x_start, x_start + num_beat), minor=True)
        ax.tick_params(axis="x", which="minor", width=0)

    # y-axis
    if ytick == "octave":
        octave_ticks = np.arange(0, num_notes, 12)
        ax.set_yticks(octave_ticks)
        if yticklabel == "name":
            # Exactly one label per tick; a fixed list of 11 labels does not
            # match the tick count when num_notes differs from 128.
            # NOTE(review): the names assume that row 0 is MIDI pitch 0 (C-2);
            # confirm for the 78-note representation used here.
            ax.set_yticklabels(
                ["C{}".format(i - 2) for i in range(len(octave_ticks))]
            )
    elif ytick == "pitch":
        # This branch used to test for "step", a value the validation above
        # rejects, so it could never run.
        ax.set_yticks(np.arange(0, num_notes))
        if yticklabel == "name":
            if is_drum:
                ax.set_yticklabels(
                    [pretty_midi.note_number_to_drum_name(i) for i in range(num_notes)]
                )
            else:
                ax.set_yticklabels(
                    [pretty_midi.note_number_to_name(i) for i in range(num_notes)]
                )

    # axis labels
    if label in ("x", "both"):
        if xtick == "step" or not xticklabel:
            ax.set_xlabel("time (step)")
        else:
            ax.set_xlabel("time (beat)")

    if label in ("y", "both"):
        if is_drum:
            ax.set_ylabel("key name")
        else:
            ax.set_ylabel("pitch")

    # grid
    if grid != "off":
        ax.grid(
            axis=grid, color="k", linestyle=grid_linestyle, linewidth=grid_linewidth
        )

    # downbeat border
    if downbeats is not None and preset != "plain":
        for step in downbeats:
            ax.axvline(x=step, color="k", linewidth=1)
            