# Multitrack MusicVAE

The following figure shows a graphical representation of the model.

<img src="data/multitrack-music-vae.png" width=400>


The model is a variational autoencoder. Both the encoder and the decoder have a hierarchical architecture built from gated recurrent units (GRU); the code in this notebook uses LSTM cells instead.

The encoder takes as input a list of musical sequences in a MIDI-like format and produces a single embedding vector. Each sequence corresponds to one track (piano, saxophone, etc.). A shared bidirectional GRU processes each track separately and produces one embedding vector per track. A second bidirectional GRU then reads these track embeddings and produces an embedding vector for the whole input.
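The hierarchical encoder described above can be sketched as follows. This is a minimal sketch, not the official implementation: the class name `HierarchicalGRUEncoder`, the single-layer GRUs and the toy sizes are illustrative assumptions. Folding the tracks into the batch dimension lets one shared GRU encode all tracks in a single call.

```python
import torch
from torch import nn


class HierarchicalGRUEncoder(nn.Module):
    """Hypothetical sketch: shared track-level bi-GRU followed by a score-level bi-GRU."""

    def __init__(self, vocab_size, hidden_dim):
        super().__init__()
        self.track_gru = nn.GRU(vocab_size, hidden_dim, batch_first=True, bidirectional=True)
        self.score_gru = nn.GRU(2 * hidden_dim, hidden_dim, batch_first=True, bidirectional=True)

    @staticmethod
    def _final_state(h_n):
        # h_n: (2, batch, hidden) for a single-layer bi-GRU; concatenate both directions.
        return torch.cat([h_n[0], h_n[1]], dim=1)

    def forward(self, x):
        batch_size, n_tracks, seq_len, vocab_size = x.shape
        # Every track is encoded by the same GRU: fold tracks into the batch.
        _, h_n = self.track_gru(x.reshape(batch_size * n_tracks, seq_len, vocab_size))
        track_embeddings = self._final_state(h_n).view(batch_size, n_tracks, -1)
        # A second GRU reads the sequence of track embeddings.
        _, h_n = self.score_gru(track_embeddings)
        return self._final_state(h_n)  # (batch_size, 2 * hidden_dim)


encoder = HierarchicalGRUEncoder(vocab_size=342, hidden_dim=64)
x = torch.zeros(4, 3, 10, 342)
z_input = encoder(x)
print(z_input.shape)  # torch.Size([4, 128])
```

The output concatenates the forward and backward final states, so its size is twice the GRU hidden size.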

As in any variational autoencoder, the encoder output is fed to two linear layers that produce a mean vector and a log-variance vector, respectively. These parametrize the distribution from which the latent vector used for decoding is sampled.

The decoder takes a single latent vector as input and produces a list of musical sequences. The latent vector is fed to a decoder GRU, called the *conductor*, which outputs one embedding for each track to decode. Each track embedding is then fed to a shared decoder GRU, which produces a sequence of event embeddings. A final linear layer with softmax activation maps each event embedding to a probability distribution over the events of the MIDI-like representation.
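The decoder's data flow can be sketched in the same spirit. Assumptions, all illustrative: single-layer GRUs, toy sizes, the latent vector fed to the conductor once per track, the last vocabulary index used as a start token, and greedy (argmax) feedback between steps.

```python
import torch
from torch import nn
from torch.nn import functional as F

latent_dim, hidden_dim, vocab_size = 32, 64, 20
n_tracks, seq_len, batch_size = 3, 5, 2

conductor = nn.GRU(latent_dim, hidden_dim, batch_first=True)
track_decoder = nn.GRU(vocab_size + hidden_dim, hidden_dim, batch_first=True)  # shared by all tracks
to_events = nn.Linear(hidden_dim, vocab_size)

z = torch.randn(batch_size, latent_dim)
# The conductor reads z once per track and emits one embedding per track.
track_embeddings, _ = conductor(z.unsqueeze(1).repeat(1, n_tracks, 1))

tracks = []
for t in range(n_tracks):
    emb = track_embeddings[:, t]                       # (batch_size, hidden_dim)
    h = emb.unsqueeze(0).contiguous()                  # initial state of the track decoder
    start = torch.full((batch_size,), vocab_size - 1, dtype=torch.long)
    event = F.one_hot(start, vocab_size).float()       # start token (assumed last index)
    probs_per_step = []
    for _ in range(seq_len):
        out, h = track_decoder(torch.cat([event, emb], dim=1).unsqueeze(1), h)
        probs = F.softmax(to_events(out[:, 0]), dim=1)  # distribution over events
        probs_per_step.append(probs)
        event = F.one_hot(probs.argmax(dim=1), vocab_size).float()  # feed back the greedy choice
    tracks.append(torch.stack(probs_per_step, dim=1))

output = torch.stack(tracks, dim=1)
print(output.shape)  # torch.Size([2, 3, 5, 20])
```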

Links to [long paper](https://arxiv.org/pdf/1806.00195.pdf), [short paper](https://nips2018creativity.github.io/doc/Learning_a_Latent_Space_of_Multitrack_Measures.pdf), [poster](https://colinraffel.com/posters/neurips2018learning.pdf), and the [official implementation in TensorFlow](https://github.com/magenta/magenta/blob/be6558f1a06984faff6d6949234f5fe9ad0ffdb5/magenta/models/music_vae/lstm_models.py).

# The encoder
The encoder takes as input a tensor of shape `(batch_size, n_tracks, seq_len, vocab_size)`, where
- `batch_size` is the batch size.
- `n_tracks` is the number of tracks.
- `seq_len` is the length of each sequence.
- `vocab_size` is the size of the vocabulary of events in the MIDI-like format (each event is one-hot encoded).
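The last axis holds a one-hot encoding of event ids. Assuming each track is already tokenized into integer event ids (the ids below are made up), `torch.nn.functional.one_hot` produces the expected layout:

```python
import torch
from torch.nn import functional as F

vocab_size = 342
# Hypothetical event ids: 2 scores, 3 tracks, 4 events per track.
event_ids = torch.tensor([
    [[5, 17, 17, 341], [60, 61, 62, 341], [0, 0, 1, 341]],
    [[8, 9, 10, 341], [100, 100, 100, 341], [3, 2, 1, 341]],
])
x = F.one_hot(event_ids, num_classes=vocab_size).float()
print(x.shape)  # torch.Size([2, 3, 4, 342])

# argmax over the last dimension recovers the event ids.
assert torch.equal(x.argmax(dim=-1), event_ids)
```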

First, we make sure the model runs on dummy values.

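The encoder returns a tensor of shape `(batch_size, 2 * hidden_dim)`: `hidden_dim` is the hidden size of each LSTM direction, and the final forward and backward states are concatenated. This vector is then projected onto the latent space Z.
For now, we try `hidden_dim = 512`, so the encoder output has 1024 dimensions.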
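For a multi-layer bidirectional `nn.LSTM`, the final hidden state `h_n` has shape `(num_layers * 2, batch, hidden_size)` and can be viewed as `(num_layers, 2, batch, hidden_size)` to separate layers and directions. The check below uses toy sizes chosen for illustration. It confirms that the top layer's forward and backward states match the last and first time steps of the output sequence, and that reshaping is safe as long as the batch axis keeps its full `batch_size * n_tracks` size before the tracks are unfolded.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, n_tracks, seq_len, input_size, hidden_dim, num_layers = 2, 3, 5, 7, 4, 2

lstm = nn.LSTM(input_size, hidden_dim, num_layers=num_layers, batch_first=True, bidirectional=True)
x = torch.randn(batch_size * n_tracks, seq_len, input_size)  # tracks folded into the batch
output, (h_n, _) = lstm(x)

# Separate layers and directions, keep the top layer: (2, batch_size * n_tracks, hidden_dim).
top = h_n.view(num_layers, 2, batch_size * n_tracks, hidden_dim)[-1]
# The forward state is the output at the last step; the backward state is the output at step 0.
assert torch.allclose(top[0], output[:, -1, :hidden_dim])
assert torch.allclose(top[1], output[:, 0, hidden_dim:])

# Concatenate both directions, then unfold the tracks: (batch_size, n_tracks, 2 * hidden_dim).
track_embeddings = torch.cat([top[0], top[1]], dim=1).view(batch_size, n_tracks, 2 * hidden_dim)
# Score 1, track 2 is row 1 * n_tracks + 2 = 5 of the folded batch.
assert torch.allclose(track_embeddings[1, 2], torch.cat([top[0][5], top[1][5]]))
print(track_embeddings.shape)  # torch.Size([2, 3, 8])
```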

```python
import torch

batch_size = 4
n_tracks = 3
seq_len = 10
vocab_size = 342

# Here, we create some dummy data to test the model: (batch_size, n_tracks, seq_len, vocab_size)
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.zeros((batch_size, n_tracks, seq_len, vocab_size), device=device)
```

```python
import torch
from torch import nn


class MultitrackEncoder(nn.Module):

    def __init__(self, input_size, hidden_dim, num_layers):
        """Hierarchical encoder: a shared track-level bidirectional LSTM,
        followed by a score-level bidirectional LSTM over the track embeddings."""
        super().__init__()
        self.track_rnn = nn.LSTM(input_size, hidden_dim, batch_first=True, num_layers=num_layers, bidirectional=True, dropout=0.6)
        self.score_rnn = nn.LSTM(hidden_dim * 2, hidden_dim, batch_first=True, num_layers=num_layers, bidirectional=True, dropout=0.6)
        self.input_size = input_size
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

    def last_layer_state(self, h, n_sequences):
        """(num_layers * 2, n_sequences, hidden_dim) -> (n_sequences, 2 * hidden_dim):
        concatenates the final forward and backward states of the top layer."""
        h = h.view(self.num_layers, 2, n_sequences, self.hidden_dim)[-1]  # 2 for forward/backward
        return torch.cat([h[0], h[1]], dim=1)

    def forward(self, x, h0=None, c0=None):
        batch_size, n_tracks, seq_len, input_size = x.shape
        # Fold the tracks into the batch so that all tracks share track_rnn.
        x = x.reshape(batch_size * n_tracks, seq_len, input_size)
        if h0 is None or c0 is None:
            h0, c0 = self.init_hidden(batch_size * n_tracks, x.device)
        _, (h, _) = self.track_rnn(x, (h0, c0))

        # Keep the batch axis at batch_size * n_tracks before unfolding, so that states
        # of different tracks are not mixed: (batch_size, n_tracks, 2 * hidden_dim).
        h = self.last_layer_state(h, batch_size * n_tracks).view(batch_size, n_tracks, -1)

        h0, c0 = self.init_hidden(batch_size, x.device)
        _, (h, _) = self.score_rnn(h, (h0, c0))

        return self.last_layer_state(h, batch_size)  # (batch_size, 2 * hidden_dim)

    def init_hidden(self, batch_size=1, device="cpu"):
        # Bidirectional LSTM, so the first dimension is num_layers * 2.
        shape = (self.num_layers * 2, batch_size, self.hidden_dim)
        return (torch.zeros(shape, device=device),
                torch.zeros(shape, device=device))
```

```python
encoder_config = {
    "input_size": vocab_size,
    "hidden_dim": 512,
    "num_layers": 3,
}

encoder = MultitrackEncoder(**encoder_config).to(device)
```

```python
hidden = encoder(x)
hidden, hidden.shape
```

```
(tensor([[-6.8038e-03, -6.2736e-03,  4.3171e-05,  ..., -2.4107e-02,
           2.2319e-02, -1.3294e-02],
         [-2.5326e-03, -3.9660e-03,  1.0518e-02,  ..., -2.0999e-02,
           1.7384e-02, -1.0262e-02],
         [-4.2294e-03, -1.4460e-02,  5.9571e-03,  ..., -2.2872e-02,
           1.3252e-02, -1.1259e-02],
         [-1.7087e-03, -1.0652e-02,  1.1336e-02,  ..., -2.5888e-02,
           1.3819e-02, -1.9869e-02]], grad_fn=<ViewBackward0>),
 torch.Size([4, 1024]))
```

# Latent space sampling
In this step, the output of the encoder parametrizes a Gaussian distribution from which we sample a latent vector; the decoder then uses this vector to reconstruct the musical score.

We compute the mean vector $\mu$ and the log-variance vector $\log\sigma^2$ with two linear layers, and then sample a latent vector with the reparametrization trick:
$$
\begin{aligned}
\epsilon &\sim \mathcal{N}(0, I) \\
\sigma &= \exp\left(\tfrac{1}{2}\log\sigma^2\right) \\
z &= \mu + \sigma \odot \epsilon
\end{aligned}
$$
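A small numeric check of the reparametrization trick: because ε is drawn independently of the parameters, z stays differentiable with respect to μ and the log-variance, with ∂z/∂μ = 1 and ∂z/∂(log σ²) = ½ σ ε. The values below are arbitrary.

```python
import torch

torch.manual_seed(0)
mu = torch.tensor([0.5, -1.0, 2.0], requires_grad=True)
logvar = torch.tensor([0.0, 1.0, -2.0], requires_grad=True)

sigma = torch.exp(0.5 * logvar)   # standard deviation from the log-variance
eps = torch.randn(3)              # noise drawn independently of the parameters
z = mu + sigma * eps              # reparametrized sample

z.sum().backward()
# dz/dmu = 1 and dz/dlogvar = 0.5 * sigma * eps, so gradients reach both parameters.
assert torch.allclose(mu.grad, torch.ones(3))
assert torch.allclose(logvar.grad, 0.5 * sigma.detach() * eps)
```

`torch.distributions.Normal(mu, sigma).rsample()` performs the same reparametrized sampling internally.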

```python
import torch
from torch import nn
from torch.distributions import Normal

latent_dim = 256

# The encoder output has size 2 * hidden_dim (forward and backward states).
hidden_to_mu = nn.Linear(2 * encoder_config["hidden_dim"], latent_dim).to(device)
hidden_to_logvar = nn.Linear(2 * encoder_config["hidden_dim"], latent_dim).to(device)
mu = hidden_to_mu(hidden)
logvar = hidden_to_logvar(hidden)
sigma = torch.exp(0.5 * logvar)  # standard deviation from the log-variance
latent_distribution = Normal(mu, sigma)
latent = latent_distribution.rsample()  # Reparametrization using rsample method in PyTorch
latent, latent.shape
```

```
(tensor([[ 1.3365,  1.4243,  1.0314,  ..., -0.2053,  0.3318,  0.1156],
         [-0.4156,  0.7189,  0.4965,  ..., -0.1918,  0.6623,  0.9574],
         [ 1.3969, -1.8929,  1.6019,  ..., -1.0083, -1.5105, -1.2558],
         [ 1.0896,  0.2108, -0.8237,  ...,  0.0660, -0.2607, -1.0459]],
        grad_fn=<AddBackward0>),
 torch.Size([4, 256]))
```

# The decoder (TODO)
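The decoder takes a latent vector as input.
The latent vector has shape `(batch_size, latent_dim)`.
The output is a tensor of (log-)probabilities over the MIDI-like events, with shape `(batch_size, n_tracks, seq_len, vocab_size)`.
Put simply, `n_tracks` is the number of instruments.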
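Decoded output can be mapped back to event ids with an argmax over the vocabulary axis. In this sketch random logits stand in for the decoder scores, and `log_softmax` produces the log-probabilities; exponentiating them gives distributions that sum to one.

```python
import torch
from torch.nn import functional as F

batch_size, n_tracks, seq_len, vocab_size = 4, 3, 10, 342
logits = torch.randn(batch_size, n_tracks, seq_len, vocab_size)  # stand-in for decoder scores
log_probs = F.log_softmax(logits, dim=-1)

# Each (score, track, step) position holds a distribution over the vocabulary.
assert torch.allclose(log_probs.exp().sum(dim=-1), torch.ones(batch_size, n_tracks, seq_len), atol=1e-5)

event_ids = log_probs.argmax(dim=-1)  # greedy decoding: (batch_size, n_tracks, seq_len)
print(event_ids.shape)  # torch.Size([4, 3, 10])
```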

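The decoder has two modes:
- Training mode: at each time step, the recurrent decoder receives the ground-truth event as input (teacher forcing).
- Inference mode: the output of the previous time step is fed back as input, so the sequence is generated autoregressively.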
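The two modes differ only in which event is fed back at the next step. A minimal sketch with a single `nn.GRUCell`; the `decode` helper, the sizes and the start token at the last vocabulary index are illustrative assumptions.

```python
import torch
from torch import nn
from torch.nn import functional as F

vocab_size, hidden_dim, batch_size, seq_len = 12, 16, 2, 6
cell = nn.GRUCell(vocab_size, hidden_dim)
to_events = nn.Linear(hidden_dim, vocab_size)


def decode(h, target=None):
    """Hypothetical helper: teacher forcing when `target` (one-hot, (batch, seq_len, vocab))
    is given, otherwise autoregressive greedy decoding. Returns log-probabilities."""
    start = torch.full((h.size(0),), vocab_size - 1, dtype=torch.long)
    event = F.one_hot(start, vocab_size).float()
    steps = []
    for t in range(seq_len):
        h = cell(event, h)
        log_probs = F.log_softmax(to_events(h), dim=1)
        steps.append(log_probs)
        if target is not None:
            event = target[:, t]                                            # training: ground truth
        else:
            event = F.one_hot(log_probs.argmax(dim=1), vocab_size).float()  # inference: own prediction
    return torch.stack(steps, dim=1)


h0 = torch.zeros(batch_size, hidden_dim)
target = F.one_hot(torch.randint(0, vocab_size, (batch_size, seq_len)), vocab_size).float()
train_out = decode(h0, target)
infer_out = decode(h0)
print(train_out.shape, infer_out.shape)  # torch.Size([2, 6, 12]) torch.Size([2, 6, 12])
```

At the first step both calls see the same state and start token, so their outputs coincide; afterwards they diverge because different events are fed back.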


We start with inference mode. Once inference mode works, training mode should be easy to add (probably).

```python
import torch
from torch import nn
from torch.nn import functional as F


class MultitrackDecoder(nn.Module):

    def __init__(self, input_size, latent_dim, conductor_hidden_dim, conductor_output_dim,
                 decoder_hidden_dim, num_layers, seq_len, n_tracks):
        super().__init__()
        self.tanh = nn.Tanh()
        if conductor_output_dim is None:
            conductor_output_dim = conductor_hidden_dim

        # Conductor: a learned vector sets its initial state; the latent vector is its
        # input at every step, and each step yields one track embedding.
        conductor_input_dim = latent_dim
        self.conductor_input = nn.Parameter(torch.rand(conductor_input_dim))
        self.conductor_hidden_linear = nn.Linear(conductor_input_dim, conductor_hidden_dim * num_layers)
        self.conductor_rnn = nn.LSTM(latent_dim, conductor_hidden_dim, batch_first=True,
                                     num_layers=num_layers, dropout=0.6)
        self.conductor_output_linear = nn.Linear(conductor_hidden_dim, conductor_output_dim)

        # Track decoder (shared by all tracks): its state is initialized from the track
        # embedding, and its input is [previous event, track embedding].
        self.decoder_hidden_linear = nn.Linear(conductor_output_dim, decoder_hidden_dim * num_layers)
        self.decoder_rnn = nn.LSTM(input_size + conductor_output_dim, decoder_hidden_dim, batch_first=True,
                                   num_layers=num_layers, dropout=0.6)
        self.decoder_output = nn.Linear(decoder_hidden_dim, input_size)

        self.input_size = input_size
        self.conductor_hidden_dim = conductor_hidden_dim
        self.decoder_hidden_dim = decoder_hidden_dim
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.n_tracks = n_tracks
        self.teacher_forcing_ratio = 0.5

    def initial_state(self, linear, inputs):
        """Map `inputs` to an (h0, c0) pair of shape (num_layers, batch_size, hidden).
        tanh squashes the initial state into (-1, 1)."""
        state = self.tanh(linear(inputs))
        state = state.view(inputs.shape[0], self.num_layers, -1).transpose(0, 1).contiguous()
        return state, state

    def sample(self, x, method="argmax"):
        """Convert (batch_size, input_size) scores into one-hot events."""
        if method != "argmax":
            raise NotImplementedError(f"Unknown sampling method: {method}")
        return F.one_hot(x.argmax(dim=1), num_classes=x.size(1)).to(x.dtype)

    def forward(self, latent, target=None, seq_len=None, teacher_forcing=True):
        """Return log-probabilities of shape (batch_size, n_tracks, seq_len, input_size)."""
        batch_size = latent.shape[0]
        device = latent.device
        seq_len = self.seq_len if seq_len is None else seq_len
        conductor_state = self.initial_state(self.conductor_hidden_linear,
                                             self.conductor_input.repeat(batch_size, 1))
        tracks_sequences = []

        for track in range(self.n_tracks):
            # One conductor step per track, with the latent vector as input.
            _, conductor_state = self.conductor_rnn(latent.unsqueeze(1), conductor_state)
            track_embedding = self.conductor_output_linear(conductor_state[0][-1])

            decoder_state = self.initial_state(self.decoder_hidden_linear, track_embedding)
            previous_event = torch.zeros(batch_size, self.input_size, device=device)
            previous_event[:, -1] = 1.  # Set as SOS token in one-hot representation
            track_sequence = []
            for event in range(seq_len):
                decoder_input = torch.cat([previous_event, track_embedding], dim=1).unsqueeze(1)
                _, decoder_state = self.decoder_rnn(decoder_input, decoder_state)
                event_log_probs = F.log_softmax(self.decoder_output(decoder_state[0][-1]), dim=1)
                track_sequence.append(event_log_probs)
                use_target = False
                if self.training and teacher_forcing:
                    assert target is not None, "Teacher forcing requires the target sequences."
                    use_target = torch.rand(1).item() < self.teacher_forcing_ratio
                previous_event = target[:, track, event, :] if use_target else self.sample(event_log_probs)
            tracks_sequences.append(torch.stack(track_sequence, dim=1))

        return torch.stack(tracks_sequences, dim=1)
```

```python
decoder_config = {
    "input_size": vocab_size,
    "latent_dim": latent_dim,
    "conductor_hidden_dim": 256,
    "conductor_output_dim": 256,
    "decoder_hidden_dim": 256,
    "num_layers": 3,
    "seq_len": seq_len,
    "n_tracks": n_tracks,
}
decoder = MultitrackDecoder(**decoder_config).to(device)
```