<a href="https://colab.research.google.com/github/GiovanniSorice/Deep_Music_Generator/blob/main/notebooks/Music_Generation_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformer Music Generator 



In this notebook, we use a Transformer to generate music.


**This notebook was inspired by [Music_Generation_LSTM](https://colab.research.google.com/drive/19TQqekOlnOSW36VCL8CPVEQKBBukmaEQ#scrollTo=DDOBVWULXfpz), and part of the code comes from it.**




**Load dependencies**

In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"


In [2]:
pip install compressive_transformer_pytorch

Collecting compressive_transformer_pytorch
  Downloading https://files.pythonhosted.org/packages/30/39/b8caf2671abcb8615977c08766aa9f450addd6949f57c7dda87224e844b5/compressive_transformer_pytorch-0.3.20-py3-none-any.whl
Collecting mogrifier
  Downloading https://files.pythonhosted.org/packages/77/01/62a55d0f8048e788fce435f2ade6478f443e4e53ed9b89b55ba0fc42c198/mogrifier-0.0.3-py3-none-any.whl
Installing collected packages: mogrifier, compressive-transformer-pytorch
Successfully installed compressive-transformer-pytorch-0.3.20 mogrifier-0.0.3


In [3]:
import torch
import tqdm
import numpy as np
import pandas as pd
import tensorflow as tf
from compressive_transformer_pytorch import CompressiveTransformer
from compressive_transformer_pytorch import AutoregressiveWrapper
from torchsummary import summary
from torch.utils.data import DataLoader, Dataset
from tensorflow.keras import utils
from sklearn.metrics import roc_auc_score 
import matplotlib.pyplot as plt
import glob
import pickle
from music21 import converter, instrument, stream, note, chord
import math
import shutil

In [4]:
# Set to False if you are not running
# this notebook in Google Colaboratory
run_on_colab = True

**Set hyperparameters**

In [5]:
# output directory name:
output_dir = '/content/drive/My Drive/ISPR_project/Transformer/'
current_path ='/content/drive/My Drive/ISPR_project/'
# training:
epochs = 2000
batch_size = 64
learning_rate=1e-2
# vector-space embedding: 
n_dim = 64 
sequence_length = 64


VALIDATE_EVERY  = 5

GENERATE_EVERY  = 500



**Save model function**

In [6]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, output_dir+filename)
    if is_best:
        shutil.copyfile(output_dir+filename, output_dir+'model_best.pth.tar')

**Google drive configuration (only Colab)**

In [7]:
!nvidia-smi

Wed Dec 16 08:26:30 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.45.01    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   49C    P8    10W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
if(run_on_colab):
  from google.colab import drive
  # This will prompt for authorization.
  drive.mount('/content/drive')

**Load data**

The original **MIDI files** come from [The Lakh MIDI Dataset v0.1](https://colinraffel.com/projects/lmd/).

## Processing data

Let's process the files, and load them into **music21**

In [None]:
file = current_path+"midi_songs/dataset/Metal/Metallica/Am I Evil?.mid"
midi = converter.parse(file)
notes_to_parse = midi.flat.notes
for element in notes_to_parse[:10]:
  print(element, element.offset)

<music21.chord.Chord E2 E3 B3 E4> 0.0
<music21.chord.Chord E2 E3 B3 E4> 0.0
<music21.note.Note E> 0.0
<music21.chord.Chord C2 C#3> 0.0
<music21.note.Note G#> 2.0
<music21.chord.Chord D3 A3 D4> 3.0
<music21.chord.Chord D3 A3 D4> 3.0
<music21.note.Note D> 3.0
<music21.chord.Chord C#3 C2> 3.0
<music21.chord.Chord B3 E3 E4> 3.5


I will process all the MIDI files, extracting data from each note or chord.

- If I process a **note**, I store in the list a string representing its pitch (the note name) and octave.

- If I process a **chord** (remember that a chord is a set of notes played at the same time), I store a different kind of string: numbers separated by dots. Each number represents the pitch of one note of the chord.

As you can see, **I am not yet considering the time offset of each element**. In this first version they are ignored, so all the notes and chords will have the same duration. I may consider them in a future version.

I am creating a big list that holds, for each composition, the list of all its elements.
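
The encoding rules above can be sketched without music21. This is only a minimal illustration, assuming each element has already been reduced to a plain tuple; `encode_element` and the tuple layout are hypothetical names used here, not part of the notebook.

```python
def encode_element(element):
    """Return the string token for a ("note", name) or ("chord", [ints]) tuple."""
    kind, value = element
    if kind == "note":
        # a note is stored as pitch name plus octave, e.g. "E4"
        return str(value)
    if kind == "chord":
        # a chord is stored as its pitch numbers joined by dots, e.g. "4.7.11"
        return ".".join(str(n) for n in value)
    raise ValueError(f"unknown element kind: {kind}")


elements = [("note", "E4"), ("chord", [4, 7, 11]), ("note", "G#3")]
tokens = [encode_element(e) for e in elements]
print(tokens)  # ['E4', '4.7.11', 'G#3']
```

Because chords always contain a dot or consist only of digits, the two kinds of token can be told apart again when the MIDI file is rebuilt at the end of the notebook.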

In [None]:
notes_for_instruments = []
for i,file in enumerate(glob.glob(current_path+"midi_songs/dataset/*/*/*.mid")):
      midi = converter.parse(file)
      print('Parsing file ', i, " ", file)
      notes_to_parse = None
      try:  # file has instrument parts
          s2 = instrument.partitionByInstrument(midi)
          notes_to_parse = s2.recurse()
      except:  # file has notes in a flat structure
          notes_to_parse = midi.flat.notes
      notes_instrument = []
      for element in notes_to_parse:
          if isinstance(element, note.Note):
              notes_instrument.append(str(element.pitch))
          elif isinstance(element, chord.Chord):
              notes_instrument.append('.'.join(str(n) for n in element.normalOrder))
      notes_for_instruments.append(notes_instrument)
with open(current_path + 'notes_for_instruments', 'wb') as filepath:
    pickle.dump(notes_for_instruments, filepath)


In [None]:
notes_for_instruments_validation = []
for i,file in enumerate(glob.glob(current_path+"midi_songs/validation/*.mid")):
      midi = converter.parse(file)
      print('Parsing file ', i, " ", file)
      notes_to_parse = None
      try:  # file has instrument parts
          s2 = instrument.partitionByInstrument(midi)
          notes_to_parse = s2.recurse()
      except:  # file has notes in a flat structure
          notes_to_parse = midi.flat.notes
      notes_instrument = []
      for element in notes_to_parse:
          if isinstance(element, note.Note):
              notes_instrument.append(str(element.pitch))
          elif isinstance(element, chord.Chord):
              notes_instrument.append('.'.join(str(n) for n in element.normalOrder))
      notes_for_instruments_validation.append(notes_instrument)
with open(current_path + 'VALIDATION_notes_for_instruments', 'wb') as filepath:
    pickle.dump(notes_for_instruments_validation, filepath)


Parsing file  0   /content/drive/My Drive/ISPR_project/midi_songs/validation/Crazy Little Thing Called Love.mid
Parsing file  1   /content/drive/My Drive/ISPR_project/midi_songs/validation/Nothing Else Matters.2.mid
Parsing file  2   /content/drive/My Drive/ISPR_project/midi_songs/validation/King Nothing.1.mid
Parsing file  3   /content/drive/My Drive/ISPR_project/midi_songs/validation/Fixxxer.mid
Parsing file  4   /content/drive/My Drive/ISPR_project/midi_songs/validation/Motorbreath.mid
Parsing file  5   /content/drive/My Drive/ISPR_project/midi_songs/validation/Porch.mid
Parsing file  6   /content/drive/My Drive/ISPR_project/midi_songs/validation/A Kind of Magic.mid
Parsing file  7   /content/drive/My Drive/ISPR_project/midi_songs/validation/Don't Chain My Heart.mid
Parsing file  8   /content/drive/My Drive/ISPR_project/midi_songs/validation/Se tornerai.1.mid
Parsing file  9   /content/drive/My Drive/ISPR_project/midi_songs/validation/Pamela.1.mid


In [9]:
with open(current_path + 'notes_for_instruments', 'rb') as f:
    notes_for_instruments = pickle.load(f)

In [10]:
with open(current_path + 'VALIDATION_notes_for_instruments', 'rb') as f:
    notes_for_instruments_validation = pickle.load(f)

I count the distinct notes and chords in the dataset, because this will be the **number of possible output classes** of the model.

In [11]:
# Count different possible outputs
print(len(set(item for notes_for_instrument in notes_for_instruments for item in notes_for_instrument)))

833


In [12]:
# Count different possible outputs (validation)
print(len(set(item for notes_for_instrument in notes_for_instruments_validation for item in notes_for_instrument)))

229


**Preprocess data**

Now there is some **data processing** that I have to do:

- I will map each pitch or chord to an integer
- I will create the input sequences of consecutive elements that the model is trained on

I can try different **sequence_length** values to obtain different results. In this version, I use a sequence_length of 64, as set in the hyperparameters above.

The network will make its prediction of the next note (or chord) based on the previous *sequence_length* notes (or chords).
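
The two preprocessing steps can be sketched on a toy dataset in plain Python. This is an illustration under assumptions: `build_windows` and the sample songs are made up for the example, and lists stand in for the NumPy array the notebook builds. Like the notebook's loop, the range stops at `len(song) - sequence_length`, so a song of n tokens yields n - sequence_length windows and a song that is too short yields none.

```python
def build_windows(songs, sequence_length):
    """Map tokens to integers and cut each song into overlapping windows."""
    vocab = sorted(set(tok for song in songs for tok in song))
    token_to_int = {tok: idx for idx, tok in enumerate(vocab)}
    windows = []
    for song in songs:
        # slide a window of sequence_length tokens one step at a time
        for i in range(0, len(song) - sequence_length):
            windows.append([token_to_int[t] for t in song[i:i + sequence_length]])
    return token_to_int, windows


songs = [["E4", "4.7", "G4", "E4", "C5"], ["C5", "E4"]]
token_to_int, windows = build_windows(songs, sequence_length=3)
print(token_to_int)  # {'4.7': 0, 'C5': 1, 'E4': 2, 'G4': 3}
print(windows)       # [[2, 0, 3], [0, 3, 2]]
```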


In [13]:
# get all pitch names
pitchnames_training = set(item for notes_for_instrument in notes_for_instruments for item in notes_for_instrument)
pitchnames_validation = set(item for notes_for_instrument in notes_for_instruments_validation for item in notes_for_instrument)
pitchnames = sorted(pitchnames_training.union(pitchnames_validation))

In [14]:
n_vocab = len(pitchnames)
n_vocab

839
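
The vocabulary is the union of the training and validation token sets, which is why `n_vocab` (839) is larger than the 833 tokens counted in the training data alone: six tokens occur only in the validation files. The same set arithmetic on made-up tokens, as a small sketch:

```python
training_tokens = {"E4", "G4", "4.7", "C5"}
validation_tokens = {"E4", "C5", "0.4.7"}

# the vocabulary must cover every token seen in either split
vocab = sorted(training_tokens | validation_tokens)
only_in_validation = validation_tokens - training_tokens

print(len(training_tokens), len(validation_tokens), len(vocab))  # 4 3 5
print(sorted(only_in_validation))  # ['0.4.7']
```

Building the mapping from the union guarantees that every validation token has an integer id, even if the model never sees it during training.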

In [15]:
# create a dictionary to map pitches to integers
note_to_int = dict((note, number) for number, note in enumerate(pitchnames))
network_input = []
for notes in notes_for_instruments:
    if len(notes) - sequence_length<=0:
        print("song too short")
    # create input sequences and the corresponding outputs
    for i in range(0, len(notes) - sequence_length, 1):
      # Map pitches of sequence_in to integers
      network_input.append([note_to_int[char] for char in notes[i:i + sequence_length]])
n_patterns = len(network_input)
# reshape the input into a format compatible with Transformer layers
network_input = np.reshape(network_input, (n_patterns, sequence_length))

In [16]:
# create a dictionary to map pitches to integers
note_to_int_validation = dict((notes_validation, number) for number, notes_validation in enumerate(pitchnames))
network_input_validation = []
network_output_validation = []
for notes_validation in notes_for_instruments_validation:
    if len(notes_validation) - sequence_length<=0:
        print("song too short")
    # create input sequences and the corresponding outputs
    for i in range(0, len(notes_validation) - sequence_length, 1):
      # Map pitches of sequence_in to integers
      network_input_validation.append([note_to_int_validation[char] for char in notes_validation[i:i + sequence_length]])
n_patterns_validation = len(network_input_validation)
# reshape the input into a format compatible with Transformer layers
network_input_validation = np.reshape(network_input_validation, (n_patterns_validation, sequence_length))

Let's see the new network_input size

In [17]:
network_input.shape

(366889, 64)

In [18]:
network_input_validation.shape

(36341, 64)

**Design neural network architecture** 

In [19]:
def create_network(sequence_length, n_vocab):
    """ create the structure of the neural network """
    model = CompressiveTransformer(
    num_tokens = n_vocab,
    dim = sequence_length,
    depth = 6,
    seq_len = sequence_length,
    mem_len = sequence_length,
    cmem_len = 256,
    cmem_ratio = 4,
    memory_layers = [5,6],
    gru_gated_residual = False
    )

    model = AutoregressiveWrapper(model)
    model.cuda()
    return model

In [20]:
model = create_network(sequence_length,n_vocab)

print(model)


AutoregressiveWrapper(
  (net): CompressiveTransformer(
    (token_emb): Embedding(839, 64)
    (to_model_dim): Identity()
    (to_logits): Sequential(
      (0): Identity()
      (1): Linear(in_features=64, out_features=839, bias=True)
    )
    (attn_layers): ModuleList(
      (0): Residual(
        (fn): PreNorm(
          (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (fn): SelfAttention(
            (compress_mem_fn): ConvCompress(
              (conv): Conv1d(64, 64, kernel_size=(4,), stride=(4,))
            )
            (to_q): Linear(in_features=64, out_features=64, bias=False)
            (to_kv): Linear(in_features=64, out_features=128, bias=False)
            (to_out): Linear(in_features=64, out_features=64, bias=True)
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (dropout): Dropout(p=0.0, inplace=False)
            (reconstruction_attn_dropout): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (1): Residual(

In [21]:
def cycle(loader):
    while True:
        for data in loader:
          yield data

data_train = torch.from_numpy(network_input).cuda()
train_loader = torch.utils.data.DataLoader(data_train, batch_size=32) 
cycle_train_loader  = cycle(DataLoader(data_train, batch_size = data_train.shape[0]))
num_batches=math.ceil(data_train.shape[0]/batch_size) # Total number of batches

In [22]:
#Validation
data_validation = torch.from_numpy(network_input_validation).cuda()
validation_loader = torch.utils.data.DataLoader(data_validation, batch_size=32) 
cycle_validation_loader  = cycle(DataLoader(data_validation, batch_size = data_validation.shape[0]))
num_batches_val=math.ceil(data_validation.shape[0]/batch_size) # Total number of batches

In [23]:
# optimizer

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

If we want to continue training from the point where we left off, we should load the previously trained weights into the model.

This is very useful in Google Colaboratory, which usually kills the virtual machine running the Jupyter notebook after a certain amount of time. If this happens to you, look for the last weights file in your configured Drive account and use it to resume training.
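
Finding the last weights file can be automated. The sketch below assumes the file names follow the pattern written by the training loop in this notebook (`Tran_64_Checkpoint<epoch>_<loss>.pth.tar`); `latest_checkpoint` is a hypothetical helper, and the file names in the example are illustrative. In practice the list would come from `os.listdir(output_dir)`.

```python
import re

# Matches names such as "Tran_64_Checkpoint215_3.4073.pth.tar"
CHECKPOINT_RE = re.compile(r"^Tran_64_Checkpoint(\d+)_(\d+\.\d+)\.pth\.tar$")


def latest_checkpoint(filenames):
    """Return (epoch, filename) for the highest epoch, or None if nothing matches."""
    found = []
    for name in filenames:
        match = CHECKPOINT_RE.match(name)
        if match:
            found.append((int(match.group(1)), name))
    return max(found) if found else None


names = [
    "Tran_64_Checkpoint205_3.4601.pth.tar",
    "model_best.pth.tar",
    "Tran_64_Checkpoint215_3.4073.pth.tar",
    "Tran_64_Checkpoint210_3.4336.pth.tar",
]
print(latest_checkpoint(names))  # (215, 'Tran_64_Checkpoint215_3.4073.pth.tar')
```

The returned epoch plus one is a natural starting value for the training loop's range when resuming.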


In [24]:
# In case we want to use previously trained weights
weights = "model_best.pth.tar"
checkpoint = torch.load("/content/drive/MyDrive/ISPR_project/Transformer/model_best.pth.tar")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']


In [None]:
# training
# Track the best training loss across epochs (not per epoch), so that
# model_best.pth.tar is only overwritten when the loss actually improves.
best_loss_value = float('inf')
# Start from epoch 211 to continue the earlier run loaded above.
for i in tqdm.tqdm(range(211,epochs), mininterval=20., desc='training'):
    model.train()
    tot_loss = 0.0
    is_best=0
    avg_loss_val=0
    for mlm_loss, aux_loss, is_last in model(next(cycle_train_loader), max_batch_size = batch_size, return_loss = True):
        loss = mlm_loss + aux_loss

        loss.backward()

        # detach so the running total does not keep the autograd graph alive
        tot_loss+=loss.detach()

        if is_last:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            optimizer.zero_grad()
    
    if i % VALIDATE_EVERY == 0 or i==epochs-1:
      model.eval()
      with torch.no_grad():
          for loss_val, aux_loss_val, is_last_val in model(next(cycle_validation_loader), max_batch_size = batch_size, return_loss = True):
            avg_loss_val+=loss_val/num_batches_val

            if is_last_val:
              print(f'\n validation loss: {avg_loss_val.item():.4f}')


    avg_loss=tot_loss/num_batches

    # save a checkpoint every 5 epochs and at the last epoch
    if i%5==0 or i==epochs-1:
      if best_loss_value>avg_loss:
        best_loss_value=avg_loss
        is_best=1

      save_checkpoint({
      'epoch': i,
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict' : optimizer.state_dict(),
      'loss':avg_loss.item(),
     }, is_best, 'Tran_64_Checkpoint'+str(i)+'_'+"{:.4f}".format(avg_loss.item())+'.pth.tar')
      is_best=0
    print(f'\n Epoch: {i} |Training loss: {avg_loss.item():.4f}')
print('\nTraining complete.')






training:   0%|          | 1/1789 [04:51<144:35:19, 291.12s/it]


 Epoch: 211 |Training loss: 3.4283


training:   0%|          | 2/1789 [09:37<143:50:00, 289.76s/it]


 Epoch: 212 |Training loss: 3.4230


training:   0%|          | 3/1789 [14:20<142:46:46, 287.80s/it]


 Epoch: 213 |Training loss: 3.4177


training:   0%|          | 4/1789 [19:02<141:49:34, 286.04s/it]


 Epoch: 214 |Training loss: 3.4125

 validation loss: 3.7472


training:   0%|          | 5/1789 [23:56<142:55:06, 288.40s/it]


 Epoch: 215 |Training loss: 3.4073


training:   0%|          | 6/1789 [28:43<142:37:16, 287.96s/it]


 Epoch: 216 |Training loss: 3.4021


training:   0%|          | 7/1789 [33:32<142:42:08, 288.29s/it]


 Epoch: 217 |Training loss: 3.3969


training:   0%|          | 8/1789 [38:19<142:25:58, 287.90s/it]


 Epoch: 218 |Training loss: 3.3918


training:   1%|          | 9/1789 [43:08<142:31:16, 288.25s/it]


 Epoch: 219 |Training loss: 3.3867

 validation loss: 3.7301


training:   1%|          | 10/1789 [48:10<144:28:16, 292.35s/it]


 Epoch: 220 |Training loss: 3.3816


training:   1%|          | 11/1789 [52:59<143:50:16, 291.24s/it]


 Epoch: 221 |Training loss: 3.3765


training:   1%|          | 12/1789 [57:47<143:15:41, 290.23s/it]


 Epoch: 222 |Training loss: 3.3715


training:   1%|          | 13/1789 [1:02:33<142:35:48, 289.05s/it]


 Epoch: 223 |Training loss: 3.3665


training:   1%|          | 14/1789 [1:07:19<142:07:45, 288.26s/it]


 Epoch: 224 |Training loss: 3.3615

 validation loss: 3.7141


training:   1%|          | 15/1789 [1:12:17<143:24:06, 291.01s/it]


 Epoch: 225 |Training loss: 3.3565


training:   1%|          | 16/1789 [1:17:03<142:40:08, 289.68s/it]


 Epoch: 226 |Training loss: 3.3516


training:   1%|          | 17/1789 [1:21:50<142:08:58, 288.79s/it]


 Epoch: 227 |Training loss: 3.3466


training:   1%|          | 18/1789 [1:26:37<141:42:20, 288.05s/it]


 Epoch: 228 |Training loss: 3.3417


training:   1%|          | 19/1789 [1:31:23<141:20:28, 287.47s/it]


 Epoch: 229 |Training loss: 3.3369

 validation loss: 3.6991


training:   1%|          | 20/1789 [1:36:19<142:35:08, 290.17s/it]


 Epoch: 230 |Training loss: 3.3320


training:   1%|          | 21/1789 [1:41:05<141:51:01, 288.84s/it]


 Epoch: 231 |Training loss: 3.3272


training:   1%|          | 22/1789 [1:45:50<141:10:03, 287.61s/it]


 Epoch: 232 |Training loss: 3.3224


training:   1%|▏         | 23/1789 [1:50:36<140:58:09, 287.37s/it]


 Epoch: 233 |Training loss: 3.3176


training:   1%|▏         | 24/1789 [1:55:22<140:35:38, 286.76s/it]


 Epoch: 234 |Training loss: 3.3128

 validation loss: 3.6851


training:   1%|▏         | 25/1789 [2:00:30<143:40:06, 293.20s/it]


 Epoch: 235 |Training loss: 3.3081


training:   1%|▏         | 26/1789 [2:05:21<143:16:43, 292.57s/it]


 Epoch: 236 |Training loss: 3.3034


training:   2%|▏         | 27/1789 [2:10:08<142:22:02, 290.88s/it]


 Epoch: 237 |Training loss: 3.2987


training:   2%|▏         | 28/1789 [2:14:53<141:25:51, 289.13s/it]


 Epoch: 238 |Training loss: 3.2940


training:   2%|▏         | 29/1789 [2:19:37<140:37:33, 287.64s/it]


 Epoch: 239 |Training loss: 3.2893

 validation loss: 3.6720


training:   2%|▏         | 30/1789 [2:24:32<141:31:20, 289.64s/it]


 Epoch: 240 |Training loss: 3.2847


training:   2%|▏         | 31/1789 [2:29:15<140:30:47, 287.74s/it]


 Epoch: 241 |Training loss: 3.2801


training:   2%|▏         | 32/1789 [2:33:58<139:46:19, 286.39s/it]


 Epoch: 242 |Training loss: 3.2755


training:   2%|▏         | 33/1789 [2:38:41<139:14:11, 285.45s/it]


 Epoch: 243 |Training loss: 3.2709


training:   2%|▏         | 34/1789 [2:43:24<138:46:11, 284.66s/it]


 Epoch: 244 |Training loss: 3.2664

 validation loss: 3.6597


training:   2%|▏         | 35/1789 [2:48:18<140:05:51, 287.54s/it]


 Epoch: 245 |Training loss: 3.2618


training:   2%|▏         | 36/1789 [2:53:02<139:22:55, 286.24s/it]


 Epoch: 246 |Training loss: 3.2573


training:   2%|▏         | 37/1789 [2:57:45<138:50:16, 285.28s/it]


 Epoch: 247 |Training loss: 3.2528


training:   2%|▏         | 38/1789 [3:02:28<138:27:33, 284.67s/it]


 Epoch: 248 |Training loss: 3.2483


training:   2%|▏         | 39/1789 [3:07:11<138:09:57, 284.23s/it]


 Epoch: 249 |Training loss: 3.2439

 validation loss: 3.6482


training:   2%|▏         | 40/1789 [3:12:05<139:29:43, 287.13s/it]


 Epoch: 250 |Training loss: 3.2394


training:   2%|▏         | 41/1789 [3:16:49<138:55:16, 286.11s/it]


 Epoch: 251 |Training loss: 3.2350


training:   2%|▏         | 42/1789 [3:21:35<138:53:14, 286.20s/it]


 Epoch: 252 |Training loss: 3.2306


training:   2%|▏         | 43/1789 [3:26:24<139:13:51, 287.07s/it]


 Epoch: 253 |Training loss: 3.2262


training:   2%|▏         | 44/1789 [3:31:12<139:19:25, 287.43s/it]


 Epoch: 254 |Training loss: 3.2218

 validation loss: 3.6375


training:   3%|▎         | 45/1789 [3:36:10<140:39:49, 290.36s/it]


 Epoch: 255 |Training loss: 3.2174


training:   3%|▎         | 46/1789 [3:40:57<140:10:31, 289.52s/it]


 Epoch: 256 |Training loss: 3.2131


training:   3%|▎         | 47/1789 [3:45:44<139:40:35, 288.65s/it]


 Epoch: 257 |Training loss: 3.2087


training:   3%|▎         | 48/1789 [3:50:30<139:11:54, 287.83s/it]


 Epoch: 258 |Training loss: 3.2044


training:   3%|▎         | 49/1789 [3:55:15<138:46:15, 287.11s/it]


 Epoch: 259 |Training loss: 3.2001


training:   3%|▎         | 50/1789 [4:00:12<140:06:02, 290.03s/it]


 validation loss: 3.6275

 Epoch: 260 |Training loss: 3.1958


training:   3%|▎         | 51/1789 [4:04:59<139:35:00, 289.13s/it]


 Epoch: 261 |Training loss: 3.1916


training:   3%|▎         | 52/1789 [4:09:46<139:10:58, 288.46s/it]


 Epoch: 262 |Training loss: 3.1873


training:   3%|▎         | 53/1789 [4:14:33<138:54:55, 288.07s/it]


 Epoch: 263 |Training loss: 3.1831


training:   3%|▎         | 54/1789 [4:19:20<138:40:29, 287.74s/it]


 Epoch: 264 |Training loss: 3.1788

 validation loss: 3.6181


training:   3%|▎         | 55/1789 [4:24:19<140:14:40, 291.17s/it]


 Epoch: 265 |Training loss: 3.1746


training:   3%|▎         | 56/1789 [4:29:05<139:23:36, 289.57s/it]


 Epoch: 266 |Training loss: 3.1704


training:   3%|▎         | 57/1789 [4:33:52<138:53:46, 288.70s/it]


 Epoch: 267 |Training loss: 3.1662


training:   3%|▎         | 58/1789 [4:38:39<138:32:50, 288.14s/it]


 Epoch: 268 |Training loss: 3.1621


training:   3%|▎         | 59/1789 [4:43:25<138:16:14, 287.73s/it]


 Epoch: 269 |Training loss: 3.1579

 validation loss: 3.6094


training:   3%|▎         | 60/1789 [4:48:21<139:21:33, 290.16s/it]


 Epoch: 270 |Training loss: 3.1537


training:   3%|▎         | 61/1789 [4:53:12<139:25:24, 290.47s/it]


 Epoch: 271 |Training loss: 3.1496


training:   3%|▎         | 62/1789 [4:58:07<139:57:13, 291.74s/it]


 Epoch: 272 |Training loss: 3.1455


training:   4%|▎         | 63/1789 [5:03:00<140:04:04, 292.15s/it]


 Epoch: 273 |Training loss: 3.1414


training:   4%|▎         | 64/1789 [5:07:53<140:02:20, 292.26s/it]


 Epoch: 274 |Training loss: 3.1373

 validation loss: 3.6012


training:   4%|▎         | 65/1789 [5:12:58<141:53:00, 296.28s/it]


 Epoch: 275 |Training loss: 3.1332


training:   4%|▎         | 66/1789 [5:17:49<141:00:35, 294.62s/it]


 Epoch: 276 |Training loss: 3.1291


training:   4%|▎         | 67/1789 [5:22:41<140:28:33, 293.68s/it]


 Epoch: 277 |Training loss: 3.1250


training:   4%|▍         | 68/1789 [5:27:31<139:58:14, 292.79s/it]


 Epoch: 278 |Training loss: 3.1210


training:   4%|▍         | 69/1789 [5:32:21<139:27:58, 291.91s/it]


 Epoch: 279 |Training loss: 3.1170

 validation loss: 3.5936


training:   4%|▍         | 70/1789 [5:37:20<140:26:28, 294.12s/it]


 Epoch: 280 |Training loss: 3.1129


training:   4%|▍         | 71/1789 [5:42:08<139:25:53, 292.17s/it]


 Epoch: 281 |Training loss: 3.1089


training:   4%|▍         | 72/1789 [5:46:54<138:27:00, 290.29s/it]


 Epoch: 282 |Training loss: 3.1049


training:   4%|▍         | 73/1789 [5:51:41<137:51:46, 289.22s/it]


 Epoch: 283 |Training loss: 3.1009


training:   4%|▍         | 74/1789 [5:56:26<137:16:49, 288.17s/it]


 Epoch: 284 |Training loss: 3.0969

 validation loss: 3.5865


training:   4%|▍         | 75/1789 [6:01:23<138:20:33, 290.57s/it]


 Epoch: 285 |Training loss: 3.0930


training:   4%|▍         | 76/1789 [6:06:10<137:48:34, 289.62s/it]


 Epoch: 286 |Training loss: 3.0890


training:   4%|▍         | 77/1789 [6:10:59<137:34:55, 289.31s/it]


 Epoch: 287 |Training loss: 3.0851


training:   4%|▍         | 78/1789 [6:15:46<137:12:40, 288.70s/it]


 Epoch: 288 |Training loss: 3.0811


training:   4%|▍         | 79/1789 [6:20:33<136:56:54, 288.31s/it]


 Epoch: 289 |Training loss: 3.0772

 validation loss: 3.5799


training:   4%|▍         | 80/1789 [6:25:32<138:25:15, 291.58s/it]


 Epoch: 290 |Training loss: 3.0733


training:   5%|▍         | 81/1789 [6:30:22<138:06:10, 291.08s/it]


 Epoch: 291 |Training loss: 3.0694


training:   5%|▍         | 82/1789 [6:35:13<138:00:18, 291.05s/it]


 Epoch: 292 |Training loss: 3.0655


training:   5%|▍         | 83/1789 [6:40:08<138:24:54, 292.08s/it]


 Epoch: 293 |Training loss: 3.0616


training:   5%|▍         | 84/1789 [6:45:01<138:27:47, 292.36s/it]


 Epoch: 294 |Training loss: 3.0577

 validation loss: 3.5738


training:   5%|▍         | 85/1789 [6:50:04<139:57:30, 295.69s/it]


 Epoch: 295 |Training loss: 3.0539


training:   5%|▍         | 86/1789 [6:54:56<139:21:11, 294.58s/it]


 Epoch: 296 |Training loss: 3.0500


training:   5%|▍         | 87/1789 [6:59:47<138:44:46, 293.47s/it]


 Epoch: 297 |Training loss: 3.0462


training:   5%|▍         | 88/1789 [7:04:35<137:55:29, 291.90s/it]


 Epoch: 298 |Training loss: 3.0424


training:   5%|▍         | 89/1789 [7:09:22<137:06:36, 290.35s/it]


 Epoch: 299 |Training loss: 3.0385

 validation loss: 3.5681


training:   5%|▌         | 90/1789 [7:14:20<138:04:36, 292.57s/it]


 Epoch: 300 |Training loss: 3.0347


training:   5%|▌         | 91/1789 [7:19:06<137:04:49, 290.63s/it]


 Epoch: 301 |Training loss: 3.0309


training:   5%|▌         | 92/1789 [7:23:52<136:20:15, 289.23s/it]


 Epoch: 302 |Training loss: 3.0271


training:   5%|▌         | 93/1789 [7:28:38<135:50:51, 288.36s/it]


 Epoch: 303 |Training loss: 3.0234


training:   5%|▌         | 94/1789 [7:33:24<135:20:12, 287.44s/it]


 Epoch: 304 |Training loss: 3.0196

 validation loss: 3.5629


training:   5%|▌         | 95/1789 [7:38:22<136:46:06, 290.65s/it]


 Epoch: 305 |Training loss: 3.0158


training:   5%|▌         | 96/1789 [7:43:13<136:49:53, 290.96s/it]


 Epoch: 306 |Training loss: 3.0121


training:   5%|▌         | 97/1789 [7:48:06<136:58:26, 291.43s/it]


 Epoch: 307 |Training loss: 3.0084


training:   5%|▌         | 98/1789 [7:52:55<136:30:17, 290.61s/it]


 Epoch: 308 |Training loss: 3.0046


training:   6%|▌         | 99/1789 [7:57:44<136:12:27, 290.15s/it]


 Epoch: 309 |Training loss: 3.0009

 validation loss: 3.5581


training:   6%|▌         | 100/1789 [8:02:44<137:29:21, 293.05s/it]


 Epoch: 310 |Training loss: 2.9972


training:   6%|▌         | 101/1789 [8:07:33<136:54:52, 292.00s/it]


 Epoch: 311 |Training loss: 2.9935


training:   6%|▌         | 102/1789 [8:12:23<136:31:51, 291.35s/it]


 Epoch: 312 |Training loss: 2.9898


training:   6%|▌         | 103/1789 [8:17:12<136:06:57, 290.64s/it]


 Epoch: 313 |Training loss: 2.9862


training:   6%|▌         | 104/1789 [8:22:01<135:48:13, 290.14s/it]


 Epoch: 314 |Training loss: 2.9825

 validation loss: 3.5536


training:   6%|▌         | 105/1789 [8:26:58<136:41:13, 292.20s/it]


 Epoch: 315 |Training loss: 2.9788


training:   6%|▌         | 106/1789 [8:31:44<135:43:57, 290.34s/it]


 Epoch: 316 |Training loss: 2.9752


training:   6%|▌         | 107/1789 [8:36:30<135:02:09, 289.02s/it]


 Epoch: 317 |Training loss: 2.9716


training:   6%|▌         | 108/1789 [8:41:16<134:35:11, 288.23s/it]


 Epoch: 318 |Training loss: 2.9680


training:   6%|▌         | 109/1789 [8:46:05<134:36:05, 288.43s/it]


 Epoch: 319 |Training loss: 2.9643

 validation loss: 3.5494


training:   6%|▌         | 110/1789 [8:51:05<136:11:16, 292.01s/it]


 Epoch: 320 |Training loss: 2.9608


training:   6%|▌         | 111/1789 [8:55:53<135:26:42, 290.59s/it]


 Epoch: 321 |Training loss: 2.9572


training:   6%|▋         | 112/1789 [9:00:40<134:52:14, 289.53s/it]


 Epoch: 322 |Training loss: 2.9536


training:   6%|▋         | 113/1789 [9:05:27<134:27:00, 288.79s/it]


 Epoch: 323 |Training loss: 2.9500


training:   6%|▋         | 114/1789 [9:10:14<134:11:11, 288.40s/it]


 Epoch: 324 |Training loss: 2.9464

 validation loss: 3.5455


training:   6%|▋         | 115/1789 [9:15:14<135:38:06, 291.69s/it]


 Epoch: 325 |Training loss: 2.9429


training:   6%|▋         | 116/1789 [9:20:02<135:05:35, 290.70s/it]


 Epoch: 326 |Training loss: 2.9394


training:   7%|▋         | 117/1789 [9:24:51<134:41:39, 290.01s/it]


 Epoch: 327 |Training loss: 2.9358


training:   7%|▋         | 118/1789 [9:29:39<134:22:43, 289.51s/it]


 Epoch: 328 |Training loss: 2.9323


training:   7%|▋         | 119/1789 [9:34:29<134:25:46, 289.79s/it]


 Epoch: 329 |Training loss: 2.9288

 validation loss: 3.5412


training:   7%|▋         | 120/1789 [9:39:31<135:57:25, 293.26s/it]


 Epoch: 330 |Training loss: 2.9253


training:   7%|▋         | 121/1789 [9:44:16<134:49:03, 290.97s/it]


 Epoch: 331 |Training loss: 2.9219


training:   7%|▋         | 122/1789 [9:49:01<133:49:12, 288.99s/it]


 Epoch: 332 |Training loss: 2.9186


training:   7%|▋         | 123/1789 [9:53:45<133:02:24, 287.48s/it]


 Epoch: 333 |Training loss: 2.9153


training:   7%|▋         | 124/1789 [9:58:29<132:31:20, 286.53s/it]


 Epoch: 334 |Training loss: 2.9121


**Music generation**

---



In [None]:
# In case we want to use previously trained weights
weights = "model_best.pth.tar"
checkpoint = torch.load(output_dir+weights)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']


In [None]:
# Generate network input again
# (rebuilt from every song, as in the preprocessing cell above)
network_input = []
for notes in notes_for_instruments:
    for i in range(0, len(notes) - sequence_length, 1):
        network_input.append([note_to_int[char] for char in notes[i:i + sequence_length]])
n_patterns = len(network_input)
network_input = np.reshape(network_input, (n_patterns, sequence_length))


The workflow now is:


1.   Pick a **seed sequence** randomly from your list of inputs (*pattern* variable)
2.   Pass it as input for your model to generate a new element (note or chord)
3.   Add the new element to your final song and to your *pattern* list
4.   Remove the first item from *pattern*
5.   Go to step 2
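
The steps above can be sketched as a plain sliding-window loop. In the notebook itself this loop is delegated to `model.generate(pattern, 500)`; the code below is only an illustration with a toy predictor standing in for the model, and `generate`, `toy_predictor`, and the sample inputs are hypothetical names, not the library's implementation.

```python
import random


def generate(seed, predict_next, n_steps):
    """Sliding-window generation: predict, append, drop the oldest element."""
    pattern = list(seed)  # copy, so the seed itself is not modified
    song = []
    for _ in range(n_steps):
        new_element = predict_next(pattern)  # step 2
        song.append(new_element)             # step 3: add to the song
        pattern.append(new_element)          # step 3: add to the pattern
        pattern.pop(0)                       # step 4: drop the first item
    return song


def toy_predictor(pattern):
    # stand-in for the model: sum of the last two tokens, modulo 10
    return (pattern[-1] + pattern[-2]) % 10


inputs = [[1, 1, 2], [2, 3, 5], [3, 4, 7]]
seed = random.choice(inputs)                 # step 1
print(generate(seed, toy_predictor, 5))
```

The pattern always keeps the same length, so the predictor sees exactly *sequence_length* elements at every step.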


In [None]:
""" Generate notes from the neural network based on a sequence of notes """
# pick a random sequence from the input as a starting point for the prediction
start = np.random.randint(0, len(network_input)-1)
int_to_note = dict((number, note) for number, note in enumerate(pitchnames))
pattern = torch.from_numpy(network_input[start]).cuda()

prediction_output = model.generate(pattern, 500)


In [None]:
result_sample=[]

for i in range(500):
  result = int_to_note[prediction_output[i].item()]
  print(' Predicted ', i, " ", result)
  result_sample.append(result)

prediction_output=result_sample

 Predicted  0   6
 Predicted  1   4.6
 Predicted  2   6.11
 Predicted  3   6
 Predicted  4   6.11
 Predicted  5   A4
 Predicted  6   4.6
 Predicted  7   F4
 Predicted  8   6
 Predicted  9   6
 Predicted  10   5.7.9.0
 Predicted  11   2.3.7.10
 Predicted  12   D5
 Predicted  13   C5
 Predicted  14   5.7.9.0
 Predicted  15   C5
 Predicted  16   4.6
 Predicted  17   B-1
 Predicted  18   10.2.5
 Predicted  19   C5
 Predicted  20   6.11
 Predicted  21   6
 Predicted  22   F2
 Predicted  23   6.11
 Predicted  24   4.6
 Predicted  25   B-2
 Predicted  26   B-1
 Predicted  27   A4
 Predicted  28   6
 Predicted  29   C5
 Predicted  30   E-3
 Predicted  31   F2
 Predicted  32   4.6
 Predicted  33   5
 Predicted  34   5.10
 Predicted  35   4.6
 Predicted  36   6
 Predicted  37   4.6
 Predicted  38   4.6
 Predicted  39   F2
 Predicted  40   4.6
 Predicted  41   B-2
 Predicted  42

The last step is creating a MIDI file from the predictions.

**music21** will help us again for this task. We should create a **Stream** and add to it the predicted notes and chords.

We are adding an offset of 0.5 between elements.

In [None]:
offset = 0
output_notes = []
# create note and chord objects based on the values generated by the model
for pattern in prediction_output:
    # pattern is a chord
    if ('.' in pattern) or pattern.isdigit():
        notes_in_chord = pattern.split('.')
        notes = []
        for current_note in notes_in_chord:
            new_note = note.Note(int(current_note))
            new_note.storedInstrument = instrument.Piano()
            notes.append(new_note)
        new_chord = chord.Chord(notes)
        new_chord.offset = offset
        output_notes.append(new_chord)
    # pattern is a note
    else:
        new_note = note.Note(pattern)
        new_note.offset = offset
        new_note.storedInstrument = instrument.Piano()
        output_notes.append(new_note)

    # increase offset each iteration so that notes do not stack
    offset += 0.5

midi_stream = stream.Stream(output_notes)
midi_stream.write('midi', fp='test_output.mid')

'test_output.mid'
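
The note/chord branching and the 0.5 offset step used in the cell above can be checked without music21. This is a minimal sketch, assuming tokens in the string format described earlier; `decode_tokens` and the tuple layout are hypothetical and exist only for this illustration.

```python
def decode_tokens(tokens, step=0.5):
    """Turn string tokens into (kind, pitches, offset) tuples."""
    decoded = []
    offset = 0.0
    for token in tokens:
        if "." in token or token.isdigit():
            # chord: dot-separated integers; a single integer is a one-note chord
            decoded.append(("chord", [int(n) for n in token.split(".")], offset))
        else:
            # note: a pitch name with its octave
            decoded.append(("note", [token], offset))
        # advance the offset so that elements do not stack
        offset += step
    return decoded


result = decode_tokens(["6", "4.6", "A4", "B-1"])
for item in result:
    print(item)
```

Note that a token such as `B-1` is treated as a note: it contains no dot and `isdigit()` is false for it.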