In [1]:
import torch
import numpy as np
import math
import os
import matplotlib.pyplot as plt

In [9]:
# Size of the event vocabulary. Presumably the layout is note-on and note-off
# events for each pitch, then velocity bins, then time-shift bins (the sampling
# cell below starts from message 387, the last index, as a time shift) — TODO
# confirm against the preprocessing code.
num_notes = 128
num_time_shifts = 100
num_velocities = 32
message_dim = num_notes * 2 + num_velocities + num_time_shifts

# Instrument numbers (presumably General MIDI programs) the model can play.
# A program's position in this list is the instrument index fed to the model.
instrument_numbers = [0, 6, 40, 41, 42, 43, 45, 60, 68, 70, 71, 73]
num_instruments = len(instrument_numbers)

# Largest number of channels (instruments) in any single recording
max_channels = 14

# Ensemble LSTM definition

In [3]:
# Adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html.
# Changes from the tutorial: forward() broadcasts the encoding over a batch
# dimension, and odd d_model values are supported.
class PositionalEncoding(torch.nn.Module):
    """Adds fixed sinusoidal position encodings to a sequence-first tensor.

    Arguments:
        d_model: feature size of the inputs (last dimension of x).
        dropout: dropout probability applied after adding the encodings.
        max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model, dropout=0.1, max_len=10000):
        super(PositionalEncoding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim div_term to match (the tutorial version raises a shape error).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Registered as a buffer: saved in state_dict and moved by .to(), but not trained
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the position encodings to x and apply dropout.

        x: an LxBxd_model tensor (sequence first); L must not exceed max_len.
        Returns a tensor with the same shape as x.
        """
        # pe[:L] is Lxd_model; unsqueeze to Lx1xd_model so it broadcasts over B
        x = x + self.pe[:x.shape[0]].unsqueeze(1)
        return self.dropout(x)
        

# EnsembleLSTM: takes a history of MIDI messages
# for instruments in an ensemble and generates a distribution for the next message,
# as well as the channel (instrument) that should issue the message
class EnsembleLSTM(torch.nn.Module):
    # CONSTRUCTOR
    # ARGUMENTS
    # message_dim: number of distinct MIDI messages (size of the message vocabulary)
    # embed_dim: dimension of the message and instrument embeddings
    # num_instruments: number of instrument labels
    # hidden_size: size of hidden LSTM state
    # heads: number of attention heads (torch requires embed_dim to be divisible by heads)
    # recurrent_layers: the number of layers in the lstm
    def __init__(self, message_dim, embed_dim, num_instruments, hidden_size, heads, recurrent_layers=3):
        super(EnsembleLSTM, self).__init__()

        self.embed_dim = embed_dim

        # The tanh of an instrument's embedding (plus a channel position encoding)
        # is added to the embedding of every message that instrument issues
        self.i_embedding = torch.nn.Embedding(num_instruments, embed_dim)

        self.embedding = torch.nn.Embedding(message_dim, embed_dim)

        # A recurrent_layers-deep LSTM takes the history of message embeddings (summed with
        # their associated instrument encoding) and produces a decoding
        self.lstm = torch.nn.LSTM(embed_dim, hidden_size, num_layers=recurrent_layers)

        # The decoding is passed through a linear layer to get the logits for the next message
        self.message_logits = torch.nn.Linear(hidden_size, message_dim)

        # The decoding becomes a query for attention across the channels; the attention
        # weights are used as the distribution over the next channel
        self.inst_query = torch.nn.Linear(hidden_size, embed_dim)
        self.inst_attention = torch.nn.MultiheadAttention(embed_dim, heads)

        # Encodes each channel's index, so channels with the same instrument stay distinguishable
        self.position_encoding = PositionalEncoding(embed_dim)

    # _instrument_embeddings: embeds the instrument of every channel
    # instruments: a CxB tensor of instrument indices
    # RETURN: a CxBxembed_dim tensor, tanh(instrument embedding) plus the channel position encoding
    def _instrument_embeddings(self, instruments):
        return self.position_encoding(torch.tanh(self.i_embedding(instruments)))

    # _embed_tokens: builds the LSTM inputs from (message, channel) tokens
    # tokens: an LxBx2 tensor of (message index, channel index); channel < 0 marks a time shift
    # inst_embed: the CxBxembed_dim output of _instrument_embeddings
    # RETURN: an LxBxembed_dim tensor, the message embedding plus the issuing channel's
    # instrument embedding (time shifts get no instrument term)
    def _embed_tokens(self, tokens, inst_embed):
        channels = tokens[:, :, 1]
        time_shift_mask = channels < 0

        # gather cannot take negative indices, so point time shifts at channel 0;
        # their instrument term is zeroed out right after
        inst_sel = channels.clamp(min=0).unsqueeze(2).expand(-1, -1, self.embed_dim)
        inst_tags = torch.gather(inst_embed, 0, inst_sel)
        inst_tags = inst_tags.masked_fill(time_shift_mask.unsqueeze(2), 0)

        return self.embedding(tokens[:, :, 0]) + inst_tags

    # forward: generates a distribution for the next MIDI message and the channel that
    # issues it at every time step, given a message history for the instrument ensemble
    # ARGUMENTS
    # history: an LxBx2 tensor, where L is the length of the longest message history in
    # the batch, and B is the batch size. The first index along dimension 2 stores the
    # message number. The second stores the channel number. This should be END-PADDED
    # along dimension 0. All time shifts should be associated with channel -1.
    # mask: an LxB tensor, containing True in any locations where history contains
    # padding. Only its shape is checked: the LSTM is causal, so end padding does not
    # change the outputs at real positions; exclude padded positions from the loss instead.
    # instruments: a CxB tensor indicating the instrument number for each channel, where
    # C is the maximum number of channels in the batch. This should be END-PADDED along dimension 0
    # inst_mask: a CxB bool tensor, False where an instrument exists and True elsewhere
    # RETURN: two tensors. The first is LxBxmessage_dim, the logits of the next message at each time
    # step (take the softmax to get probabilities). The second is LxBxC, the probabilities of the
    # next channel at each time step (attention weights, already normalized; padded channels get 0)
    def forward(self, history, mask, instruments, inst_mask):
        L = history.shape[0] # longest length
        B = history.shape[1] # batch size
        assert(mask.shape == (L, B))

        # CxBxembed_dim
        inst_embed = self._instrument_embeddings(instruments)

        # LxBxembed_dim
        inputs = self._embed_tokens(history, inst_embed)

        # LxBxhidden_size
        decoding, _ = self.lstm(inputs)

        # LxBxmessage_dim
        message_dist = self.message_logits(decoding)

        # channel_dist (BxLxC) holds the attention weights (averaged over heads by
        # MultiheadAttention) over the channels. There are L queries (the decodings);
        # the keys and values are the instrument embeddings.
        _, channel_dist = self.inst_attention(self.inst_query(decoding),
                                              inst_embed, inst_embed,
                                              key_padding_mask=inst_mask.transpose(0, 1))

        return message_dist, channel_dist.transpose(0, 1)

    # forward_generate: samples the next MIDI message and the channel that should issue it,
    # given the previous message and channel, as well as the LSTM hidden state
    # ARGUMENTS
    # last_token: a 1x1x2 tensor. The first index along dimension 2 stores the
    # message number. The second stores the channel number (-1 for a time shift)
    # instruments: a Cx1 tensor indicating the instrument number for each channel, where
    # C is the number of channels
    # hidden: the (h, c) LSTM state returned by the previous call, or None to start from zeros
    # RETURN: (ret, new_hidden). ret is a 1x1x2 long tensor holding the sampled next message
    # and channel; new_hidden is the updated LSTM state to pass to the next call.
    # NOTE(review): a channel is sampled for every message, including time shifts. In training,
    # time shifts carry channel -1 (ignored by the channel loss), so feeding ret back unchanged
    # gives time shifts an instrument tag never seen in training — TODO confirm whether callers
    # should reset the channel to -1 for time-shift messages.
    def forward_generate(self, last_token, instruments, hidden=None):
        assert(last_token.shape == (1, 1, 2))

        # Cx1xembed_dim
        inst_embed = self._instrument_embeddings(instruments)

        # 1x1xembed_dim
        inputs = self._embed_tokens(last_token, inst_embed)

        decoding, new_hidden = self.lstm(inputs, hidden)

        # 1x1xmessage_dim
        message_dist = self.message_logits(decoding)

        # channel_dist (1x1xC) holds the attention weights over the channels for the
        # single query (the decoding); the keys and values are the instrument embeddings
        _, channel_dist = self.inst_attention(self.inst_query(decoding), inst_embed, inst_embed)

        # Sample both outputs; swap torch.multinomial for torch.argmax to decode greedily
        message = torch.multinomial(torch.softmax(message_dist.flatten(), dim=0), 1)
        channel = torch.multinomial(channel_dist.flatten(), 1)

        ret = torch.cat((message.view(1, 1, 1), channel.view(1, 1, 1)), dim=2)

        return ret, new_hidden

# Tests for EnsembleLSTM
We train in `model.eval()` mode to disable dropout, since these tests try to get the model to overfit a small sequence.

Get the model to overfit to a single song

In [36]:
embed_dim = 256
hidden_size = 1024
heads = 4

grad_clip = 10 # bound for element-wise gradient clamping

model = EnsembleLSTM(message_dim, embed_dim, num_instruments, hidden_size, heads)
# Clamp each parameter's gradient to [-grad_clip, grad_clip] as backward computes it.
# The hook looks up the global grad_clip when it fires, not when it is registered.
for param in model.parameters():
    param.register_hook(lambda grad: grad.clamp(-grad_clip, grad_clip))

# Training in eval mode just to see if we can overfit without dropout;
# the trailing semicolon keeps Jupyter from displaying the returned module
model.eval();

In [60]:
# Reload the weights written by the torch.save cell below (this notebook was run out of order).
# torch.load unpickles the file, so only load checkpoints you trust.
overfit_state = torch.load('overfit_single_song2.pth')
model.load_state_dict(overfit_state)

<All keys matched successfully>

In [None]:
# Overfit a single song: the first nsamples messages of recording0.
# allow_pickle=True can execute code from a malicious file; only load trusted data.
recording = np.load('train_unified/recording0.npy', allow_pickle=True)
instruments_np = np.load('train_unified/instruments0.npy', allow_pickle=True)

nsamples = 100

# LxBx2 with B=1; nothing is padded, so both masks are all False
history = torch.tensor(recording[:nsamples], dtype=torch.long).view(-1, 1, 2)
mask = torch.zeros((history.shape[0], history.shape[1]), dtype=torch.bool)
instruments = torch.tensor(
    [instrument_numbers.index(program) for program in instruments_np], dtype=torch.long
).view(-1, 1)
inst_mask = torch.zeros(instruments.shape, dtype=torch.bool)

num_channels = instruments.shape[0]

batch_size = 1
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
message_loss_fn = torch.nn.CrossEntropyLoss()
# Time shifts carry channel -1, which the channel loss ignores
channel_loss_fn = torch.nn.NLLLoss(ignore_index=-1)
epochs = 500
train_losses = np.zeros(epochs)

# Teacher forcing: the model reads history[:-1] and predicts history[1:]
target_messages = history[1:, :, 0].flatten()
target_channels = history[1:, :, 1].flatten()

for epoch in range(epochs):
    print('Starting epoch %d' %(epoch))

    message_logits, channel_probs = model(history[:-1], mask[:-1], instruments, inst_mask)
    # The model returns channel probabilities, but NLLLoss expects log-probabilities
    channel_log_probs = torch.log(channel_probs + 1e-10)

    message_loss = message_loss_fn(message_logits.view(-1, message_dim), target_messages)
    channel_loss = channel_loss_fn(channel_log_probs.view(-1, num_channels), target_channels)
    loss = message_loss + channel_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    train_losses[epoch] = loss_value
    print('Loss: %f' %(loss_value))

In [57]:
# Persist the overfit weights; the load_state_dict cell above reads this same file
torch.save(model.state_dict(), f='overfit_single_song2.pth')

In [None]:
# Loss per epoch while overfitting the single song; it should approach zero
fig, ax = plt.subplots()
ax.plot(train_losses)
ax.set(title='Single-song overfitting loss', xlabel='Epoch', ylabel='Message + channel loss');

In [None]:
# Sample from the overfit model, feeding each generated token back in, and count how
# often the sample diverges from the song it was trained on
gen_history = history[0].unsqueeze(0)
model.eval() # Turns off the dropout for evaluation. Need to do this to get repeatable evaluation outputs

# Move forward in time
wrong_cnt = 0
# The initial LSTM state needs one row per layer, so read the depth from the model
num_layers = model.lstm.num_layers
hidden = (torch.zeros(num_layers, 1, hidden_size), torch.zeros(num_layers, 1, hidden_size))
# Sampling needs no gradients; this avoids building an autograd graph across every step
with torch.no_grad():
    for t in range(0, history.shape[0] - 1):
        ret, hidden = model.forward_generate(gen_history[-1].view(1, 1, 2), instruments, hidden)

        gen_history = torch.cat((gen_history, ret), dim=0)

        if gen_history[-1, 0, 0] != history[t + 1, 0, 0]:
            print('Wrong message at time %d!' %(t))
            wrong_cnt += 1

        # Time shifts are recorded with channel -1, which the training loss ignores
        # (ignore_index=-1) and the model never outputs, so only check real channels
        if history[t + 1, 0, 1] >= 0 and gen_history[-1, 0, 1] != history[t + 1, 0, 1]:
            print('Wrong instrument at time %d!' %(t))
            wrong_cnt += 1

print(wrong_cnt)

In [44]:
np.save('test_history.npy', gen_history.squeeze(1).detach().numpy())
# Map model instrument indices back to instrument program numbers. Every row of
# `instruments` is a real instrument here (it was built from instruments0.npy),
# so keep all of them rather than dropping the last one.
np.save('test_instruments.npy', [instrument_numbers[i] for i in instruments[:, 0].tolist()])

# Custom dataset class

In [10]:
# Custom Dataset class
class MIDIDataset(torch.utils.data.Dataset):
    # CONSTRUCTOR: splits every recording into fixed-size, end-padded chunks of messages,
    # each paired with the recording's (end-padded) instrument list.
    # Assumes that the directory contains recording0.npy to recordingM.npy
    # as well as instruments0.npy to instrumentsM.npy; recordingK.npy is paired with
    # instrumentsK.npy by its numeric suffix K.
    # ARGUMENTS
    # root_dir: the directory to search
    # chunk_size: we'll chunk the data into chunks of this size (or less)
    # max_channels: what's the largest number of instruments in any file?
    # transform: optional callable applied to each instance dict in __getitem__
    def __init__(self, root_dir, chunk_size, max_channels, transform=None):
        recording_files = self._files_by_index(root_dir, 'recording')
        instrument_files = self._files_by_index(root_dir, 'instruments')
        assert recording_files.keys() == instrument_files.keys(), \
            'every recordingK.npy needs a matching instrumentsK.npy'

        self.chunks = []
        self.masks = []
        self.instruments = []
        self.inst_masks = []

        for k in sorted(recording_files):
            # allow_pickle=True can execute code from a malicious file; only load trusted data
            recording = np.load(recording_files[k], allow_pickle=True)
            inst = [instrument_numbers.index(i) for i in np.load(instrument_files[k], allow_pickle=True)]
            assert len(inst) <= max_channels, \
                '%s has %d instruments but max_channels is %d' % (instrument_files[k], len(inst), max_channels)
            inst_tensor = torch.tensor(inst, dtype=torch.long)

            for chunk_start in range(0, recording.shape[0], chunk_size):
                chunk_end = min(chunk_start + chunk_size, recording.shape[0])
                size = chunk_end - chunk_start

                # End-pad the chunk to chunk_size; mask is True on the padding
                chunk = torch.zeros((chunk_size, 2), dtype=torch.long)
                chunk[:size] = torch.tensor(recording[chunk_start:chunk_end], dtype=torch.long)
                mask = torch.ones(chunk_size, dtype=torch.bool)
                mask[:size] = False

                # End-pad the instruments to max_channels; inst_mask is True on the padding
                inst_row = torch.zeros(max_channels, dtype=torch.long)
                inst_row[:len(inst)] = inst_tensor
                inst_mask = torch.ones(max_channels, dtype=torch.bool)
                inst_mask[:len(inst)] = False

                self.chunks.append(chunk)
                self.masks.append(mask)
                self.instruments.append(inst_row)
                self.inst_masks.append(inst_mask)

        self.transform = transform

    # _files_by_index: maps K -> path for every file named <prefix>K.npy in root_dir
    @staticmethod
    def _files_by_index(root_dir, prefix):
        files = {}
        for name in os.listdir(root_dir):
            suffix = name[len(prefix):-len('.npy')]
            if name.startswith(prefix) and name.endswith('.npy') and suffix.isdecimal():
                files[int(suffix)] = os.path.join(root_dir, name)
        return files

    # __len__
    # RETURN: the number of recording chunks in the dataset
    def __len__(self):
        return len(self.chunks)

    # __getitem__
    # ARGUMENTS
    # idx: indicates which chunk to get
    # RETURN: instance, a dictionary with keys 'history', 'instruments', 'mask' and 'inst_mask'
    # instance['history'] is an Lx2 tensor containing messages and associated channels
    # instance['instruments'] a length N tensor of instrument numbers
    # instance['mask'] a length L bool tensor containing False where messages exist and True otherwise
    # instance['inst_mask'] a length N bool tensor containing False where instruments exist and True otherwise
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        instance = {'history': self.chunks[idx], \
                    'instruments': self.instruments[idx],
                    'mask': self.masks[idx],
                    'inst_mask': self.inst_masks[idx]}

        if self.transform:
            instance = self.transform(instance)

        return instance
    
# collate_fn: stacks a list of dataset instances into one batch, placing the batch on
# dimension 1 (the sequence-first layout EnsembleLSTM.forward takes)
# batch: list of instance dicts with equal-length 'history', 'instruments', 'mask', 'inst_mask'
# RETURN: dict with 'history' (LxBx2 long), 'instruments' (CxB long),
# 'mask' (LxB bool) and 'inst_mask' (CxB bool)
def collate_fn(batch):
    key_dtypes = {'history': torch.long,
                  'instruments': torch.long,
                  'mask': torch.bool,
                  'inst_mask': torch.bool}
    return {key: torch.stack([instance[key].to(dtype) for instance in batch], dim=1)
            for key, dtype in key_dtypes.items()}

# Train the model

In [11]:
# compute_loss: computes the next-token loss for the model over the batch
# ARGUMENTS
# model: EnsembleLSTM model (any callable with the same forward signature works)
# message_loss_fn: torch.nn.CrossEntropyLoss object (receives message logits)
# channel_loss_fn: torch.nn.NLLLoss object (receives channel log-probabilities); time-shift
#   targets carry channel -1, so it should be built with ignore_index=-1
# batch: see collate_fn definition
# RETURN: a scalar loss tensor, message loss + channel loss over the non-padded targets
# (averaged over them when the loss functions use the default mean reduction)
def compute_loss(model, message_loss_fn, channel_loss_fn, batch):
    history = batch['history']

    # Teacher forcing: inputs are positions 0..L-2, targets are positions 1..L-1
    message_logits, channel_dist = model(history[:-1], batch['mask'][:-1], batch['instruments'], batch['inst_mask'])
    # The model returns channel probabilities; the epsilon avoids log(0) for masked-out channels
    log_channel_dist = torch.log(channel_dist + 1e-10)

    # (L-1)xB, True where the target token is real rather than padding
    target_mask = torch.logical_not(batch['mask'][1:])
    # Kx2: the (message, channel) pair of every real target
    targets = history[1:][target_mask]

    message_loss = message_loss_fn(message_logits[target_mask], targets[:, 0])
    channel_loss = channel_loss_fn(log_channel_dist[target_mask], targets[:, 1])

    return message_loss + channel_loss

In [12]:
embed_dim = 256
hidden_size = 1024
heads = 4
grad_clip = 10  # element-wise bound applied to every parameter gradient

model = EnsembleLSTM(message_dim, embed_dim, num_instruments, hidden_size, heads)
for param in model.parameters():
    # Clamp the gradient to [-grad_clip, grad_clip] as backward produces it
    param.register_hook(lambda g: g.clamp(-grad_clip, grad_clip))

In [13]:
# Adam at its default learning rate (1e-3), spelled out for readers
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
batch_size = 10
learning_rate = 0.001
chunk_size = 500
model_dir = 'unified_transformer_models'

train_dataset = MIDIDataset('train_unified', chunk_size, max_channels)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# Evaluation order does not matter; fixed test batches keep per-epoch losses comparable
test_dataset = MIDIDataset('test_unified', chunk_size, max_channels)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# The optimizer was created in an earlier cell; apply learning_rate so changing it takes effect
for param_group in optimizer.param_groups:
    param_group['lr'] = learning_rate

message_loss_fn = torch.nn.CrossEntropyLoss()
channel_loss_fn = torch.nn.NLLLoss(ignore_index=-1)
epochs = 20
train_losses = np.zeros(epochs)
test_losses = np.zeros(epochs)

# torch.save does not create missing directories
os.makedirs(model_dir, exist_ok=True)

for epoch in range(epochs):
    print('Starting epoch %d' %(epoch))
    model.train()
    for b, batch in enumerate(train_dataloader):
        print('Starting iteration %d' %(b))
        loss = compute_loss(model, message_loss_fn, channel_loss_fn, batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    torch.save(model.state_dict(), os.path.join(model_dir, 'epoch' + str(epoch) + '.pth'))

    model.eval()
    # Loss evaluation only: don't build autograd graphs
    with torch.no_grad():
        print('Computing test loss')
        for batch in test_dataloader:
            test_losses[epoch] += compute_loss(model, message_loss_fn, channel_loss_fn, batch).item()

        print('Computing train loss')
        for batch in train_dataloader:
            train_losses[epoch] += compute_loss(model, message_loss_fn, channel_loss_fn, batch).item()

    # Mean of the per-batch losses
    train_losses[epoch] /= len(train_dataloader)
    test_losses[epoch] /= len(test_dataloader)
    print('Train Loss: %f, Test Loss: %f' %(train_losses[epoch], test_losses[epoch]))

In [None]:
# Train and test loss per epoch (test_losses was computed by the training loop but never shown)
fig, ax = plt.subplots()
ax.plot(train_losses, label='train')
ax.plot(test_losses, label='test')
ax.set(title='EnsembleLSTM loss per epoch', xlabel='Epoch', ylabel='Mean batch loss')
ax.legend();

# Sample from the model

In [None]:
# Load a previously trained checkpoint (torch.load unpickles, so only load trusted files)
checkpoint_state = torch.load('trained_models_12_3/epoch4.pth')
model.load_state_dict(checkpoint_state)

In [None]:
model.eval() # Disable dropout; note that multinomial sampling is still random unless torch is seeded

time_steps = 500 # How many time steps do we sample?

# Start with a time shift (message 387, channel -1)
gen_history = torch.zeros((1, 1, 2), dtype=torch.long)
gen_history[0, 0, 0] = 387
gen_history[0, 0, 1] = -1

# Violin (instrument index 2, i.e. instrument_numbers[2] == 40)
instruments = torch.zeros((1, 1), dtype=torch.long)
instruments[0, 0] = 2

# The initial LSTM state needs one row per layer, so read the depth from the model
num_layers = model.lstm.num_layers
hidden = (torch.zeros(num_layers, 1, hidden_size), torch.zeros(num_layers, 1, hidden_size))
# Sampling needs no gradients; this avoids building an autograd graph across every step
with torch.no_grad():
    for t in range(0, time_steps):
        ret, hidden = model.forward_generate(gen_history[-1].view(1, 1, 2), instruments, hidden)
        gen_history = torch.cat((gen_history, ret), dim=0)

In [None]:
# Save the sampled (message, channel) sequence and the ensemble's instrument program numbers
np.save('gen_history.npy', gen_history.squeeze(1).detach().numpy())
gen_programs = [instrument_numbers[idx] for idx in instruments[:, 0].tolist()]
np.save('gen_instruments.npy', gen_programs)