This notebook is designed purely for model training, mainly so that it is easy to run on the HPC cluster.

In [None]:
import glob
import numpy as np
import pathlib
import pandas as pd
import pretty_midi
import tensorflow as tf
import keras

# config
n_files = 10  # number of MIDI files used for training; len(filenames) would use all of them
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)
learning_rate = 0.005


# load data
data_dir = pathlib.Path('data/maestro-v2.0.0')
if not data_dir.exists():
    # Download the MAESTRO v2.0.0 MIDI archive into ./data and extract it there.
    # NOTE(review): the exists() check assumes extraction produces
    # data/maestro-v2.0.0; some Keras versions extract into a differently named
    # subdirectory — confirm for the installed version.
    tf.keras.utils.get_file(
        'maestro-v2.0.0-midi.zip',
        origin='https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip',
        extract=True,
        cache_dir='.', cache_subdir='data',
    )

# glob returns paths in filesystem order, which is arbitrary. Sort them so the
# first n_files (training set) and filenames[0] (generation seed) are the same
# files on every machine, matching the intent of the fixed seeds above.
filenames = sorted(glob.glob(str(data_dir/'**/*.mid*')))

# Column order of the per-note feature matrix fed to the model.
key_order = ["pitch", "step", "duration", "velocity"]
def get_midi_note_data(file_path:str):
    """Parse a MIDI file and return per-note features of its first instrument.

    Notes of ``instruments[0]`` are ordered by onset. The result is a dict of
    equal-length lists with keys, in this order: "start", "end", "pitch",
    "step", "duration", "velocity".
    - step: onset minus the previous note's onset (the first note's step is its
      own start time, i.e. measured from 0).
    - duration: end minus start.
    """
    midi = pretty_midi.PrettyMIDI(file_path)
    ordered = sorted(midi.instruments[0].notes, key=lambda n: n.start)

    starts = [n.start for n in ordered]
    # Onset of the preceding note for each note; the first note is measured from 0.
    preceding_starts = [0] + starts[:-1]

    return {
        "start": starts,
        "end": [n.end for n in ordered],
        "pitch": [n.pitch for n in ordered],
        "step": [current - before for current, before in zip(starts, preceding_starts)],
        "duration": [n.end - n.start for n in ordered],
        "velocity": [n.velocity for n in ordered],
    }


def notes_to_midi(
    notes: pd.DataFrame,
    out_file: str,
    instrument_name: str,
    velocity: int = 100,  # note loudness
    *,
    use_note_velocity: bool = False,
) -> pretty_midi.PrettyMIDI:
    """Render a note table to a single-instrument MIDI file.

    Note onsets are rebuilt by accumulating each row's ``step`` starting from
    time 0; each note ends ``duration`` after its onset.

    Args:
        notes: table with at least 'pitch', 'step' and 'duration' columns
            (plus 'velocity' when ``use_note_velocity`` is True), one row per
            note in playback order.
        out_file: path the MIDI file is written to.
        instrument_name: instrument name, mapped to a program number with
            ``pretty_midi.instrument_name_to_program``.
        velocity: loudness applied to every note unless ``use_note_velocity``.
        use_note_velocity: if True, take each note's loudness from its
            'velocity' column (clamped to the MIDI velocity range 0-127)
            instead of the constant ``velocity``.

    Returns:
        The ``PrettyMIDI`` object that was written to ``out_file``.
    """
    pm = pretty_midi.PrettyMIDI()
    program = pretty_midi.instrument_name_to_program(instrument_name)
    instrument = pretty_midi.Instrument(program=program)

    prev_start = 0
    for _, row in notes.iterrows():
        start = float(prev_start + row['step'])
        end = float(start + row['duration'])
        if use_note_velocity:
            note_velocity = min(max(int(row['velocity']), 0), 127)
        else:
            note_velocity = velocity
        # Build the pretty_midi note under a new name rather than rebinding the
        # loop variable that holds the source row.
        instrument.notes.append(pretty_midi.Note(
            velocity=note_velocity,
            pitch=int(row['pitch']),
            start=start,
            end=end,
        ))
        prev_start = start

    pm.instruments.append(instrument)
    pm.write(out_file)
    return pm

def getAllNotes():
    """Load the first ``n_files`` MIDI files and concatenate their notes.

    Each file is parsed with ``get_midi_note_data`` (first instrument only) and
    turned into a DataFrame with columns start, end, pitch, step, duration,
    velocity. The per-file frames are then concatenated in ``filenames`` order.
    The index is not reset, so index values repeat across files.

    Reads the module-level ``filenames`` and ``n_files``. If fewer than
    ``n_files`` files exist, all available files are used.

    Raises:
        FileNotFoundError: if there are no files to load.
    """
    selected_files = filenames[:max(n_files, 0)]
    if not selected_files:
        raise FileNotFoundError(
            "no MIDI files to load: `filenames` is empty or n_files <= 0")

    per_file_notes = [
        pd.DataFrame({name: np.array(values)
                      for name, values in get_midi_note_data(path).items()})
        for path in selected_files
    ]
    return pd.concat(per_file_notes)

def create_note_sequences(dataset: tf.data.Dataset, sequence_length: int, vocab_size: int):
    """Turn a dataset of per-note feature rows into (inputs, labels) pairs.

    Every run of ``sequence_length + 1`` consecutive rows (stride 1, partial
    windows dropped) becomes one example:
    - inputs: the first ``sequence_length`` rows, with column 0 divided by
      ``vocab_size`` and the other columns unchanged.
    - labels: the final row, split into a dict keyed by ``key_order``.

    Elements must be float64 vectors of length 4, since they are divided by a
    4-element float64 constant.
    NOTE(review): if ``dataset`` is several pieces concatenated together,
    windows can span the boundary between pieces.
    """
    window_size = sequence_length + 1
    sequences = dataset.window(
        window_size, shift=1, stride=1, drop_remainder=True
    ).flat_map(lambda window: window.batch(window_size, drop_remainder=True))

    def split_labels(sequence):
        divisors = tf.constant(
            [float(vocab_size), 1.0, 1.0, 1.0], dtype=tf.float64)
        history = tf.divide(sequence[:-1], divisors)

        target = sequence[-1]
        labels = {}
        for column, key in enumerate(key_order):
            labels[key] = target[column]
        return history, labels

    return sequences.map(split_labels, tf.data.AUTOTUNE)

      
      
all_notes = getAllNotes()

seq_length = 100
# Plain Python ints: both become layer sizes in create_model.
vocab_size = int(all_notes["pitch"].max()) + 1
max_velocity = int(all_notes["velocity"].max()) + 1

# (num_notes, len(key_order)) matrix, columns ordered like key_order.
notes_by_column = [all_notes[key].to_numpy() for key in key_order]
stacked_notes = np.stack(notes_by_column, axis=1)

notes_ds = tf.data.Dataset.from_tensor_slices(stacked_notes)


seq_ds = create_note_sequences(notes_ds, seq_length, vocab_size)

batch_size = 64
buffer_size = len(all_notes)
# cache() must precede shuffle(). Caching after shuffling stores the first
# epoch's order and replays it on every later epoch, so the data would never
# be reshuffled.
training_ds = (
    seq_ds.cache()
    .shuffle(buffer_size)
    .batch(batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)


def mse_with_positive_presssure(y_true: tf.Tensor, y_pred: tf.Tensor):
    """Mean squared error plus a penalty on negative predictions.

    Each element contributes ``(y_true - y_pred)**2 + 10 * max(-y_pred, 0)``.
    The penalty pushes the regression heads (step, duration) towards
    non-negative outputs. The result is the mean over all elements.
    """
    squared_error = (y_true - y_pred) ** 2
    # relu(-y) == max(-y, 0): non-zero only where the prediction is negative.
    negativity_penalty = 10 * tf.nn.relu(-y_pred)
    return tf.reduce_mean(squared_error + negativity_penalty)

# Per-output losses. Pitch and velocity heads emit raw logits over integer
# classes; step and duration are regressed with the penalized MSE above.
loss = dict(
    pitch=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    step=mse_with_positive_presssure,
    duration=mse_with_positive_presssure,
    velocity=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

def create_model(optimizer=None):
    """Build and compile the note-prediction model.

    Architecture:
    - An LSTM with ``vocab_size`` units reads the
      (seq_length, len(key_order)) input window.
    - Its output is reshaped into ``vocab_size`` one-feature steps and fed to
      a second LSTM with ``vocab_size`` units.
    - Four dense heads read that output: 'pitch' (vocab_size logits), 'step'
      and 'duration' (one value each) and 'velocity' (max_velocity logits).

    Reads the module-level ``seq_length``, ``key_order``, ``vocab_size``,
    ``max_velocity``, ``loss`` and ``learning_rate``.

    Args:
        optimizer: optimizer to compile with. Defaults to a new
            ``keras.optimizers.Adam(learning_rate)`` per call, so separate
            models never share one stateful optimizer instance.

    Returns:
        The compiled ``keras.Model``.
    """
    if optimizer is None:
        optimizer = keras.optimizers.Adam(learning_rate)

    input_shape = (seq_length, len(key_order))

    input_layer = keras.layers.Input(input_shape)
    lstm_layer = keras.layers.LSTM(vocab_size)(input_layer)
    # Reinterpret the first LSTM's vocab_size-wide output as a sequence of
    # vocab_size steps with one feature each, for the second LSTM.
    lstm_layer = keras.layers.Reshape([vocab_size, 1])(lstm_layer)
    lstm_layer = keras.layers.LSTM(vocab_size)(lstm_layer)

    outputs = {
        "pitch": keras.layers.Dense(vocab_size, name = 'pitch')(lstm_layer),
        "step": keras.layers.Dense(1, name="step")(lstm_layer),
        "duration": keras.layers.Dense(1, name="duration")(lstm_layer),
        "velocity": keras.layers.Dense(max_velocity, name="velocity")(lstm_layer)
    }

    model = keras.Model(input_layer, outputs)

    model.compile(optimizer, loss)

    return model

# Build the model and measure each output head's loss before any training.
model = create_model()

losses = model.evaluate(training_ds, return_dict=True)
# Weight each head by the inverse of its initial loss, times len(losses) - 1
# (the "- 1" excludes the aggregate "loss" entry). After reweighting, every
# head's initial weighted loss equals that same constant.
# NOTE(review): assumes `losses` holds only "loss" plus one "<key>_loss" per
# output; any extra entries would change the multiplier — confirm.
loss_weights = {key: 1 / losses[f"{key}_loss"] * (len(losses) - 1) for key in key_order}
# Diagnostic only (each value equals len(losses) - 1); not used below.
weighted_losses = {key: loss_weights[key]*losses[f"{key}_loss"] for key in loss_weights}

# Recompile with the balanced weights (uses the module-level optimizer/loss).
model.compile(optimizer, loss, loss_weights = loss_weights)

# Diagnostic only: losses after reweighting; not used below.
new_losses = model.evaluate(training_ds, return_dict=True)

callbacks = [
    
    # Lower the learning rate when the training loss plateaus.
    keras.callbacks.ReduceLROnPlateau(monitor="loss"),
    # Weights-only checkpoints; the filename embeds the epoch number, which
    # the prediction cell's get_last_checkpoint parses.
    keras.callbacks.ModelCheckpoint("checkpoints/MIDI_with_LSTM/checkpoint_{epoch}", save_weights_only=True),
    # Stop after 5 epochs without training-loss improvement and roll back to
    # the best weights seen.
    keras.callbacks.EarlyStopping(monitor="loss", patience=5, verbose=1, restore_best_weights=True)
]


# Upper bound only; EarlyStopping is expected to end training well before this.
epochs = 1000
history = model.fit(training_ds, epochs=epochs, callbacks=callbacks)

In [None]:
# predictions

import os
import re
from scipy.io import wavfile

# Sample rate (Hz) used both to synthesize the generated MIDI and as the rate
# written into the output WAV file.
_SAMPLING_RATE = 16_000

def predict_next_note(
    notes: np.ndarray, 
    model: tf.keras.Model, 
    temperature: float = 1.0) -> tuple[int, float, float, int]:
  """Sample the next note from the model given a window of previous notes.

  Args:
    notes: window of previous notes, one row per note with columns in
      ``key_order`` order. Presumably (seq_length, 4) with pitch already divided
      by vocab_size, as in training — confirm against the caller.
    model: model built by ``create_model``; its output is a dict with keys
      'pitch', 'step', 'duration' and 'velocity'.
    temperature: must be > 0. The pitch and velocity logits are divided by it
      before sampling; values above 1 make sampling more random.

  Returns:
    (pitch, step, duration, velocity). Pitch and velocity are sampled class
    indices; step and duration are clamped to be non-negative.

  Raises:
    ValueError: if ``temperature`` is not positive.
  """
  # Explicit check rather than assert, which is stripped under `python -O`.
  if temperature <= 0:
    raise ValueError(f"temperature must be positive, got {temperature!r}")

  # Add batch dimension
  inputs = tf.expand_dims(notes, 0)

  # Call the model directly: model.predict() carries per-call overhead that
  # dominates when it is invoked once per generated note inside a loop.
  predictions = model(inputs, training=False)
  pitch_logits = predictions['pitch'] / temperature
  velocity_logits = predictions['velocity'] / temperature

  # Sample pitch before velocity (same order of random draws as before).
  pitch = tf.squeeze(tf.random.categorical(pitch_logits, num_samples=1), axis=-1)
  velocity = tf.squeeze(tf.random.categorical(velocity_logits, num_samples=1), axis=-1)
  step = tf.maximum(0.0, tf.squeeze(predictions['step'], axis=-1))
  duration = tf.maximum(0.0, tf.squeeze(predictions['duration'], axis=-1))

  # Each tensor has batch size 1; take that single element explicitly.
  return int(pitch[0]), float(step[0]), float(duration[0]), int(velocity[0])


def save_wav(audio, rate, filename):
    """Peak-normalize a float waveform and write it as 16-bit PCM WAV.

    Args:
        audio: array-like of float samples, e.g. the output of
            ``PrettyMIDI.synthesize``. It is scaled so its largest absolute
            value maps to full scale.
        rate: sample rate in Hz written to the WAV header.
        filename: output path.

    Silent (all-zero) or empty input is written without normalization instead
    of dividing by zero.
    """
    samples = np.asarray(audio, dtype=np.float64)
    # Normalize by the largest *absolute* sample so negative peaks also end up
    # inside [-1, 1].
    peak = np.max(np.abs(samples)) if samples.size else 0.0
    if peak > 0:
        samples = samples / peak
    # Signed 16-bit PCM is centred on 0: map [-1, 1] to [-32767, 32767].
    # An unsigned-style "x * 65535 - 32768" mapping would overflow int16 for
    # negative samples and wrap around.
    pcm = np.round(samples * np.iinfo(np.int16).max).astype(np.int16)
    wavfile.write(filename, rate, pcm)


# Passed to predict_next_note, which divides the pitch/velocity logits by it;
# values above 1.0 make sampling more random.
temperature = 1.1
# Number of notes generated after the seed window.
num_predictions = 5000

# Same directory the training cell's ModelCheckpoint writes weights into.
checkpoint_dir = "checkpoints/MIDI_with_LSTM"
def get_last_checkpoint(directory=None):
    """Return the highest epoch number among ``checkpoint_<epoch>.*`` files.

    Args:
        directory: directory to scan; defaults to the module-level
            ``checkpoint_dir``.

    Returns:
        The largest ``<epoch>`` found, as an int.

    Raises:
        FileNotFoundError: if no file in ``directory`` matches
            ``checkpoint_<digits>.`` (for example only a bookkeeping file
            named ``checkpoint`` exists).
    """
    if directory is None:
        directory = checkpoint_dir
    pattern = re.compile(r'checkpoint_(\d+)\.')
    epochs = [
        int(match.group(1))
        for name in os.listdir(directory)
        if (match := pattern.match(name))
    ]
    if not epochs:
        raise FileNotFoundError(
            f"no checkpoint_<epoch>.* files found in {directory!r}")
    return max(epochs)

checkpoint_path = f"{checkpoint_dir}/checkpoint_{get_last_checkpoint()}"

# Rebuild the architecture and load the most recent training checkpoint.
model = create_model()
model.load_weights(checkpoint_path)

# Seed generation with the first seq_length notes of the first file.
example = filenames[0]
raw_notes = get_midi_note_data(example)
sample_notes = np.stack([raw_notes[key] for key in key_order], axis=1)

# Per-column divisor matching create_note_sequences: only pitch is scaled (by
# vocab_size); step, duration and velocity are fed to the model unchanged.
feature_scale = np.array([vocab_size, 1, 1, 1])
input_notes = sample_notes[:seq_length] / feature_scale

generated_notes = []
prev_start = 0
for _ in range(num_predictions):
    pitch, step, duration, velocity = predict_next_note(input_notes, model, temperature)
    start = prev_start + step
    end = start + duration
    input_note = (pitch, step, duration, velocity)
    generated_notes.append((*input_note, start, end))
    # Slide the window forward by one note. The new note is scaled like the
    # seed notes, so the model keeps seeing pitch / vocab_size rather than raw
    # pitch indices.
    scaled_note = np.asarray(input_note, dtype=np.float64) / feature_scale
    input_notes = np.append(input_notes[1:], scaled_note[np.newaxis, :], axis=0)
    prev_start = start

generated_notes = pd.DataFrame(
    generated_notes, columns=(*key_order, 'start', 'end'))


instrument_name = pretty_midi.program_to_instrument_name(pretty_midi.PrettyMIDI(example).instruments[0].program)

# Ensure the output directories exist before the MIDI/WAV writes below.
os.makedirs("test_outputs/MIDI_with_LSTM", exist_ok=True)
os.makedirs("predictions/MIDI_with_LSTM", exist_ok=True)

midi = notes_to_midi(generated_notes, "test_outputs/MIDI_with_LSTM/test.midi", instrument_name, 100)

synth = midi.synthesize(_SAMPLING_RATE)
midi.write("predictions/MIDI_with_LSTM/example.midi")

save_wav(synth, _SAMPLING_RATE, "predictions/MIDI_with_LSTM/example.wav")
