Script for creating and training the LSTM network models used to generate classical music.

Matilda Wikström, matwiks@student.chalmers.se

Jesper Larsson, jesplar@student.chalmers.se

Links:

https://github.com/craffel/pretty-midi

https://towardsdatascience.com/generate-piano-instrumental-music-by-using-deep-learning-80ac35cdbd2e





In [130]:
#Imports
import torch
import pretty_midi
import numpy as np
import unicodedata
import re
import numpy as np
from random import shuffle, seed
import os
import io
import time
from tqdm import tnrange, tqdm_notebook, tqdm
# "Our code"
import preprocessing as pp
#import generatemidi as gm # Broken file, fix
import tokenizer as token
import generatemidi as gen

In [None]:
# Using pretty_midi example
fs=30
dict_note = {}
midi_pretty_format = pretty_midi.PrettyMIDI('test.midi')
piano_midi = midi_pretty_format.instruments[0] # Get the piano channels
piano_roll = piano_midi.get_piano_roll(fs=fs)
dict_note[0] = piano_roll

print(len(piano_roll[0]))

In [3]:
#Download the required data create folder structure
#Make sure we are in the correct folder

assert (os.path.basename(os.getcwd())=='LSTM-MusicGenerator'), "Wrong working dir"

#Download the MAESTRO Dataset
if not os.path.isfile('./maestro-v1.0.0-midi.zip'):
    !wget https://storage.googleapis.com/magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0-midi.zip

#Check if we have extracted the files
if not os.path.isdir('./maestro-v1.0.0-midi') and not os.path.isfile('./maestro-v1.0.0/LICENSE'):
    !unzip maestro-v1.0.0-midi.zip


/bin/sh: 1: wget: not found


In [142]:
# Preprocessing
nbr_of_songs = 10 #100
list_all_midi = pp.get_list_midi() # 
sampled_200_midi = list_all_midi[0:nbr_of_songs]
batch = 1
start_index = 0
note_tokenizer = token.NoteTokenizer()
import pretty_midi

for i in tqdm_notebook(range(len(sampled_200_midi))):
    dict_time_notes = token.generate_dict_time_notes(sampled_200_midi, batch_song=1, start_index=i, use_tqdm=False, fs=5)
    full_notes = token.process_notes_in_song(dict_time_notes)
    for note in full_notes:
        note_tokenizer.partial_fit(list(note.values()))
note_tokenizer.add_new_note('e') # Add empty notes



HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




In [224]:
# ---- define model
import torch.nn as nn
import torch.nn.functional as F
   
input_size = 50
hidden_size = 100
num_layers = 2
is_bidirectional = True
dropout_rate = 0.2
batch_size = 1 #96 ### Should be 2085

seq_len = 50
unique_notes = note_tokenizer.unique_word #Used in our output layer to map



if is_bidirectional:
    num_directions = 2
else:
    num_directions = 1


class our_LSTM(torch.nn.Module):
    def __init__(self, h0, c0):
        super().__init__()
        self.hn = h0
        self.cn = c0
        
        self.lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, dropout=dropout_rate, bidirectional=is_bidirectional)
        self.fc = torch.nn.Linear(num_directions * hidden_size, unique_notes)
        self.softmax = torch.nn.Softmax(dim=-1)
        
    def forward(self, inp, use_softmax = False):
        out, (self.hn, self.cn) = self.lstm(inp, (self.hn, self.cn))
        out = self.fc(out)
        if use_softmax:
            out = self.softmax(out)
        out = out.permute(0,2,1)
        out = out.view(-1,unique_notes)
        
        return out, self.hn, self.cn
    


In [200]:
from torch import optim
import matplotlib.pyplot as plt

# ---- variables
seq_len = 50 
EPOCHS = 200-13
BATCH_SONG = 1 #16
BATCH_NNET_SIZE = 128 
TOTAL_SONGS = len(sampled_200_midi)
FRAME_PER_SECOND = 5
DATE = 20191023

# ---- initialize model, states, loss fcn, optimizer 
hn = torch.zeros(num_layers*num_directions, batch_size, hidden_size,dtype=torch.float32)
cn = torch.zeros(num_layers*num_directions, batch_size, hidden_size,dtype=torch.float32)
model = our_LSTM(hn,cn)

# -- load trained model
START_EPOCH = 13
model.load_state_dict(torch.load(f'lstm_20191023_2layer_100hidden_ep{START_EPOCH}'))

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = optim.RMSprop(model.parameters(), lr=0.0001)

# ---- live plotting
%matplotlib notebook 
fig, axes = plt.subplots(1,1)
plt.ion()
fig.show()
fig.canvas.draw()
loss_vec = []
loss_vec = np.load(f'loss_20191023_2layer_100hiddenep{START_EPOCH}.npy', allow_pickle=True).tolist()




# ---- training loop
for epoch in tqdm_notebook(range(START_EPOCH,START_EPOCH+EPOCHS), desc='epoch'):
    batch_song = BATCH_SONG
    frame_per_second = FRAME_PER_SECOND
    shuffle(sampled_200_midi)
        
    for i in tqdm_notebook(range(0,len(sampled_200_midi),BATCH_SONG), desc=f'completed songs in epoch {epoch+1}'):
        
        # -- prepare data
        inputs_nnet_large, outputs_nnet_large = token.generate_batch_song(
                sampled_200_midi, batch_song, start_index=i, fs=frame_per_second, 
                seq_len=seq_len, use_tqdm=False)
        trans_in = note_tokenizer.transform(inputs_nnet_large)
        trans_out = note_tokenizer.transform(outputs_nnet_large)
        trans_out = trans_out - 1
        
        # -- reset hidden states every song
        model.hn.detach_()
        model.cn.detach_()
        
        # -- train on 1 song 
        counter = 0
        for j in range(0,len(trans_in)-BATCH_NNET_SIZE,BATCH_NNET_SIZE):
            counter += 1 
            input_tensor = trans_in[j:j+BATCH_NNET_SIZE]
            target_tensor = trans_out[j:j+BATCH_NNET_SIZE]

            input_tensor = torch.tensor(input_tensor,dtype=torch.float32)
            input_tensor = input_tensor[None,:,:] # Should be (seq_len, batch, input_size)
            input_tensor = input_tensor.permute(1,0,2)
            target_tensor = torch.tensor(target_tensor,dtype=torch.long).view(-1)
            pred, hn, cn = model(input_tensor)
            loss = loss_fn(pred,target_tensor)
            loss.backward(retain_graph=True)
            optimizer.step()
            optimizer.zero_grad()
        
        # -- plot loss after every song
        loss_vec.append(loss.data)
        axes.clear()
        axes.plot(loss_vec)
        axes.set_xlabel('Nbr of Songs')
        axes.set_ylabel('Loss')
        plt.legend(["Training Loss"])
        fig.canvas.draw()
    
    # -- save model and loss vector every epoch
    torch.save(model.state_dict(), f'lstm_{DATE}_{num_layers}layer_{hidden_size}hidden_ep{epoch+1}')
    np.save(f'loss_{DATE}_{num_layers}layer_{hidden_size}hiddenep{epoch+1}', np.array(loss_vec))
    
    # -- reset hidden states every song
    model.hn.detach_()
    model.cn.detach_()





<IPython.core.display.Javascript object>

HBox(children=(IntProgress(value=0, description='epoch', max=187, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='completed songs in epoch 14', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 15', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 16', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 17', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 18', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 19', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 20', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 21', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 22', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 23', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 24', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 25', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 26', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 27', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 28', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 29', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 30', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 31', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 32', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 33', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 34', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 35', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 36', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 37', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 38', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 39', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 40', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 41', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 42', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 43', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 44', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 45', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 46', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 47', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 48', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 49', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 50', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 51', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 52', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 53', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 54', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 55', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 56', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 57', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 58', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 59', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 60', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 61', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 62', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 63', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 64', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 65', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 66', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 67', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 68', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 69', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 70', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 71', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 72', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 73', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 74', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 75', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 76', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 77', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 78', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 79', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 80', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 81', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 82', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 83', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 84', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 85', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 86', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 87', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 88', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 89', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 90', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 91', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 92', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 93', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 94', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 95', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 96', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 97', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 98', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 99', max=10, style=ProgressStyle(des…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 100', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 101', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 102', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 103', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 104', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 105', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 106', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 107', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 108', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 109', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 110', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 111', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 112', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 113', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 114', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 115', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 116', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 117', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 118', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 119', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 120', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 121', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 122', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 123', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 124', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 125', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 126', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 127', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 128', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 129', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 130', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 131', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 132', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 133', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 134', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 135', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 136', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 137', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 138', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 139', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 140', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 141', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 142', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 143', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 144', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 145', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 146', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 147', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 148', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 149', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 150', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 151', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 152', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 153', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 154', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 155', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 156', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 157', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 158', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 159', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 160', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 161', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 162', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 163', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 164', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 165', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 166', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 167', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 168', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 169', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 170', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 171', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 172', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 173', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 174', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 175', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 176', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 177', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 178', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 179', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 180', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 181', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 182', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 183', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 184', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 185', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 186', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 187', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 188', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 189', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 190', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 191', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 192', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 193', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 194', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 195', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 196', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 197', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 198', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 199', max=10, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='completed songs in epoch 200', max=10, style=ProgressStyle(de…





In [6]:
# ---- save model and loss vector
torch.save(model.state_dict(), f'lstm_{DATE}_{num_layers}layer_{hidden_size}hidden_ep{15}_3')
np.save(f'loss_{DATE}_{num_layers}layer_{hidden_size}hiddenep{15}_3', np.array(loss_vec))

In [232]:
# ---- plot loss
plt.figure()
plt.plot(np.load('loss_20191022_2layer_100hiddenep13.npy', allow_pickle=True).tolist(), label='2 LSTM layers, 100 hidden units, 100 songs')
plt.plot(np.load(f'loss_20191023_2layer_100hiddenep{200}.npy', allow_pickle=True).tolist(), label='2 LSTM layers, 100 hidden units, 10 songs')
plt.plot(np.load(f'loss_20191022_1layer_200hiddenep{200}_3.npy', allow_pickle=True).tolist(), label='1 LSTM layer, 200 hidden units, 10 songs')
plt.xlabel('Number of Songs')
plt.ylabel('Cross Entropy Loss')
plt.legend()
plt.title('Traning Loss for Different Architectures')
plt.show()


<IPython.core.display.Javascript object>

In [220]:
# ---- help function for generating mew music and creating MIDI file
def generate_from_random(unique_notes, seq_len=50):
  generate = np.random.randint(0,unique_notes-1,seq_len).tolist()
  return generate
    
def generate_from_one_note(note_tokenizer, new_notes='35'):
  generate = [note_tokenizer.notes_to_index['e'] for i in range(49)]
  generate += [note_tokenizer.notes_to_index[new_notes]]
  return generate

def generate_notes(generate, model, unique_notes, max_generated=1000, seq_len=50):
  for i in tqdm_notebook(range(max_generated), desc='genrt'):
    test_input = np.array([generate])[:,i:i+seq_len]
    test_input = torch.tensor(test_input, dtype=torch.float32)
    test_inputa = test_input[None,:,:]
    
    predicted_note, _, _ = model(test_inputa, use_softmax=False)
    predicted_note = predicted_note.detach().numpy()
#     predicted_note[predicted_note<0]=0
    prednote=predicted_note/np.sum(predicted_note)
    r = np.random.rand()
    if r < 0.0:
        random_note_pred = np.random.choice(unique_notes, 1, replace=False, p=prednote[0])
        generate.append(random_note_pred[0])
    else: 
        generate.append(prednote.argmax().item())
  return generate


def write_midi_file_from_generated(generate, midi_file_name = "result.mid", start_index=0, fs=8, max_generated=1000):
    generate = np.array(generate)
#     generate[generate<1] = 1
    note_string = [note_tokenizer.index_to_notes[ind_note] for ind_note in generate]
    array_piano_roll = np.zeros((128,max_generated+1), dtype=np.int16)
 
    for index, note in enumerate(note_string[start_index:]):
       if note == 'e':
         pass
       else:
         splitted_note = note.split(',')
         for j in splitted_note:
           array_piano_roll[int(j),index] = 1
 
    generate_to_midi = pp.piano_roll_to_pretty_midi(array_piano_roll, fs=fs)
    print("Tempo {}".format(generate_to_midi.estimate_tempo()))
    for note in generate_to_midi.instruments[0].notes:
        note.velocity = 75
    generate_to_midi.write(midi_file_name)


In [226]:
# ---- load saved model
hn = torch.randn(num_layers*num_directions, batch_size, hidden_size,dtype=torch.float32)
cn = torch.randn(num_layers*num_directions, batch_size, hidden_size,dtype=torch.float32)
model = our_LSTM(hn,cn)

model.load_state_dict(torch.load('lstm_20191023_2layer_100hidden_ep200'))

<All keys matched successfully>

In [231]:
# ---- generate new music
one_note = True
max_generate = 100
input_len = 10
unique_notes = note_tokenizer.unique_word
print(unique_notes)
seq_len=50
if one_note:
    generate = generate_from_one_note(note_tokenizer, '27')
    filename = 'a_one_note.mid'
else:
    generate = generate_from_random(unique_notes-1, seq_len)
    filename = 'a_random.mid'
# print(generate)
generate_song = generate_notes(generate, model, unique_notes, max_generate, seq_len)
print(generate_song)
write_midi_file_from_generated(generate_song, filename, start_index=seq_len-1, fs=5, max_generated = max_generate)

5756


HBox(children=(IntProgress(value=0, description='genrt', style=ProgressStyle(description_width='initial')), HT…


[5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 5756, 3470, 5737, 5737, 2106, 5736, 4529, 2106, 5737, 5737, 5737, 5742, 5742, 5742, 5742, 5742, 5742, 5742, 4550, 5742, 5742, 2110, 2110, 5747, 2105, 5747, 4568, 4568, 4568, 4568, 4568, 4568, 4568, 4568, 5752, 5752, 5737, 5736, 5736, 5736, 5742, 5742, 5742, 5736, 5736, 5736, 4523, 4523, 4523, 2107, 2107, 2107, 4544, 4544, 4544, 5742, 2105, 2105, 5742, 5742, 5742, 5742, 2105, 5742, 4550, 4550, 1179, 5747, 5747, 1179, 5747, 2102, 2413, 2413, 2413, 2413, 2413, 2413, 2413, 2413, 2413, 4550, 1185, 1185, 1185, 1185, 1465, 5736, 4535, 4535, 4535, 4535, 4535, 5753, 5753, 5753, 5753, 5753, 2111, 4555, 4531, 4528]
Tempo 150.00000000000009
