In [326]:
import torch
import numpy as np
import os
import time
import functools
from IPython import display as ipythondisplay
from tqdm import tqdm
!apt-get install abcmidi timidity > /dev/null 2>&1

In [414]:
# Custom Dataset class
class MIDIDataset(torch.utils.data.Dataset):
    # CONSTRUCTOR: creates an array of message chunks. Assumes that the directory
    # contains recording0.npy to recordingM.npy
    # ARGUMENTS
    # root_dir: the directory to search
    # chunk_size: we'll chunk the data into chunks of this size (or less)
    def __init__(self, root_dir, chunk_size, transform=None):
        self.chunks = []
        self.masks = []
        
        ch = 0
        for f, file in enumerate(os.listdir(root_dir)):
            data = np.load(root_dir + '/' + file)
            nchunks = int(np.ceil(data.shape[0]/chunk_size))
            self.chunks += [torch.zeros(chunk_size, dtype=torch.long) for c in range(nchunks)]
            self.masks += [torch.zeros(chunk_size, dtype=torch.bool) for c in range(nchunks)]
            for chunk_start in range(0, data.shape[0], chunk_size):
                chunk_end = min(chunk_start + chunk_size, data.shape[0])
                size = chunk_end - chunk_start
                self.chunks[ch][:size] = torch.tensor(data[chunk_start:chunk_end])
                self.masks[ch][:size] = False
                ch += 1
            
            if f%100 == 0:
                print(f)
            
        self.transform = transform

    # __len__
    # RETURN: the number of chunks in the dataset
    def __len__(self):
        return len(self.chunks)

    # __getitem__
    # ARGUMENTS
    # idx: indicates which chunk to get
    # RETURN: instance, a dictionary with keys 'messages' and 'mask'.
    # Both values associated with these keys are length L tensors
    def __getitem__(self, idx):  
        instance = {'messages': self.chunks[idx], \
                    'mask': self.masks[idx]}

        if self.transform:
            instance = self.transform(instance)
            
        return instance
    
def collate_fn(batch):
    chunk_size = batch[0]['messages'].shape[0]
    sample = {'messages': torch.zeros((chunk_size, len(batch)), dtype=torch.long), \
              'mask': torch.ones((chunk_size, len(batch)), dtype=torch.bool)}
    
    for b, instance in enumerate(batch):
        sample['messages'][:, b] = instance['messages']
        sample['mask'][:, b] = instance['mask']
    return sample

In [422]:
train_data = MIDIDataset('train_sc', 200)
len(train_data)

0
100
200


5743

In [423]:
test_data = MIDIDataset('test_sc', 200)
len(test_data)

0


1494

In [424]:
# BaselineLSTM: generates a sequence of MIDI messages
class BaselineLSTM(torch.nn.Module):
    # CONSTRUCTOR
    # ARGUMENTS
    # message_dim: dimension of a MIDI message
    # embed_dim: dimension of message embedding
    # hidden_size: size of hidden LSTM state
    # recurrent_layers: the number of layers in the lstm
    def __init__(self, message_dim, embed_dim, hidden_size, recurrent_layers=3):
        super(BaselineLSTM, self).__init__()
        self.embedding = torch.nn.Embedding(message_dim, embed_dim)
        self.lstm = torch.nn.LSTM(embed_dim, hidden_size, num_layers=recurrent_layers)  
        self.logits = torch.nn.Linear(hidden_size, message_dim)
    
    # forward: generates a probability distribution for the next MIDI message at each time step in a sequence
    # ARGUMENTS
    # seq: an LxB tensor, where L is the length of the longest message sequence in
    # the batch, and B is the batch size. This should be END-PADDED along dimension 0
    # hidden: previous hidden state (default None)
    # RETURN: an LxBxD tensor representing the logits for the next message in each batch,
    # as well as the last hidden state for the LSTM
    def forward(self, seq, hidden=None):
        out, new_hidden = self.lstm(self.embedding(seq), hidden)
        return self.logits(out), new_hidden

In [425]:
def compute_loss(model, loss_fn, batch):
    logits, new_hidden = model(batch['messages'][:-1])
    mask = torch.logical_not(batch['masks'][1:].flatten())
    targets = batch['messages'][1:].flatten()
    return loss_fn(logits.view(-1, message_dim)[mask], targets[mask])

In [426]:
# Optimization parameters:
epochs = 10
batch_size = 10
learning_rate = 1e-3

# Model parameters: 
message_dim = 388
embed_dim = 256 
hidden_size = 1024
recurrent_layers = 3

# Checkpoint location: 
checkpoint_dir = 'sc_lstm_checkpoints'

train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True)
model = BaselineLSTM(message_dim, embed_dim, hidden_size, recurrent_layers)
optimizer = torch.optim.Adam(model.parameters())
loss_fn = torch.nn.CrossEntropyLoss()

In [427]:
train_losses = [0 for epoch in range(epochs)]
test_losses = [0 for epoch in range(epochs)]
for epoch in range(epochs):
    print('Starting epoch %d' %(epoch))
    model.train()
    for b, batch in enumerate(train_dataloader):
        if b%100 == 0:
            print('Starting iteration %d' %(b))
        loss = compute_loss(model, loss_fn, batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    torch.save(model.state_dict(), checkpoint_dir + '/epoch' + str(epoch) + '.pth')
        
    model.eval()
    print('Computing test loss')
    for b, batch in enumerate(test_dataloader):
        test_losses[epoch] += compute_loss(model, loss_fn, batch).data
        
    test_losses[epoch] /= len(test_dataloader)
        
    print('Computing train loss')
    for b, batch in enumerate(train_dataloader):
        train_losses[epoch] += compute_loss(model, loss_fn, batch).data
        
    train_losses[epoch] /= len(train_dataloader)
        
    print('Train loss: %f, Test loss: %f' %(train_losses[epoch], test_losses[epoch]))

Starting epoch 0
Starting iteration 0
Starting iteration 100
Starting iteration 200
Starting iteration 300
Starting iteration 400
Starting iteration 500
Computing test loss
Computing train loss
Train loss: 5.048178, Test loss: 5.065695
Starting epoch 1
Starting iteration 0
Starting iteration 100
Starting iteration 200
Starting iteration 300
Starting iteration 400
Starting iteration 500
Computing test loss
Computing train loss
Train loss: 5.037400, Test loss: 5.058115
Starting epoch 2
Starting iteration 0
Starting iteration 100


KeyboardInterrupt: 

In [None]:
np.save('maestro_lstm_checkpoints/train_losses.npy', np.array(train_losses))
np.save('maestro_lstm_checkpoints/test_losses.npy', np.array(test_losses))

In [None]:
model = BaselineLSTM(message_dim, embed_dim, hidden_size, recurrent_layers)
model.load_state_dict(torch.load('maestro_lstm_checkpoints/epoch' + str(epochs - 1)))

In [311]:
# Sample from model
def generate_music(model, primer, gen_length=1000):
    hidden = None
    gen = torch.tensor(primer).unsqueeze(1)
    message = gen
    for i in range(gen_length):
        logits, hidden = model(message, hidden)
        message = torch.multinomial(torch.nn.functional.softmax(logits[-1].flatten(), dim=0), 1).view(1, 1)
        gen = torch.cat((gen, message))
        
    return gen

In [316]:
generated_music = generate_event(model, start_event=[16,60, 330, 188, 64, 330, 192, 67,330, 195, 72, 330, 200,67,330, 195,64, 330, 192,60, 330, 188], generation_length=1000)

100%|██████████| 1000/1000 [00:11<00:00, 86.23it/s]


In [317]:
np.save('maestro_midi.npy', generated_music.flatten().detach().numpy())