<a href="https://colab.research.google.com/github/GiovanniSorice/Deep_Music_Generator/blob/main/notebooks/Music_Generation_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformer Music Generator 



In this notebook, we use a Transformer (a Compressive Transformer from the `compressive_transformer_pytorch` package) to generate music.


**This notebook was inspired by [Music_Generation_LSTM](https://colab.research.google.com/drive/19TQqekOlnOSW36VCL8CPVEQKBBukmaEQ#scrollTo=DDOBVWULXfpz), from which part of the code is taken.**




**Load dependencies**

In [1]:
pip install compressive_transformer_pytorch

Collecting compressive_transformer_pytorch
  Downloading https://files.pythonhosted.org/packages/30/39/b8caf2671abcb8615977c08766aa9f450addd6949f57c7dda87224e844b5/compressive_transformer_pytorch-0.3.20-py3-none-any.whl
Collecting mogrifier
  Downloading https://files.pythonhosted.org/packages/77/01/62a55d0f8048e788fce435f2ade6478f443e4e53ed9b89b55ba0fc42c198/mogrifier-0.0.3-py3-none-any.whl
Installing collected packages: mogrifier, compressive-transformer-pytorch
Successfully installed compressive-transformer-pytorch-0.3.20 mogrifier-0.0.3


In [2]:
import torch
import tqdm
import numpy as np
import pandas as pd
import tensorflow as tf
import os
from compressive_transformer_pytorch import CompressiveTransformer
from compressive_transformer_pytorch import AutoregressiveWrapper
from torchsummary import summary
from torch.utils.data import DataLoader, Dataset
from tensorflow.keras import utils
from sklearn.metrics import roc_auc_score 
import matplotlib.pyplot as plt
import glob
import pickle
from music21 import converter, instrument, stream, note, chord
import math
import shutil

In [3]:
# Set to False if you are not running
# this notebook in Google Colaboratory
run_on_colab = True

**Set hyperparameters**

In [4]:
# output directory name:
output_dir = '/content/drive/My Drive/ISPR_project/Transformer/'
current_path = '/content/drive/My Drive/ISPR_project/'
# training:
epochs = 500
batch_size = 32
learning_rate = 1e-3
# vector-space embedding (not used by create_network, which sets dim = sequence_length):
n_dim = 64
sequence_length = 128

VALIDATE_EVERY = 100
GENERATE_EVERY = 500



**Save model function**

In [5]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Save `state` to output_dir/filename; if is_best, also copy it to model_best.pth.tar."""
    torch.save(state, output_dir+filename)
    if is_best:
        shutil.copyfile(output_dir+filename, output_dir+'model_best.pth.tar')

**Google drive configuration (only Colab)**

In [7]:
if run_on_colab:
  from google.colab import drive
  # This will prompt for authorization.
  drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


**Load data** \\
The original **MIDI files** were obtained from [The Lakh MIDI Dataset v0.1](https://colinraffel.com/projects/lmd/).

## Processing data

Let's load one of the files with **music21** and print its first ten elements together with their offsets.

In [8]:
file = current_path+"midi_songs/Andra tutto bene ('58).1.mid"
midi = converter.parse(file)
notes_to_parse = midi.flat.notes
for element in notes_to_parse[:10]:
  print(element, element.offset)

<music21.chord.Chord F3 F2> 4.0
<music21.note.Note A> 4.0
<music21.chord.Chord B1 F#3 F#2> 4.0
<music21.note.Note F> 4.0
<music21.chord.Chord C4 F4> 4.0
<music21.chord.Chord F#3 C#6 F#2> 4.5
<music21.note.Note C#> 4.75
<music21.chord.Chord F#2 E2 F#3> 5.0
<music21.chord.Chord A4 A3 F4 C4 A3> 5.0
<music21.note.Note F> 5.0


I will process all the MIDI files, extracting data from each note or chord.

- For a **note**, I store a string with its pitch name and octave (e.g. `F#3`).

- For a **chord** (a set of notes played at the same time), I store a string of numbers separated by dots. Each number is the pitch class of one chord note (an integer from 0 to 11, with C = 0), listed in the chord's normal order, so the octave of chord notes is not kept.

Note that **I am not yet considering the time offset of each element**: in this first version, all notes and chords are treated as having the same duration. Offsets may be taken into account in a future version.

The result is one big list containing all the elements of all the compositions.
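The two token formats can be illustrated without music21. The sketch below uses two hypothetical helpers that are not part of the notebook: `encode_chord` joins pitch classes the same way the parsing cell joins `element.normalOrder`, and `decode_token` tells the two kinds of token apart again, which is roughly what is needed later to turn predicted tokens back into notes and chords. It assumes, as in the parsing cell, that note tokens start with a letter and chord tokens contain only digits and dots.

```python
def encode_chord(pitch_classes):
    """Join pitch classes (0 to 11, C = 0) into a chord token such as '4.8.11'."""
    return '.'.join(str(pc) for pc in pitch_classes)


def decode_token(token):
    """Return ('chord', [pitch classes]) for chord tokens, ('note', token) otherwise."""
    # Chord tokens contain only digits and dots; note tokens start with a letter (e.g. 'F#3').
    if token[0].isdigit():
        return 'chord', [int(pc) for pc in token.split('.')]
    return 'note', token


tokens = ['F3', encode_chord([4, 8, 11]), 'C#6', encode_chord([5])]
for tok in tokens:
    print(tok, '->', decode_token(tok))
```

A chord with a single pitch class (e.g. `5`) contains no dot, which is why the check looks at the first character instead of searching for a `.`.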

In [17]:
notes = []
for i,file in enumerate(glob.glob("/content/drive/My Drive/ISPR_project/midi_songs/Andra tutto bene ('58).1.mid")):
  midi = converter.parse(file)
  print('\r', 'Parsing file ', i, " ",file, end='')
  notes_to_parse = None
  try: # file has instrument parts
    s2 = instrument.partitionByInstrument(midi)
    print("s2")
    print(s2)
    notes_to_parse = s2.parts[0].recurse() 
    print("notes_to_parse")
    print(notes_to_parse)
  except Exception: # file has notes in a flat structure
    notes_to_parse = midi.flat.notes
    print("notes_to_parse flat")
    print(notes_to_parse)

  for element in notes_to_parse:
    print("element")
    print(element)

    if isinstance(element, note.Note):
      print("element.pitch")
      print(element.pitch)
      notes.append(str(element.pitch))
    elif isinstance(element, chord.Chord):
      print("chord")
      print('.'.join(str(n) for n in element.normalOrder))
      notes.append('.'.join(str(n) for n in element.normalOrder))
with open('notes', 'wb') as filepath:
  pickle.dump(notes, filepath)

len(notes)

 Parsing file  0   /content/drive/My Drive/ISPR_project/midi_songs/Andra tutto bene ('58).1.mids2
<music21.stream.Score 0x7f6f3c685080>
notes_to_parse
<music21.stream.iterator.RecursiveIterator for Part:Electric Bass @:0>
element
Electric Bass


0

In [9]:
notes = []
for i,file in enumerate(glob.glob(current_path+"midi_songs/*.mid")):
  midi = converter.parse(file)
  print('\r', 'Parsing file ', i, " ",file, end='')
  notes_to_parse = None
  try: # file has instrument parts: only the first part (s2.parts[0]) is used
    s2 = instrument.partitionByInstrument(midi)
    notes_to_parse = s2.parts[0].recurse()
  except Exception: # file has notes in a flat structure
    notes_to_parse = midi.flat.notes
  for element in notes_to_parse:
    if isinstance(element, note.Note):
      notes.append(str(element.pitch))  # pitch name + octave, e.g. 'F#3'
    elif isinstance(element, chord.Chord):
      notes.append('.'.join(str(n) for n in element.normalOrder))  # pitch classes, e.g. '4.8.11'
with open('notes', 'wb') as filepath:
  pickle.dump(notes, filepath)

 Parsing file  3   /content/drive/My Drive/ISPR_project/midi_songs/Andra tutto bene ('58).1.mid

Next, I count the number of distinct notes and chords in the dataset: this is the **number of possible output classes** of our model (its vocabulary size).

In [10]:
# Count different possible outputs
n_vocab = len(set(notes))
n_vocab

71
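Since the class ids come from a sorted vocabulary, the same mapping can be inverted to read the model's predictions back as notes and chords. A minimal sketch on a made-up token list (`toy_notes` and `build_vocab` are illustrative, not part of the notebook):

```python
def build_vocab(tokens):
    """Return (token -> id, id -> token) maps over the sorted distinct tokens."""
    vocab = sorted(set(tokens))
    token_to_int = {tok: i for i, tok in enumerate(vocab)}
    int_to_token = {i: tok for i, tok in enumerate(vocab)}
    return token_to_int, int_to_token


toy_notes = ['E4', '4.8.11', 'E4', 'C#6', '0.4.7', 'E4']
toy_to_int, toy_to_token = build_vocab(toy_notes)

encoded = [toy_to_int[t] for t in toy_notes]
print(len(toy_to_int), encoded)  # 4 [3, 1, 3, 2, 0, 3]
```

The ids depend only on the sorted set of tokens (digits sort before letters, so chord tokens get the lowest ids), which means a trained checkpoint is only meaningful together with the same dataset or a saved copy of the vocabulary. On the real data, the inverse map would be `dict(enumerate(pitchnames))`.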

**Preprocess data** \\
Now there is some **data processing** to do:

- I will map each pitch or chord to an integer.
- I will split the list into overlapping input sequences of `sequence_length` elements. No separate output array is needed: the target for each position is simply the next element, and the `AutoregressiveWrapper` builds these targets by shifting each sequence by one step.

Different values of **sequence_length** can give different results. In this version I use a `sequence_length` of 128.

The network will predict the next note (or chord) based on the preceding notes (or chords) within a window of *sequence_length* elements.
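The loop in the next cell builds the overlapping windows one at a time. An equivalent vectorized sketch uses NumPy's `sliding_window_view` (available since NumPy 1.20) and, on a toy sequence, shows how next-token targets are obtained by shifting each window by one step. The toy `ids` array and the `make_windows` helper are illustrative only.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def make_windows(ids, seq_len):
    """Same windows as the notebook loop: start positions 0 .. len(ids) - seq_len - 1."""
    ids = np.asarray(ids)
    # sliding_window_view yields len(ids) - seq_len + 1 windows; dropping the last one
    # matches range(0, len(notes) - sequence_length). The copy makes the result writable.
    return sliding_window_view(ids, seq_len)[:-1].copy()


ids = np.arange(10)                                # toy token ids 0..9
windows = make_windows(ids, 4)
inputs, targets = windows[:, :-1], windows[:, 1:]  # predict token t+1 from tokens up to t
print(windows.shape)                               # (6, 4)
print(inputs[0], targets[0])                       # [0 1 2] [1 2 3]
```

On the real data, `make_windows(np.array([note_to_int[n] for n in notes]), sequence_length)` should produce the same array as the loop.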


In [11]:
# get all pitch names
pitchnames = sorted(set(item for item in notes))
# create a dictionary to map pitches to integers
note_to_int = dict((note, number) for number, note in enumerate(pitchnames))
network_input = []
network_output = []  # not filled: the targets are the inputs shifted by one step
# create overlapping input sequences of sequence_length tokens
for i in range(0, len(notes) - sequence_length, 1):
  # Map pitches of sequence_in to integers
  network_input.append([note_to_int[char] for char in notes[i:i + sequence_length]])
n_patterns = len(network_input)
# reshape the input into an (n_patterns, sequence_length) array of integer token ids
network_input = np.reshape(network_input, (n_patterns, sequence_length))
# no normalization: the embedding layer expects integer ids, not scaled floats


Let's check the shape of the new `network_input`:

In [12]:
network_input.shape

(4875, 128)

**Design neural network architecture** 

In [13]:
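def create_network(sequence_length, n_vocab):
    """ create the structure of the neural network """
    model = CompressiveTransformer(
        num_tokens = n_vocab,         # vocabulary size (number of output classes)
        dim = sequence_length,        # model width, tied here to sequence_length (128)
        depth = 6,                    # number of layers
        seq_len = sequence_length,    # tokens processed per segment
        mem_len = sequence_length,    # length of the (uncompressed) memory
        cmem_len = 256,               # length of the compressed memory
        cmem_ratio = 4,               # memory tokens compressed into one slot
        memory_layers = [5, 6]        # layers that keep memories (1-indexed)
    )

    # wrapper that computes the next-token loss and handles generation
    model = AutoregressiveWrapper(model)
    model.cuda()
    return model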
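To get a feel for what the memory settings mean, here is a back-of-the-envelope calculation based on the design described in the Compressive Transformer paper (our reading, not checked against every detail of `compressive_transformer_pytorch`): in a memory layer each query can attend to the current segment, the memory and the compressed memory, and each compressed slot summarizes `cmem_ratio` older memory tokens. The helper `memory_reach` is hypothetical, and the numbers are upper bounds reached only once both memories are full.

```python
def memory_reach(seq_len, mem_len, cmem_len, cmem_ratio):
    """Upper bounds for one memory layer of a Compressive Transformer (hypothetical helper).

    Returns (attended, past_tokens):
      attended    - keys a query can attend to: segment + memory + compressed memory
      past_tokens - past tokens covered: memory + compressed slots * compression ratio
    """
    attended = seq_len + mem_len + cmem_len
    past_tokens = mem_len + cmem_len * cmem_ratio
    return attended, past_tokens


# Values passed to CompressiveTransformer in create_network (sequence_length = 128)
print(memory_reach(seq_len=128, mem_len=128, cmem_len=256, cmem_ratio=4))  # (512, 1152)
```

Only layers 5 and 6 (`memory_layers = [5, 6]`) keep memories in this configuration; the other layers see just the current segment.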

In [15]:
model = create_network(sequence_length,n_vocab)

print(model)


AutoregressiveWrapper(
  (net): CompressiveTransformer(
    (token_emb): Embedding(71, 128)
    (to_model_dim): Identity()
    (to_logits): Sequential(
      (0): Identity()
      (1): Linear(in_features=128, out_features=71, bias=True)
    )
    (attn_layers): ModuleList(
      (0): GRUGating(
        (fn): PreNorm(
          (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (fn): SelfAttention(
            (compress_mem_fn): ConvCompress(
              (conv): Conv1d(128, 128, kernel_size=(4,), stride=(4,))
            )
            (to_q): Linear(in_features=128, out_features=128, bias=False)
            (to_kv): Linear(in_features=128, out_features=256, bias=False)
            (to_out): Linear(in_features=128, out_features=128, bias=True)
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (dropout): Dropout(p=0.0, inplace=False)
            (reconstruction_attn_dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (gru): GRUC

In [16]:
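def cycle(loader):
    """Loop over a DataLoader forever, so next() always returns a batch."""
    while True:
        for data in loader:
            yield data


data_train = torch.from_numpy(network_input).cuda()
train_loader = torch.utils.data.DataLoader(data_train, batch_size=batch_size)  # not used by the training loop below
# A single "batch" holds the whole training set; the wrapper splits it into chunks
# of batch_size through the max_batch_size argument of the training loop.
cycle_train_loader = cycle(DataLoader(data_train, batch_size = data_train.shape[0]))
num_bathes = math.ceil(data_train.shape[0] / batch_size)  # total number of chunks per epoch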
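`cycle_train_loader` returns the whole training set (4875 sequences, see the shape above) as one batch, and the training loop passes `max_batch_size = batch_size` to the model. Assuming the `AutoregressiveWrapper` splits that batch into `ceil(B / max_batch_size)` chunks for gradient accumulation, and each sequence (one token shorter after the input/target shift) into segments of `seq_len` tokens, which is our reading of `compressive_transformer_pytorch` rather than something this notebook shows, the counts work out as follows. The helper `passes_per_epoch` is hypothetical.

```python
import math


def passes_per_epoch(n_sequences, max_batch_size, seq_tokens, seq_len):
    """Hypothetical helper: (chunks, segments) per epoch under the assumptions above."""
    # Gradient-accumulation chunks: the full batch split into pieces of max_batch_size
    n_chunks = math.ceil(n_sequences / max_batch_size)
    # Segments per sequence: inputs are one token shorter after the input/target shift
    n_segments = math.ceil((seq_tokens - 1) / seq_len)
    return n_chunks, n_segments


# network_input has shape (4875, 128); batch_size = 32; seq_len = sequence_length = 128
print(passes_per_epoch(4875, 32, 128, 128))  # (153, 1)
```

The chunk count matches `num_bathes`, so dividing each chunk loss by it gives a per-epoch average. Because the loop only calls `optimizer.step()` when `is_last` is true, this setup would perform a single optimizer update per epoch. With one segment per sequence, the memories would also start empty on every pass, so they probably contribute little during training; longer input sequences would be needed to exercise them.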

In [17]:
# optimizer

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

If we want to continue training from where we left off, we can load previously saved weights into the model.

This is very useful in Google Colaboratory, which usually kills the virtual machine running the notebook after a certain amount of time. If this happens, look for the most recent checkpoint file in the configured Drive folder (`output_dir`) and set `weights` to its name to resume training.
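Finding the last weights file by hand is easy to get wrong, because the checkpoint names contain the epoch number without zero padding. A small sketch, assuming the `Tran_Checkpoint<epoch>_<loss>.pth.tar` naming used by the training loop; `latest_checkpoint` is a hypothetical helper, not part of the notebook:

```python
import os
import re
import tempfile

# Names written by the training loop: 'Tran_Checkpoint<epoch>_<loss>.pth.tar'
CKPT_PATTERN = re.compile(r'Tran_Checkpoint(\d+)_([0-9.]+)\.pth\.tar')


def latest_checkpoint(directory):
    """Return the path of the checkpoint with the highest epoch number, or None."""
    best_epoch, best_path = -1, None
    for name in os.listdir(directory):
        match = CKPT_PATTERN.fullmatch(name)
        # Compare epochs as integers: as strings, '90' would sort after '200'.
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), os.path.join(directory, name)
    return best_path


# Demo on a temporary folder with example file names
with tempfile.TemporaryDirectory() as tmp:
    for name in ['Tran_Checkpoint90_1.1418.pth.tar',
                 'Tran_Checkpoint200_0.7395.pth.tar',
                 'model_best.pth.tar']:
        open(os.path.join(tmp, name), 'wb').close()
    found = os.path.basename(latest_checkpoint(tmp))

print(found)  # Tran_Checkpoint200_0.7395.pth.tar
```

In the loading cell, `weights = os.path.basename(latest_checkpoint(output_dir))` would then pick the newest file instead of `model_best.pth.tar`, provided the helper does not return `None` because the folder has no checkpoints yet.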


In [18]:
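# In case we want to use previously trained weights
weights = "model_best.pth.tar"
checkpoint_path = output_dir + weights
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']  # epoch at which the checkpoint was saved
    loss = checkpoint['loss']
else:
    print('No checkpoint found at', checkpoint_path, '- training from scratch.')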


In [24]:
# training

best_loss_value = float('inf')  # best loss seen so far; must persist across epochs
for i in tqdm.tqdm(range(epochs), mininterval=20., desc='training'):
    model.train()
    avg_loss = 0.0
    # The wrapper yields one loss per chunk of batch_size sequences; gradients
    # accumulate until is_last, then the optimizer steps once.
    for mlm_loss, aux_loss, is_last in model(next(cycle_train_loader), max_batch_size = batch_size, return_loss = True):
        loss = mlm_loss + aux_loss
        loss.backward()

        # .item() keeps a plain float, so the running average does not hold autograd graphs
        avg_loss += loss.item() / num_bathes

        if is_last:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            optimizer.zero_grad()

    # Every 10 epochs, save a checkpoint; also copy it to model_best.pth.tar if the loss improved
    if i % 10 == 0:
        is_best = avg_loss < best_loss_value
        if is_best:
            best_loss_value = avg_loss

        save_checkpoint({
            'epoch': i,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, is_best, 'Tran_Checkpoint' + str(i) + '_' + '{:.4f}'.format(avg_loss) + '.pth.tar')

    print(f'Epoch: {i} |Training loss: {avg_loss:.4f}')
print('Training complete.')

Epoch: 0 |Training loss: 1.9461
Epoch: 1 |Training loss: 1.9251
Epoch: 2 |Training loss: 1.9089
Epoch: 3 |Training loss: 1.8980
Epoch: 4 |Training loss: 1.8742
Epoch: 5 |Training loss: 1.8625
Epoch: 6 |Training loss: 1.8460
Epoch: 7 |Training loss: 1.8290
Epoch: 8 |Training loss: 1.8128
Epoch: 9 |Training loss: 1.7954
Epoch: 10 |Training loss: 1.7792
Epoch: 11 |Training loss: 1.7687
Epoch: 12 |Training loss: 1.7508
Epoch: 13 |Training loss: 1.7438
Epoch: 14 |Training loss: 1.7223
Epoch: 15 |Training loss: 1.7072
Epoch: 16 |Training loss: 1.6935
Epoch: 17 |Training loss: 1.6789
Epoch: 18 |Training loss: 1.6828
Epoch: 19 |Training loss: 1.6522
Epoch: 20 |Training loss: 1.6515
Epoch: 21 |Training loss: 1.6346
Epoch: 22 |Training loss: 1.6203
Epoch: 23 |Training loss: 1.6047
Epoch: 24 |Training loss: 1.5925
Epoch: 25 |Training loss: 1.5787
Epoch: 26 |Training loss: 1.5700
Epoch: 27 |Training loss: 1.5600
Epoch: 28 |Training loss: 1.5420
Epoch: 29 |Training loss: 1.5361
Epoch: 30 |Training loss: 1.5217
Epoch: 31 |Training loss: 1.5172
Epoch: 32 |Training loss: 1.5111
Epoch: 33 |Training loss: 1.4925
Epoch: 34 |Training loss: 1.4948
Epoch: 35 |Training loss: 1.4878
Epoch: 36 |Training loss: 1.4709
Epoch: 37 |Training loss: 1.4746
Epoch: 38 |Training loss: 1.4604
Epoch: 39 |Training loss: 1.4610
Epoch: 40 |Training loss: 1.4403
Epoch: 41 |Training loss: 1.4553
Epoch: 42 |Training loss: 1.4304
Epoch: 43 |Training loss: 1.4539
Epoch: 44 |Training loss: 1.4505
Epoch: 45 |Training loss: 1.4144
Epoch: 46 |Training loss: 1.4161
Epoch: 47 |Training loss: 1.3967
Epoch: 48 |Training loss: 1.3996
Epoch: 49 |Training loss: 1.3786
Epoch: 50 |Training loss: 1.3900
Epoch: 51 |Training loss: 1.3707
Epoch: 52 |Training loss: 1.3733
Epoch: 53 |Training loss: 1.3633
Epoch: 54 |Training loss: 1.3586
Epoch: 55 |Training loss: 1.3468
Epoch: 56 |Training loss: 1.3389
Epoch: 57 |Training loss: 1.3339
Epoch: 58 |Training loss: 1.3196
Epoch: 59 |Training loss: 1.3158
Epoch: 60 |Training loss: 1.3083
Epoch: 61 |Training loss: 1.2976
Epoch: 62 |Training loss: 1.3018
Epoch: 63 |Training loss: 1.2911
Epoch: 64 |Training loss: 1.2869
Epoch: 65 |Training loss: 1.2834
Epoch: 66 |Training loss: 1.2707
Epoch: 67 |Training loss: 1.2652
Epoch: 68 |Training loss: 1.2540
Epoch: 69 |Training loss: 1.2524
Epoch: 70 |Training loss: 1.2405
Epoch: 71 |Training loss: 1.2287
Epoch: 72 |Training loss: 1.2231
Epoch: 73 |Training loss: 1.2213
Epoch: 74 |Training loss: 1.2143
Epoch: 75 |Training loss: 1.2107
Epoch: 76 |Training loss: 1.2048
Epoch: 77 |Training loss: 1.1866
Epoch: 78 |Training loss: 1.1850
Epoch: 79 |Training loss: 1.1782
Epoch: 80 |Training loss: 1.1829
Epoch: 81 |Training loss: 1.1665
Epoch: 82 |Training loss: 1.1828
Epoch: 83 |Training loss: 1.1599
Epoch: 84 |Training loss: 1.1886
Epoch: 85 |Training loss: 1.1865
Epoch: 86 |Training loss: 1.1609
Epoch: 87 |Training loss: 1.1575
Epoch: 88 |Training loss: 1.1471
Epoch: 89 |Training loss: 1.1459
Epoch: 90 |Training loss: 1.1418
Epoch: 91 |Training loss: 1.1352
Epoch: 92 |Training loss: 1.1292
Epoch: 93 |Training loss: 1.1305
Epoch: 94 |Training loss: 1.1209
Epoch: 95 |Training loss: 1.1120
Epoch: 96 |Training loss: 1.1071
Epoch: 97 |Training loss: 1.1196
Epoch: 98 |Training loss: 1.1024
Epoch: 99 |Training loss: 1.1041
Epoch: 100 |Training loss: 1.1029
Epoch: 101 |Training loss: 1.0867
Epoch: 102 |Training loss: 1.0845
Epoch: 103 |Training loss: 1.0827
Epoch: 104 |Training loss: 1.0674
Epoch: 105 |Training loss: 1.0637
Epoch: 106 |Training loss: 1.0534
Epoch: 107 |Training loss: 1.0561
Epoch: 108 |Training loss: 1.0487
Epoch: 109 |Training loss: 1.0582
Epoch: 110 |Training loss: 1.0499
Epoch: 111 |Training loss: 1.0324
Epoch: 112 |Training loss: 1.0635
Epoch: 113 |Training loss: 1.0409
Epoch: 114 |Training loss: 1.0492
Epoch: 115 |Training loss: 1.0452
Epoch: 116 |Training loss: 1.0329
Epoch: 117 |Training loss: 1.0200
Epoch: 118 |Training loss: 1.0273
Epoch: 119 |Training loss: 1.0207
Epoch: 120 |Training loss: 1.0041
Epoch: 121 |Training loss: 1.0218
Epoch: 122 |Training loss: 1.0104
Epoch: 123 |Training loss: 0.9911
Epoch: 124 |Training loss: 0.9938
Epoch: 125 |Training loss: 0.9928
Epoch: 126 |Training loss: 0.9810
Epoch: 127 |Training loss: 0.9706
Epoch: 128 |Training loss: 0.9730
Epoch: 129 |Training loss: 0.9794
Epoch: 130 |Training loss: 0.9692
Epoch: 131 |Training loss: 0.9645
Epoch: 132 |Training loss: 0.9631
Epoch: 133 |Training loss: 0.9593
Epoch: 134 |Training loss: 0.9442
Epoch: 135 |Training loss: 0.9711
Epoch: 136 |Training loss: 0.9577
Epoch: 137 |Training loss: 0.9527
Epoch: 138 |Training loss: 0.9442
Epoch: 139 |Training loss: 0.9345
Epoch: 140 |Training loss: 0.9479
Epoch: 141 |Training loss: 0.9335
Epoch: 142 |Training loss: 0.9262
Epoch: 143 |Training loss: 0.9208
Epoch: 144 |Training loss: 0.9120
Epoch: 145 |Training loss: 0.9123
Epoch: 146 |Training loss: 0.9094
Epoch: 147 |Training loss: 0.8929
Epoch: 148 |Training loss: 0.9076
Epoch: 149 |Training loss: 0.8998
Epoch: 150 |Training loss: 0.8957
Epoch: 151 |Training loss: 0.9019
Epoch: 152 |Training loss: 0.8810
Epoch: 153 |Training loss: 0.8883
Epoch: 154 |Training loss: 0.8746
Epoch: 155 |Training loss: 0.8880
Epoch: 156 |Training loss: 0.8935
Epoch: 157 |Training loss: 0.8709
Epoch: 158 |Training loss: 0.8726
Epoch: 159 |Training loss: 0.8692
Epoch: 160 |Training loss: 0.8595
Epoch: 161 |Training loss: 0.8497
Epoch: 162 |Training loss: 0.8458
Epoch: 163 |Training loss: 0.8635
Epoch: 164 |Training loss: 0.8490
Epoch: 165 |Training loss: 0.8413
Epoch: 166 |Training loss: 0.8314
Epoch: 167 |Training loss: 0.8377
Epoch: 168 |Training loss: 0.8296
Epoch: 169 |Training loss: 0.8265
Epoch: 170 |Training loss: 0.8099
Epoch: 171 |Training loss: 0.8453
Epoch: 172 |Training loss: 0.8306
Epoch: 173 |Training loss: 0.8234
Epoch: 174 |Training loss: 0.8134
Epoch: 175 |Training loss: 0.8167
Epoch: 176 |Training loss: 0.8145
Epoch: 177 |Training loss: 0.8034
Epoch: 178 |Training loss: 0.7935
Epoch: 179 |Training loss: 0.8177
Epoch: 180 |Training loss: 0.8091
Epoch: 181 |Training loss: 0.8078
Epoch: 182 |Training loss: 0.8197
Epoch: 183 |Training loss: 0.7996
Epoch: 184 |Training loss: 0.7873
Epoch: 185 |Training loss: 0.8015
Epoch: 186 |Training loss: 0.7779
Epoch: 187 |Training loss: 0.8191
Epoch: 188 |Training loss: 0.8175
Epoch: 189 |Training loss: 0.7992
Epoch: 190 |Training loss: 0.7759
Epoch: 191 |Training loss: 0.7840
Epoch: 192 |Training loss: 0.7707
Epoch: 193 |Training loss: 0.7807
Epoch: 194 |Training loss: 0.7616
Epoch: 195 |Training loss: 0.7717
Epoch: 196 |Training loss: 0.7709
Epoch: 197 |Training loss: 0.7551
Epoch: 198 |Training loss: 0.7550
Epoch: 199 |Training loss: 0.7424
Epoch: 200 |Training loss: 0.7395
Epoch: 201 |Training loss: 0.7329
Epoch: 202 |Training loss: 0.7256
Epoch: 203 |Training loss: 0.7224
Epoch: 204 |Training loss: 0.7196
Epoch: 205 |Training loss: 0.7089
Epoch: 206 |Training loss: 0.7118
Epoch: 207 |Training loss: 0.7075
Epoch: 208 |Training loss: 0.7033
Epoch: 209 |Training loss: 0.6903
Epoch: 210 |Training loss: 0.6975
Epoch: 211 |Training loss: 0.7025
Epoch: 212 |Training loss: 0.6833
Epoch: 213 |Training loss: 0.7110
Epoch: 214 |Training loss: 0.6926
Epoch: 215 |Training loss: 0.7004
Epoch: 216 |Training loss: 0.6993
Epoch: 217 |Training loss: 0.6820
Epoch: 218 |Training loss: 0.6787
Epoch: 219 |Training loss: 0.6790
Epoch: 220 |Training loss: 0.6672
Epoch: 221 |Training loss: 0.6667
Epoch: 222 |Training loss: 0.6554
Epoch: 223 |Training loss: 0.6661
Epoch: 224 |Training loss: 0.6627
Epoch: 225 |Training loss: 0.6647
Epoch: 226 |Training loss: 0.6700
Epoch: 227 |Training loss: 0.6432
Epoch: 228 |Training loss: 0.6624
Epoch: 229 |Training loss: 0.6488
Epoch: 230 |Training loss: 0.6507
Epoch: 231 |Training loss: 0.6468
Epoch: 232 |Training loss: 0.6384
Epoch: 233 |Training loss: 0.6340
Epoch: 234 |Training loss: 0.6221
Epoch: 235 |Training loss: 0.6469
Epoch: 236 |Training loss: 0.6352
Epoch: 237 |Training loss: 0.6439
Epoch: 238 |Training loss: 0.6442
Epoch: 239 |Training loss: 0.6313
Epoch: 240 |Training loss: 0.6414
Epoch: 241 |Training loss: 0.6298
Epoch: 242 |Training loss: 0.6305
Epoch: 243 |Training loss: 0.6222
Epoch: 244 |Training loss: 0.6226
Epoch: 245 |Training loss: 0.6110
Epoch: 246 |Training loss: 0.6215
Epoch: 247 |Training loss: 0.6036
Epoch: 248 |Training loss: 0.6263
Epoch: 249 |Training loss: 0.6114
Epoch: 250 |Training loss: 0.6153
Epoch: 251 |Training loss: 0.6201
Epoch: 252 |Training loss: 0.5927
Epoch: 253 |Training loss: 0.6143
Epoch: 254 |Training loss: 0.6060
Epoch: 255 |Training loss: 0.5962
Epoch: 256 |Training loss: 0.5979
Epoch: 257 |Training loss: 0.5886
Epoch: 258 |Training loss: 0.5878
Epoch: 259 |Training loss: 0.5815
Epoch: 260 |Training loss: 0.5808
Epoch: 261 |Training loss: 0.5713
Epoch: 262 |Training loss: 0.5786
Epoch: 263 |Training loss: 0.5713
Epoch: 264 |Training loss: 0.5642
Epoch: 265 |Training loss: 0.5865
Epoch: 266 |Training loss: 0.5772
Epoch: 267 |Training loss: 0.5633

training:  53%|█████▎    | 267/500 [43:53<38:03,  9.80s/it]

Epoch: 268 |Training loss: 0.5759





training:  54%|█████▍    | 270/500 [44:08<37:32,  9.79s/it][A[A[A

Epoch: 269 |Training loss: 0.5826
Epoch: 270 |Training loss: 0.5795





training:  54%|█████▍    | 270/500 [44:23<37:32,  9.79s/it][A[A[A

Epoch: 271 |Training loss: 0.5590





training:  55%|█████▍    | 273/500 [44:38<37:09,  9.82s/it][A[A[A

Epoch: 272 |Training loss: 0.5777
Epoch: 273 |Training loss: 0.5575





training:  55%|█████▍    | 273/500 [44:53<37:09,  9.82s/it][A[A[A

Epoch: 274 |Training loss: 0.5876





training:  55%|█████▌    | 276/500 [45:07<36:37,  9.81s/it][A[A[A

Epoch: 275 |Training loss: 0.5789
Epoch: 276 |Training loss: 0.5769





training:  55%|█████▌    | 276/500 [45:23<36:37,  9.81s/it][A[A[A

Epoch: 277 |Training loss: 0.5615





training:  56%|█████▌    | 279/500 [45:37<36:05,  9.80s/it][A[A[A

Epoch: 278 |Training loss: 0.5692
Epoch: 279 |Training loss: 0.5452





training:  56%|█████▌    | 279/500 [45:53<36:05,  9.80s/it][A[A[A

Epoch: 280 |Training loss: 0.5522





training:  56%|█████▋    | 282/500 [46:06<35:42,  9.83s/it][A[A[A

Epoch: 281 |Training loss: 0.5407
Epoch: 282 |Training loss: 0.5481





training:  56%|█████▋    | 282/500 [46:23<35:42,  9.83s/it][A[A[A

Epoch: 283 |Training loss: 0.5383





training:  57%|█████▋    | 285/500 [46:36<35:10,  9.81s/it][A[A[A

Epoch: 284 |Training loss: 0.5418
Epoch: 285 |Training loss: 0.5372





training:  57%|█████▋    | 285/500 [46:53<35:10,  9.81s/it][A[A[A

Epoch: 286 |Training loss: 0.5362





training:  58%|█████▊    | 288/500 [47:05<34:37,  9.80s/it][A[A[A

Epoch: 287 |Training loss: 0.5299
Epoch: 288 |Training loss: 0.5376





training:  58%|█████▊    | 288/500 [47:23<34:37,  9.80s/it][A[A[A

Epoch: 289 |Training loss: 0.5261





training:  58%|█████▊    | 291/500 [47:35<34:12,  9.82s/it][A[A[A

Epoch: 290 |Training loss: 0.5294
Epoch: 291 |Training loss: 0.5335





training:  58%|█████▊    | 291/500 [47:53<34:12,  9.82s/it][A[A[A

Epoch: 292 |Training loss: 0.5081





training:  59%|█████▉    | 294/500 [48:04<33:41,  9.81s/it][A[A[A

Epoch: 293 |Training loss: 0.5160
Epoch: 294 |Training loss: 0.5122





training:  59%|█████▉    | 294/500 [48:23<33:41,  9.81s/it][A[A[A

Epoch: 295 |Training loss: 0.5097





training:  59%|█████▉    | 297/500 [48:34<33:13,  9.82s/it][A[A[A

Epoch: 296 |Training loss: 0.4969
Epoch: 297 |Training loss: 0.5030





training:  59%|█████▉    | 297/500 [48:53<33:13,  9.82s/it][A[A[A

Epoch: 298 |Training loss: 0.4997





training:  60%|██████    | 300/500 [49:03<32:41,  9.81s/it][A[A[A

Epoch: 299 |Training loss: 0.4887
Epoch: 300 |Training loss: 0.4896





training:  60%|██████    | 300/500 [49:23<32:41,  9.81s/it][A[A[A

Epoch: 301 |Training loss: 0.4866





training:  61%|██████    | 303/500 [49:33<32:17,  9.83s/it][A[A[A

Epoch: 302 |Training loss: 0.4831
Epoch: 303 |Training loss: 0.4771





training:  61%|██████    | 303/500 [49:43<32:17,  9.83s/it][A[A[A

Epoch: 304 |Training loss: 0.4719





training:  61%|██████    | 306/500 [50:02<31:44,  9.82s/it][A[A[A

Epoch: 305 |Training loss: 0.4745
Epoch: 306 |Training loss: 0.4714





training:  61%|██████    | 306/500 [50:13<31:44,  9.82s/it][A[A[A

Epoch: 307 |Training loss: 0.4706





training:  62%|██████▏   | 309/500 [50:31<31:13,  9.81s/it][A[A[A

Epoch: 308 |Training loss: 0.4691
Epoch: 309 |Training loss: 0.4639





training:  62%|██████▏   | 309/500 [50:43<31:13,  9.81s/it][A[A[A

Epoch: 310 |Training loss: 0.4697





training:  62%|██████▏   | 312/500 [51:01<30:48,  9.83s/it][A[A[A

Epoch: 311 |Training loss: 0.4741
Epoch: 312 |Training loss: 0.4583





training:  62%|██████▏   | 312/500 [51:13<30:48,  9.83s/it][A[A[A

Epoch: 313 |Training loss: 0.4735





training:  63%|██████▎   | 315/500 [51:30<30:16,  9.82s/it][A[A[A

Epoch: 314 |Training loss: 0.4646
Epoch: 315 |Training loss: 0.5256





training:  63%|██████▎   | 315/500 [51:43<30:16,  9.82s/it][A[A[A

Epoch: 316 |Training loss: 0.5165





training:  64%|██████▎   | 318/500 [52:00<29:45,  9.81s/it][A[A[A

Epoch: 317 |Training loss: 0.5023
Epoch: 318 |Training loss: 0.4946





training:  64%|██████▎   | 318/500 [52:13<29:45,  9.81s/it][A[A[A

Epoch: 319 |Training loss: 0.4894





training:  64%|██████▍   | 321/500 [52:29<29:19,  9.83s/it][A[A[A

Epoch: 320 |Training loss: 0.4813
Epoch: 321 |Training loss: 0.4824





training:  64%|██████▍   | 321/500 [52:43<29:19,  9.83s/it][A[A[A

Epoch: 322 |Training loss: 0.4597





training:  65%|██████▍   | 324/500 [52:59<28:47,  9.82s/it][A[A[A

Epoch: 323 |Training loss: 0.4738
Epoch: 324 |Training loss: 0.4663





training:  65%|██████▍   | 324/500 [53:13<28:47,  9.82s/it][A[A[A

Epoch: 325 |Training loss: 0.4550





training:  65%|██████▌   | 327/500 [53:28<28:15,  9.80s/it][A[A[A

Epoch: 326 |Training loss: 0.4603
Epoch: 327 |Training loss: 0.4424





training:  65%|██████▌   | 327/500 [53:43<28:15,  9.80s/it][A[A[A

Epoch: 328 |Training loss: 0.4565





training:  66%|██████▌   | 330/500 [53:57<27:44,  9.79s/it][A[A[A

Epoch: 329 |Training loss: 0.4546
Epoch: 330 |Training loss: 0.4605





training:  66%|██████▌   | 330/500 [54:13<27:44,  9.79s/it][A[A[A

Epoch: 331 |Training loss: 0.4510





training:  67%|██████▋   | 333/500 [54:27<27:19,  9.81s/it][A[A[A

Epoch: 332 |Training loss: 0.4478
Epoch: 333 |Training loss: 0.4453





training:  67%|██████▋   | 333/500 [54:43<27:19,  9.81s/it][A[A[A

Epoch: 334 |Training loss: 0.4353





training:  67%|██████▋   | 336/500 [54:56<26:47,  9.80s/it][A[A[A

Epoch: 335 |Training loss: 0.4436
Epoch: 336 |Training loss: 0.4518





training:  67%|██████▋   | 336/500 [55:13<26:47,  9.80s/it][A[A[A

Epoch: 337 |Training loss: 0.4302





training:  68%|██████▊   | 339/500 [55:25<26:16,  9.79s/it][A[A[A

Epoch: 338 |Training loss: 0.4311
Epoch: 339 |Training loss: 0.4314





training:  68%|██████▊   | 339/500 [55:43<26:16,  9.79s/it][A[A[A

Epoch: 340 |Training loss: 0.4263


KeyboardInterrupt: ignored

**Music generation**

In [142]:
# In case we want to use previously trained weights
weights = "model_best.pth.tar"
checkpoint = torch.load(output_dir+weights)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
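The restored `epoch` and `loss` can be cross-checked against the training log. The cell that writes `model_best.pth.tar` is not part of this section, so the following is only a sketch under an assumption: that the checkpoint is saved whenever the epoch loss reaches a new minimum. Under that assumption, the epoch it holds is the one with the lowest logged loss. `best_epoch` is a hypothetical helper that uses only the standard library:

```python
import re

# Matches the lines printed during training, e.g. "Epoch: 340 |Training loss: 0.4263"
LOG_LINE = re.compile(r"Epoch:\s*(\d+)\s*\|Training loss:\s*([0-9.]+)")


def best_epoch(log_text):
    """Return (epoch, loss) for the lowest training loss among the given log lines.

    Hypothetical helper: it assumes the best checkpoint is written on every new
    minimum, and it only sees the lines passed in (earlier epochs are ignored).
    """
    best = None
    for match in LOG_LINE.finditer(log_text):
        epoch, loss = int(match.group(1)), float(match.group(2))
        # Strict "<" keeps the earliest epoch when two losses tie
        if best is None or loss < best[1]:
            best = (epoch, loss)
    return best


sample_log = """Epoch: 338 |Training loss: 0.4311
Epoch: 339 |Training loss: 0.4314
Epoch: 340 |Training loss: 0.4263"""
print(best_epoch(sample_log))  # (340, 0.4263)
```

If the two values disagree, the save criterion in the training cell is probably different from the one assumed here.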


In [32]:
# Generate network input again
network_input = []
network_output = []
for i in range(0, len(notes) - sequence_length, 1):
  network_input.append([note_to_int[char] for char in notes[i:i + sequence_length]])
n_patterns = len(network_input)
network_input = np.reshape(network_input, (n_patterns, sequence_length))
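The cell above slides a window of `sequence_length` elements over `notes` with stride 1 and maps each element to its integer index. The range stops one window short, presumably because during training every input window also needed the element that follows it. A self-contained sketch on a toy melody follows. `build_windows` is a hypothetical helper, and building the vocabulary with `sorted(set(...))` is an assumption about how `pitchnames` was created earlier. It also builds the inverse `int_to_note` mapping that is used later for decoding:

```python
def build_windows(notes, sequence_length):
    """Return (windows, note_to_int, int_to_note) for a list of note/chord strings."""
    # Vocabulary: one integer per distinct note or chord string (assumed sorted, like pitchnames)
    pitchnames = sorted(set(notes))
    note_to_int = {name: number for number, name in enumerate(pitchnames)}
    int_to_note = {number: name for number, name in enumerate(pitchnames)}
    # Windows of length sequence_length, stride 1, stopping one window before the end
    windows = [
        [note_to_int[name] for name in notes[i:i + sequence_length]]
        for i in range(0, len(notes) - sequence_length, 1)
    ]
    return windows, note_to_int, int_to_note


toy_notes = ["E5", "0.4.7", "E5", "G4", "F#3", "G4"]
windows, note_to_int, int_to_note = build_windows(toy_notes, 3)
print(windows)                                # [[1, 0, 1], [0, 1, 3], [1, 3, 2]]
print([int_to_note[i] for i in windows[0]])   # ['E5', '0.4.7', 'E5']
```

Decoding a window with `int_to_note` recovers the original strings, which is exactly what the generation cells rely on.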


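The workflow now is:


1.   Pick a **seed sequence** at random from the list of inputs (the *pattern* variable)
2.   Pass it to the model to generate a new element (a note or a chord)
3.   Append the new element to the final song and to the *pattern* list
4.   Remove the first item from *pattern*
5.   Go back to step 2

Here `model.generate`, provided by the `AutoregressiveWrapper`, runs the generation loop internally, so we only supply the seed and the number of elements to generate (500).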
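The five steps can be written as a plain loop. The following is a minimal sketch that assumes a hypothetical `predict_next(window)` function returning the next token index. The toy predictor only repeats the token seen `period` positions earlier, so the example runs without a trained model. It is not how the compressive transformer predicts:

```python
from collections import deque


def generate(seed, n_steps, predict_next):
    """Autoregressive generation with a fixed-size sliding window (steps 1-5 above)."""
    window = deque(seed, maxlen=len(seed))   # step 1: start from the seed sequence
    song = []
    for _ in range(n_steps):                 # step 5: repeat
        token = predict_next(list(window))   # step 2: predict one new element
        song.append(token)                   # step 3: add it to the song...
        window.append(token)                 # ...and to the window; maxlen drops the oldest item (step 4)
    return song


def make_repeater(period):
    # Hypothetical stand-in for the model: echo the token `period` positions back
    return lambda window: window[-period]


print(generate([1, 2, 3, 4], 6, make_repeater(3)))  # [2, 3, 4, 2, 3, 4]
```

Using `deque(maxlen=...)` keeps the window length constant without explicit slicing, which mirrors step 4.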


In [34]:
# Generate notes from the neural network based on a sequence of notes.
# Pick a random sequence from the input as a starting point for the prediction
# (np.random.randint excludes its upper bound, so this can select any index, including the last)
start = np.random.randint(0, len(network_input))
# Inverse vocabulary: integer token -> note/chord string
int_to_note = dict((number, name) for number, name in enumerate(pitchnames))
pattern = torch.from_numpy(network_input[start]).cuda()

# Generate 500 new elements after the seed
prediction_output = model.generate(pattern, 500)
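From the `compressive_transformer_pytorch` source, `AutoregressiveWrapper.generate` appears to sample each new token from filtered logits rather than always taking the most likely one. We have not pinned down its default filter settings, so the following is only an illustration of the general technique. It is a standard-library sketch of temperature scaling plus top-k sampling over a list of logits, and `sample_top_k` and its parameters are hypothetical names:

```python
import math
import random


def sample_top_k(logits, k, temperature=1.0, rng=random):
    """Sample one token index from the k highest logits after temperature scaling."""
    # Keep only the k largest logits; every other token gets zero probability
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    scaled = [logits[i] / temperature for i in top]
    # Unnormalized softmax weights (subtracting the max avoids overflow in exp)
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    return rng.choices(top, weights=weights, k=1)[0]


rng = random.Random(0)
logits = [2.0, 0.5, 1.0, -1.0]
draws = [sample_top_k(logits, k=2, rng=rng) for _ in range(1000)]
# Only the two highest-scoring tokens (indices 0 and 2) can ever be drawn
print(sorted(set(draws)))
```

With `k=1` this reduces to greedy decoding. Lowering `temperature` below 1 makes the kept distribution sharper, and raising it makes the output more varied.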


In [35]:
result_sample = []

# Convert each generated token index back into its note/chord string
for i, token in enumerate(prediction_output):
  result = int_to_note[token.item()]
  print('Predicted', i, result)
  result_sample.append(result)

prediction_output = result_sample

Predicted 0 E5
Predicted 1 0.4.7
Predicted 2 7.0
Predicted 3 6.11
Predicted 4 E4
Predicted 5 F#3
Predicted 6 B3
Predicted 7 6.11
Predicted 8 E4
Predicted 9 F#3
Predicted 10 C2
Predicted 11 G5
Predicted 12 D5
Predicted 13 7.11.2
Predicted 14 2.7
Predicted 15 F#3
Predicted 16 G1
Predicted 17 G4
Predicted 18 F#3
Predicted 19 G1
Predicted 20 G4
Predicted 21 F#3
Predicted 22 G1
Predicted 23 G4
Predicted 24 F#3
Predicted 25 G1
Predicted 26 G5
Predicted 27 D5
Predicted 28 2.7
Predicted 29 F#3
Predicted 30 G1
Predicted 31 G4
Predicted 32 F#3
Predicted 33 2.7
Predicted 34 F#3
Predicted 35 G1
Predicted 36 G4
Predicted 37 G4
Predicted 38 G4
Predicted 39 G4
Predicted 40 7.11.2
Predicted 41 2.7
Predicted 42 F#3

The last step is to create a MIDI file from the predictions.

**music21** will help us again with this task: we create a **Stream** and add the predicted notes and chords to it.

Consecutive elements are placed 0.5 apart. music21 measures offsets in quarter notes, so each element starts an eighth note after the previous one.
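The conversion rule in the next cell can be checked without music21. A string that contains a dot, or consists only of digits, is treated as a chord of pitch classes, and anything else is treated as a note name. The following minimal sketch applies the same decision and pairs each element with its offset. `describe_elements` is a hypothetical helper, and the tuples stand in for the music21 `Note` and `Chord` objects:

```python
def describe_elements(prediction_output, step=0.5):
    """Return (offset, kind, payload) tuples mirroring the MIDI conversion rule."""
    described = []
    offset = 0.0
    for element in prediction_output:
        if '.' in element or element.isdigit():
            # Chord: dot-separated pitch classes, e.g. '0.4.7' -> [0, 4, 7]
            described.append((offset, 'chord', [int(pc) for pc in element.split('.')]))
        else:
            # Single note given by name and octave, e.g. 'E5'
            described.append((offset, 'note', element))
        offset += step  # 0.5 quarter lengths: one eighth note later
    return described


for row in describe_elements(['E5', '0.4.7', '7', 'F#3']):
    print(row)
# (0.0, 'note', 'E5')
# (0.5, 'chord', [0, 4, 7])
# (1.0, 'chord', [7])
# (1.5, 'note', 'F#3')
```

With the 500 generated elements, the last one starts at offset 499 × 0.5 = 249.5 quarter notes.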

In [36]:
offset = 0
output_notes = []
# create note and chord objects based on the values generated by the model
for element in prediction_output:
    # element is a chord: dot-separated pitch classes such as '0.4.7', or a single pitch class such as '7'
    if ('.' in element) or element.isdigit():
        notes_in_chord = element.split('.')
        # a separate name, so the global `notes` list used to build network_input is not overwritten
        chord_notes = []
        for current_note in notes_in_chord:
            new_note = note.Note(int(current_note))
            new_note.storedInstrument = instrument.Piano()
            chord_notes.append(new_note)
        new_chord = chord.Chord(chord_notes)
        new_chord.offset = offset
        output_notes.append(new_chord)
    # element is a single note such as 'E5'
    else:
        new_note = note.Note(element)
        new_note.offset = offset
        new_note.storedInstrument = instrument.Piano()
        output_notes.append(new_note)

    # increase offset each iteration so that notes do not stack
    offset += 0.5

midi_stream = stream.Stream(output_notes)
midi_stream.write('midi', fp='test_output.mid')

'test_output.mid'