# **Import libraries**

In [10]:
!pip install mido



In [11]:
import mido # easy to use python MIDI library
import matplotlib.pyplot as plt # plotting
import numpy as np # linear algebra
import os # accessing directory structure
import random
import pandas as pd
import pickle
import time
import math

from mido import MidiFile, MidiTrack, Message

from sklearn import model_selection

import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

In [12]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

# **Hyperparameters**

In [13]:
# --- Training schedule ---
num_epochs = 100
batch_size = 512        # also read as a global by the batching code in validate_model / train_model

# --- Sequence / embedding sizes ---
sequence_length = 16
embedding_dim_note = 128
embedding_dim_duree = 128
embedding_dim_time = 128
embedding_dim_veloc = 128

# --- Recurrent model ---
hidden_size = 128
num_layers = 3

# --- Output classes (the special values -1 / -2 / -3 add 3 extra indices on top of each) ---
num_classes_note = 128  # regular note indices 0..127; specials map to 128..130 in process_note
num_classes_duree = 32  # duration buckets 0..31, matching the clamp in process_duree
num_classes_time = 18   # time buckets 0..17, matching the clamp in process_time

learning_rate = 0.05

# --- Reproducibility ---
# Seed every RNG the notebook draws from: Python `random` (train/val split, batch order,
# teacher forcing), NumPy, and PyTorch (weight initialisation, dropout), so that
# Restart & Run All gives the same split and the same initial weights.
# NOTE: cuDNN LSTM kernels may still be non-deterministic on GPU.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# **Load data**

In [14]:
def retrive_list_of_seq(file_name):
    """Load and return the object pickled in ``file_name``.

    The file is opened in binary mode and closed even if unpickling raises.

    Security: ``pickle.load`` can execute arbitrary code while loading, so only call
    this on files you produced yourself or otherwise trust.
    """
    with open(file_name, 'rb') as infile:
        return pickle.load(infile)

In [15]:
# The three pickled datasets hold aligned sequences; they are loaded in this fixed order.
# (Only load pickles you trust -- unpickling can execute code.)
sequences_r, sequences_l_p, sequences_l = (
    retrive_list_of_seq(dump_name)
    for dump_name in ('sequences_r', 'sequences_l_p', 'sequences_l')
)

In [16]:
def split(list_of_seq_r, list_of_seq_l_p, list_of_seq_l, valset_size = 0.8):
    """Randomly split three index-aligned sequence lists into train and validation parts.

    Despite its name, ``valset_size`` is the fraction of items placed in the TRAINING
    part: the first ``int(valset_size * n)`` shuffled indices go to training, the rest
    to validation.

    Returns ``(train_set, val_set)``, each a list ``[seqs_r, seqs_l_p, seqs_l]`` whose
    three sub-lists stay aligned with one another. Uses the global ``random`` module.
    """
    order = list(range(len(list_of_seq_r)))
    random.shuffle(order)
    cut = int(valset_size * len(order))
    train_ids, val_ids = order[:cut], order[cut:]

    def take(ids):
        # Pick the same positions from each of the three aligned sources.
        return [[source[i] for i in ids] for source in (list_of_seq_r, list_of_seq_l_p, list_of_seq_l)]

    return take(train_ids), take(val_ids)

In [17]:
def pad_sequence(seq_batch, start_tag=False):
  """Right-pad every sequence of a batch to the batch's longest length, in place.

  Each element of ``seq_batch`` is a list of events; shorter sequences are extended
  with ``[-2, -2, -2, -2]`` padding events. When ``start_tag`` is true, a
  ``[-3, -3, -3, -3]`` start event is also prepended to every sequence (after padding,
  so all sequences still end up the same length).

  The sequence lists themselves are modified (``generate_batch`` passes shallow
  copies for that reason) and the same ``seq_batch`` object is returned.
  An empty batch is returned unchanged instead of raising on ``max([])``.
  """
  if not seq_batch:
    return seq_batch
  max_length = max(len(seq) for seq in seq_batch)
  for seq in seq_batch:
    # A fresh list per padding event, so no two positions alias the same object.
    seq.extend([-2, -2, -2, -2] for _ in range(max_length - len(seq)))
    if start_tag:
      seq.insert(0, [-3, -3, -3, -3])
  return seq_batch

In [18]:
def generate_batch(dataset, list_id):
  """Gather the sequences at positions ``list_id`` from the three parts of ``dataset``.

  Returns ``(seq_r, seq_l_p, seq_l)``, each the padded NumPy array built by
  ``pad_sequence`` for that part. Every selected sequence is shallow-copied first,
  because padding appends to the sequence lists and must not alter ``dataset``.
  """
  padded_parts = []
  for part_index in range(3):
    source = dataset[part_index]
    selected = [source[idx].copy() for idx in list_id]
    padded_parts.append(np.array(pad_sequence(selected)))
  seq_r_padded, seq_l_p_padded, seq_l_padded = padded_parts
  return seq_r_padded, seq_l_p_padded, seq_l_padded

In [19]:
# 70 % of the aligned (right, previous-left, left) triples go to training, the rest to validation.
train_set, val_set = split(sequences_r, sequences_l_p, sequences_l, valset_size = 0.7)
n_total_train, n_total_val = (len(subset[0]) for subset in (train_set, val_set))

# **Models**

## Preprocessing utilities

In [20]:
def process_note(x):
  """Map a raw note value to an embedding index.

  Non-negative notes pass through unchanged. The special values -1, -2 (padding,
  as inserted by pad_sequence) and -3 (start tag) move to the three indices just
  above the note range: 128, 129 and 130. Any other negative value is returned as is.
  """
  if x >= 0:
    return x
  return {-1: 128, -2: 129, -3: 130}.get(x, x)

In [21]:
def process_duree(x):
  """Quantise a raw duration into a class index in [0, 31], or map a special value.

  For x >= 0, ``(2x) // 60 - x // 60`` rounds x / 60 half-up; one is subtracted and
  the result is clamped to 0..31. The special values -1, -2 (padding) and -3 (start
  tag) map to 32, 33 and 34. Other negative values are returned unchanged.
  """
  if x < 0:
    return {-1: 32, -2: 33, -3: 34}.get(x, x)
  bucket = (x * 2) // 60 - (x // 60) - 1
  return min(max(bucket, 0), 31)

In [22]:
def process_time(x):
  """Quantise a raw time offset into a class index in [0, 17], or map a special value.

  For x >= 0, ``(2x) // 60 - x // 60`` rounds x / 60 half-up and is clamped to
  0..17. The special values -1, -2 (padding) and -3 (start tag) map to 18, 19 and
  20. Other negative values are returned unchanged.
  """
  if x < 0:
    return {-1: 18, -2: 19, -3: 20}.get(x, x)
  bucket = (x * 2) // 60 - (x // 60)
  return min(max(bucket, 0), 17)

In [23]:
def process_sequence(seq):
  """Turn a padded batch of events into model-ready tensors on ``device``.

  ``seq`` is reshaped to ``(batch, -1, 4)``; the four event fields are read as
  [note, velocity, duree, time]. Note, duree and time are mapped element-wise
  (on the CPU, via ``apply_``) through process_note / process_duree / process_time.
  Velocity is divided by 100 and given a trailing feature axis.

  Returns ``(notes_sequence, duree_sequence, time_sequence, velocity_sequence)``.
  """
  raw = torch.tensor(seq)
  raw = raw.reshape(raw.shape[0], -1, 4)

  # apply_ rewrites each column view in place (so `raw` changes too); it only works on CPU tensors.
  notes_sequence = raw[:, :, 0].apply_(process_note).to(device)
  duree_sequence = raw[:, :, 2].apply_(process_duree).to(device)
  time_sequence = raw[:, :, 3].apply_(process_time).to(device)

  scaled_velocity = raw[:, :, 1].to(device) / 100
  velocity_sequence = torch.unsqueeze(scaled_velocity, 2)

  return (notes_sequence, duree_sequence, time_sequence, velocity_sequence)

## Encoder for the previous left-hand sequence

In [24]:
# Recurrent encoder (many-to-one): only the final LSTM state of each sequence is returned.
class Encoder_left_p(nn.Module):
  """Encode the previous left-hand sequence into the final LSTM ``(h_n, c_n)`` states.

  Each time step is the concatenation of learned embeddings of the note, duree and
  time indices and of a ``Linear(1, embedding_dim_veloc)`` projection of the velocity.

  Args:
    num_classes_note, num_classes_duree, num_classes_time: number of regular classes;
      each embedding table gets 3 extra rows for the indices of the special values
      -1 / -2 / -3.
    embedding_dim_note, embedding_dim_duree, embedding_dim_time, embedding_dim_veloc:
      per-feature embedding sizes (summed to form the LSTM input size).
    hidden_size, num_layers: LSTM configuration.
    drop_prob: dropout between stacked LSTM layers.
    drop_fc: unused here; accepted so the signature matches ``Decoder``.
  """
  def __init__(self, num_classes_note, num_classes_duree, num_classes_time, embedding_dim_note, embedding_dim_duree, embedding_dim_time, embedding_dim_veloc, hidden_size, num_layers, drop_prob=0., drop_fc=0.):
    super(Encoder_left_p, self).__init__()

    self.embedding_note = nn.Embedding(num_classes_note + 3, embedding_dim_note)
    self.embedding_duree = nn.Embedding(num_classes_duree + 3, embedding_dim_duree)
    self.embedding_time = nn.Embedding(num_classes_time + 3, embedding_dim_time)
    self.fc_emb_veloc = nn.Linear(1, embedding_dim_veloc)

    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.lstm = nn.LSTM(embedding_dim_note + embedding_dim_duree + embedding_dim_time + embedding_dim_veloc, hidden_size, num_layers, dropout=drop_prob, batch_first=True)

  def forward(self, inputs, pad_mask):
    """Run the encoder over a padded batch.

    Args:
      inputs: ``(notes, duree, time, veloc)``. The index tensors are presumably
        ``(batch, seq_len)`` long tensors and ``veloc`` a ``(batch, seq_len, 1)`` float
        tensor, as built by process_sequence -- TODO confirm for other callers.
      pad_mask: despite the name, the valid LENGTH of each sequence (list, NumPy
        array or CPU tensor); positions past a sequence's length are ignored.

    Returns:
      ``(h_n, c_n)``, each ``(num_layers, batch, hidden_size)``, in the input batch order.
    """
    notes, duree, time, veloc = inputs

    x = torch.cat((
        self.embedding_note(notes),
        self.embedding_duree(duree),
        self.embedding_time(time),
        self.fc_emb_veloc(veloc),
    ), dim=2)
    packed = torch.nn.utils.rnn.pack_padded_sequence(x, pad_mask, batch_first=True, enforce_sorted=False)

    # Zero initial states on x's own device/dtype (not the notebook-global `device`),
    # so the module keeps working wherever it and its inputs actually live.
    h0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
    c0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)

    _, hidden = self.lstm(packed, (h0, c0))
    return hidden


## Encoder for the right-hand sequence

In [25]:
# Recurrent encoder (many-to-one): only the final LSTM state of each sequence is returned.
class Encoder_right(nn.Module):
  """Encode the right-hand sequence into the final LSTM ``(h_n, c_n)`` states.

  Each time step is the concatenation of learned embeddings of the note, duree and
  time indices and of a ``Linear(1, embedding_dim_veloc)`` projection of the velocity.

  Args:
    num_classes_note, num_classes_duree, num_classes_time: number of regular classes;
      each embedding table gets 3 extra rows for the indices of the special values
      -1 / -2 / -3.
    embedding_dim_note, embedding_dim_duree, embedding_dim_time, embedding_dim_veloc:
      per-feature embedding sizes (summed to form the LSTM input size).
    hidden_size, num_layers: LSTM configuration.
    drop_prob: dropout between stacked LSTM layers.
    drop_fc: unused here; accepted so the signature matches ``Decoder``.
  """
  def __init__(self, num_classes_note, num_classes_duree, num_classes_time, embedding_dim_note, embedding_dim_duree, embedding_dim_time, embedding_dim_veloc, hidden_size, num_layers, drop_prob=0., drop_fc=0.):
    super(Encoder_right, self).__init__()

    self.embedding_note = nn.Embedding(num_classes_note + 3, embedding_dim_note)
    self.embedding_duree = nn.Embedding(num_classes_duree + 3, embedding_dim_duree)
    self.embedding_time = nn.Embedding(num_classes_time + 3, embedding_dim_time)
    self.fc_emb_veloc = nn.Linear(1, embedding_dim_veloc)

    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.lstm = nn.LSTM(embedding_dim_note + embedding_dim_duree + embedding_dim_time + embedding_dim_veloc, hidden_size, num_layers, dropout=drop_prob, batch_first=True)

  def forward(self, inputs, pad_mask):
    """Run the encoder over a padded batch.

    Args:
      inputs: ``(notes, duree, time, veloc)``. The index tensors are presumably
        ``(batch, seq_len)`` long tensors and ``veloc`` a ``(batch, seq_len, 1)`` float
        tensor, as built by process_sequence -- TODO confirm for other callers.
      pad_mask: despite the name, the valid LENGTH of each sequence (list, NumPy
        array or CPU tensor); positions past a sequence's length are ignored.

    Returns:
      ``(h_n, c_n)``, each ``(num_layers, batch, hidden_size)``, in the input batch order.
    """
    notes, duree, time, veloc = inputs

    x = torch.cat((
        self.embedding_note(notes),
        self.embedding_duree(duree),
        self.embedding_time(time),
        self.fc_emb_veloc(veloc),
    ), dim=2)
    packed = torch.nn.utils.rnn.pack_padded_sequence(x, pad_mask, batch_first=True, enforce_sorted=False)

    # Zero initial states on x's own device/dtype (not the notebook-global `device`),
    # so the module keeps working wherever it and its inputs actually live.
    h0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
    c0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)

    _, hidden = self.lstm(packed, (h0, c0))
    return hidden


## Decoder

In [29]:
# Step-wise decoder: an LSTM pass over the given input event(s), then four chained output heads.
class Decoder(nn.Module):
  """Left-hand decoder.

  ``forward_RNN`` embeds the input event(s) (note, duree, time embeddings plus a
  ``Linear(1, ·)`` velocity projection), runs the LSTM from ``hidden`` and keeps the
  output at the last time step. The heads are chained: note logits come from that
  output alone; duree logits are additionally conditioned on the note target; time
  logits on the note and duree targets; the velocity regression on all three.
  Every head is Linear -> Dropout(drop_fc) -> ReLU -> Linear, and every class head
  has ``num_classes_* + 3`` outputs so it can also emit the special indices.
  """
  def __init__(self, num_classes_note, num_classes_duree, num_classes_time, embedding_dim_note, embedding_dim_duree, embedding_dim_time, embedding_dim_veloc, hidden_size, num_layers, drop_prob=0., drop_fc=0.):
    super().__init__()
    n_note = num_classes_note + 3
    n_duree = num_classes_duree + 3
    n_time = num_classes_time + 3

    # Embeddings of the previous event, fed to the LSTM.
    self.embedding_note = nn.Embedding(n_note, embedding_dim_note)
    self.embedding_duree = nn.Embedding(n_duree, embedding_dim_duree)
    self.embedding_time = nn.Embedding(n_time, embedding_dim_time)
    self.fc_emb_veloc = nn.Linear(1, embedding_dim_veloc)

    # Separate embeddings of the current step's targets, used to condition later heads.
    self.embedding_note_cond = nn.Embedding(n_note, embedding_dim_note)
    self.embedding_duree_cond = nn.Embedding(n_duree, embedding_dim_duree)
    self.embedding_time_cond = nn.Embedding(n_time, embedding_dim_time)

    self.hidden_size = hidden_size
    self.num_layers = num_layers
    lstm_input_size = embedding_dim_note + embedding_dim_duree + embedding_dim_time + embedding_dim_veloc
    self.lstm = nn.LSTM(lstm_input_size, hidden_size, num_layers, dropout=drop_prob, batch_first=True)

    # Output heads (registration order kept stable for checkpoints / seeded init).
    self.fc_note_1 = nn.Linear(hidden_size, hidden_size)
    self.fc_note_2 = nn.Linear(hidden_size, n_note)

    self.fc_duree_1 = nn.Linear(hidden_size + embedding_dim_note, hidden_size)
    self.fc_duree_2 = nn.Linear(hidden_size, n_duree)

    self.fc_time_1 = nn.Linear(hidden_size + embedding_dim_note + embedding_dim_duree, hidden_size)
    self.fc_time_2 = nn.Linear(hidden_size, n_time)

    self.fc_veloc_1 = nn.Linear(hidden_size + embedding_dim_note + embedding_dim_duree + embedding_dim_time, hidden_size)
    self.fc_veloc_2 = nn.Linear(hidden_size, 1)

    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(p=drop_fc)

  def _mlp(self, features, first, second):
    """Shared head shape: first Linear -> Dropout -> ReLU -> second Linear."""
    return second(self.relu(self.dropout(first(features))))

  def forward_RNN(self, inputs, hidden):
    """Embed ``inputs = (notes, duree, time, veloc)`` and run the LSTM from ``hidden``.

    Returns the LSTM output at the last time step and the new ``(h, c)`` state.
    """
    notes, duree, time_idx, veloc = inputs
    step_features = torch.cat([
        self.embedding_note(notes),
        self.embedding_duree(duree),
        self.embedding_time(time_idx),
        self.fc_emb_veloc(veloc),
    ], dim=2)
    lstm_out, new_hidden = self.lstm(step_features, hidden)
    return lstm_out[:, -1, :], new_hidden

  def classif_note(self, out):
    """Note logits from the LSTM output alone."""
    return self._mlp(out, self.fc_note_1, self.fc_note_2)

  def classif_duree(self, out, note_target):
    """Duree logits conditioned on the note target."""
    features = torch.cat((out, self.embedding_note_cond(note_target)), dim=1)
    return self._mlp(features, self.fc_duree_1, self.fc_duree_2)

  def classif_time(self, out, note_target, duree_target):
    """Time logits conditioned on the note and duree targets."""
    features = torch.cat((
        out,
        self.embedding_note_cond(note_target),
        self.embedding_duree_cond(duree_target),
    ), dim=1)
    return self._mlp(features, self.fc_time_1, self.fc_time_2)

  def reg_veloc(self, out, note_target, duree_target, time_target):
    """Velocity regression (one value per row) conditioned on all three targets."""
    features = torch.cat((
        out,
        self.embedding_note_cond(note_target),
        self.embedding_duree_cond(duree_target),
        self.embedding_time_cond(time_target),
    ), dim=1)
    return self._mlp(features, self.fc_veloc_1, self.fc_veloc_2)

  def forward(self, inputs, targets, hidden):
    """One decoding call.

    Returns ``(note_logits, duree_logits, time_logits, veloc, hidden)``.
    """
    out, hidden = self.forward_RNN(inputs, hidden)
    note_target, duree_target, time_target = targets
    note_logits = self.classif_note(out)
    duree_logits = self.classif_duree(out, note_target)
    time_logits = self.classif_time(out, note_target, duree_target)
    veloc = self.reg_veloc(out, note_target, duree_target, time_target)
    return note_logits, duree_logits, time_logits, veloc, hidden

# **Training**

## Validation metrics (accuracies and velocity distance)

In [30]:
def validate_model(models, dataset, use_teacher_forcing = True, max_iter=100000):
    """Evaluate the encoder/decoder stack on ``dataset`` without updating any weights.

    Args:
        models: ``(encoder_right, encoder_left_p, decoder_left)``. All three are put in
            eval mode and left there.
        dataset: ``[seqs_r, seqs_l_p, seqs_l]`` aligned lists, as returned by ``split``.
        use_teacher_forcing: if True, each decoding step is fed the ground-truth previous
            event; if False, the decoder's own arg-max predictions and velocity output.
        max_iter: maximum number of batches (of the global ``batch_size``) to evaluate;
            batches are drawn in a fresh random order on every call.

    Returns:
        ``(accuracy_note, accuracy_duree, accuracy_time, distance_velo)``. The accuracies
        are percentages over non-padding target positions. ``distance_velo`` is the mean
        absolute velocity error, multiplied by 100 to undo the ``/ 100`` applied in
        ``process_sequence``, averaged over decoding steps. All four are NaN when no
        target position was evaluated (empty dataset or ``max_iter <= 0``).
    """
    encoder_right, encoder_left_p, decoder_left = models
    encoder_right.eval()
    encoder_left_p.eval()
    decoder_left.eval()

    # Shuffle once and cut into consecutive batches. This replaces random.sample() on a
    # set (a TypeError since Python 3.11) and avoids the empty trailing batch that the
    # old ``int(n / batch_size) + 1`` count produced when n was a multiple of batch_size.
    n_total_set = len(dataset[0])
    ids = list(range(n_total_set))
    random.shuffle(ids)
    batches = [ids[k:k + batch_size] for k in range(0, n_total_set, batch_size)]
    batches = batches[:max(max_iter, 0)]

    correct_note = 0
    correct_duree = 0
    correct_time = 0
    total = 0            # non-padding target positions evaluated
    sum_distance = 0.    # sum over decoding steps of the per-step mean velocity error
    count = 0            # decoding steps evaluated

    with torch.no_grad():
        for list_id_batch in batches:
            current_batch_size = len(list_id_batch)

            ################# PROCESS SEQUENCE #################
            seq_r, seq_l_p, seq_l = generate_batch(dataset, list_id_batch)
            inputs_r = process_sequence(seq_r)
            inputs_l_p = process_sequence(seq_l_p)
            notes_sequence_l, duree_sequence_l, time_sequence_l, velocity_sequence_l = process_sequence(seq_l)

            ################### ENCODING ###################
            # Sequence lengths = number of non-padding notes (index 129) in each row.
            pad_mask_r = (inputs_r[0] != 129).sum(1).cpu().numpy()
            pad_mask_l_p = (inputs_l_p[0] != 129).sum(1).cpu().numpy()
            hidden_right = encoder_right(inputs_r, pad_mask_r)
            hidden_left_p = encoder_left_p(inputs_l_p, pad_mask_l_p)

            # The decoder starts from both encoders' final (h, c) states, concatenated
            # along the feature axis.
            hidden = tuple(torch.cat((state_r, state_lp), dim=2)
                           for state_r, state_lp in zip(hidden_right, hidden_left_p))

            ################### DECODING ###################
            # Start-tag event: note 130, duree 34, time 20, velocity -1.
            start = torch.tensor([[130, 34, 20, -1]], device=device).repeat(current_batch_size, 1).unsqueeze(1)
            decoder_inputs = (start[:, :, 0], start[:, :, 1], start[:, :, 2], start[:, :, 3].float().unsqueeze(2))

            for i in range(notes_sequence_l.shape[1]):
                count += 1
                note_target = notes_sequence_l[:, i]
                duree_target = duree_sequence_l[:, i]
                time_target = time_sequence_l[:, i]
                velocity_target = velocity_sequence_l[:, i]

                targets = (note_target, duree_target, time_target)
                note_pred_logits, duree_pred_logits, time_pred_logits, veloc_pred, hidden = decoder_left(decoder_inputs, targets, hidden)

                # Padding positions (note index 129) are excluded from every metric.
                valid = note_target != 129
                n_valid = int(valid.sum().item())
                total += n_valid

                note_pred = note_pred_logits.argmax(1)
                duree_pred = duree_pred_logits.argmax(1)
                time_pred = time_pred_logits.argmax(1)
                correct_note += ((note_pred == note_target) & valid).sum().item()
                correct_duree += ((duree_pred == duree_target) & valid).sum().item()
                correct_time += ((time_pred == time_target) & valid).sum().item()

                # reshape(-1) instead of squeeze() so a batch of one keeps its batch axis.
                abs_err = torch.abs(veloc_pred.reshape(-1) - velocity_target.reshape(-1))
                sum_distance += 100 * (abs_err * valid).sum().item() / max(n_valid, 1)

                # Next decoder input -- updated inside the step loop; otherwise every step
                # would be fed the start tag again.
                if use_teacher_forcing:
                    next_event = (note_target, duree_target, time_target, velocity_target)
                else:
                    next_event = (note_pred, duree_pred, time_pred, veloc_pred)
                decoder_inputs = tuple(t.unsqueeze(1) for t in next_event)

    if total == 0:
        nan = float('nan')
        return (nan, nan, nan, nan)

    accuracy_note = 100 * correct_note / total
    accuracy_duree = 100 * correct_duree / total
    accuracy_time = 100 * correct_time / total
    distance_velo = sum_distance / count

    return (accuracy_note, accuracy_duree, accuracy_time, distance_velo)

# **Training loop** 

In [31]:
def train_model(models, optimizers, train_set, val_set, num_epochs, lr_scheduler=None, display_loss=False, teacher_forcing_ratio = 1., max_train_iter = 100000000):
  """Jointly train the two encoders and the left-hand decoder, validating every epoch.

  For each batch, both encoders run on their packed sequences. Their final ``(h, c)``
  states are concatenated along the feature axis to initialise the decoder, which is
  then unrolled over the left-hand target sequence. The batch loss is the sum of the
  note/duree/time cross-entropies plus the L1 velocity error over non-padding
  positions, divided by the number of such positions.

  Args:
    models: ``(encoder_right, encoder_left_p, decoder_left)``.
    optimizers: one optimizer per model, in the same order.
    train_set, val_set: ``[seqs_r, seqs_l_p, seqs_l]`` aligned lists, as from ``split``.
    num_epochs: number of passes over ``train_set``.
    lr_scheduler: ``'multi_steps'`` multiplies every learning rate by 0.1 at epoch
      ``int(num_epochs * 0.5)``; any other value keeps the learning rates constant.
    display_loss: unused; kept for backward compatibility.
    teacher_forcing_ratio: probability, drawn once per batch, of feeding the decoder the
      ground-truth previous event instead of its own predictions.
    max_train_iter: maximum number of batches per epoch.
  """
  encoder_right, encoder_left_p, decoder_left = models
  optimizer_right, optimizer_left_p, optimizer_left = optimizers
  all_models = (encoder_right, encoder_left_p, decoder_left)
  all_optimizers = (optimizer_right, optimizer_left_p, optimizer_left)

  # reduction='none' (replacing the deprecated reduce=False) keeps per-sample losses so
  # padding positions can be masked out below.
  criterion_note = nn.CrossEntropyLoss(reduction='none')
  criterion_duree = nn.CrossEntropyLoss(reduction='none')
  criterion_time = nn.CrossEntropyLoss(reduction='none')

  best_val_accuracy_note, best_epoch_note = 0, 0
  best_val_accuracy_duree, best_epoch_duree = 0, 0
  best_val_accuracy_time, best_epoch_time = 0, 0
  best_val_distance_velo, best_epoch_velo = 1000, 0

  n_total_train = len(train_set[0])

  for epoch in range(num_epochs):
    start = time.time()
    for model in all_models:
      model.train()

    #### UPDATE LEARNING RATE ####
    if lr_scheduler == 'multi_steps' and epoch == int(num_epochs * 0.5):
      for optimizer in all_optimizers:
        for param_group in optimizer.param_groups:
          param_group['lr'] *= 0.1

    # Reported losses are those of the first decoding step of the epoch's last batch.
    # They stay NaN (instead of raising NameError) if no batch runs this epoch.
    loss_note_display = loss_duree_display = loss_time_display = loss_veloc_display = float('nan')
    length_batch = 0

    # Shuffle once and cut into consecutive batches: random.sample() on a set is a
    # TypeError since Python 3.11, and the old ``int(n / batch_size) + 1`` count produced
    # an empty (crashing) last batch when n was a multiple of batch_size.
    ids = list(range(n_total_train))
    random.shuffle(ids)
    batches = [ids[k:k + batch_size] for k in range(0, n_total_train, batch_size)]
    batches = batches[:max(max_train_iter, 0)]

    for list_id_batch in batches:
      current_batch_size = len(list_id_batch)

      ################# PROCESS SEQUENCE #################
      seq_r, seq_l_p, seq_l = generate_batch(train_set, list_id_batch)
      inputs_r = process_sequence(seq_r)
      inputs_l_p = process_sequence(seq_l_p)
      notes_sequence_l, duree_sequence_l, time_sequence_l, velocity_sequence_l = process_sequence(seq_l)

      ################### ENCODING ###################
      # Sequence lengths = number of non-padding notes (index 129) in each row.
      pad_mask_r = (inputs_r[0] != 129).sum(1).cpu().numpy()
      pad_mask_l_p = (inputs_l_p[0] != 129).sum(1).cpu().numpy()
      hidden_right = encoder_right(inputs_r, pad_mask_r)
      hidden_left_p = encoder_left_p(inputs_l_p, pad_mask_l_p)

      # Decoder initial state: both encoders' (h, c) concatenated along the feature axis.
      hidden = tuple(torch.cat((state_r, state_lp), dim=2)
                     for state_r, state_lp in zip(hidden_right, hidden_left_p))

      ################### DECODING ###################
      # Start-tag event: note 130, duree 34, time 20, velocity -1.
      start_tag = torch.tensor([[130, 34, 20, -1]], device=device).repeat(current_batch_size, 1).unsqueeze(1)
      decoder_inputs = (start_tag[:, :, 0], start_tag[:, :, 1], start_tag[:, :, 2], start_tag[:, :, 3].float().unsqueeze(2))

      use_teacher_forcing = random.random() < teacher_forcing_ratio

      loss_total = torch.tensor(0., device=device)
      length_batch = notes_sequence_l.shape[1]
      total_count_batch = 0

      for i in range(length_batch):
        note_target = notes_sequence_l[:, i]
        duree_target = duree_sequence_l[:, i]
        time_target = time_sequence_l[:, i]
        velocity_target = velocity_sequence_l[:, i]

        targets = (note_target, duree_target, time_target)
        note_pred_logits, duree_pred_logits, time_pred_logits, veloc_pred, hidden = decoder_left(decoder_inputs, targets, hidden)

        # Masked losses: padding positions (note index 129) contribute nothing.
        mask = (note_target != 129).float()
        count_non_null = mask.sum().item()
        total_count_batch += count_non_null

        loss_note = torch.sum(mask * criterion_note(note_pred_logits, note_target))
        loss_duree = torch.sum(mask * criterion_duree(duree_pred_logits, duree_target))
        loss_time = torch.sum(mask * criterion_time(time_pred_logits, time_target))
        # reshape(-1) instead of squeeze() so a batch of one keeps its batch axis.
        loss_veloc = torch.sum(mask * torch.abs(veloc_pred.reshape(-1) - velocity_target.reshape(-1)))
        loss_total = loss_total + loss_note + loss_duree + loss_time + loss_veloc

        # Next decoder input. Arg-max indices carry no gradient; the predicted velocity is
        # fed back un-detached, so gradients flow through it across steps.
        if use_teacher_forcing:
          next_event = (note_target, duree_target, time_target, velocity_target)
        else:
          next_event = (note_pred_logits.argmax(1), duree_pred_logits.argmax(1), time_pred_logits.argmax(1), veloc_pred)
        decoder_inputs = tuple(t.unsqueeze(1) for t in next_event)

        if i == 0:
          loss_note_display = loss_note.item() / count_non_null
          loss_duree_display = loss_duree.item() / count_non_null
          loss_time_display = loss_time.item() / count_non_null
          loss_veloc_display = loss_veloc.item() / count_non_null

      # Backward and weight update
      for optimizer in all_optimizers:
        optimizer.zero_grad()
      (loss_total / total_count_batch).backward()
      for optimizer in all_optimizers:
        optimizer.step()

    ################### VALIDATION ###################
    # Train accuracy
    train_accuracy_note, train_accuracy_duree, train_accuracy_time, train_distance_velo = validate_model(models, train_set, use_teacher_forcing = True, max_iter=4)
    train_accuracy_note, train_accuracy_duree, train_accuracy_time, train_distance_velo = round(train_accuracy_note, 2), round(train_accuracy_duree, 2), round(train_accuracy_time, 2), round(train_distance_velo, 2)

    # Val accuracy
    val_accuracy_note, val_accuracy_duree, val_accuracy_time, val_distance_velo = validate_model(models, val_set, use_teacher_forcing = True, max_iter=4)
    val_accuracy_note, val_accuracy_duree, val_accuracy_time, val_distance_velo = round(val_accuracy_note, 2), round(val_accuracy_duree, 2), round(val_accuracy_time, 2), round(val_distance_velo, 2)

    if val_accuracy_note > best_val_accuracy_note:
      best_val_accuracy_note = val_accuracy_note
      best_epoch_note = epoch

    if val_accuracy_duree > best_val_accuracy_duree:
      best_val_accuracy_duree = val_accuracy_duree
      best_epoch_duree = epoch

    if val_accuracy_time > best_val_accuracy_time:
      best_val_accuracy_time = val_accuracy_time
      best_epoch_time = epoch

    if val_distance_velo < best_val_distance_velo:
      best_val_distance_velo = val_distance_velo
      best_epoch_velo = epoch

    end = time.time()

    print('################')
    print(f'Epoch: {epoch}, Time: {round(end - start, 2)}, Loss note: {round(loss_note_display, 4)}, Loss  duree: {round(loss_duree_display, 4)}, Loss time: {round(loss_time_display, 4)}, Loss velocity: {round(100 * loss_veloc_display, 2)}, Length batch: {length_batch}')
    print('------')
    print(f'Epoch : {epoch}, Train accuracy note : {train_accuracy_note} %, Val accuracy note : {val_accuracy_note} %')
    print(f'Best val accuracy at epoch {best_epoch_note}: {best_val_accuracy_note} %')
    print('------')
    print(f'Epoch : {epoch}, Train accuracy duree : {train_accuracy_duree} %, Val accuracy duree : {val_accuracy_duree} %')
    print(f'Best val accuracy at epoch {best_epoch_duree}: {best_val_accuracy_duree} %')
    print('------')
    print(f'Epoch : {epoch}, Train accuracy time : {train_accuracy_time} %, Val accuracy time: {val_accuracy_time} %')
    print(f'Best val accuracy at epoch {best_epoch_time}: {best_val_accuracy_time} %')
    print('------')
    print(f'Epoch : {epoch}, Train distance velo : {train_distance_velo}, Val distance velo: {val_distance_velo}')
    print(f'Best val distance at epoch {best_epoch_velo}: {best_val_distance_velo}')

# **Experiments**

In [32]:
# Both encoders share one width. The decoder's initial LSTM state is their two final
# states concatenated along the feature axis, so it must be exactly twice as wide and
# have the same number of layers.
encoder_hidden_size = 256
decoder_hidden_size = 2 * encoder_hidden_size

# Sizes shared by all three models come from the hyperparameter cell.
model_kwargs = dict(
    embedding_dim_note=embedding_dim_note,
    embedding_dim_duree=embedding_dim_duree,
    embedding_dim_time=embedding_dim_time,
    embedding_dim_veloc=embedding_dim_veloc,
    num_layers=num_layers,
)

encoder_right = Encoder_right(num_classes_note, num_classes_duree, num_classes_time, hidden_size=encoder_hidden_size, **model_kwargs).to(device)
encoder_left_p = Encoder_left_p(num_classes_note, num_classes_duree, num_classes_time, hidden_size=encoder_hidden_size, **model_kwargs).to(device)
decoder = Decoder(num_classes_note, num_classes_duree, num_classes_time, hidden_size=decoder_hidden_size, **model_kwargs).to(device)

start_lr = learning_rate

# One Nesterov-momentum SGD optimizer per model, in the same order as `models`.
optimizer_right, optimizer_left_p, optimizer_left = (
    torch.optim.SGD(module.parameters(), lr=start_lr, nesterov=True, momentum=0.9)
    for module in (encoder_right, encoder_left_p, decoder)
)

models = (encoder_right, encoder_left_p, decoder)
optimizers = (optimizer_right, optimizer_left_p, optimizer_left)


In [None]:
# Full training run: 200 epochs, constant learning rate, always teacher-forced,
# no practical cap on batches per epoch.
train_model(
    models=models,
    optimizers=optimizers,
    train_set=train_set,
    val_set=val_set,
    num_epochs=200,
    lr_scheduler=None,
    display_loss=False,
    teacher_forcing_ratio=1.,
    max_train_iter=1000000000,
)