In [1]:
import torch
import numpy as np
import os
import time
import functools
from IPython import display as ipythondisplay
from tqdm import tqdm
!apt-get install abcmidi timidity > /dev/null 2>&1

In [2]:
# CPU variant: run this cell to keep the model and every tensor on the CPU
dev = torch.device("cpu")

In [2]:
# GPU variant: prefer CUDA, but fall back to the CPU when no GPU is visible so
# that running every cell top to bottom still works on a CPU-only machine.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Instrument numbers that may appear in the dataset's instrument files. The
# position of a number in this list is the index the rest of the notebook uses
# for it (the dataset maps numbers to indices with instrument_numbers.index).
instrument_numbers = [
    0, 6, 40, 41, 42, 43,
    45, 60, 68, 70, 71, 73,
]
num_instruments = len(instrument_numbers)

In [4]:
# MIDIDataset: fixed-length message chunks cut from saved recordings.
# Relies on two notebook-level names: dev (device every tensor is created on)
# and instrument_numbers (maps stored instrument numbers to indices).
class MIDIDataset(torch.utils.data.Dataset):
    # CONSTRUCTOR: loads every recording in a directory and splits it into chunks.
    # Files whose name contains 'recording' are paired, by sorted order, with the
    # files whose name contains 'instruments'; there must be equally many of each.
    # ARGUMENTS
    # root_dir: the directory to search
    # chunk_size: length of every chunk; a recording's last chunk is zero-padded
    #             at the end, and the padded positions are True in its mask
    # transform: optional callable applied to each instance in __getitem__
    def __init__(self, root_dir, chunk_size, transform=None):
        entries = os.listdir(root_dir)
        recording_files = sorted(
            os.path.join(root_dir, name) for name in entries if 'recording' in name
        )
        instrument_files = sorted(
            os.path.join(root_dir, name) for name in entries
            if 'recording' not in name and 'instruments' in name
        )
        assert(len(recording_files) == len(instrument_files))

        self.chunks = []
        self.masks = []
        self.instruments = []

        for file_idx, (recording_path, instrument_path) in enumerate(zip(recording_files, instrument_files)):
            data = np.load(recording_path)
            inst = [instrument_numbers.index(number) for number in np.load(instrument_path)]
            total = data.shape[0]
            nchunks = int(np.ceil(total/chunk_size))

            for c in range(nchunks):
                start = c*chunk_size
                stop = min(start + chunk_size, total)
                filled = stop - start

                # Only column 0 of the recording is used as the message value
                messages = torch.zeros(chunk_size, dtype=torch.long, device=dev)
                messages[:filled] = torch.tensor(data[start:stop, 0], device=dev)

                # True marks padding, False marks a real message
                padding = torch.ones(chunk_size, dtype=torch.bool, device=dev)
                padding[:filled] = False

                self.chunks.append(messages)
                self.masks.append(padding)
                # Every chunk of a recording carries that recording's instrument indices
                self.instruments.append(torch.tensor(inst, device=dev))

            # Coarse progress report: one line every 100 files
            if file_idx%100 == 0:
                print(file_idx)

        self.transform = transform

    # __len__
    # RETURN: the total number of chunks over all recordings
    def __len__(self):
        return len(self.chunks)

    # __getitem__
    # ARGUMENTS
    # idx: indicates which chunk to get
    # RETURN: a dictionary with keys 'messages' and 'mask' (both length
    # chunk_size tensors) and 'instruments' (the instrument indices of the
    # recording the chunk came from), passed through transform when one is set
    def __getitem__(self, idx):
        instance = dict(messages=self.chunks[idx],
                        mask=self.masks[idx],
                        instruments=self.instruments[idx])

        if self.transform:
            return self.transform(instance)

        return instance
    
# collate_fn: stacks dataset instances into one batch with the batch on dimension 1
# ARGUMENTS
# batch: non-empty list of instances as returned by MIDIDataset.__getitem__
# RETURN: a dictionary of tensors created on the notebook-level device dev:
# 'messages' and 'mask' are chunk_size x B; 'instruments' and 'inst_mask' are
# max_inst x B, where max_inst is the largest instrument count in the batch.
# Unused instrument slots hold 0 and are False in 'inst_mask'.
def collate_fn(batch):
    batch_size = len(batch)
    seq_len = batch[0]['messages'].shape[0]
    inst_counts = [item['instruments'].shape[0] for item in batch]
    widest = max(inst_counts)

    messages = torch.zeros((seq_len, batch_size), dtype=torch.long, device=dev)
    mask = torch.ones((seq_len, batch_size), dtype=torch.bool, device=dev)
    instruments = torch.zeros((widest, batch_size), dtype=torch.long, device=dev)
    inst_mask = torch.zeros((widest, batch_size), dtype=torch.bool, device=dev)

    # Each instance fills one column of every output tensor
    for column, (item, count) in enumerate(zip(batch, inst_counts)):
        messages[:, column] = item['messages']
        mask[:, column] = item['mask']
        instruments[:count, column] = item['instruments']
        inst_mask[:count, column] = True

    return {'messages': messages,
            'mask': mask,
            'instruments': instruments,
            'inst_mask': inst_mask}

In [None]:
# Training set: recordings under 'train_unified', cut into chunks of 1000 messages
train_data = MIDIDataset(root_dir='train_unified', chunk_size=1000)
len(train_data)

In [None]:
# Test set: recordings under 'test_unified', cut into chunks of 1000 messages
test_data = MIDIDataset(root_dir='test_unified', chunk_size=1000)
len(test_data)

In [17]:
# Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html.
# Changes: forward adds a batch axis to the table (accounts for batches), and
# an odd d_model is supported.
class PositionalEncoding(torch.nn.Module):
    # CONSTRUCTOR: precomputes a max_len x d_model table of sines (even columns)
    # and cosines (odd columns) on the notebook-level device dev
    # ARGUMENTS
    # d_model: size of the last dimension of the inputs to forward
    # dropout: dropout probability applied after the table is added
    # max_len: number of positions in the table
    def __init__(self, d_model, dropout=0.1, max_len=10000):
        super(PositionalEncoding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model, device=dev)
        position = torch.arange(0, max_len, dtype=torch.float, device=dev).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, device=dev).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the cosine block must drop the last frequency to fit.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        # A buffer follows the module through .to(...) and is saved in the
        # state_dict, but is not a trainable parameter
        self.register_buffer('pe', pe)

    # forward: adds the position table to the input, then applies dropout
    # ARGUMENTS
    # x: an LxBxd_model tensor (position on dimension 0, batch on dimension 1)
    # RETURN: a tensor of the same shape as x
    def forward(self, x):
        # unsqueeze(1) inserts the batch axis; broadcasting repeats the table per batch entry
        x = x + self.pe[:x.shape[0], :].unsqueeze(1)
        return self.dropout(x)

# UnifiedTransformer: generates a sequence of MIDI messages
class UnifiedTransformer(torch.nn.Module):
    # CONSTRUCTOR
    # ARGUMENTS
    # message_dim: number of distinct message values (size of the message
    #              embedding table and of the output logits)
    # embed_dim: dimension of message embedding
    # heads: number of heads
    # attention_layers: number of attention layers
    # ff_size: feed-forward size passed to each encoder layer
    # num_instruments: number of distinct instrument indices
    def __init__(self, message_dim, embed_dim, heads, attention_layers, ff_size, num_instruments):
        super(UnifiedTransformer, self).__init__()
        self.embed_dim = embed_dim
        self.inst_embedding = torch.nn.Embedding(num_instruments, embed_dim)
        self.embedding = torch.nn.Embedding(message_dim, embed_dim)
        self.position_encoding = PositionalEncoding(embed_dim)
        encoder_layer = torch.nn.TransformerEncoderLayer(embed_dim, heads, ff_size)
        self.encoder = torch.nn.TransformerEncoder(encoder_layer, attention_layers)
        self.logits = torch.nn.Linear(embed_dim, message_dim)
    
    # forward: computes logits for the next MIDI message at each time step in a sequence
    # ARGUMENTS
    # seq: an LxB tensor of message values, where L is the sequence length and
    # B is the batch size. This should be END-PADDED along dimension 0
    # instruments: an IxB tensor of instrument indices
    # mask: an LxB boolean tensor, True where seq is padding
    # inst_mask: an IxB boolean tensor, True where instruments holds a real entry
    # hidden: unused; kept so existing callers that pass it keep working
    # RETURN: an LxBxmessage_dim tensor of logits; the entry at time step t is
    # computed from time steps 0..t only
    def forward(self, seq, instruments, mask, inst_mask, hidden=None):
        L = seq.shape[0]
        B = seq.shape[1]
        # Squash each instrument embedding, zero out the padded slots, and sum
        # what is left into a single tag per batch entry
        inst_embed = torch.tanh(self.inst_embedding(instruments))*inst_mask.unsqueeze(2).expand(-1, -1, self.embed_dim)
        # The same tag is added at every time step
        inst_tags = torch.sum(inst_embed, dim=0).view(1, B, -1).expand(L, -1, -1)
        transformer_inputs = self.position_encoding(self.embedding(seq)) + inst_tags
        # True strictly above the diagonal: a time step may not attend to later ones.
        # Built on the input's own device rather than a notebook-level global.
        src_mask = torch.triu(torch.ones((L, L), dtype=torch.bool, device=seq.device), diagonal=1)
        return self.logits(self.encoder(transformer_inputs, mask=src_mask, src_key_padding_mask=mask.transpose(0, 1)))

In [18]:
# compute_loss: next-message prediction loss for one batch
# ARGUMENTS
# model: callable taking (seq, instruments, mask, inst_mask) and returning
#        logits with one row per input time step and batch entry
# loss_fn: callable taking (logits, targets), e.g. torch.nn.CrossEntropyLoss()
# batch: dictionary with keys 'messages', 'mask', 'instruments', 'inst_mask'
#        (the layout produced by collate_fn); 'mask' is True at padding
# RETURN: whatever loss_fn returns for the non-padded target positions
def compute_loss(model, loss_fn, batch):
    # Inputs are every time step but the last; targets are the same sequence shifted by one
    logits = model(batch['messages'][:-1], batch['instruments'], batch['mask'][:-1], batch['inst_mask'])
    keep = torch.logical_not(batch['mask'][1:].flatten())
    targets = batch['messages'][1:].flatten()
    # Take the number of classes from the logits themselves rather than from a
    # notebook-level variable that may not be defined yet when this cell runs
    flat_logits = logits.reshape(-1, logits.shape[-1])
    return loss_fn(flat_logits[keep], targets[keep])

In [19]:
# Optimization parameters:
epochs = 10
batch_size = 16
learning_rate = 1e-3

# Model parameters:
num_notes = 128
num_time_shifts = 100
# Size of the message vocabulary.
# NOTE(review): presumably two events per note plus the time shifts - confirm against the data encoding
message_dim = 2*num_notes + num_time_shifts
embed_dim = 256 
heads = 4
attention_layers = 6
ff_size = 512

# Checkpoint location: 
checkpoint_dir = 'unified_transformer_checkpoints'
# torch.save / np.save do not create missing directories
os.makedirs(checkpoint_dir, exist_ok=True)

train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
# Evaluation does not need shuffling; a fixed order keeps the test loss repeatable
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
model = UnifiedTransformer(message_dim, embed_dim, heads, attention_layers, ff_size, num_instruments).to(dev)
# Pass the configured rate explicitly so that changing learning_rate above takes effect
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
# Per-epoch mean losses, stored as plain Python floats
train_losses = [0 for epoch in range(epochs)]
test_losses = [0 for epoch in range(epochs)]
# torch.save does not create missing directories
os.makedirs(checkpoint_dir, exist_ok=True)
for epoch in range(epochs):
    print('Starting epoch %d' %(epoch))
    model.train()
    for b, batch in enumerate(train_dataloader):
        if b%10 == 0:
            print('Starting iteration %d' %(b))
        loss = compute_loss(model, loss_fn, batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    # One checkpoint per epoch
    torch.save(model.state_dict(), checkpoint_dir + '/epoch' + str(epoch) + '.pth')
        
    # Both losses are measured in eval mode after the epoch's updates. no_grad
    # avoids building autograd graphs, and .item() stores plain floats so the
    # results can later be turned into a NumPy array even when training on a GPU.
    model.eval()
    with torch.no_grad():
        print('Computing test loss')
        for batch in test_dataloader:
            test_losses[epoch] += compute_loss(model, loss_fn, batch).item()

        test_losses[epoch] /= len(test_dataloader)

        print('Computing train loss')
        for batch in train_dataloader:
            train_losses[epoch] += compute_loss(model, loss_fn, batch).item()

        train_losses[epoch] /= len(train_dataloader)
        
    print('Train loss: %f, Test loss: %f' %(train_losses[epoch], test_losses[epoch]))

In [None]:
# The recorded losses may be 0-d tensors (possibly living on the GPU), which
# np.array cannot convert directly; float() works for tensors and plain numbers alike.
np.save(checkpoint_dir + '/train_losses.npy', np.array([float(value) for value in train_losses]))
np.save(checkpoint_dir + '/test_losses.npy', np.array([float(value) for value in test_losses]))

In [None]:
# Rebuild the network and restore the weights saved after the last training epoch
model = UnifiedTransformer(message_dim, embed_dim, heads, attention_layers, ff_size, num_instruments).to(dev)
# Derived from epochs so it keeps pointing at the final checkpoint if epochs changes
final_checkpoint = checkpoint_dir + '/epoch' + str(epochs - 1) + '.pth'
model.load_state_dict(torch.load(final_checkpoint, map_location=dev))
model.eval()

In [11]:
# Sample from model
# ARGUMENTS
# model: callable taking (seq, instruments, mask, inst_mask) and returning logits
#        whose last row along dimension 0 belongs to the newest time step
# primer: non-empty list of integer message values that starts the sequence
# instruments: an Ix1 tensor of instrument indices
# gen_length: number of messages to sample after the primer
# RETURN: a (len(primer) + gen_length) x 1 tensor holding the primer followed by
# the sampled messages, on the notebook-level device dev
def generate_music(model, primer, instruments, gen_length=1000):
    if len(primer) == 0:
        raise ValueError('primer must contain at least one message')

    gen = torch.tensor(primer, device=dev).unsqueeze(1)
    # Callers may build the instrument tensor on another device
    instruments = instruments.to(dev)
    inst_mask = torch.ones(instruments.shape, dtype=torch.bool, device=dev)
    # Nothing generated here is padding, so the padding mask stays all False
    mask = torch.zeros(gen.shape, dtype=torch.bool, device=dev)
    new_step_mask = torch.zeros((1, 1), dtype=torch.bool, device=dev)

    # Sampling needs no gradients; without no_grad every step would keep its graph alive
    with torch.no_grad():
        for i in range(gen_length):
            # Feed everything generated so far and sample from the newest step's distribution
            logits = model(gen, instruments, mask, inst_mask)
            message = torch.multinomial(torch.nn.functional.softmax(logits[-1].flatten(), dim=0), 1).view(1, 1)
            gen = torch.cat((gen, message))
            mask = torch.cat((mask, new_step_mask))
        
    return gen

In [58]:
# One sample: four instrument slots (indices 2, 2, 3, 4) as a 4x1 column, primed with message 16
instruments = torch.tensor([2, 2, 3, 4], device=dev).unsqueeze(1)
generated_music = generate_music(model, primer=[16], instruments=instruments, gen_length=1000)

In [59]:
# Store the generated sequence as a flat NumPy array
np.save(
    'unified_transformer_midi.npy',
    generated_music.detach().cpu().flatten().numpy(),
)

In [18]:
import random
import os
# Generates 10 samples for every instrument count from 1 to 20; each sample uses
# randomly drawn instrument indices (repeats allowed) and one random primer message.
possible_primers = list(range(50, 70))
base_path = 'composer_transformer_midis/'
folders = [str(i) + 'chan' for i in range(1, 21)]
for nchan in range(1, 21):
    out_dir = base_path + folders[nchan - 1]
    # np.save does not create missing directories
    os.makedirs(out_dir, exist_ok=True)
    for i in range(10):
        instruments = random.choices(list(range(num_instruments)), k=nchan)
        primer = random.sample(possible_primers, 1)
        # Build the instrument column on the model's device; a CPU tensor fails when dev is a GPU
        instrument_column = torch.tensor(instruments, dtype=torch.long, device=dev).unsqueeze(1)
        generated_music = generate_music(model, primer, instrument_column, gen_length=1000)
        np.save(out_dir + '/sample' + str(i) + '.npy', generated_music.cpu().flatten().detach().numpy())
    