In [2]:
import torch
import numpy as np
import math
import os
import matplotlib.pyplot as plt
num_notes = 128
num_time_shifts = 100
num_velocities = 32
message_dim = 2*num_notes + num_velocities + num_time_shifts
instrument_numbers = [0, 6, 40, 41, 42, 43, 45, 60, 68, 70, 71, 73]
num_instruments = len(instrument_numbers)
# Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html.
# Only change is the view/expand in forward (accounts for multiple instruments and batches)
# Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html.
# Only change is the view/expand in forward (accounts for multiple instruments and batches)
class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=100000):
        super(PositionalEncoding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.shape[0], :].view(x.shape[0], 1, 1, -1).expand(-1, x.shape[1], x.shape[2], -1)
        return self.dropout(x)
        

# EnsembleTransformer: takes a history of MIDI messages 
# for instruments in an ensemble and generates a distribution for the next message
# for a specific instrument
class EnsembleTransformer(torch.nn.Module):
    # CONSTRUCTOR
    # ARGUMENTS
    # message_dim: dimension of a MIDI message
    # embed_dim: dimension of message embedding
    # num_instruments: number of instrument labels
    # heads: number of attention heads
    # attention_layers: number of attention layers
    # ff_size: size of the feedforward output at the end of the decoder
    def __init__(self, message_dim, embed_dim, num_instruments, heads, attention_layers, ff_size):
        super(EnsembleTransformer, self).__init__()
        
        self.embed_dim = embed_dim
        
        # We add the tanhed instrument embedding to each input message 
        # (this is the global conditioning idea from DeepJ, which comes from WaveNet)
        self.i_embedding = torch.nn.Embedding(num_instruments, embed_dim)
        
        self.position_encoding = PositionalEncoding(embed_dim)
        
        self.embedding = torch.nn.Embedding(message_dim, embed_dim)
        
        # An encoder is used to transform histories of all instruments in the ensemble
        # except the instrument we're generating music for
        encoder_layer = torch.nn.TransformerEncoderLayer(embed_dim, heads, ff_size)
        self.encoder = torch.nn.TransformerEncoder(encoder_layer, attention_layers)
        
        # A decoder is used to transform the history of the instrument we're generating
        # music for, then combine this with the encoder output to generate the next message
        decoder_layer = torch.nn.TransformerDecoderLayer(embed_dim, heads, ff_size)
        self.decoder = torch.nn.TransformerDecoder(decoder_layer, attention_layers)
        
        self.logits = torch.nn.Linear(embed_dim, message_dim)
    
    # forward: generates a probability distribution for the next MIDI message
    # for an instrument, given the message history of the ensemble
    # ARGUMENTS
    # history: an LxNxB tensor, where L is the length of the longest history in
    # the batch, N is the max number of instruments in the batch, and B is the batch size. As we
    # walk along the first dimension, we should see indices of MIDI events for a particular
    # instrument
    # mask: an LxNxB tensor, containing True where a message or instrument doesn't exist
    # instruments: a 1xNxB tensor indicating the instrument numbers for each batch
    # gen_idx: a the index (along dimension 1 of history) of the instrument we want to generate music for
    # RETURN: an LxBxD tensor, representing the distribution for
    # the next MIDI message for the instrument indicated by gen_idx at each time step. 
    # Note, to get the actual probabilities you'll have to take the softmax
    # of this tensor along dimension 1
    def forward(self, history, mask, instruments, gen_idx):
        L = history.shape[0] # longest length
        N = history.shape[1] # max instruments
        B = history.shape[2] # batch size
        assert(instruments.shape == (1, N, B))
        assert(mask.shape == history.shape)
        
        inputs = self.embedding(history) + torch.tanh(self.i_embedding(instruments)).expand(L, -1, -1, -1)
        
        inputs = self.position_encoding(inputs)
        
        memory_key_padding_mask = None
        memory_mask = None
        tgt_mask = torch.triu(torch.ones((L, L), dtype=torch.bool))
        tgt_mask.fill_diagonal_(False)
        
        # If only one instrument, only run the decoder
        if N == 1:
            memory = torch.zeros((1, B, self.embed_dim))
        else:
            memory_idx = [i for i in range(N) if i != gen_idx]
            memory = inputs[:, memory_idx].view(-1, B, self.embed_dim)
            memory_key_padding_mask = mask[:, memory_idx].view(-1, B).transpose(0, 1)
            memory_mask = tgt_mask.repeat(1, N - 1)
        decoder_inputs = inputs[:, gen_idx]
        tgt_key_padding_mask = mask[:, gen_idx].transpose(0, 1)
        
        chunk_size = 50
        decoding = None
        
        for i in range(0, decoder_inputs.shape[0], chunk_size):
            if i == 0:
                start, end = i, min(i+chunk_size, decoder_inputs.shape[0])
            else:
                start, end = i-chunk_size,  min(i+chunk_size, decoder_inputs.shape[0])
            chunked_decoder_inputs = decoder_inputs[start:end, ]
            chunked_tgt_key_padding_mask = tgt_key_padding_mask[:, start:end]
            size = end - start
            chunked_tgt_mask = torch.triu(torch.ones((size, size), dtype=torch.bool))
            chunked_tgt_mask.fill_diagonal_(False)
            chunked_memory_mask = memory_mask[start:end,]
            chunked_decoding = self.decoder(chunked_decoder_inputs, memory, tgt_mask=chunked_tgt_mask, memory_mask=chunked_memory_mask, \
                                tgt_key_padding_mask=chunked_tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask)
            if decoding == None:
                decoding = chunked_decoding
            else:
                #only take the second half of the output since each chunk attends to the current and the previous chunks
                decoding = torch.cat((decoding, chunked_decoding[chunk_size:chunk_size*2,]))
            print(decoding.shape)
        return self.logits(decoding)

    # Custom Dataset class
class MIDIDataset(torch.utils.data.Dataset):
    # CONSTRUCTOR: creates a list of recording files and a list
    # of instrument files in root_dir. Assumes that the directory
    # contains recording0.npy to recordingM.npy,
    # as well as instruments0.npy to instrumentsM.npy
    # ARGUMENTS
    # root_dir: the directory to search
    def __init__(self, root_dir, transform=None):
        files = os.listdir(root_dir)
        self.recordings = []
        self.instrument_files = []
        self.time_files = []
        for file in files:
            if 'recording' in file:
                self.recordings.append(os.path.join(root_dir, file))
            elif 'instruments' in file:
                self.instrument_files.append(os.path.join(root_dir, file))
            elif 'times' in file:
                self.time_files.append(os.path.join(root_dir, file))
                
        assert(len(self.recordings) == len(self.instrument_files))
        assert(len(self.recordings) == len(self.time_files))
        self.recordings.sort()
        self.instrument_files.sort()
        self.time_files.sort()
        self.transform = transform

    # __len__
    # RETURN: the number of recording files in the dataset
    def __len__(self):
        return len(self.recordings)

    # __getitem__
    # ARGUMENTS
    # idx: indicates which file to get
    # RETURN: an instance with keys 'instruments', 'history', and 'times'
    # instance['history'] is a numpy array of message sequences for each instrument
    # instance['instruments'] a numpy array of instrument numbers
    # instance['times'] is a numpy array of message time sequences for each instrument
    def __getitem__(self, idx):
        instance = {'history': np.load(self.recordings[idx], allow_pickle=True), \
                    'instruments': np.load(self.instrument_files[idx], allow_pickle=True), \
                    'times': np.load(self.time_files[idx], allow_pickle=True)}
        
        assert(len(instance['history']) == len(instance['times']))
        assert(len(instance['history']) == len(instance['instruments']))
        
        for i in range(len(instance['history'])):
            assert(len(instance['history'][i]) == len(instance['times'][i]))
        
        if self.transform:
            instance = self.transform(instance)
        return instance
    
# collate_fn: takes a list of samples from the dataset and turns them into a batch
# ARGUMENTS
# batch: a list of dictionaries
# RETURN: a sample with keys 'history', 'instruments', 'times', and 'mask'
# sample['history']: an LxNxB tensor containing messages
# sample['instruments']: a 1xNxB tensor containing instrument numbers
# sample['mask']: an LxNxB tensor containing False where a message is
# valid, and True where it isn't (accounts for variable length sequences
# and zero padding)
# sample['times']: an LxNxB tensor containing times of each message
def collate_fn(batch):
    batch_size = len(batch)
    
    # We size our tensors to accomodate the longest sequence and the largest ensemble
    max_instruments = max([len(instance['history']) for instance in batch])
    longest_len = max([max([seq.shape[0] for seq in instance['history']]) for instance in batch])

    sample = {'history': torch.ones((longest_len, max_instruments, batch_size), dtype=torch.long), \
              'instruments': torch.zeros((1, max_instruments, batch_size), dtype=torch.long), \
              'times': torch.zeros((longest_len, max_instruments, batch_size)), \
              'mask': torch.ones((longest_len, max_instruments, batch_size), dtype=torch.bool)}

    for b in range(batch_size):
        instrument_idx = [instrument_numbers.index(inst) for inst in batch[b]['instruments']]
        sample['instruments'][0, :len(instrument_idx), b] = torch.tensor(instrument_idx, dtype=torch.long)
        
        for inst_idx in range(len(batch[b]['history'])):
            seq_length = len(batch[b]['history'][inst_idx])
            sample['history'][:seq_length, inst_idx, b] = torch.tensor(batch[b]['history'][inst_idx], dtype=torch.long)
            sample['mask'][:seq_length, inst_idx, b] = False
            sample['times'][:seq_length, inst_idx, b] = torch.tensor(batch[b]['times'][inst_idx])
            
    return sample

In [None]:
ff_size = 512
heads = 8
attention_layers=  6
embed_dim = 256
grad_clip = 10

model = EnsembleTransformer(message_dim, embed_dim, len(instrument_numbers), heads, attention_layers, ff_size)

for p in model.parameters():
    p.register_hook(lambda grad: torch.clamp(grad, -grad_clip, grad_clip))
    
batch_size = 1
learning_rate = 0.001

dataset = MIDIDataset('preprocessed_data')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

optimizer = torch.optim.Adam(model.parameters())
loss_fn = torch.nn.CrossEntropyLoss()
epochs = 10
train_losses = np.zeros(epochs)
test_losses = np.zeros(epochs) # TODO: train/test split. Can we do this with Dataloader?

model.train()

for epoch in range(epochs):
    print('Starting epoch %d' %(epoch))
    for b, batch in enumerate(dataloader):
        print('Starting iteration %d' %(b))
        max_seq_length = batch['history'].shape[0]
        num_targets = max_seq_length - 1 # Messages start from t = 0, but we start generating at t = 1
        max_instruments = batch['history'].shape[1]
        loss = torch.tensor([0])
        for inst in range(max_instruments):         
            mask = batch['mask']
            
            logits = model(batch['history'][:-1], mask[:-1], batch['instruments'], inst)
            logits = logits.view(-1, message_dim)
            target_messages = batch['history'][1:, inst].flatten()
            output_mask = torch.logical_not(mask[1:, inst].flatten())
            loss = loss + loss_fn(logits[output_mask], target_messages[output_mask])
        
        loss /= max_instruments
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses[epoch] += loss.data

Starting epoch 0
Starting iteration 0
torch.Size([50, 1, 256])
torch.Size([100, 1, 256])
torch.Size([150, 1, 256])
torch.Size([200, 1, 256])
torch.Size([250, 1, 256])
torch.Size([300, 1, 256])
torch.Size([350, 1, 256])
torch.Size([400, 1, 256])
torch.Size([450, 1, 256])
torch.Size([500, 1, 256])
torch.Size([550, 1, 256])
torch.Size([600, 1, 256])
torch.Size([650, 1, 256])
torch.Size([700, 1, 256])
torch.Size([750, 1, 256])
torch.Size([800, 1, 256])
torch.Size([850, 1, 256])
torch.Size([900, 1, 256])
torch.Size([950, 1, 256])
