In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import librosa

## Data

TODO: Find the data; this could be MIDI songs from the [Lakh MIDI Dataset](https://colinraffel.com/projects/lmd/).

## Vector Quantizer Layer

Les embeddings (vecteurs ou tenseurs) vont être quantifiés (placés) dans le codebook (dictionnaire) pour créer un espace latent de vecteurs discrets.

In [None]:
class VectorQuantizer(nn.Module):
    """Vector-quantization layer of a VQ-VAE.

    Every vector of size ``embedding_dim`` found along the last dimension of
    the input is replaced by its nearest (Euclidean) entry of a learnable
    codebook.

    Args:
        num_embeddings: Number of vectors in the codebook.
        embedding_dim: Size of each codebook vector.
        commitment_cost: Weight of the commitment term of the loss (the term
            that pulls the inputs towards their selected codebook vectors).
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._commitment_cost = commitment_cost

        # Codebook, initialised uniformly in [-1/K, 1/K] with K = num_embeddings.
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)

    def forward(self, inputs):
        """Quantize ``inputs`` against the codebook.

        Args:
            inputs: Tensor whose elements are grouped into consecutive runs of
                ``embedding_dim`` values, i.e. the embedding dimension is the
                last one (``(..., embedding_dim)``).

        Returns:
            A tuple ``(loss, quantized, perplexity, encodings)``:

            * ``loss`` - scalar codebook loss plus ``commitment_cost`` times
              the commitment loss.
            * ``quantized`` - tensor shaped like ``inputs`` holding the selected
              codebook vectors; gradients are passed straight through to
              ``inputs``.
            * ``perplexity`` - scalar, exponential of the entropy of the
              codebook usage over all quantized vectors.
            * ``encodings`` - one-hot code assignments of shape
              ``(num_vectors, num_embeddings)``.
        """
        # Flatten to (num_vectors, embedding_dim). reshape (unlike view) also
        # accepts non-contiguous inputs such as permuted tensors.
        flat_input = inputs.reshape(-1, self._embedding_dim)
        codebook = self._embedding.weight

        # Squared Euclidean distances: sum(A^2) + sum(B^2) - 2 * dot(A, B)
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(codebook ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, codebook.t()))

        # Index of the closest codebook vector for every input vector,
        # and the matching one-hot assignment matrix.
        indices = torch.argmin(distances, dim=1)
        encodings = F.one_hot(indices, self._num_embeddings).to(flat_input.dtype)

        # Look up the selected codebook vectors and restore the input shape.
        quantized = self._embedding(indices).view(inputs.shape)

        # e_latent_loss: codebook side is detached, so only the inputs move
        # (commitment). q_latent_loss: inputs are detached, so only the
        # codebook moves.
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is the quantized vector,
        # backward gradient is copied unchanged to the inputs.
        quantized = inputs + (quantized - inputs).detach()

        # Perplexity must be computed from the distribution of code usage
        # (mean of the one-hot assignments), not from the raw index values.
        avg_probs = encodings.mean(dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return loss, quantized, perplexity, encodings


## Encoder

TODO: Find a good way to bottleneck the input data. One option is to copy the [ss-vq-vae encoder](https://arxiv.org/pdf/2102.05749.pdf) (2x conv[4, 2], conv[1, 1]).

In [3]:
class Encoder(nn.Module):
    """Convolutional bottleneck: two stride-2 convolutions, then a 1x1 convolution.

    Each stride-2 convolution (kernel size 4, no padding) roughly halves the
    length of the last axis; the 1x1 convolution keeps it unchanged.

    Args:
        input_shape: Accepted for backward compatibility but not used by the
            layers. Optional, so the encoder can be built as ``Encoder()``.
        channels: Number of input and output channels of every layer
            (default 1024, the previously hard-coded value).
    """

    def __init__(self, input_shape=None, channels=1024):
        super(Encoder, self).__init__()
        self.model = nn.Sequential(
            nn.Conv1d(channels, channels, kernel_size=4, stride=2),
            nn.BatchNorm1d(channels),
            nn.LeakyReLU(),
            nn.Conv1d(channels, channels, kernel_size=4, stride=2),
            nn.BatchNorm1d(channels),
            nn.LeakyReLU(),
            nn.Conv1d(channels, channels, kernel_size=1, stride=1),
        )

    def forward(self, inputs):
        """Encode ``inputs`` of shape ``(batch, channels, length)``.

        Returns:
            Tensor of shape ``(batch, channels, reduced_length)``.
        """
        return self.model(inputs)


## Decoder

In [3]:
class Decoder(nn.Module):
    """Convolutional/recurrent decoder that upsamples a latent sequence.

    Three stages are applied in order; each stage is a convolutional stack
    followed by a GRU, and every convolution and GRU is followed by batch
    normalisation and LeakyReLU:

    1. 1x1 conv -> GRU
    2. transposed conv (kernel 4, stride 2) -> 1x1 conv -> GRU
    3. transposed conv (kernel 4, stride 2) -> 1x1 conv to ``out_channels`` -> GRU

    The stride-2 layers are transposed convolutions so that the decoder
    lengthens the sequence (each maps a length ``T`` to ``2 * T + 2``).

    Args:
        channels: Number of channels of the latent input and hidden layers.
        out_channels: Number of channels of the decoded output.
    """

    def __init__(self, channels=1024, out_channels=1025):
        super(Decoder, self).__init__()

        def conv_unit(conv, num_features):
            # A convolution followed by batch norm and LeakyReLU.
            return [conv, nn.BatchNorm1d(num_features), nn.LeakyReLU()]

        # Convolutional part of each stage, operating on (batch, channels, length).
        self.convs = nn.ModuleList([
            nn.Sequential(
                *conv_unit(nn.Conv1d(channels, channels, 1, 1), channels)),
            nn.Sequential(
                *conv_unit(nn.ConvTranspose1d(channels, channels, 4, 2), channels),
                *conv_unit(nn.Conv1d(channels, channels, 1, 1), channels)),
            nn.Sequential(
                *conv_unit(nn.ConvTranspose1d(channels, channels, 4, 2), channels),
                *conv_unit(nn.Conv1d(channels, out_channels, 1, 1), out_channels)),
        ])

        # nn.GRU returns (output, h_n) and works on sequences, so it cannot
        # live inside nn.Sequential; it is applied explicitly in forward().
        gru_sizes = (channels, channels, out_channels)
        self.grus = nn.ModuleList(
            [nn.GRU(size, size, batch_first=True) for size in gru_sizes])
        self.gru_norms = nn.ModuleList(
            [nn.Sequential(nn.BatchNorm1d(size), nn.LeakyReLU()) for size in gru_sizes])

    def forward(self, inputs):
        """Decode ``inputs`` of shape ``(batch, channels, length)``.

        Returns:
            Tensor of shape ``(batch, out_channels, 4 * length + 6)``.
        """
        hidden = inputs
        for conv, gru, norm in zip(self.convs, self.grus, self.gru_norms):
            hidden = conv(hidden)
            # GRU (batch_first) expects (batch, length, features).
            recurrent, _ = gru(hidden.transpose(1, 2).contiguous())
            # Back to (batch, features, length) for BatchNorm1d.
            hidden = norm(recurrent.transpose(1, 2))
        return hidden

## Model

In [None]:
class Model(nn.Module):
    """VQ-VAE container wiring an encoder, a vector quantizer and a decoder.

    Args:
        num_hiddens: Currently unused.
        num_residual_layers: Currently unused.
        num_residual_hiddens: Currently unused.
        num_embeddings: Number of codebook vectors, passed to the quantizer.
        embedding_dim: Size of each codebook vector, passed to the quantizer.
        commitment_cost: Commitment loss weight, passed to the quantizer.
        decay: Currently unused.
    """

    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 num_embeddings, embedding_dim, commitment_cost, decay=0):
        super(Model, self).__init__()

        # Encoder takes an `input_shape` argument that it does not use; pass it
        # explicitly so construction does not rely on a default being present.
        self._encoder = Encoder(input_shape=None)
        # TODO: a 1x1 "pre-VQ" convolution projecting num_hiddens channels to
        # embedding_dim channels was sketched here (kernel_size=1, stride=1)
        # but is not enabled yet.
        self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                       commitment_cost)
        self._decoder = Decoder()

    def forward(self, x):
        """Forward pass (not implemented yet).

        Raises:
            NotImplementedError: Always. The encode -> quantize -> decode
                pipeline has not been written, and failing loudly is safer
                than silently returning ``None``.
        """
        raise NotImplementedError(
            "Model.forward is not implemented yet: "
            "encode -> quantize -> decode still has to be wired up."
        )