In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Data

**TODO:** Find the data. One candidate is a collection of MIDI songs from the [Lakh MIDI Dataset](https://colinraffel.com/projects/lmd/).

## Vector Quantizer Layer

Les embeddings (vecteurs ou tenseurs) vont être quantifiés (placés) dans le codebook (dictionnaire) pour créer un espace latent de vecteurs discrets.

In [None]:
class VectorQuantizer(nn.Module):
    """Vector-quantization layer with a learnable codebook.

    Every vector of size ``embedding_dim`` found along the last axis of the
    input is replaced by its nearest codebook entry (squared Euclidean
    distance). Gradients are passed to the input with a straight-through
    estimator.

    Args:
        num_embeddings: Number of entries (codes) in the codebook.
        embedding_dim: Size of each codebook vector; the input's last
            dimension must be laid out so that it can be flattened to
            ``(-1, embedding_dim)``.
        commitment_cost: Weight applied to the commitment term of the loss.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._commitment_cost = commitment_cost

        # Codebook: one row per code, initialised uniformly in [-1/K, 1/K].
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        nn.init.uniform_(
            self._embedding.weight,
            -1 / self._num_embeddings,
            1 / self._num_embeddings,
        )

    def forward(self, inputs):
        """Quantize ``inputs`` against the codebook.

        Args:
            inputs: Tensor whose elements can be flattened to
                ``(-1, embedding_dim)``.

        Returns:
            A tuple ``(loss, quantized, perplexity)``:
              * ``loss`` - scalar: codebook MSE plus ``commitment_cost`` times
                the commitment MSE.
              * ``quantized`` - tensor shaped like ``inputs`` holding the
                selected codebook vectors (straight-through gradient).
              * ``perplexity`` - scalar: exponential of the entropy of the
                code-usage distribution over this call's vectors.
        """
        # Flatten the input. reshape (not view) so non-contiguous tensors,
        # e.g. the result of a permute, are accepted too.
        flat_input = inputs.reshape(-1, self._embedding_dim)

        # Squared Euclidean distances: sum(A^2) + sum(B^2) - 2 * dot(A, B)
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Index of the closest codebook vector for every input vector.
        indices = torch.argmin(distances, dim=1)

        # Look up the corresponding embeddings from the codebook.
        quantized = self._embedding(indices).view_as(inputs)

        # Commitment term: codebook is detached, so only the inputs receive
        # gradient. Codebook term: inputs are detached, so only the codebook
        # receives gradient.
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        # Straight-through estimator: the forward value is the quantized
        # vector, the backward pass treats the operation as identity on inputs.
        quantized = inputs + (quantized - inputs).detach()

        # Perplexity must be computed from the fraction of vectors assigned to
        # each code (a distribution over the codebook), not from the mean of
        # the integer indices.
        counts = torch.bincount(indices, minlength=self._num_embeddings)
        avg_probs = counts.to(flat_input.dtype) / max(indices.numel(), 1)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return loss, quantized, perplexity
