In [1]:
import torch
import numpy as np
import os
import time
import functools
from IPython import display as ipythondisplay
from tqdm import tqdm
!apt-get install abcmidi timidity > /dev/null 2>&1

In [2]:
# If CPU
dev = torch.device('cpu')

In [None]:
# If GPU
dev = torch.device('gpu')

In [3]:
instrument_numbers = [0, 6, 40, 41, 42, 43, 45, 60, 68, 70, 71, 73]
num_instruments = len(instrument_numbers)

In [4]:
# Custom Dataset class
class MIDIDataset(torch.utils.data.Dataset):
    # CONSTRUCTOR: creates an array of message chunks. Assumes that the directory
    # contains recording0.npy to recordingM.npy
    # ARGUMENTS
    # root_dir: the directory to search
    # chunk_size: we'll chunk the data into chunks of this size (or less)
    def __init__(self, root_dir, chunk_size, transform=None):
        recording_files = []
        instrument_files = []
        for file in os.listdir(root_dir):
            if 'recording' in file:
                recording_files.append(os.path.join(root_dir, file))
            elif 'instruments' in file:
                instrument_files.append(os.path.join(root_dir, file))
                
        assert(len(recording_files) == len(instrument_files))
        recording_files.sort()
        instrument_files.sort()
        
        self.chunks = []
        self.masks = []
        self.instruments = []
        
        ch = 0
        for f in range(len(recording_files)):
            data = np.load(recording_files[f])
            inst = [instrument_numbers.index(i) for i in np.load(instrument_files[f])]
            nchunks = int(np.ceil(data.shape[0]/chunk_size))
            self.chunks += [torch.zeros((chunk_size, 2), dtype=torch.long, device=dev) for c in range(nchunks)]
            self.masks += [torch.ones(chunk_size, dtype=torch.bool, device=dev) for c in range(nchunks)]
            self.instruments += [torch.tensor(inst, dtype=torch.long, device=dev) for c in range(nchunks)]
            for chunk_start in range(0, data.shape[0], chunk_size):
                chunk_end = min(chunk_start + chunk_size, data.shape[0])
                size = chunk_end - chunk_start
                self.chunks[ch][:size] = torch.tensor(data[chunk_start:chunk_end]).to(dev)
                self.masks[ch][:size] = False
                ch += 1
            
            if f%100 == 0:
                print(f)
            
        self.transform = transform

    # __len__
    # RETURN: the number of chunks in the dataset
    def __len__(self):
        return len(self.chunks)

    # __getitem__
    # ARGUMENTS
    # idx: indicates which chunk to get
    # RETURN: instance, a dictionary with keys 'history,' 'mask,' and 'instruments'.
    # Both values associated with these keys are length L tensors
    def __getitem__(self, idx):  
        instance = {'history': self.chunks[idx], \
                    'mask': self.masks[idx],
                    'instruments': self.instruments[idx]}

        if self.transform:
            instance = self.transform(instance)
            
        return instance
    
def collate_fn(batch):
    chunk_size = batch[0]['history'].shape[0]
    ninst = [instance['instruments'].shape[0] for instance in batch]
    max_inst = max(ninst)
    B = len(batch)
    sample = {'history': torch.zeros((chunk_size, 2, B), dtype=torch.long, device=dev), \
              'mask': torch.ones((chunk_size, B), dtype=torch.bool, device=dev),
              'instruments': torch.zeros((max_inst, B), dtype=torch.long, device=dev), \
              'inst_mask': torch.ones((max_inst, B), dtype=torch.bool, device=dev)}
    
    for b, instance in enumerate(batch):
        sample['history'][:, :, b] = instance['history']
        sample['mask'][:, b] = instance['mask']
        sample['instruments'][:ninst[b], b] = instance['instruments']
        sample['inst_mask'][:ninst[b], b] = False
        
    return sample

In [5]:
train_data = MIDIDataset('train_unified', 1000)
len(train_data)

0
100
200


2545

In [6]:
test_data = MIDIDataset('test_unified', 1000)
len(test_data)

0


636

In [47]:
# Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html.
# Only change is the view/expand in forward (accounts for batches)
class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=10000):
        super(PositionalEncoding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.shape[0], :].unsqueeze(1).expand(-1, x.shape[1], -1)
        return self.dropout(x)

# MultiTransformer: generates a sequence of MIDI messages and their associated channels
class MultiTransformer(torch.nn.Module):
    # CONSTRUCTOR
    # ARGUMENTS
    # message_dim: dimension of a MIDI message
    # embed_dim: dimension of message embedding
    # heads: number of heads
    # attention layers: number of attention layers
    def __init__(self, message_dim, embed_dim, heads, attention_layers, ff_size, num_instruments):
        super(MultiTransformer, self).__init__()
        self.embed_dim = embed_dim
        self.embedding = torch.nn.Embedding(message_dim, embed_dim)
        self.inst_embedding = torch.nn.Embedding(num_instruments, embed_dim)
        self.position_encoding = PositionalEncoding(embed_dim)
        encoder_layer = torch.nn.TransformerEncoderLayer(embed_dim, heads, ff_size)
        self.encoder = torch.nn.TransformerEncoder(encoder_layer, attention_layers)
        self.message_logits = torch.nn.Linear(embed_dim, message_dim)
        
        # Instrument-wise attention
        self.inst_logits = torch.nn.Linear(embed_dim, embed_dim)
        self.inst_attention = torch.nn.MultiheadAttention(embed_dim, heads)
    
    # forward: generates a probability distribution for the next MIDI message at each time step in a sequence
    # ARGUMENTS
    # history: an Lx2xB tensor, where L is the length of the longest message sequence in
    # the batch, and B is the batch size. This should be END-PADDED along dimension 0.
    # The first index along dimension 1 contains the message, the second contains the channel.
    # All time-shifts should be associated with channel -1
    # mask: LxB, contains False where messages exist, and True otherwise
    # instruments: NxB, instrument numbers for each instance
    # inst_mask: NxB, contains False where instruments exist, and True otherwise
    # RETURN: an LxBxD tensor representing the logits for the next message in each batch
    def forward(self, history, mask, instruments, inst_mask):
        L = history.shape[0]
        B = history.shape[1]
        
        inst_embed = self.inst_embedding(instruments)
        message_embed = self.position_encoding(self.embedding(history[:, 0, :]))
        
        time_shift_mask = history[:, 1, :] < 0
        channel_sel = history[:, 1, :].clone()
        channel_sel[time_shift_mask] = 0
        
        # LxBxD
        inst_tags = torch.gather(inst_embed, 0, channel_sel.unsqueeze(2).expand(-1, -1, self.embed_dim))
        inst_tags[time_shift_mask] = 0
        
        inputs = message_embed + inst_tags + torch.sum(inst_embed, dim=0).unsqueeze(0).expand(L, -1, -1)
        
        src_mask = torch.triu(torch.ones((history.shape[0], history.shape[0]), dtype=torch.bool))
        src_mask.fill_diagonal_(False)
        
        encoding = self.encoder(inputs, mask=src_mask, src_key_padding_mask=mask.transpose(0, 1))
        
        message_logits = self.message_logits(encoding)
        inst_logits = torch.tanh(self.inst_logits(encoding))
        
        # attn_weights is BxLxN
        attn_output, attn_weights = self.inst_attention(inst_logits, inst_embed, inst_embed, key_padding_mask=inst_mask.transpose(0, 1))
        return message_logits, attn_weights.transpose(0, 1)

def compute_loss(model, batch):
    message_logits, channel_weights = model(batch['history'][:-1], batch['mask'][:-1], batch['instruments'], batch['inst_mask'])
    
    target_mask = torch.logical_not(batch['mask'][1:].flatten())
    target_messages = batch['history'][1:, 0].flatten()
    message_loss = torch.nn.functional.cross_entropy(message_logits.view(-1, message_dim)[target_mask], target_messages[target_mask])
    
    target_channels = batch['history'][1:, 1].flatten()
    max_inst = batch['instruments'].shape[0]
    channel_loss = torch.nn.functional.nll_loss(channel_weights.reshape(-1, max_inst)[target_mask], target_channels[target_mask], ignore_index=-1)
    
    return message_loss + channel_loss

In [48]:
# Optimization parameters:
epochs = 10
batch_size = 10
learning_rate = 1e-3

# Model parameters: 
num_notes = 128
num_time_shifts = 100
message_dim = 2*num_notes + num_time_shifts
embed_dim = 512
heads = 4
attention_layers = 6
ff_size = 512

# Checkpoint location: 
checkpoint_dir = 'multi_transformer_checkpoints'

train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
model = MultiTransformer(message_dim, embed_dim, heads, attention_layers, ff_size, num_instruments).to(dev)
optimizer = torch.optim.Adam(model.parameters())

In [None]:
train_losses = [0 for epoch in range(epochs)]
test_losses = [0 for epoch in range(epochs)]
for epoch in range(epochs):
    print('Starting epoch %d' %(epoch))
    model.train()
    for b, batch in enumerate(train_dataloader):
        if b%100 == 0:
            print('Starting iteration %d' %(b))
        loss = compute_loss(model, batch)
        optimizer.zero_grad()
        print(loss.data)
        loss.backward()
        optimizer.step()
        
    torch.save(model.state_dict(), checkpoint_dir + '/epoch' + str(epoch) + '.pth')
        
    model.eval()
    print('Computing test loss')
    for b, batch in enumerate(test_dataloader):
        test_losses[epoch] += compute_loss(model,  batch).data
        
    test_losses[epoch] /= len(test_dataloader)
        
    print('Computing train loss')
    for b, batch in enumerate(train_dataloader):
        train_losses[epoch] += compute_loss(model, batch).data
        
    train_losses[epoch] /= len(train_dataloader)
        
    print('Train loss: %f, Test loss: %f' %(train_losses[epoch], test_losses[epoch]))

Starting epoch 0
Starting iteration 0
tensor(5.8959)
tensor(5.1866)
tensor(5.0292)


In [None]:
import matplotlib.pyplot as plt
COLOR = 'white'
plt.rcParams['text.color'] = COLOR
plt.rcParams['axes.labelcolor'] = COLOR
plt.rcParams['xtick.color'] = COLOR
plt.rcParams['ytick.color'] = COLOR
plt.rcParams['figure.figsize'] = [12, 8]
plt.rcParams['figure.dpi'] = 100 # 200 e.g. is really fine, but slower
plt.rcParams.update({'font.size': 22})
plt.plot(train_losses)
plt.xlabel('Epochs')
plt.ylabel('Average Loss (Nats)')
plt.title('Training Loss')
plt.savefig('transformer_baseline_train_loss.png')

In [None]:
plt.plot(test_losses)
plt.xlabel('Epochs')
plt.ylabel('Average Loss (Nats)')
plt.title('Test Loss')
plt.savefig('transformer_baseline_test_loss.png')

In [77]:
np.save(checkpoint_dir + '/train_losses.npy', np.array(train_losses))
np.save(checkpoint_dir + '/test_losses.npy', np.array(test_losses))

In [9]:
model = MultiTransformer(message_dim, embed_dim, heads, attention_layers, ff_size, num_instruments)
model.load_state_dict(torch.load(checkpoint_dir + '/epoch9.pth'))

<All keys matched successfully>

In [31]:
# Sample from model
def generate_music(model, primer, instruments, gen_length=1000):
    gen = primer.clone() # Lx2x1
    mask = torch.zeros(gen.shape[0], dtype=torch.bool, device=dev).unsqueeze(1) # Lx1
    inst_mask = torch.zeros(instruments.shape, dtype=torch.bool)
    for i in range(gen_length):
        message_logits, channel_attn = model(gen, mask, instruments, inst_mask)
        message = torch.multinomial(torch.nn.functional.softmax(message_logits[-1].flatten(), dim=0), 1).view(1, 1)
        channel = torch.multinomial(channel_attn[-1].flatten(), 1).view(1, 1)
        cat = torch.cat((message, channel), dim=1).unsqueeze(2)
        gen = torch.cat((gen, cat), dim=0)
        mask = torch.cat((mask, torch.zeros((1, 1), dtype=torch.bool, device=dev)), dim=0)
        
    return gen

In [32]:
instruments = torch.tensor([0], dtype=torch.long, device=dev).view(1, 1)
primer = torch.tensor([16, 0], dtype=torch.long, device=dev).view(1, 2, 1)
gen = generate_music(model, primer, instruments, 1000)

In [12]:
np.save('multi_transformer_midi.npy', generated_music.squeeze(2).detach().numpy())

In [34]:
print(gen[:10, 0, :].flatten())

tensor([ 16,  66, 210, 198, 204, 200, 235,  76, 179, 193])
