In [2]:
import torch
import numpy as np
import math
import os
import matplotlib.pyplot as plt

In [3]:
# Sizes of the individual message vocabularies
num_notes = 128
num_time_shifts = 100
num_velocities = 32

# Total number of distinct message tokens; used as the size of the message
# embedding table. The note vocabulary is counted twice (presumably once for
# note-on and once for note-off events -- confirm against the data encoder).
message_dim = num_velocities + num_time_shifts + (2 * num_notes)

# Instrument numbers that may appear in a recording. An instrument is fed to the
# model as its position in this list, so the list order must stay fixed.
instrument_numbers = [0, 6, 40, 41, 42, 43, 45, 60, 68, 70, 71, 73]
num_instruments = len(instrument_numbers)

# AssignerLSTM Definition

Predicts which channel (i.e. which instrument in the ensemble) is associated with each message in the message history.

In [7]:
# Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html.
# Changes from the tutorial: forward broadcasts the encoding over the batch
# dimension, and odd values of d_model are supported.
class PositionalEncoding(torch.nn.Module):
    # CONSTRUCTOR: precomputes a (max_len x d_model) table of sinusoidal encodings
    # ARGUMENTS
    # d_model: size of the last dimension of the tensors passed to forward
    # dropout: dropout probability applied after the encoding is added
    # max_len: largest supported size of dimension 0 of the input
    def __init__(self, d_model, dropout=0.1, max_len=10000):
        super(PositionalEncoding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Even columns hold sines, odd columns hold cosines
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine columns,
        # so the frequency vector is truncated to match
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        # A buffer (not a parameter): saved in the state dict and moved with the
        # module, but never updated by the optimizer
        self.register_buffer('pe', pe)

    # forward: adds the encoding of position i to every batch element of x[i]
    # ARGUMENTS
    # x: an SxBxD tensor, where S <= max_len and D == d_model
    # RETURN: an SxBxD tensor, x plus the positional encoding, after dropout
    # RAISES: ValueError if x is longer than max_len along dimension 0
    def forward(self, x):
        seq_len = x.shape[0]
        if seq_len > self.pe.shape[0]:
            raise ValueError('sequence length %d exceeds max_len %d' % (seq_len, self.pe.shape[0]))

        # Sx1xD broadcasts against SxBxD, so no explicit expand is needed
        x = x + self.pe[:seq_len].unsqueeze(1)
        return self.dropout(x)

# Takes a history of MIDI messages and assigns each one to an instrument
class AssignerLSTM(torch.nn.Module):
    # CONSTRUCTOR
    # ARGUMENTS
    # message_dim: number of distinct message tokens (size of the message embedding table)
    # embed_dim: dimension of the message and instrument embeddings
    # num_instruments: number of instrument labels
    # hidden_size: size of hidden LSTM state
    # heads: number of attention heads
    # recurrent_layers: the number of layers in the lstm
    def __init__(self, message_dim, embed_dim, num_instruments, hidden_size, heads, recurrent_layers=3):
        super(AssignerLSTM, self).__init__()

        self.embed_dim = embed_dim

        # Learned embedding for each instrument label
        self.i_embedding = torch.nn.Embedding(num_instruments, embed_dim)

        # Used to indicate which channel belongs to each instrument
        self.position_encoding = PositionalEncoding(embed_dim)

        # Learned embedding for each message token
        self.embedding = torch.nn.Embedding(message_dim, embed_dim)

        # A multi-layer LSTM takes the history of messages and produces a decoding
        self.lstm = torch.nn.LSTM(embed_dim, hidden_size, num_layers=recurrent_layers)

        # The decoding is passed through a linear layer to get a query for instrument attention
        self.query = torch.nn.Linear(hidden_size, embed_dim)

        # Attention of each message query over the instrument embeddings; its
        # weights are the model output
        self.attention = torch.nn.MultiheadAttention(embed_dim, heads)

    # forward: generates a probability distribution for which instrument in the ensemble is associated with each
    # message in the history
    # ARGUMENTS
    # history: an LxB tensor of message tokens, where L is the length of the longest message history in
    # the batch, and B is the batch size. This should be END-PADDED along dimension 0
    # mask: an LxB tensor, containing True in any locations where history contains padding.
    # It is only shape-checked here: padded positions still pass through the LSTM, so the
    # caller has to ignore the outputs at those positions
    # instruments: a NxB tensor indicating the instruments in each channel. This should be END-PADDED along dimension 0
    # inst_mask: a NxB containing False where an instrument exists and True otherwise
    # RETURN: an LxBxN tensor representing the probabilities that an instrument is associated with a message
    def forward(self, history, mask, instruments, inst_mask):
        assert(mask.shape == history.shape)
        assert(instruments.shape == inst_mask.shape)

        # NxBxD
        inst_embed = self.position_encoding(torch.tanh(self.i_embedding(instruments)))

        # LxBxD
        inputs = self.embedding(history)

        # Only the per-step outputs are needed; the final hidden state is discarded
        decoding, _ = self.lstm(inputs)

        queries = self.query(decoding)

        # The padding mask is expected as BxN, hence the transpose.
        # att_weights is BxLxN
        _, att_weights = self.attention(queries, inst_embed, inst_embed,
                                        key_padding_mask=inst_mask.transpose(0, 1))

        return att_weights.transpose(0, 1)

# Tests for AssignerLSTM

In [14]:
# Hyperparameters of the model used for the single-recording overfitting test
embed_dim = 256
hidden_size = 1024
heads = 4

# Every gradient element is clamped to [-grad_clip, grad_clip] during backward
grad_clip = 10

model = AssignerLSTM(message_dim, embed_dim, num_instruments, hidden_size, heads)

for param in model.parameters():
    param.register_hook(lambda grad: grad.clamp(-grad_clip, grad_clip))

# Deliberately left in eval mode: the overfitting check is trained without dropout
model.eval()
# `pass` keeps model.eval() from being the last expression of the cell, so the
# notebook does not echo the module repr
pass

In [None]:
# Restores the weights written by the torch.save(..., 'overfit_assigner.pth') cell
# further down; the checkpoint file must already exist on disk when this cell runs.
# NOTE(review): depending on the PyTorch version, torch.load may unpickle arbitrary
# objects -- only load checkpoints from a trusted source.
model.load_state_dict(torch.load('overfit_assigner.pth'))

In [None]:
# Overfitting sanity check: train on the first nsamples messages of a single
# recording and see whether the model can memorize their channel assignments.
# NOTE(review): allow_pickle=True unpickles the files -- only load trusted data.
recording = np.load('train_unified/recording0.npy', allow_pickle=True)
instruments_np = np.load('train_unified/instruments0.npy', allow_pickle=True)

nsamples = 200

# Column 0 of the recording is fed to the model as the message history (reshaped
# to Lx1, i.e. a batch holding one sequence); column 1 is the target channel
message_history = torch.tensor(recording[:nsamples, 0], dtype=torch.long).view(-1, 1)
channel_history = torch.tensor(recording[:nsamples, 1], dtype=torch.long)
# A single full-length sequence has no padding, so both masks are all False
mask = torch.zeros(message_history.shape, dtype=torch.bool)
instruments = torch.tensor([instrument_numbers.index(i) for i in instruments_np], dtype=torch.long).view(-1, 1)
inst_mask = torch.zeros(instruments.shape, dtype=torch.bool)

batch_size = 1
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = torch.nn.NLLLoss()
epochs = 250
train_losses = np.zeros(epochs)

target_channels = channel_history.flatten()

# Messages with a negative channel id are left out of the loss (presumably the
# time-shift messages, which belong to no instrument -- confirm against the data).
# Both values are loop-invariant, so they are computed once up front.
time_shift_mask = target_channels >= 0
masked_targets = target_channels[time_shift_mask]

for epoch in range(epochs):
    print('Starting epoch %d' %(epoch))

    # The model outputs probabilities, while NLLLoss expects log-probabilities
    channel_weights = torch.log(model(message_history, mask, instruments, inst_mask))

    loss = loss_fn(channel_weights.squeeze(1)[time_shift_mask], masked_targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # .item() extracts the Python float without going through the legacy .data attribute
    epoch_loss = loss.item()
    train_losses[epoch] = epoch_loss
    print('Loss: %f' %(epoch_loss))

In [22]:
# Persists the current model weights so that the load_state_dict cell above can
# restore them in a later session
torch.save(model.state_dict(), 'overfit_assigner.pth')

In [21]:
model.eval() # Turns off the dropout for evaluation

# Inference only: no_grad avoids building an autograd graph for the forward pass
with torch.no_grad():
    channel_probs = model(message_history, mask, instruments, inst_mask).squeeze(1)

# One channel is sampled per message from the predicted distribution. The sampling
# itself is random, so call torch.manual_seed first if identical draws are needed.
channels = torch.multinomial(channel_probs, 1)

# Number of messages (ignoring those with a negative target channel) whose sampled
# channel differs from the target
print(torch.sum(channels.flatten()[time_shift_mask] != target_channels[time_shift_mask]))

tensor(0)


In [None]:
# Show the sampled channels next to the target channels for the messages that
# take part in the loss
sampled_channels = channels[time_shift_mask].flatten()
expected_channels = target_channels[time_shift_mask].flatten()
print(sampled_channels)
print(expected_channels)

# Custom dataset class

In [94]:
# Custom Dataset class
class MIDIDataset(torch.utils.data.Dataset):
    # CONSTRUCTOR: collects the recording files and the instrument files found
    # in root_dir and loads all of them into memory. Any file whose name contains
    # 'recording' counts as a recording; any other file whose name contains
    # 'instruments' counts as an instrument file. Both lists are sorted by path,
    # and the i-th recording is paired with the i-th instrument file.
    # ARGUMENTS
    # root_dir: the directory to search
    # transform: optional callable applied to every instance returned by __getitem__
    def __init__(self, root_dir, transform=None):
        self.recording_files = []
        self.instrument_files = []
        for name in os.listdir(root_dir):
            path = os.path.join(root_dir, name)
            if 'recording' in name:
                self.recording_files.append(path)
                continue
            if 'instruments' in name:
                self.instrument_files.append(path)

        # Every recording needs a matching instrument file
        assert(len(self.recording_files) == len(self.instrument_files))
        self.recording_files.sort()
        self.instrument_files.sort()

        # NOTE(review): allow_pickle=True unpickles the files -- only point this at trusted data
        self.recordings = []
        self.instruments = []
        for i, recording_path in enumerate(self.recording_files):
            self.recordings.append(np.load(recording_path, allow_pickle=True))
            self.instruments.append(np.load(self.instrument_files[i], allow_pickle=True))

        self.transform = transform

    # __len__
    # RETURN: the number of recordings held by the dataset
    def __len__(self):
        return len(self.recordings)

    # __getitem__
    # ARGUMENTS
    # idx: position of the recording to fetch
    # RETURN: a dictionary with keys 'history' and 'instruments' (the arrays loaded
    # from the idx-th recording file and instrument file), passed through
    # transform when one was given
    def __getitem__(self, idx):
        instance = dict(history=self.recordings[idx], instruments=self.instruments[idx])
        return self.transform(instance) if self.transform else instance

In [95]:
# collate_fn: takes a list of samples from the dataset and turns them into a batch.
# ARGUMENTS
# batch: a non-empty list of dictionaries with keys 'history' (an Lx2 array of
# messages and their channels) and 'instruments' (an array of instrument numbers)
# RETURN: a sample with keys 'history', 'instruments', 'mask', and 'inst_mask'
# sample['history']: an LxBx2 tensor containing messages and their associated channels,
# zero-padded at the end of dimension 0
# sample['instruments']: a CxB tensor containing, for each channel, the position of its
# instrument number in instrument_numbers; unused slots hold 0
# sample['mask']: an LxB tensor containing False where a message is
# valid, and True where it isn't (accounts for variable length sequences
# and zero padding)
# sample['inst_mask']: a CxB tensor containing False where an instrument exists
# and True in the padded slots
def collate_fn(batch):
    batch_size = len(batch)

    # We size our tensors to accommodate the longest sequence and the largest number of instruments
    max_inst = max(instance['instruments'].shape[0] for instance in batch)
    longest_len = max(instance['history'].shape[0] for instance in batch)

    # Masks start out all True (everything is padding) and are cleared below
    # wherever real data is written
    sample = {'history': torch.zeros((longest_len, batch_size, 2), dtype=torch.long),
              'instruments': torch.zeros((max_inst, batch_size), dtype=torch.long),
              'mask': torch.ones((longest_len, batch_size), dtype=torch.bool),
              'inst_mask': torch.ones((max_inst, batch_size), dtype=torch.bool)}

    for b, instance in enumerate(batch):
        # Instruments are stored as positions in instrument_numbers; an unknown
        # instrument number raises ValueError here
        instrument_idx = [instrument_numbers.index(i) for i in instance['instruments']]
        num_inst = len(instrument_idx)

        sample['instruments'][:num_inst, b] = torch.tensor(instrument_idx, dtype=torch.long)
        sample['inst_mask'][:num_inst, b] = False

        seq_length = instance['history'].shape[0]
        sample['history'][:seq_length, b] = torch.tensor(instance['history'], dtype=torch.long)
        sample['mask'][:seq_length, b] = False

    return sample

# Train the model

In [96]:
# compute_loss: computes the loss for the model over the batch
# ARGUMENTS
# model: AssignerLSTM model (any callable taking history, mask, instruments, inst_mask
# and returning LxBxN probabilities)
# loss_fn: torch.nn.NLLLoss object
# batch: see collate_fn definition
# RETURN: a scalar loss tensor, computed over the messages that are not padding
# and have a non-negative target channel
def compute_loss(model, loss_fn, batch):
    messages = batch['history'][:, :, 0]
    channels = batch['history'][:, :, 1]

    probs = model(messages, batch['mask'], batch['instruments'], batch['inst_mask'])

    # Padded instrument slots get a probability of exactly 0. log(0) is -inf and its
    # gradient is 0 * inf = NaN, which would poison every parameter, so the
    # probabilities are clamped to the smallest positive normal value first.
    # Probabilities above that bound are unaffected.
    channel_probs = torch.log(probs.clamp_min(torch.finfo(probs.dtype).tiny))

    # Skip padding, and skip messages with a negative channel id: NLLLoss rejects
    # negative class indices (the single-recording overfitting cell filters them
    # the same way)
    target_mask = torch.logical_not(batch['mask']) & (channels >= 0)

    target_channels = channels[target_mask]

    return loss_fn(channel_probs[target_mask], target_channels)

In [117]:
# Fresh model for the full training run
embed_dim, hidden_size, heads = 256, 1024, 4

# Bound applied to every gradient element during backward
grad_clip = 10

model = AssignerLSTM(message_dim, embed_dim, num_instruments, hidden_size, heads)

for weight in model.parameters():
    weight.register_hook(lambda grad: torch.clamp(grad, min=-grad_clip, max=grad_clip))

# The model starts out in eval mode; the training cell switches between train
# and eval mode itself
model.eval()
# `pass` keeps model.eval() from being the last expression of the cell, so the
# notebook does not echo the module repr
pass

In [118]:
# Adam with its default hyperparameters.
# NOTE(review): the learning_rate variable assigned in the training cell below is
# never passed to this optimizer -- confirm that the default step size is intended.
optimizer = torch.optim.Adam(model.parameters())

In [None]:
batch_size = 10
# NOTE(review): learning_rate is not used in this cell; the optimizer is created
# in an earlier cell
learning_rate = 0.001
# Sequences are processed in windows of chunk_size messages; each window is run
# through the model independently and gets its own optimizer step
chunk_size = 500

train_dataset = MIDIDataset('train_unified')
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

test_dataset = MIDIDataset('test_unified')
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

loss_fn = torch.nn.NLLLoss()
epochs = 20
train_losses = np.zeros(epochs)
test_losses = np.zeros(epochs)


# Yields consecutive windows of at most chunk_size messages from a collated batch.
# The instrument tensors are shared by every window.
def _iter_chunks(batch, chunk_size):
    total_len = batch['history'].shape[0]
    for chunk_start in range(0, total_len, chunk_size):
        chunk_end = min(total_len, chunk_start + chunk_size)
        yield {'history': batch['history'][chunk_start:chunk_end],
               'instruments': batch['instruments'],
               'mask': batch['mask'][chunk_start:chunk_end],
               'inst_mask': batch['inst_mask']}


# Returns the loss averaged over every chunk produced by the dataloader,
# computed without tracking gradients.
def _mean_chunk_loss(model, loss_fn, dataloader, chunk_size):
    total_loss = 0.0
    num_chunks = 0
    with torch.no_grad():
        for batch in dataloader:
            for chunk in _iter_chunks(batch, chunk_size):
                total_loss += compute_loss(model, loss_fn, chunk).item()
                num_chunks += 1
    # max() guards against an empty dataloader
    return total_loss / max(num_chunks, 1)


# torch.save does not create missing directories
os.makedirs('assigner_models', exist_ok=True)

for epoch in range(epochs):
    print('Starting epoch %d' %(epoch))
    model.train()
    for b, batch in enumerate(train_dataloader):
        print('Starting iteration %d' %(b))
        for chunk in _iter_chunks(batch, chunk_size):
            loss = compute_loss(model, loss_fn, chunk)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    torch.save(model.state_dict(), 'assigner_models/epoch' + str(epoch) + '.pth')

    # Both losses are measured in eval mode so that they are comparable
    print('Computing test loss')
    model.eval()
    test_losses[epoch] = _mean_chunk_loss(model, loss_fn, test_dataloader, chunk_size)

    print('Computing train loss')
    train_losses[epoch] = _mean_chunk_loss(model, loss_fn, train_dataloader, chunk_size)

    print('Train Loss: %f, Test Loss: %f' %(train_losses[epoch], test_losses[epoch]))

In [None]:
# Plot the per-epoch losses recorded by the training cell. The test curve is
# drawn next to the train curve so that overfitting is visible.
fig, ax = plt.subplots()
ax.plot(train_losses, label='train')
ax.plot(test_losses, label='test')
ax.set_xlabel('Epoch')
ax.set_ylabel('NLL loss')
ax.set_title('AssignerLSTM loss per epoch')
ax.legend()
plt.show()

# Sample from the model

In [115]:
model.eval() # Disable dropout for inference

# NOTE(review): allow_pickle=True unpickles the files -- only load trusted data.
recording = np.load('train_unified/recording0.npy', allow_pickle=True)
instruments_np = np.load('train_unified/instruments0.npy', allow_pickle=True)

nsamples = 200

# Column 0 of the recording is fed to the model as the message history (reshaped
# to Lx1, i.e. a batch holding one sequence); column 1 is the target channel
message_history = torch.tensor(recording[:nsamples, 0], dtype=torch.long).view(-1, 1)
channel_history = torch.tensor(recording[:nsamples, 1], dtype=torch.long)
# A single full-length sequence has no padding, so both masks are all False
mask = torch.zeros(message_history.shape, dtype=torch.bool)
instruments = torch.tensor([instrument_numbers.index(i) for i in instruments_np], dtype=torch.long).view(-1, 1)
inst_mask = torch.zeros(instruments.shape, dtype=torch.bool)
target_channels = channel_history.flatten()
# Messages with a negative channel id are excluded from the comparison
time_shift_mask = target_channels >= 0

# forward requires the instrument padding mask as its fourth argument.
# no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    channel_probs = model(message_history, mask, instruments, inst_mask).squeeze(1)

# One channel is sampled per message; the sampling is random, so call
# torch.manual_seed first if identical draws are needed
channels = torch.multinomial(channel_probs, 1)

# Number of compared messages whose sampled channel differs from the target
print(torch.sum(channels.flatten()[time_shift_mask] != target_channels[time_shift_mask]))

In [116]:
# Pair every message with the channel sampled for it (an Lx2 tensor with the
# messages in column 0 and the sampled channels in column 1) and save the result
# together with the instrument numbers of the recording
history = torch.cat((message_history, channels), dim=1)
np.save('test_history.npy', history.detach().numpy())
np.save('test_instruments.npy', instruments_np)