In [315]:
import torch
import numpy as np
import math
import os

In [217]:
num_notes = 128
num_time_shifts = 100
num_velocities = 32
message_dim = 2*num_notes + num_velocities + num_time_shifts
instrument_numbers = [0, 6, 40, 41, 42, 43, 45, 60, 68, 70, 71, 73]
num_instruments = len(instrument_numbers)

In [119]:
###### Music Transformer version of our network ######
# Uses the relative position representation from the Music Transformer


# HuangMHA: multi-headed attention using relative position representation
# (specifically, the representation introduced by Shaw and optimized by Huang)
class HuangMHA(torch.nn.Module):
    def __init__(self, heads, embed_dim, heads):
        pass

SyntaxError: duplicate argument 'heads' in function definition (<ipython-input-119-da3088c90439>, line 11)

In [120]:
##### TransformerXL version of our network #####

In [308]:
##### Baseline Transformer version of our network #####
# Vanilla transformer, uses absolute position representation

# Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html.
# Only change is the view/expand in forward (accounts for multiple instruments and batches)
class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.shape[0], :].view(x.shape[0], 1, 1, -1).expand(-1, x.shape[1], x.shape[2], -1)
        return self.dropout(x)
        

# EnsembleTransformer: takes a history of MIDI messages 
# for instruments in an ensemble and generates a distribution for the next message
# for a specific instrument
class EnsembleTransformer(torch.nn.Module):
    # CONSTRUCTOR
    # ARGUMENTS
    # message_dim: dimension of a MIDI message
    # embed_dim: dimension of message embedding
    # num_instruments: number of instrument labels
    # heads: number of attention heads
    # attention_layers: number of attention layers
    # ff_size: size of the feedforward output at the end of the decoder
    def __init__(self, message_dim, embed_dim, num_instruments, heads, attention_layers, ff_size):
        super(EnsembleTransformer, self).__init__()
        
        self.embed_dim = embed_dim
        
        # We add the tanhed instrument embedding to each input message 
        # (this is the global conditioning idea from DeepJ, which comes from WaveNet)
        self.i_embedding = torch.nn.Embedding(num_instruments, embed_dim)
        
        self.position_encoding = PositionalEncoding(embed_dim)
        
        self.embedding = torch.nn.Embedding(message_dim, embed_dim)
        
        # An encoder is used to transform histories of all instruments in the ensemble
        # except the instrument we're generating music for
        encoder_layer = torch.nn.TransformerEncoderLayer(embed_dim, heads, ff_size)
        self.encoder = torch.nn.TransformerEncoder(encoder_layer, attention_layers)
        
        # A decoder is used to transform the history of the instrument we're generating
        # music for, then combine this with the encoder output to generate the next message
        decoder_layer = torch.nn.TransformerDecoderLayer(embed_dim, heads, ff_size)
        self.decoder = torch.nn.TransformerDecoder(decoder_layer, attention_layers)
        
        self.logits = torch.nn.Linear(embed_dim, message_dim)
    
    # forward: generates a probability distribution for the next MIDI message
    # for an instrument, given the message history of the ensemble
    # ARGUMENTS
    # history: an LxNxB tensor, where L is the length of the longest history in
    # the batch, N is the max number of instruments in the batch, and B is the batch size. As we
    # walk along the first dimension, we should see indices of MIDI events for a particular
    # instrument. The END of the sequence should be padded, if needed
    # mask: an LxNxB tensor, containing False along dimension 0 where a message doesn't
    # exist for an instrument, and containing False along dimension 1 where an instrument
    # doesn't exist for a batch
    # instruments: a 1xNxB tensor indicating the instrument numbers for each batch
    # gen_idx: a length B tensor indicating the index of the instrument we want to generate music
    # for, for each batch. The END of dimension 1 should be padded, if needed
    # RETURN: a 1xD tensor, representing the distribution for
    # the next MIDI message for the instrument indicated by gen_idx. 
    # Note, to get the actual probabilities you'll have to take the softmax
    # of this tensor along dimension 1
    def forward(self, history, mask, instruments, gen_idx):
        longest_length = history.shape[0]
        batch_size = history.shape[2]
        max_instruments = instruments.shape[1]
        assert(history.shape[1] == max_instruments)
        assert(mask.shape[1] == max_instruments)
        assert(instruments.shape[2] == batch_size)
        assert(mask.shape[2] == batch_size)
        assert(mask.shape[0] == longest_length)
        
        batch_size = history.shape[2]
        
        inputs = self.embedding(history) + torch.tanh(self.i_embedding(instruments)).expand(history.shape[0], -1, -1, -1)
        
        inputs = self.position_encoding(inputs)
        
        encode_idx = torch.tensor([[i for i in range(max_instruments) if i != gen_idx[b]] for b in range(gen_idx.shape[0])])
        encode_idx = encode_idx.transpose(0, 1).unsqueeze(0)
        encoder_inputs = torch.gather(inputs, 1, encode_idx.unsqueeze(3).expand(inputs.shape[0], -1, -1, self.embed_dim)).view(-1, batch_size, self.embed_dim)

        encoder_mask = torch.gather(mask, 1, encode_idx.expand(inputs.shape[0], -1, -1)).view(-1, batch_size).transpose(0, 1)
        encoding = self.encoder(encoder_inputs, encoder_mask)
        
        decode_idx = gen_idx.view(1, 1, -1).expand(inputs.shape[0], -1, -1)
        decoder_inputs = torch.gather(inputs, 1, decode_idx)
        decoder_mask = torch.gather(mask, 1, decode_idx).transpose(0, 1)
        decoding = self.decoder(decoder_inputs, encoding)
        
        return self.logits(decoding[-1])

In [309]:
##### Tests for baseline transformer #####
recording = np.load('preprocessed_data/recording0.npy', allow_pickle=True)
instruments_np = np.load('preprocessed_data/instruments0.npy', allow_pickle=True)

instrument_idx = [instrument_numbers.index(i) for i in instruments_np]

# Form a single-element batch from this recording
seq_lengths = [messages.shape[0] for messages in recording]
longest_len = max(seq_lengths)
batch = torch.ones((longest_len, recording.shape[0], 1), dtype=torch.long)
mask = torch.zeros((longest_len, recording.shape[0], 1), dtype=torch.bool)
instruments = torch.zeros((1, recording.shape[0], 1), dtype=torch.long)

for i, messages in enumerate(recording):
    batch[:seq_lengths[i], i, 0] = torch.tensor(messages, dtype=torch.long)
    mask[:seq_lengths[i], i, 0] = 1
    
instruments[0, :len(instrument_idx), 0] = torch.tensor(instrument_idx, dtype=torch.long)

et = EnsembleTransformer(message_dim, 256, num_instruments, 4, 6, 2048).double()

print(et(batch, mask, instruments, torch.tensor([0])))

RuntimeError: [enforce fail at CPUAllocator.cpp:64] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 2065444992 bytes. Error code 12 (Cannot allocate memory)


In [310]:
# Custom Dataset class
class MIDIDataset(torch.utils.data.Dataset):
    # CONSTRUCTOR: creates a list of recording files and a list
    # of instrument files in root_dir. Assumes that the directory
    # contains recording0.npy to recordingN.npy,
    # as well as instruments0.npy to instrumentsN.npy
    # ARGUMENTS
    # root_dir: the directory to search
    def __init__(self, root_dir, transform=None):
        files = os.listdir(root_dir)
        self.recordings = []
        self.instrument_files = []
        self.time_files = []
        for file in files:
            if 'recording' in file:
                self.recordings.append(os.path.join(root_dir, file))
            elif 'instruments' in file:
                self.instrument_files.append(os.path.join(root_dir, file))
            elif 'times' in file:
                self.time_files.append(os.path.join(root_dir, file))
                
        assert(len(self.recordings) == len(self.instrument_files))
        assert(len(self.recordings) == len(self.time_files))
        self.recordings.sort()
        self.instrument_files.sort()
        self.time_files.sort()
        self.transform = transform

    # __len__
    # RETURN: the number of recording files in the dataset
    def __len__(self):
        return len(self.recordings)

    # __getitem__
    # ARGUMENTS
    # idx: either an int or a list of ints saying which items to get
    # RETURN: a dictionary d with keys 'instruments'
    # 'history', and 'times'. Assuming idx is a list of indices i1, i2, ... iB,
    # where B is the batch size:
    # d['history'] is an LxNxB tensor (see 'history' in the EnsembleTransformer module)
    # d['instruments'] is a 1xNxB tensor
    # d['times'] is a LxNxB tensor containing message times
    # d['mask'] is a LxNxB boolean tensor masking out invalid message indices
    # (accounts for variable sequence lengths)
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # history_list[i] contains the recording associated with the ith file
        history_list = []
        inst_list = []
        time_list = []
        seq_lengths = []
        
        if type(idx) != list:
            idx = [idx]
            
        batch_size = len(idx)
            
        # Max number of instruments in any sequence in the batch
        max_instruments = 0
        for i in idx:
            history_list.append(np.load(self.recordings[i], allow_pickle=True))
            inst_list.append(np.load(self.instrument_files[i], allow_pickle=True))
            time_list.append(np.load(self.time_files[i], allow_pickle=True))
            seq_lengths.append([len(arr) for arr in history_list[-1]])
            if len(history_list[-1]) > max_instruments:
                max_instruments = len(history_list[-1])
            
        # We size our tensors to accomodate the longest sequence and the largest ensemble
        longest_len = max(max(seq_lengths))
        
        history = torch.ones((longest_len, max_instruments, batch_size), dtype=torch.long)
        mask = torch.zeros((longest_len, max_instruments, batch_size), dtype=torch.bool)
        instruments = torch.zeros((1, max_instruments, batch_size), dtype=torch.long)
        times = torch.zeros((longest_len, max_instruments, batch_size), dtype=torch.double)

        for b in range(batch_size):
            for i, messages in enumerate(history_list[b]):
                history[:seq_lengths[b][i], i, b] = torch.tensor(messages, dtype=torch.long)
                mask[:seq_lengths[b][i], i, b] = 1
                times[:seq_lengths[b][i], i, b] = torch.tensor(times[b][i], dtype=torch.double)
                
            instrument_idx = [instrument_numbers.index(i) for i in inst_list[b]]
            instruments[0, :len(instrument_idx), 0] = torch.tensor(instrument_idx, dtype=torch.long)
            
        sample = {'history': history, 'instruments': instruments, 'times': times, 'mask': mask}

        if self.transform:
            sample = self.transform(sample)
            
        return sample

In [311]:
dataset = MIDIDataset('preprocessed_data')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

# Training
batch_size = 1
learning_rate = 0.001
ff_size = 512
heads = 8
attention_layers=  6
embed_dim = 256
model = EnsembleTransformer(message_dim, embed_dim, len(instrument_numbers), heads, attention_layers, ff_size)
optimizer = torch.optim.Adam(model.parameters())

for b, batch in enumerate(dataloader):
    for inst in range(batch['history'].shape[1]):
        # Index of the instrument we want to generate music for (same for each batch element)
        gen_idx = torch.tensor([inst for i in range(batch_size)])
        mask = batch['mask']
        
        # Move forward in time
        for t in range(batch['history'].shape[0]):
            pass
        
    # Generate the next message of an instrument given all previous midi messages from the instrument and ALL OTHER INSTRUMENTS
    # Use timing information to create masks
    # torch.nn.CrossEntropyLoss



In [None]:
# Sample from model