In [1]:
import matplotlib
matplotlib.use('Agg')
import numpy as np
import glob
import pypianoroll as ppr
import time
import music21
import os
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
import torch.utils.data as data
from utils.utilsPreprocessing import *
import time
#np.set_printoptions(threshold=np.inf)
torch.set_printoptions(threshold=50000)

In [2]:
###HYPERPARAMETERS######################
epochs = 100            # number of passes over the training set
batch_size = 10         # batch size of the train/test loaders; also passed to the model
learning_rate = 1e-3    # Adam step size
log_interval = 1        # print a training-progress line every `log_interval` batches
hidden_size = 400       # width of the note embedding and of each LSTM layer
num_layers = 2          # number of stacked LSTM layers
validate = False        # after each epoch, also compute the loss on one random validation batch (val())

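# Beat resolution
`beat_resolution` is the number of time steps (ticks) per quarter note in the piano roll:<br>
If beat resolution = 1 --> 1 tick = 1/4 note<br>
If beat resolution = 2 --> 1 tick = 1/8 note<br>
If beat resolution = 4 --> 1 tick = 1/16 note<br>
If beat resolution = 8 --> 1 tick = 1/32 note<br>
In general, 1 tick = 1/(4 · beat resolution) note.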



In [None]:
# Ticks per quarter note: with 4, a single tick lasts one 1/16 note.
beat_resolution: int = 4

In [None]:
# Build one dataset per split. Every split is created with the same
# beat_resolution so a time step means the same note duration in all of them.
pathToFiles = "/media/EXTHD/niciData/Datasets/YamahaPianoCompetition2002NoTransposeFiles/"
midiDatasetTrain = createDatasetLSTM(pathToFiles + "train/*.mid", beat_res=beat_resolution)
midiDatasetTrain.setMaxLength()

midiDatasetTest = createDatasetLSTM(pathToFiles + "test/*.mid", beat_res=beat_resolution)
midiDatasetTest.setMaxLength()

# Validation samples live in a separate folder; pass beat_res explicitly so they
# share the train/test time grid instead of createDatasetLSTM's default.
midiDatasetVal = createDatasetLSTM("../WikifoniaServer/samples/*.mid", beat_res=beat_resolution)
midiDatasetVal.setMaxLength()

In [None]:
# Report how many songs ended up in each split.
split_summaries = (
    ("\nThere are {} songs in the training set\n", midiDatasetTrain),
    ("There are {} songs in the test set\n", midiDatasetTest),
    ("There are {} songs in the validation set\n", midiDatasetVal),
)
for message_template, split_dataset in split_summaries:
    print(message_template.format(len(split_dataset)))

In [None]:
# NOTE(review): the training loader does not shuffle, so batches arrive in the
# same order every epoch; consider shuffle=True for training.
train_loader = torch.utils.data.DataLoader(midiDatasetTrain, batch_size=batch_size, shuffle=False, drop_last=True)
test_loader = torch.utils.data.DataLoader(midiDatasetTest, batch_size=batch_size, shuffle=False, drop_last=True)
# The validation loader must read the validation split, not the test split.
# batch_size=1 with shuffling, so each val() call scores one random song.
val_loader = torch.utils.data.DataLoader(midiDatasetVal, batch_size=1, shuffle=True, drop_last=True)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
class LSTM_notewise(nn.Module):
    """Note-wise LSTM mapping a padded sequence of 60-dim note vectors to per-step probabilities.

    Each time step is embedded with a linear layer, run through a stacked
    (batch-first) LSTM, projected back to 60 outputs and squashed by a sigmoid.

    Args:
        hidden_size: Width of the note embedding and of every LSTM layer.
        num_layers: Number of stacked LSTM layers.
        batch_size: Default batch dimension used by ``initState`` when no size
            is given. ``forward`` sizes the recurrent state from the actual
            input, so batches of any size are accepted.
    """
    def __init__(self, hidden_size=400, num_layers=2, batch_size=1):
        super(LSTM_notewise, self).__init__()

        self.input_size = 60
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.num_classes = 60

        self.i2h = nn.Linear(self.input_size, self.hidden_size)
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size,
                            self.num_layers, batch_first=True)
        self.h2o = nn.Linear(self.hidden_size, self.num_classes)

    def initState(self, batch_size=None):
        """Return a zero LSTM state of shape (num_layers, batch_size, hidden_size).

        The tensor takes the dtype and device of the model's parameters, so it
        always matches the LSTM weights (no dependency on a global device).

        Args:
            batch_size: Batch dimension of the state; defaults to ``self.batch_size``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        reference = next(self.parameters())
        return torch.zeros(self.num_layers, batch_size, self.hidden_size,
                           dtype=reference.dtype, device=reference.device)

    def forward(self, input, seq_lengths):
        """Run the network on a padded batch.

        Args:
            input: Padded batch, presumably (batch, time, 60) since the LSTM is
                batch_first and ``i2h`` expects 60 features — TODO confirm.
            seq_lengths: True length of each sequence. ``pack_padded_sequence``
                (with its default settings) requires them in decreasing order.

        Returns:
            Tensor of shape (batch, time, 60) with values in (0, 1). The time
            dimension equals ``input.size(1)``. Steps beyond a sequence's length
            are not masked: they hold sigmoid of the output layer's bias.
        """
        # Size the state from the real batch so e.g. batch_size=1 loaders work.
        batch = input.size(0)
        h_t = self.initState(batch)
        c_t = self.initState(batch)

        embedded_notes = self.i2h(input)
        packed_notes = torch.nn.utils.rnn.pack_padded_sequence(embedded_notes,
                                                               seq_lengths,
                                                               batch_first=True)
        out, (h_t, c_t) = self.lstm(packed_notes, (h_t, c_t))
        # total_length re-pads to the input's length, so the output lines up
        # with targets padded the same way even when the longest sequence in
        # this batch is shorter than the padding.
        out, _ = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=True,
                                                        total_length=input.size(1))
        return torch.sigmoid(self.h2o(out))

In [None]:
# Build the network in double precision on the selected device, then attach Adam.
model = LSTM_notewise(
    hidden_size=hidden_size,
    num_layers=num_layers,
    batch_size=batch_size,
)
model = model.double().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
#load model
#from loadModel import loadStateDict
#pathToModel = "/media/EXTHD/niciData/models/LSTM_notewise.pth"

#model = loadStateDict(model, pathToModel)


In [None]:
def train(epoch):
    """Run one training epoch over ``train_loader``, updating ``model`` in place.

    Args:
        epoch: Epoch index, used only in the log lines.

    Returns:
        The sum of the per-batch mean BCE losses over the epoch.
    """
    model.train()
    train_loss = 0.0
    # forward() already ends in a sigmoid, so plain BCELoss is the matching criterion.
    criterion = nn.BCELoss()
    num_batches = len(train_loader)
    for batch_idx, batch in enumerate(train_loader):
        optimizer.zero_grad()
        input_lstm, ground_truth, seq_lengths = reorderBatch(batch)
        prediction = model(input_lstm.double().to(device), seq_lengths)
        # BCELoss needs the target on the same device and with the same dtype as the prediction.
        ground_truth = ground_truth.to(device=prediction.device, dtype=prediction.dtype)

        loss = criterion(prediction, ground_truth)
        loss.backward()
        train_loss += loss.item()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if batch_idx % log_interval == 0:
            samples_seen = batch_idx * input_lstm.size(0)
            # loss is already a per-element mean, so it is logged as-is.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, samples_seen, len(train_loader.dataset),
                100. * batch_idx / num_batches,
                loss.item()))

    # train_loss is a sum of batch means, so average over batches, not samples.
    print('====> Epoch: {} Average Loss: {:.4f}'.format(
          epoch, train_loss / max(num_batches, 1)))
    return train_loss

def test(epoch):
    """Evaluate ``model`` on ``test_loader`` without updating weights.

    Args:
        epoch: Epoch index; unused, kept for a signature parallel to ``train``.

    Returns:
        The sum of the per-batch mean BCE losses over the test set.
    """
    model.eval()
    test_loss = 0.0
    # The model outputs probabilities (forward ends in a sigmoid). BCEWithLogitsLoss
    # would apply a second sigmoid, so use BCELoss to match train().
    criterion = nn.BCELoss()
    with torch.no_grad():
        for batch in test_loader:
            input_lstm, ground_truth, seq_lengths = reorderBatch(batch)
            prediction = model(input_lstm.double().to(device), seq_lengths)
            ground_truth = ground_truth.to(device=prediction.device, dtype=prediction.dtype)
            test_loss += criterion(prediction, ground_truth).item()

    print('====> Test set Loss: {:.4f}'.format(test_loss))
    return test_loss

def val(epoch):
    """Score ``model`` on a single batch from ``val_loader`` as a quick spot check.

    Only the first batch the loader yields is used. The loader is built with
    shuffle=True, so this is a random sample on each call.

    Args:
        epoch: Epoch index; unused, kept for a signature parallel to ``train``.

    Returns:
        The mean BCE loss of that batch, or 0 if the loader yields nothing.
    """
    model.eval()
    val_loss = 0
    # forward() already ends in a sigmoid; BCELoss avoids a second sigmoid.
    criterion = nn.BCELoss()
    batch = next(iter(val_loader), None)
    if batch is not None:
        with torch.no_grad():
            input_lstm, ground_truth, seq_lengths = reorderBatch(batch)
            prediction = model(input_lstm.double().to(device), seq_lengths)
            ground_truth = ground_truth.to(device=prediction.device, dtype=prediction.dtype)
            val_loss = criterion(prediction, ground_truth).item()
    print('====> Validation Loss: {:.4f}'.format(val_loss))
    return val_loss
        

In [None]:
import matplotlib.pyplot as plt

# Weights are written here whenever the test loss improves.
checkpoint_path = '/media/EXTHD/niciData/LSTM_notewise.pth'

train_losses = []
test_losses = []
val_losses = []  # stays empty unless `validate` is True
# Start at +inf so the first epoch is always checkpointed, whatever the loss scale.
best_test_loss = float('inf')
for epoch in range(1, epochs + 1):
    train_losses.append(train(epoch))

    current_test_loss = test(epoch)
    test_losses.append(current_test_loss)
    # Keep only the weights with the lowest test loss seen so far.
    if current_test_loss < best_test_loss:
        best_test_loss = current_test_loss
        torch.save(model.state_dict(), checkpoint_path)

    if validate:
        val_losses.append(val(epoch))

# Loss curves, with the x axis numbered by epoch (starting at 1).
epoch_axis = range(1, len(train_losses) + 1)
fig, ax = plt.subplots()
ax.plot(epoch_axis, train_losses, color='red', label='Train loss')
ax.plot(epoch_axis, test_losses, color='orange', label='Test loss')
if validate:
    ax.plot(epoch_axis, val_losses, color='yellow', label='Validation loss')
ax.set(title='LSTM_notewise training', xlabel='Epoch', ylabel='Loss')
ax.legend()
fig.savefig('LSTM_notewise.png')
plt.close(fig)

# Generate

In [None]:
# --- Generation (currently disabled) ---
# Everything between the triple quotes below is a bare string literal, so Python
# evaluates it as a no-op; remove the quotes to run it. When enabled it would,
# for every batch in `val_loader` (there is no early break):
#   * run the model without gradient tracking,
#   * squeeze dim 0 of the prediction and move it to a NumPy array on the CPU,
#   * convert it back with debinarizeMidi / addCuttedOctaves and write/show it
#     via pianorollMatrixToTempMidi at '../temp/temp.mid'
#     (these helpers presumably come from the utils.utilsPreprocessing star import).
# NOTE(review): `model.batch_size=1` looks intended to match the model's hidden
# state to the batch_size=1 val_loader — verify against LSTM_notewise.initState.
"""
model.eval()
model.batch_size=1
with torch.no_grad():
    #get unseeen sample from validation set
    for data in val_loader:
        input_lstm, ground_truth, seq_lengths = reorderBatch(data)
        prediction = model(input_lstm.double().to(device), seq_lengths)
        print(prediction.size())

        prediction = prediction.squeeze(0).cpu().numpy()
        print(prediction)
        #prediction /= np.max(np.abs(prediction))
        #prediction[prediction < 0.2] = 0
        prediction = debinarizeMidi(prediction, prediction=True)
        prediction = addCuttedOctaves(prediction)
        print(prediction)
        pianorollMatrixToTempMidi(prediction, show=True, showPlayer=True, autoplay=False,
                                 path='../temp/temp.mid')
"""
print('')