# Import libraries

In [1]:
import os
import copy
from datetime import date, datetime

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Seed every random number generator imported by this notebook, so that a
# "Restart & Run All" reproduces the same run.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

from dataloader import PianoRollDataset
from models import LSTM
from utils.midi_processing import convert_midi_to_piano_roll
from utils.visuals import visualize_sequence, make_video
from utils.export import export_features

Using device: cuda


# Convert midi dataset to piano roll

This cell converts midi files to piano rolls, to save computation time during training. Make sure to prepare ```train```, ```validation``` and ```test``` subsets before running this.

The time sampling of the piano rolls must be specified when converting the files, but they can be downsampled when they are loaded later on. The ```dataset_fs``` value should be greater than or equal to the desired time sampling (in Hz).

Run this cell even if the data are already preprocessed, to define the ```pr_path``` and ```dataset_fs``` variables.

In [3]:
# Location of the MIDI dataset; it must contain train, validation, and test subfolders.
midi_path = './data/midi_dataset_example/'
# Location where the piano roll version of the dataset is written.
pr_path = './data/pr_dataset/'

# Sampling frequency used to convert MIDI files to piano roll format. The piano rolls
# can be automatically downsampled when they're preloaded before the training loop.
dataset_fs = 30

# The conversion is skipped entirely as soon as the output folder exists.
dataset_already_converted = os.path.exists(pr_path)
if not dataset_already_converted:
    subsets = ['train/', 'validation/', 'test/']
    for subset in subsets:
        source_dir = midi_path + subset
        target_dir = pr_path + subset
        print(f'Converting MIDI files to piano roll format for the {subset} subset...')
        os.makedirs(target_dir, exist_ok=True)
        convert_midi_to_piano_roll(
            data_path = source_dir,
            out_dir = target_dir,
            fs = dataset_fs,
            pedal_threshold = 64,
        )

# Load dataset

To save some computation time when using large training datasets, it is recommended to load the data in RAM with ```preload = True```. This may take several minutes.


The ```model_fs``` parameter defines the time sampling (in Hz) of the data that the model will receive as input, and thus determines the temporal resolution of the model itself. This value must be lower than or equal to ```dataset_fs```.

The ```ons_value``` and ```sus_value``` parameters define the numerical values that indicate the onset or the sustain of a note in the model's inputs. Changing these values should have little impact on the model's performance. The absence of a note is always indicated as ```0```.

The ```max_seq_length``` parameter sets the maximum length of the model's input (in timesteps).

In [4]:
# Path to the dataset in piano roll format.
dataset_path = pr_path
# If True, the entire dataset is loaded into memory. This will speed up the training process,
# but takes some time at loading, and requires a large amount of RAM.
preload = True

# Sampling frequency of the model. This value can only be equal or lower
# than the sampling frequency of the dataset.
model_fs = 20
# Number of sequences processed in parallel by the model during training.
batch_size = 16
# Value indicating the onset of a note in the model's input.
ons_value = 1
# Value indicating the sustain of a note in the model's input.
sus_value = 0.5
# Value used to pad the input sequences to the same length.
padding_value = -99
# Maximum length of the input sequence in timesteps.
max_seq_length = 60*model_fs

# Both subsets are built with exactly the same options; only the folder differs.
# NOTE(review): transposition is also enabled for the validation set — confirm this is intended.
shared_dataset_options = dict(
    dataset_fs = dataset_fs,
    model_fs = model_fs,
    ons_value = ons_value,
    sus_value = sus_value,
    padding_value = padding_value,
    source_length = max_seq_length,
    use_transposition = True,
    preload = preload,
    device = device,
    dtype = torch.float32,
)

train_dataset = PianoRollDataset(data_path = dataset_path + 'train/', **shared_dataset_options)
validation_dataset = PianoRollDataset(data_path = dataset_path + 'validation/', **shared_dataset_options)

# Each loader batches its sequences with the collate function of its own dataset.
train_loader = DataLoader(
    train_dataset,
    batch_size = batch_size,
    collate_fn = train_dataset.collate_batch,
    shuffle = True,
    num_workers = 0,
)
validation_loader = DataLoader(
    validation_dataset,
    batch_size = batch_size,
    collate_fn = validation_dataset.collate_batch,
    shuffle = True,
    num_workers = 0,
)

# Instantiate model

All the usual LSTM and optimizer parameters are listed below (see pytorch documentation).

If ```run_name = None```, a new folder will be created to store the checkpoints during training.
To resume the training of a model, set ```run_name``` to the name of its folder.

In [27]:
run_name = None     # If None, create a new run folder (named with the current date and time) and start training from scratch.
                    # Otherwise, set the name of the folder to resume training from (e.g. 'lstm_run_Oct20_20-32-25').

# Model parameters
input_size = 88             # Dimension of the input sequence. In this case, it is the number of piano keys.
output_size = 88            # Dimension of the output sequence.

n_lstm_layers = 1           # Number of LSTM layers.
hidden_dim = 128            # Dimension of the hidden state of the LSTM.
trunc_tw = 5*model_fs       # Truncated backpropagation through time (BPTT) length. This will affect the ability of the model to learn long-term dependencies.
                            # A large value will allow the model to learn dependencies over longer time scales, but the model may fail to learn anything.
                            # A small value will restrict the learning to short-term dependencies, but the model has a higher chance of learning effectively.
dropout = 0.1               # Dropout rate.

# Model initialization
model = LSTM(input_size, n_lstm_layers, hidden_dim, output_size,
             dropout, padding_value, device)
model.to(device)

# Optimizer
lr = 1e-4               # Learning rate.
betas = (0.9, 0.98)     # Betas for the Adam optimizer.
eps = 1e-9              # Epsilon for the Adam optimizer.
optimizer = optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps)

# Load model and optimizer state from the last checkpoint, or start from scratch.
if run_name is None:
    i_batch = 0

    # A single clock read names the run, so the date and the time can never disagree.
    date_run = datetime.now().strftime('%b%d_%H-%M-%S')
    run_path = './runs/lstm_run_' + date_run + '/'
    os.makedirs(run_path, exist_ok=True)
    writer = SummaryWriter(log_dir = run_path)                # Create a tensorboard writer to store training/validation losses,
                                                              # and plot dynamic graphs during training (see tensorboard's documentation).
    loss_history = {'train_loss': [], 'validation_loss': []}  # Custom history of losses. Tensorboard's writer is efficient, but can
                                                              # occasionally produce corrupted files. This custom history is a safety measure.

else:
    run_path = './runs/' + run_name + '/'

    # Collect the batch index of every 'model_<i_batch>.pt' checkpoint in the run folder.
    saved_batches = []
    for file in os.listdir(run_path):
        batch_tag = file[len('model_'):-len('.pt')]
        if file.startswith('model_') and file.endswith('.pt') and batch_tag.isdigit():
            saved_batches.append(int(batch_tag))
    if not saved_batches:
        raise FileNotFoundError(f'No model checkpoint found in {run_path}')
    i_batch = max(saved_batches)    # Resume from the most recent checkpoint.

    writer = SummaryWriter(run_path)
    loss_history = np.load(run_path + 'loss_history.npy', allow_pickle=True).item()

    # map_location lets a checkpoint saved on one device be restored on the device selected for this session.
    model_path = run_path + 'model_' + str(i_batch) + '.pt'
    model.load_state_dict(torch.load(model_path, map_location=device))
    optimizer_path = run_path + 'optimizer_' + str(i_batch) + '.pt'
    optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))

# Training

When running the training loop, the training and validation losses are stored with a TensorBoard writer, and can be visualized by launching a TensorBoard session. Since TensorBoard can sometimes be unreliable, the loss values are also stored in the ```loss_history``` dictionary.

The model state will be automatically saved during training, from every 1K batches at the beginning to every 25K batches later on.

The training loop can be manually stopped and resumed at any time. The next cell allows you to save the model manually.

In [None]:
n_batch = 1500000   # Index of the last training batch; training stops once i_batch exceeds it.


def _is_checkpoint_batch(batch_index):
    """Return True when the model state must be saved at ``batch_index``.

    Checkpoints are written every 1K batches below 10K, every 5K batches below 100K,
    every 10K batches below 300K, and every 25K batches from 300K onwards.
    """
    if batch_index < 10000:
        return batch_index % 1000 == 0
    if batch_index < 100000:
        return batch_index % 5000 == 0
    if batch_index < 300000:
        return batch_index % 10000 == 0
    return batch_index % 25000 == 0     # includes batch 300000 itself


# The training loop can be manually stopped and resumed. Use the next cell to manually save the model state.
while i_batch <= n_batch:
    for (src, tgt) in train_loader:
        # Stop in the middle of an epoch as soon as the requested number of batches is reached.
        if i_batch > n_batch:
            break

        model.train()

        # Truncated backpropagation through time: the sequence is processed in chunks of at most
        # trunc_tw timesteps, with one optimizer step per chunk.
        t = 0
        loss_batch = []
        (h,c) = model.init_hidden(batch_size = src.shape[0])
        while t < src.shape[1]:
            optimizer.zero_grad()

            step = min(trunc_tw, src.shape[1] - t)
            src_trunc = src[:,t:t+step,:]
            tgt_trunc = tgt[:,t:t+step,:]

            output, (h,c) = model(src_trunc, (h,c))
            loss = model.criterion(output, tgt_trunc)
            loss.backward()
            optimizer.step()
            # The hidden state is carried over to the next chunk, but gradients are cut here.
            h,c = h.detach_(), c.detach_()

            loss_batch.append(loss.item())
            t += step

        # Tensorboard: the logged training loss is the mean over the chunks of this batch.
        training_loss = np.mean(loss_batch)
        if i_batch % 10 == 0:
            writer.add_scalar('Loss/train', training_loss, i_batch)
            loss_history['train_loss'].append((training_loss, i_batch))

        # The value printed here is the loss of the last chunk only.
        print(f"Batch: {i_batch}, Loss: {np.round(loss.item(),4)}")

        # Validation: evaluate a single batch drawn from the validation loader.
        if i_batch % 100 == 0:
            model.eval()
            val_src, val_tgt = next(iter(validation_loader))
            # The hidden state is sized from the validation batch itself, which may not have
            # the same number of sequences as the current training batch.
            (h_val, c_val) = model.init_hidden(batch_size = val_src.shape[0])
            with torch.no_grad():
                val_output, _ = model(val_src, (h_val, c_val))
                validation_loss = model.criterion(val_output, val_tgt).item()

            writer.add_scalar('Loss/validation', validation_loss, i_batch)
            loss_history['validation_loss'].append((validation_loss, i_batch))
            print(f"Batch: {i_batch}")
            print(f" --- Training loss: {np.round(training_loss,4)}")
            print(f" --- Validation loss: {np.round(validation_loss,4)}")

            # Visual representation of the model output, sent to tensorboard.
            if i_batch % 1000 == 0:
                max_t = 25*model_fs
                fig = visualize_sequence(val_src[0,:max_t,:], val_tgt[0,:max_t,:], val_output[0,:max_t,:])
                writer.add_figure('Visual/Validation', fig, i_batch)

            model.train()

        # Save the model, the optimizer and the custom loss history.
        if _is_checkpoint_batch(i_batch):
            torch.save(model.state_dict(), run_path + f'/model_{i_batch}.pt')
            torch.save(optimizer.state_dict(), run_path + f'/optimizer_{i_batch}.pt')
            np.save(run_path + 'loss_history.npy', loss_history)

            print(f'Model saved at batch {i_batch}')

        i_batch += 1

In [None]:
# To save manually, run this cell.
# The loss history is written together with the weights, as the automatic checkpoints do,
# so that a resumed run reloads a history that matches the saved model.
torch.save(model.state_dict(), run_path + f'/model_{i_batch}.pt')
torch.save(optimizer.state_dict(), run_path + f'/optimizer_{i_batch}.pt')
np.save(run_path + 'loss_history.npy', loss_history)
print(f'Model saved at batch {i_batch}')

# Select best model

This cell computes the validation loss over the entire validation set for each checkpoint of the model, and selects the model state with the lowest validation loss.

To evaluate the model on a subset of the checkpoints only, set ```versions_range``` to the desired ```i_batch``` range.

To speed up the process, the validation set can be subsampled by an arbitrary factor ```n``` (only one batch out of ```n``` is evaluated).

In [None]:
versions_range = [0, n_batch]  # Range of model versions to evaluate (inclusive bounds).
n = 1                          # Evaluate only one validation batch out of n to speed up the process.

# Retrieve the model versions to evaluate. A dedicated loop variable is used so that the
# training counter i_batch is left untouched for a later resume or manual save.
checkpoints = []
for file in os.listdir(run_path):
    batch_tag = file[len('model_'):-len('.pt')]
    if file.startswith('model_') and file.endswith('.pt') and batch_tag.isdigit():
        checkpoint_batch = int(batch_tag)
        if versions_range[0] <= checkpoint_batch <= versions_range[1]:
            checkpoints.append(checkpoint_batch)
checkpoints = sorted(checkpoints)
if not checkpoints:
    raise FileNotFoundError(f'No model checkpoint in range {versions_range} found in {run_path}')
print(f"Evaluating checkpoints: {checkpoints}")

# Unshuffled loader, so that the same batches are selected for every checkpoint
# (with a shuffled loader, "i % n" would pick different batches each time).
selection_loader = DataLoader(validation_dataset, batch_size=batch_size,
                              collate_fn=validation_dataset.collate_batch,
                              shuffle=False, num_workers=0)

# Compute the loss on the validation set for each model version (this can take some time).
val_loss = []
for checkpoint in checkpoints:
    batch_losses = []
    model_path = run_path + 'model_' + str(checkpoint) + '.pt'
    # Make sure that the model has been instantiated before running this.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    for i, (src, tgt) in enumerate(selection_loader):
        if i % n == 0:
            (h,c) = model.init_hidden(batch_size = src.shape[0])
            with torch.no_grad():
                output, (h,c) = model(src, (h,c))
                loss = model.criterion(output, tgt)
                batch_losses.append(loss.item())
    print(f"Batch: {checkpoint}, Validation loss: {np.mean(batch_losses)}")
    val_loss.append(np.mean(batch_losses))

# Retrieve the best model: the checkpoint with the lowest mean validation loss.
print('---')
best_index = int(np.argmin(val_loss))
best_model_checkpoint = checkpoints[best_index]
best_val_loss = val_loss[best_index]
# The training loss is only logged every few batches: report the entry logged closest to the checkpoint.
closest_train_entry = min(loss_history['train_loss'],
                          key=lambda entry: np.abs(entry[1] - best_model_checkpoint))
best_train_loss = closest_train_entry[0]
print(f"Best model: model_{best_model_checkpoint}")
print(f"Training loss: {best_train_loss}")
print(f"Validation loss: {best_val_loss}")

# Save final version

The final version of a model is saved in ```./versions/model_name/```, in a ```model.pt``` file. It can be loaded in other scripts with ```torch.load()```.

All the parameters needed to instantiate, evaluate or fine-tune the model are saved in ```info.pt```.

In [15]:
name = 'lstm_maestro_test'
additional_info = ''' Examplar model, trained on a subset of the MAESTRO 3.0 dataset. '''

version_dir = './versions/' + name
os.makedirs(version_dir, exist_ok=True)

# Restore the selected checkpoint. map_location lets a checkpoint saved on one device
# be loaded on the device selected for this session.
model_path = run_path + 'model_' + str(best_model_checkpoint) + '.pt'
model.load_state_dict(torch.load(model_path, map_location=device))
optimizer_path = run_path + 'optimizer_' + str(best_model_checkpoint) + '.pt'
optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))

# An existing final version is never overwritten.
if os.path.exists(version_dir + '/model.pt'):
    print('A model already exists at this location !')
else:
    # model.pt holds the whole model object; info.pt holds the hyperparameters, the state
    # dictionaries and the losses of the selected checkpoint.
    torch.save(model, version_dir + '/model.pt')
    torch.save({'model_type': 'lstm',
                'seed': seed,
                'i_batch': best_model_checkpoint,
                'fs': model_fs,
                'ons_value': ons_value,
                'sus_value': sus_value,
                'padding_value': padding_value,
                'batch_size': batch_size,
                'input_size' : input_size,
                'output_size': output_size,
                'hidden_dim': hidden_dim,
                'n_lstm_layers': n_lstm_layers,
                'trunc_tw': trunc_tw,
                'max_seq_length': max_seq_length,
                'dropout': dropout,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': best_train_loss,
                'val_loss': best_val_loss,
                'additional_info': additional_info,
                },
                version_dir + '/info.pt')

# Visualize with video

The ```make_video()``` function requires fluidsynth and ffmpeg to work. Links are available on https://github.com/pl-robert/musecog.

The video will be saved in ./versions/model_name/video/

In [None]:
# Settings of the video rendering, gathered in one place so they are easy to edit.
video_settings = {
    'file_path': './data/midi_dataset_example/visualization/',
    'file_name': 'MIDI-Unprocessed_059_PIANO059_MID--AUDIO-split_07-07-17_Piano-e_2-03_wav--1.mid',
    'res': (1280,720),
    'graph': True,
    'model_name': 'lstm_maestro_test',
    'ffmpeg_path': r'C:/ffmpeg/',   # NOTE(review): machine-specific location — adapt to the local ffmpeg install.
}
make_video(**video_settings)

# Export features

The ```export_features()``` function allows you to compute and save a series of features of interest from a given model and midi dataset.

The ```out_fs``` parameter defines the time sampling of time-resolved features. Even if a model runs at 20Hz, the features can be upsampled to align with continuous neural or behavioral data of interest.

If ```timing_correction = True```, the small temporal misalignments caused by a low ```model_fs``` are corrected to match the exact note onset timings (from the midi files), at the cost of minor distortion in the feature values (linear interpolation).

The following features are computed at each timestep:
- **surprise:** Binary Cross Entropy (BCE) between the model's predictions and the target values
- **surprise_max:** maximum BCE across all simultaneous notes
- **surprise_scaled:** BCE scaled by the number of simultaneous notes in the target
- **surprise_positive:** positive part of the BCE
- **surprise_positive_max:** maximum positive part of the BCE across all simultaneous notes
- **surprise_positive_scaled:** positive part of the BCE scaled by the number of simultaneous notes in the target
- **surprise_negative:** negative part of the BCE
- **surprise_negative_max:** maximum negative part of the BCE across all absent notes
- **surprise_negative_scaled:** negative part of the BCE scaled by the number of absent notes in the target
- **uncertainty:** entropy of the model's predictions. The probabilities are normalized (sum = 1) before computing the entropy. 
- **predicted_density:** predicted note density (sum of probabilities) 

They are stored in the files:
- **features.csv:** contains a summary of the features of all midi files in a single table. Each feature is summed over time, either
                across all timesteps ('continuous_xxx' features) or only at the timesteps with note onsets ('onsets_xxx' features). In
                addition, the table contains the number of notes, events (groups of simultaneous notes) and the duration of each file.
- **features_over_time:** folder containing the time-resolved features for each midi file. Each file contains all features, with the 
                addition of a time axis and a binary mask for the note onsets.

In [None]:
# Settings of the feature export, gathered in one place so they are easy to edit.
export_settings = {
    'data_path': './my_stimuli/midi/',
    'output_path': './my_stimuli/',
    'model_name': 'lstm_maestro_test',
    'out_fs': 100,
    'timing_correction': False,
}
export_features(**export_settings)