In [16]:
import torch
import numpy as np
import math

In [None]:
###### Music Transformer version of our network ######
# Uses the relative position representation from the Music Transformer


# HuangMHA: multi-headed attention using relative position representation
# (specifically, the representation introduced by Shaw and optimized by Huang)
class HuangMHA(torch.nn.Module):
    """Multi-headed attention with relative position representations (stub).

    Intended to implement Shaw-style relative attention with Huang's
    memory-efficient "skewing" trick. Not implemented yet: the constructor
    only validates and records the hyperparameters, and ``forward`` raises
    NotImplementedError.

    Args:
        heads: number of attention heads (must be positive).
        embed_dim: model dimension; must be divisible by ``heads`` so each
            head gets ``embed_dim // heads`` dimensions.

    Raises:
        ValueError: if ``heads`` is not positive or does not divide
            ``embed_dim``.
    """
    def __init__(self, heads, embed_dim):
        # Required before any submodules/parameters can be registered.
        super(HuangMHA, self).__init__()
        if heads <= 0:
            raise ValueError("heads must be positive, got {}".format(heads))
        if embed_dim % heads != 0:
            raise ValueError(
                "embed_dim ({}) must be divisible by heads ({})".format(embed_dim, heads))
        self.heads = heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // heads

    def forward(self, *args, **kwargs):
        raise NotImplementedError("HuangMHA relative attention is not implemented yet")

In [None]:
##### TransformerXL version of our network #####

In [64]:
##### Baseline Transformer version of our network #####
# Vanilla transformer, uses absolute position representation

# Adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html.
# Changes: the expand in forward (accounts for multiple instruments), support for
# odd d_model, and an explicit error when the input is longer than max_len.
class PositionalEncoding(torch.nn.Module):
    """Adds fixed sinusoidal absolute position encodings, then applies dropout.

    ``forward`` expects ``x`` shaped (seq_len, batch, d_model); in this
    notebook the batch axis holds the instruments. The same encoding row is
    added to every batch entry at a given time step.

    Args:
        d_model: embedding dimension. Even columns hold sines and odd columns
            hold cosines; for odd ``d_model`` the last sine has no cosine partner.
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence the precomputed table supports.
    """
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are floor(d_model / 2) cosine columns, one fewer than sine
        # columns when d_model is odd, so trim div_term to match.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # (max_len, 1, d_model): singleton batch axis, expanded in forward.
        pe = pe.unsqueeze(1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                "sequence length {} exceeds max_len={} of the positional encoding table"
                .format(seq_len, self.pe.size(0)))
        # Repeat the (seq_len, 1, d_model) slice across the batch/instrument axis.
        x = x + self.pe[:seq_len].expand(-1, x.shape[1], -1)
        return self.dropout(x)
        

# EnsembleTransformer: takes a history of MIDI messages
# for instruments in an ensemble and generates a distribution for the next message
# for a specific instrument
class EnsembleTransformer(torch.nn.Module):
    """Predicts the next MIDI message for one instrument in an ensemble.

    The histories of all *other* instruments are flattened into one sequence
    and passed through a TransformerEncoder. The history of the instrument
    being generated for is passed through a TransformerDecoder that attends
    to that encoding. The decoder output at the last time step is projected
    to log-probabilities over the message dimension.

    Args:
        message_dim: dimension D of a MIDI message (also the output size).
        embed_dim: dimension E of the message embedding; torch's multi-head
            attention requires it to be divisible by ``heads``.
        num_instruments: number I of possible instruments (one-hot width).
        heads: number of attention heads.
        attention_layers: number of encoder layers (and of decoder layers).
        ff_size: hidden size (``dim_feedforward``) of the feed-forward
            sublayer inside every encoder and decoder layer.
    """
    def __init__(self, message_dim, embed_dim, num_instruments, heads, attention_layers, ff_size):
        super(EnsembleTransformer, self).__init__()

        # We project the one-hot instrument identity using a linear layer, then pass it through
        # a tanh, finally adding it to each input message (this is the global conditioning idea
        # from DeepJ, which comes from WaveNet)
        self.i_projection = torch.nn.Linear(num_instruments, embed_dim, False)

        self.position_encoding = PositionalEncoding(embed_dim)

        # TODO: replace this and instrument projection with torch.nn.Embedding
        # (would require us to switch from one-hots to indices)
        self.embedding = torch.nn.Linear(message_dim, embed_dim, False)

        # An encoder is used to transform histories of all instruments in the ensemble
        # except the instrument we're generating music for
        encoder_layer = torch.nn.TransformerEncoderLayer(embed_dim, heads, ff_size)
        self.encoder = torch.nn.TransformerEncoder(encoder_layer, attention_layers)

        # A decoder is used to transform the history of the instrument we're generating
        # music for, then combine this with the encoder output to generate the next message
        decoder_layer = torch.nn.TransformerDecoderLayer(embed_dim, heads, ff_size)
        self.decoder = torch.nn.TransformerDecoder(decoder_layer, attention_layers)

        self.logits = torch.nn.Linear(embed_dim, message_dim)

    def forward(self, history, instruments, gen_idx):
        """Log-probabilities of the next MIDI message for instrument ``gen_idx``.

        Args:
            history: (L, N, D) tensor. L is the length of the longest history in
                the ensemble, N the number of instruments, D the MIDI message
                dimension. Histories shorter than L should be LEFT-ALIGNED and
                RIGHT-ZERO-PADDED.
            instruments: (1, N, I) one-hot instrument identities, one per column
                of ``history``.
            gen_idx: index along dim 1 of ``history`` of the instrument to
                generate for. Negative values count from the end.

        Returns:
            (1, 1, D) tensor of log-probabilities for the next message.

        Raises:
            IndexError: if ``gen_idx`` is out of range for N instruments.
        """
        num_inst = history.shape[1]
        if not -num_inst <= gen_idx < num_inst:
            raise IndexError(
                "gen_idx {} out of range for {} instruments".format(gen_idx, num_inst))
        # Normalize so a negative index is still excluded from the encoder inputs.
        gen_idx = gen_idx % num_inst

        # (1, N, E) conditioning broadcasts over the time axis of the (L, N, E)
        # embeddings. An expand(L, 1, -1) here would fail whenever N > 1.
        conditioning = torch.tanh(self.i_projection(instruments))
        inputs = self.embedding(history) + conditioning

        # TODO: we'll probably have to chunk the data, and we'll have to keep track of the starting position for
        # this chunk
        inputs = self.position_encoding(inputs)

        # Integer index of every instrument except the one being generated for.
        # With N == 1 this is empty and the encoder memory is a (0, 1, E) sequence.
        encode_idx = torch.tensor(
            [i for i in range(num_inst) if i != gen_idx], dtype=torch.long, device=inputs.device)
        # Flatten (L, N-1, E) time-major into one (L*(N-1), 1, E) sequence;
        # the 2nd dimension is the batch (size 1).
        others = inputs.index_select(1, encode_idx).reshape(-1, 1, inputs.shape[2])
        encoding = self.encoder(others)

        decoding = self.decoder(inputs[:, gen_idx].unsqueeze(1), encoding)  # Unsqueeze is for batch dimension

        # NOTE(review): no padding masks are passed to the encoder/decoder, and
        # decoding[-1] is time step L-1. If the target instrument's history is
        # shorter than L, that step is zero padding. TODO: add key padding masks
        # and read the last *real* step.
        return torch.nn.functional.log_softmax(self.logits(decoding[-1].unsqueeze(0)), 2)

In [66]:
##### Tests for baseline transformer #####
# Smoke test: run the model on one preprocessed recording, treated as a
# single-instrument ensemble.
# NOTE(review): allow_pickle=True lets np.load execute arbitrary code stored in
# the file; only load recordings produced by our own preprocessing.
recording = np.load('preprocessed_data/recording0.npy', allow_pickle=True)

NUM_INSTRUMENTS = 12  # width of the one-hot instrument vector

history = torch.tensor(recording[0]).unsqueeze(1)  # (L, 1, D): test for a single instrument
instruments = torch.zeros((1, 1, NUM_INSTRUMENTS), dtype=torch.double)
instruments[0, 0, 0] = 1  # Pretend it's piano

torch.manual_seed(0)  # reproducible weight initialization
et = EnsembleTransformer(history.shape[2], 256, NUM_INSTRUMENTS, 4, 6, 2048).double()
et.eval()  # disable dropout so the smoke-test output is deterministic

with torch.no_grad():
    log_probs = et(history, instruments, 0)

assert log_probs.shape == (1, 1, history.shape[2])
assert torch.allclose(log_probs.exp().sum(), torch.tensor(1.0, dtype=log_probs.dtype))
log_probs

tensor([[[-6.3004, -6.0479, -6.1753, -5.6479, -6.1512, -5.2086, -7.3799,
          -7.9261, -5.7544, -5.7205, -5.8728, -6.7821, -7.0666, -6.8350,
          -6.1641, -7.4795, -6.3174, -6.1323, -6.3944, -6.0106, -6.9416,
          -5.3092, -5.6186, -5.4328, -5.5351, -5.9735, -7.3433, -6.6063,
          -6.2782, -6.8659, -6.2641, -6.7544, -6.8649, -5.3094, -5.3802,
          -6.3272, -6.3699, -6.6545, -7.2354, -5.7922, -5.3774, -6.9441,
          -6.8730, -4.8403, -5.5289, -6.8899, -6.4514, -6.3966, -5.4889,
          -5.6810, -6.6744, -5.0264, -5.4491, -6.9853, -5.9075, -5.9432,
          -4.9297, -7.2022, -6.9893, -6.0436, -6.5146, -5.4863, -6.6307,
          -5.6968, -5.7590, -5.9590, -6.2410, -6.7688, -6.7982, -6.6794,
          -5.9398, -5.7591, -6.3195, -6.3400, -6.2436, -6.5169, -5.7151,
          -5.4733, -5.3360, -6.5196, -5.6794, -6.6981, -6.4688, -6.4744,
          -6.2493, -6.3226, -6.7693, -7.2626, -4.6927, -6.2605, -7.1629,
          -8.0683, -6.5177, -6.2967, -7.2308, -6.93

In [None]:
# Custom DataLoader - Tushar

In [None]:
# Training - Tian

# Instantiate the model
# PyTorch boilerplate; see the tutorial if needed (optimizer and DataLoader initialization, etc.)

# DataLoader spits out one file
# For loop over instruments in file
# Predict the next message for this instrument given all previous MIDI messages from this instrument and from ALL OTHER INSTRUMENTS

In [None]:
# Sample from model