In [1]:
import torch
import numpy as np
import os
import time
import functools
from IPython import display as ipythondisplay
from tqdm import tqdm
!apt-get install abcmidi timidity > /dev/null 2>&1

In [2]:
# Compute device shared by the dataset, collate_fn and model cells below.
# NOTE: a later cell reassigns `dev`; whichever of the two cells ran last wins.
dev = torch.device("cpu")

In [7]:
# Prefer the GPU, but fall back to the CPU when CUDA is unavailable so that the
# later `.to(dev)` / `device=dev` calls do not fail on machines without a GPU.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Custom Dataset class: serves fixed-length, end-padded chunks of the message
# arrays stored as .npy files in a directory.
class MIDIDataset(torch.utils.data.Dataset):
    # CONSTRUCTOR: creates an array of message chunks. Assumes that the directory
    # contains recording0.npy to recordingM.npy
    # ARGUMENTS
    # root_dir: the directory to search
    # chunk_size: we'll chunk the data into chunks of this size (or less)
    # transform: optional callable applied to every instance in __getitem__
    # device: device on which the chunk/mask tensors are created. When None
    # (the default) the notebook-level `dev` is used, as before
    def __init__(self, root_dir, chunk_size, transform=None, device=None):
        if device is None:
            device = dev

        self.chunks = []
        self.masks = []

        # os.listdir returns entries in arbitrary order, so sort (lexicographically)
        # to make the chunk order reproducible. Entries that are not .npy files
        # (e.g. a Jupyter `.ipynb_checkpoints` directory) would crash np.load and
        # are skipped.
        files = sorted(name for name in os.listdir(root_dir) if name.endswith('.npy'))
        for f, file in enumerate(files):
            data = np.load(os.path.join(root_dir, file))
            for chunk_start in range(0, data.shape[0], chunk_size):
                chunk_end = min(chunk_start + chunk_size, data.shape[0])
                size = chunk_end - chunk_start

                # Every chunk has length chunk_size; positions past `size` keep
                # message 0 and mask True (True marks padding, False real data).
                chunk = torch.zeros(chunk_size, dtype=torch.long, device=device)
                mask = torch.ones(chunk_size, dtype=torch.bool, device=device)
                chunk[:size] = torch.tensor(data[chunk_start:chunk_end], device=device)
                mask[:size] = False
                self.chunks.append(chunk)
                self.masks.append(mask)

            # Coarse progress indicator: one line every 100 files
            if f%100 == 0:
                print(f)

        self.transform = transform

    # __len__
    # RETURN: the number of chunks in the dataset
    def __len__(self):
        return len(self.chunks)

    # __getitem__
    # ARGUMENTS
    # idx: indicates which chunk to get
    # RETURN: instance, a dictionary with keys 'messages' and 'mask'.
    # Both values associated with these keys are length L tensors
    # (the stored tensors themselves, not copies), passed through `transform`
    # when one was given
    def __getitem__(self, idx):
        instance = {'messages': self.chunks[idx],
                    'mask': self.masks[idx]}

        if self.transform:
            instance = self.transform(instance)

        return instance
    
# collate_fn: merges dataset instances into one batch
# ARGUMENTS
# batch: a list of instances, each a dictionary with length L tensors under
# the keys 'messages' and 'mask'
# RETURN: a dictionary with the same keys holding LxB tensors on `dev`, where
# column b is the b-th instance of the batch
def collate_fn(batch):
    n_steps = batch[0]['messages'].shape[0]
    n_items = len(batch)

    messages = torch.zeros((n_steps, n_items), dtype=torch.long, device=dev)
    mask = torch.ones((n_steps, n_items), dtype=torch.bool, device=dev)

    # Copy each instance into its own column
    for col, item in enumerate(batch):
        messages[:, col] = item['messages']
        mask[:, col] = item['mask']

    return {'messages': messages, 'mask': mask}

In [None]:
# Training set: every file under 'train_sc', split into chunks of 1000 messages
train_data = MIDIDataset('train_sc', chunk_size=1000)
len(train_data)

In [None]:
# Held-out set: every file under 'test_sc', split into chunks of 1000 messages
test_data = MIDIDataset('test_sc', chunk_size=1000)
len(test_data)

In [6]:
# BaselineGRU: next-message model built as embedding -> multi-layer GRU -> linear
# projection back onto the message vocabulary
class BaselineGRU(torch.nn.Module):
    # CONSTRUCTOR
    # ARGUMENTS
    # message_dim: number of distinct MIDI messages (embedding rows / logit width)
    # embed_dim: dimension of message embedding
    # hidden_size: size of hidden GRU state
    # recurrent_layers: the number of stacked GRU layers
    def __init__(self, message_dim, embed_dim, hidden_size, recurrent_layers=3):
        super().__init__()
        self.embedding = torch.nn.Embedding(num_embeddings=message_dim, embedding_dim=embed_dim)
        self.gru = torch.nn.GRU(input_size=embed_dim, hidden_size=hidden_size, num_layers=recurrent_layers)
        self.logits = torch.nn.Linear(in_features=hidden_size, out_features=message_dim)

    # forward: scores the next MIDI message at every time step of a sequence
    # ARGUMENTS
    # seq: an LxB tensor of message indices, where L is the sequence length and
    # B is the batch size. This should be END-PADDED along dimension 0
    # hidden: previous hidden state (default None)
    # RETURN: an LxBxD tensor of logits over the next message, together with the
    # last hidden state of the GRU
    def forward(self, seq, hidden=None):
        embedded = self.embedding(seq)
        out, new_hidden = self.gru(embedded, hidden)
        return self.logits(out), new_hidden

In [7]:
# compute_loss: next-message prediction loss over the non-padded positions of a batch
# ARGUMENTS
# model: callable mapping an LxB tensor of messages to (logits, hidden), where
# logits has the vocabulary as its last dimension
# loss_fn: callable taking (N x D logits, N targets), e.g. CrossEntropyLoss
# batch: dictionary with LxB tensors 'messages' and 'mask' (True marks padding)
# RETURN: loss_fn evaluated on the positions whose target is not padding
def compute_loss(model, loss_fn, batch):
    # Inputs are steps 0..L-2 and the targets are the same sequence shifted by one
    logits, _ = model(batch['messages'][:-1])
    targets = batch['messages'][1:].flatten()
    valid = torch.logical_not(batch['mask'][1:].flatten())
    # Take the vocabulary size from the logits themselves instead of a notebook
    # global, so a stale `message_dim` can never silently misalign the rows.
    flat_logits = logits.reshape(-1, logits.shape[-1])
    return loss_fn(flat_logits[valid], targets[valid])

In [28]:
# Optimization parameters:
epochs = 10
batch_size = 16
learning_rate = 1e-3

# Model parameters:
message_dim = 388
embed_dim = 256
hidden_size = 1024
recurrent_layers = 3

# Checkpoint location:
checkpoint_dir = 'sc_gru_checkpoints'

train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
model = BaselineGRU(message_dim, embed_dim, hidden_size, recurrent_layers).to(dev)
# Pass the configured learning rate explicitly: previously `learning_rate` was
# never handed to the optimizer, so editing it above had no effect.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = torch.nn.CrossEntropyLoss()

In [14]:
# _average_loss: mean of the per-batch losses over one full pass of a dataloader
# ARGUMENTS
# model, loss_fn: forwarded to compute_loss
# dataloader: iterable of batches with a defined length
# RETURN: a Python float, so the result can be stored, saved with NumPy and
# plotted regardless of which device the model lives on
def _average_loss(model, loss_fn, dataloader):
    total = 0.0
    # No gradients are needed for evaluation; skipping them avoids building an
    # autograd graph for every batch.
    with torch.no_grad():
        for batch in dataloader:
            total += compute_loss(model, loss_fn, batch).item()
    return total / len(dataloader)


# torch.save does not create missing directories
os.makedirs(checkpoint_dir, exist_ok=True)

# One entry per epoch; epochs that have not run yet stay at 0
train_losses = [0.0] * epochs
test_losses = [0.0] * epochs
for epoch in range(epochs):
    print('Starting epoch %d' %(epoch))
    model.train()
    for b, batch in enumerate(train_dataloader):
        if b%100 == 0:
            print('Starting iteration %d' %(b))
        loss = compute_loss(model, loss_fn, batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Checkpoint after every epoch so any epoch can be reloaded later
    torch.save(model.state_dict(), checkpoint_dir + '/epoch' + str(epoch) + '.pth')

    model.eval()
    print('Computing test loss')
    test_losses[epoch] = _average_loss(model, loss_fn, test_dataloader)

    # The training loss is re-measured with the end-of-epoch weights so it is
    # directly comparable with the test loss.
    print('Computing train loss')
    train_losses[epoch] = _average_loss(model, loss_fn, train_dataloader)

    print('Train loss: %f, Test loss: %f' %(train_losses[epoch], test_losses[epoch]))

Starting epoch 0
Starting iteration 0
Starting iteration 100
Starting iteration 200
Starting iteration 300
Computing test loss
Computing train loss
Train loss: 2.752546, Test loss: 2.882606
Starting epoch 1
Starting iteration 0
Starting iteration 100
Starting iteration 200
Starting iteration 300
Computing test loss
Computing train loss
Train loss: 2.502642, Test loss: 2.774918
Starting epoch 2
Starting iteration 0
Starting iteration 100
Starting iteration 200
Starting iteration 300
Computing test loss
Computing train loss
Train loss: 2.309583, Test loss: 2.757482
Starting epoch 3
Starting iteration 0
Starting iteration 100
Starting iteration 200
Starting iteration 300
Computing test loss
Computing train loss
Train loss: 2.120018, Test loss: 2.787809
Starting epoch 4
Starting iteration 0
Starting iteration 100
Starting iteration 200
Starting iteration 300
Computing test loss
Computing train loss
Train loss: 1.929709, Test loss: 2.865342
Starting epoch 5
Starting iteration 0
Starting ite

In [15]:
# The loss lists may hold 0-dim tensors living on `dev` (plus plain 0 placeholders
# for epochs that never ran). NumPy cannot convert CUDA tensors, so coerce every
# entry to a Python float before building the arrays.
np.save(checkpoint_dir + '/train_losses.npy', np.array([float(loss) for loss in train_losses]))
np.save(checkpoint_dir + '/test_losses.npy', np.array([float(loss) for loss in test_losses]))

In [None]:
import matplotlib.pyplot as plt

# Global figure styling, applied to every figure created after this cell.
# White text/ticks are only legible on a dark background.
COLOR = 'white'
plt.rcParams.update({
    'text.color': COLOR,
    'axes.labelcolor': COLOR,
    'xtick.color': COLOR,
    'ytick.color': COLOR,
    'figure.figsize': [12, 8],
    'figure.dpi': 100,  # 200 e.g. is really fine, but slower
    'font.size': 22,
})

# Entries may be 0-dim tensors on `dev`; matplotlib cannot convert CUDA tensors,
# so plot plain floats.
plt.plot([float(loss) for loss in train_losses])
plt.xlabel('Epochs')
plt.ylabel('Average Loss (Nats)')
plt.title('Training Loss')
plt.savefig('gru_baseline_train_loss.png')

In [None]:
# Test-loss curve. Uses `plt` and the rcParams configured by the earlier plotting cell.
# Entries may be 0-dim tensors on `dev`; matplotlib cannot convert CUDA tensors,
# so plot plain floats.
plt.plot([float(loss) for loss in test_losses])
plt.xlabel('Epochs')
plt.ylabel('Average Loss (Nats)')
# This figure shows the test loss; the title previously read 'Training Loss'.
plt.title('Test Loss')
plt.savefig('gru_baseline_test_loss.png')

In [60]:
# Rebuild the network with the configured sizes and restore the weights saved
# under the name 'epoch9.pth', remapped onto `dev`
checkpoint_state = torch.load(checkpoint_dir + '/epoch9.pth', map_location=dev)
model = BaselineGRU(message_dim, embed_dim, hidden_size, recurrent_layers).to(dev)
model.load_state_dict(checkpoint_state)
# Switch to inference mode; as the last expression this also displays the model summary
model.eval()

BaselineGRU(
  (embedding): Embedding(388, 256)
  (gru): GRU(256, 1024, num_layers=3)
  (logits): Linear(in_features=1024, out_features=388, bias=True)
)

In [61]:
# Sample from model
# generate_music: autoregressively samples messages from the model after a primer
# ARGUMENTS
# model: module mapping (seq, hidden) to (logits, new_hidden), e.g. BaselineGRU
# primer: list of message indices fed to the model first
# gen_length: the number of messages to sample after the primer
# RETURN: a (len(primer) + gen_length) x 1 tensor holding the primer followed by
# the sampled messages
def generate_music(model, primer, gen_length=1000):
    # Put the primer on the model's own device so the two always agree; fall
    # back to the notebook-level `dev` for a model without parameters.
    first_param = next(iter(model.parameters()), None)
    device = dev if first_param is None else first_param.device

    hidden = None
    primer_seq = torch.tensor(primer, device=device).unsqueeze(1)
    message = primer_seq
    samples = []

    # Sampling needs no gradients. Without no_grad the autograd graph would keep
    # growing through `hidden` on every step.
    with torch.no_grad():
        for _ in range(gen_length):
            # The first call consumes the whole primer; afterwards only the
            # newest sample is fed, with the context carried in `hidden`.
            logits, hidden = model(message, hidden)
            probs = torch.nn.functional.softmax(logits[-1].flatten(), dim=0)
            message = torch.multinomial(probs, 1).view(1, 1)
            samples.append(message)

    # Concatenate once at the end instead of re-copying the sequence every step
    return torch.cat([primer_seq] + samples)

In [64]:
# Sample 200 messages, priming the model with the single message 50.
# Sampling is stochastic, so every run produces a different sequence.
generated_music = generate_music(model, gen_length=200, primer=[50])

In [65]:
# Store the generated sequence as a flat NumPy array of message indices
sampled_messages = generated_music.detach().cpu().flatten().numpy()
np.save('sc_gru_midi.npy', sampled_messages)