<a href="https://colab.research.google.com/github/autumnjohnson/AbstractMicrophone/blob/master/sperm-whalevqvae-demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VQ-VAE for sperm whales

Autumn Johnson

Spring 2024

CS 294 Unsupervised Deep Learning



---




Variational Autoencoders (VAEs) estimate the probability distribution of high-dimensional data such as audio by learning its underlying structure, including how the individual elements of the dataset relate to one another. The learned representation can be used for generation or prediction.

The Audio Quantized VAE (AQ-VAE) learns a discrete latent representation to model language, which is compositional and built from discrete acoustic units. Continuous representations interpolate between elements, which makes it much harder to learn dependencies between them.
We define a latent embedding space $E$ of size $\lbrack K,D\rbrack$ where

*  $K$ is the number of embeddings, and
*  $D$ is the dimensionality of each latent embedding vector, i.e. $e_i \in \mathbb{R}^{D}$

The model comprises

*   An encoder which maps the input to a sequence of discrete latent variables, and
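*   A decoder which tries to reconstruct the input from the latent sequence.

The total loss

$$\mathcal{L} = \lVert x - D(z_q(x)) \rVert_2^2 + \lVert \mathrm{sg}[z_e(x)] - e \rVert_2^2 + \beta \, \lVert z_e(x) - \mathrm{sg}[e] \rVert_2^2,$$

where $z_e(x)$ is the encoder output, $e$ its nearest embedding, $z_q(x)$ the quantized latent, $D$ the decoder, $\mathrm{sg}[\cdot]$ the stop-gradient operator and $\beta$ the commitment cost, comprises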


* A reconstruction loss which optimizes the encoder and decoder
* A codebook loss: because the straight-through gradient bypasses the embedding, a dictionary-learning objective uses the $\ell_2$ error to move the embedding vectors $e_i$ towards the encoder outputs.
* A commitment loss: since the volume of the embedding space is dimensionless, it can grow arbitrarily if the embeddings $e_i$ do not train as fast as the encoder parameters, so the commitment loss makes sure that the encoder commits to an embedding.


In [None]:
%pip install umap-learn datasets POT
import numpy as np
import io
import ot
import cv2
import ot
import scipy
import numpy as np
import librosa
import matplotlib.pyplot as plt
from transformers import EncodecModel, AutoProcessor
import umap.umap_ as umap
import torchaudio
from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import savgol_filter
from six.moves import xrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torchvision.datasets
from torchvision.datasets import CIFAR10 as CIFAR10
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import huggingface_hub
from huggingface_hub import login
from datasets import load_dataset, Audio, Features
import pandas
import random
import matplotlib.cm as cm
from librosa import to_mono
from google.colab import userdata
import torch
import requests
from google.colab import userdata
import torchaudio
from IPython.display import Audio as AudioPlayer

Collecting umap-learn
  Downloading umap_learn-0.5.6-py3-none-any.whl (85 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.7/85.7 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-2.19.1-py3-none-any.whl (542 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting POT
  Downloading POT-0.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (823 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.0/823.0 kB[0m [31m23.0 MB/s[0m eta [36m0:00:00[0m
Collecting pynndescent>=0.5 (from umap-learn)
  Downloading pynndescent-0.5.12-py3-none-any.whl (56 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.8/56.8 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow>=12.0.0 (from datasets)
  Downloading pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (40.8 MB)
[2K     [90m━━━━

# Hyperparameters

In [None]:
# Hyperparameters (values taken from the reference VQ-VAE implementation)
batch_size = 256
validation_batch_size = 32
num_training_updates = 15000

num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2

embedding_dim = 64
num_embeddings = 512

commitment_cost = 0.25
decay = 0.99  # decay > 0 selects the EMA codebook (VectorQuantizerEMA) in Model
learning_rate = 1e-3

# Hugging Face dataset repo; the access token is read from Colab's secret store
repo = "autumnjohnson/ceti_audio"
token = userdata.get('HF_TOKEN')
login(token=token)

# Train on the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


## Load datasets

In [None]:
def get_spectrogram(waveform, size=(32, 32), colormap=cm.plasma):
    """Convert one decoded dataset row into a colour-mapped magnitude-spectrogram image.

    Args:
        waveform: a dataset row whose ``['audio']['array']`` entry holds the decoded samples.
        size: ``(width, height)`` passed to ``cv2.resize`` as ``dsize``; defaults to 32x32.
        colormap: matplotlib colormap applied to the normalised magnitudes.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` holding RGB values in [0, 255]
        (the colormap's alpha channel is dropped).
    """
    samples = waveform['audio']['array']
    # |STFT| gives the magnitude spectrogram
    spectrogram = np.abs(librosa.stft(samples))

    resized = cv2.resize(spectrogram, dsize=size, interpolation=cv2.INTER_LINEAR)

    # Min-max normalise to [0, 1]. A constant spectrogram (e.g. a silent clip) would
    # divide by zero and yield NaNs, so it is mapped to all zeros instead.
    lo = resized.min()
    value_range = resized.max() - lo
    if value_range > 0:
        normalized = (resized - lo) / value_range
    else:
        normalized = np.zeros_like(resized)

    # Map magnitudes to colours and discard the alpha channel
    return (colormap(normalized)[:, :, :3] * 255).astype(np.uint8)

def get_spectrograms(data):
    """
    Get spectrograms: apply get_spectrogram to every row of ``data`` and stack the
    resulting images into a single numpy array.
    """
    return np.array([get_spectrogram(row) for row in data])

def plot_spectrogram(spectrogram):
    """Display a spectrogram image on the current axes with the axis frame hidden."""
    image = plt.imshow(spectrogram)
    image.axes.set_axis_off()
    plt.show()

def get_spectrograms(data):
    """
    Get spectrograms: convert every row of ``data`` with get_spectrogram and stack them.

    Delegates to get_spectrogram instead of repeating its STFT / resize / normalise /
    colour-map steps, so the single-clip and batch code paths cannot drift apart.

    Returns:
        numpy array stacking one get_spectrogram() image per row of ``data``.
    """
    spectrograms = [get_spectrogram(waveform) for waveform in data]
    return np.array(spectrograms)

def plot_spectrogram(spectrogram):
    """Render ``spectrogram`` as an image without axis ticks or frame, then show it."""
    plt.imshow(spectrogram)
    plt.gca().set_axis_off()
    plt.show()


In [None]:
# Load the whale clips, resample them to 16 kHz and turn each one into a 3x32x32
# spectrogram image scaled to [-0.5, 0.5] (the visualisation cells add 0.5 back).
def _to_model_input(split):
    """Spectrogram images for one dataset split as float32 (N, 3, 32, 32) in [-0.5, 0.5]."""
    audio = split.select_columns(['audio']).cast_column("audio", Audio(decode=True, sampling_rate=16_000))
    images = np.transpose(get_spectrograms(audio), (0, 3, 1, 2))
    return images.astype(np.float32) / 255.0 - 0.5

whale_data = load_dataset(repo)

train_whale = _to_model_input(whale_data['train'])
whale_loader = DataLoader(train_whale, batch_size=batch_size, shuffle=True, pin_memory=True)

# Variance of the training images in the SAME scale the model is fed; the training
# loop divides the reconstruction MSE by it to get a normalised MSE.
data_variance = np.var(train_whale, dtype=np.float64)

test_whale = _to_model_input(whale_data['test'])
whale_test_loader = DataLoader(test_whale, batch_size=validation_batch_size, shuffle=False, pin_memory=True)

train_speech = torchaudio.datasets.SPEECHCOMMANDS('.', subset="training", download=True)
speech_loader = DataLoader(train_speech, batch_size=batch_size, shuffle=True, pin_memory=True)
test_speech = torchaudio.datasets.SPEECHCOMMANDS('.', subset="validation", download=True)
speech_test_loader = DataLoader(test_speech, batch_size=validation_batch_size, shuffle=False)


Downloading readme:   0%|          | 0.00/3.56k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/146M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/16.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3160 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/352 [00:00<?, ? examples/s]

100%|██████████| 2.26G/2.26G [00:21<00:00, 114MB/s]


In [None]:

# Re-create the Speech Commands *training* split. Without subset= torchaudio returns the
# whole corpus (training + validation + testing), leaking held-out utterances into train_speech.
train_speech = torchaudio.datasets.SPEECHCOMMANDS(".", subset="training", download=True)

## Vector quantization

In [None]:
class VectorQuantizer(nn.Module):
    """Vector-quantisation bottleneck with a gradient-trained codebook (VQ-VAE).

    Every spatial position of a (B, C, H, W) input is replaced by its nearest codebook
    vector under squared Euclidean distance. The encoder receives gradients through the
    straight-through estimator; the codebook is trained by the codebook term of the loss.

    Args:
        num_embeddings: codebook size K.
        embedding_dim: code dimensionality D; must equal the input channel count C.
        commitment_cost: weight beta of the commitment term.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        # Small uniform initialisation in [-1/K, 1/K]
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        bound = 1 / self._num_embeddings
        self._embedding.weight.data.uniform_(-bound, bound)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        """Quantise ``inputs``.

        Returns:
            (loss, quantised tensor in BCHW layout, codebook perplexity,
             one-hot code assignments of shape (B*H*W, K))
        """
        channels_last = inputs.permute(0, 2, 3, 1).contiguous()
        bhwc_shape = channels_last.shape
        vectors = channels_last.view(-1, self._embedding_dim)

        codebook = self._embedding.weight
        # ||x||^2 + ||e||^2 - 2 x.e is the squared distance from each vector to each code
        sq_dist = (torch.sum(vectors**2, dim=1, keepdim=True)
                   + torch.sum(codebook**2, dim=1)
                   - 2 * torch.matmul(vectors, codebook.t()))

        nearest = torch.argmin(sq_dist, dim=1).unsqueeze(1)
        one_hot = torch.zeros(nearest.shape[0], self._num_embeddings, device=channels_last.device)
        one_hot.scatter_(1, nearest, 1)

        snapped = torch.matmul(one_hot, codebook).view(bhwc_shape)

        # Commitment term trains the encoder, codebook term trains the embeddings
        commitment_term = F.mse_loss(snapped.detach(), channels_last)
        codebook_term = F.mse_loss(snapped, channels_last.detach())
        loss = codebook_term + self._commitment_cost * commitment_term

        # Straight-through estimator: forward value is the code, gradient goes to the encoder output
        snapped = channels_last + (snapped - channels_last).detach()

        usage = torch.mean(one_hot, dim=0)
        perplexity = torch.exp(-torch.sum(usage * torch.log(usage + 1e-10)))

        return loss, snapped.permute(0, 3, 1, 2).contiguous(), perplexity, one_hot

In [None]:
class VectorQuantizerEMA(nn.Module):
    """Vector quantiser whose codebook is learned by exponential moving averages (EMA).

    Nearest-code assignment and the straight-through estimator match VectorQuantizer,
    but the codebook is not trained by gradients. In training mode each forward pass
    updates, in place:
      * ``_ema_cluster_size``: EMA of how many vectors were assigned to each code,
      * ``_ema_w``: EMA of the per-code sums of the assigned encoder outputs,
      * ``_embedding.weight``: ``_ema_w`` divided by the Laplace-smoothed counts.

    Args:
        num_embeddings: codebook size K.
        embedding_dim: code dimensionality D; must equal the input channel count.
        commitment_cost: weight beta of the commitment loss (the only loss term returned).
        decay: EMA decay gamma.
        epsilon: Laplace-smoothing constant that keeps unused codes from dividing by zero.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost

        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        # Kept as an nn.Parameter so state_dict keys and model.parameters() are unchanged;
        # it never receives a gradient and is only modified in place by _ema_update().
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()

        self._decay = decay
        self._epsilon = epsilon

    @torch.no_grad()
    def _ema_update(self, encodings, flat_input):
        """Update the EMA statistics and the codebook in place (no autograd graph).

        In-place updates keep the tensor objects the optimizer holds valid; replacing
        them with new nn.Parameters every step would leave the optimizer with stale ones.
        """
        decay = self._decay
        self._ema_cluster_size.mul_(decay).add_(torch.sum(encodings, 0), alpha=1 - decay)

        # Laplace smoothing of the cluster size: used only to normalise the codebook and
        # not written back into the running counts (as in the Sonnet reference).
        n = torch.sum(self._ema_cluster_size)
        cluster_size = ((self._ema_cluster_size + self._epsilon)
                        / (n + self._num_embeddings * self._epsilon) * n)

        dw = torch.matmul(encodings.t(), flat_input)
        self._ema_w.mul_(decay).add_(dw, alpha=1 - decay)

        self._embedding.weight.copy_(self._ema_w / cluster_size.unsqueeze(1))

    def forward(self, inputs):
        """Quantise ``inputs`` (B, C, H, W), updating the codebook by EMA when training.

        Returns:
            (commitment loss, quantised tensor in BCHW layout, codebook perplexity,
             one-hot code assignments of shape (B*H*W, K))
        """
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Squared Euclidean distances ||x||^2 + ||e||^2 - 2 x.e
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize (with the codebook as it was before this step's update) and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_update(encodings, flat_input)

        # Loss: only the commitment term; the codebook is learned by the EMA update
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss

        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

## Model architecture

In [None]:
class Residual(nn.Module):
    """Pre-activation residual block computing ``x + conv1x1(relu(conv3x3(relu(x))))``.

    ``num_hiddens`` (the block's output channels) must equal ``in_channels`` so the
    skip connection can be added to the input.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super(Residual, self).__init__()
        self._block = nn.Sequential(
            # Must NOT be in-place: it receives ``x`` itself, and an in-place ReLU would
            # overwrite it so forward() returned relu(x) + block(x) instead of x + block(x).
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels,
                      out_channels=num_residual_hiddens,
                      kernel_size=3, stride=1, padding=1, bias=False),
            # In-place is safe here: it only overwrites the intermediate conv output.
            nn.ReLU(True),
            nn.Conv2d(in_channels=num_residual_hiddens,
                      out_channels=num_hiddens,
                      kernel_size=1, stride=1, bias=False)
        )

    def forward(self, x):
        return x + self._block(x)

class ResidualStack(nn.Module):
    """``num_residual_layers`` Residual blocks applied one after another, then a final ReLU."""

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()
        self._num_residual_layers = num_residual_layers
        blocks = []
        for _ in range(num_residual_layers):
            blocks.append(Residual(in_channels, num_hiddens, num_residual_hiddens))
        self._layers = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for layer in self._layers:
            out = layer(out)
        return F.relu(out)

In [None]:
class Decoder(nn.Module):
    """Maps quantised latents back to 3-channel images.

    A 3x3 conv to ``num_hiddens`` channels, a ResidualStack, then two stride-2
    4x4 transposed convolutions (each doubling H and W) ending in 3 output channels.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Decoder, self).__init__()
        half = num_hiddens // 2

        self._conv_1 = nn.Conv2d(in_channels, num_hiddens, kernel_size=3, stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)
        self._conv_trans_1 = nn.ConvTranspose2d(num_hiddens, half, kernel_size=4, stride=2, padding=1)
        self._conv_trans_2 = nn.ConvTranspose2d(half, 3, kernel_size=4, stride=2, padding=1)

    def forward(self, inputs):
        hidden = self._residual_stack(self._conv_1(inputs))
        hidden = F.relu(self._conv_trans_1(hidden))
        return self._conv_trans_2(hidden)

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder producing ``num_hiddens`` feature maps.

    Two stride-2 4x4 convolutions (each halving H and W for even sizes) with ReLUs,
    a 3x3 convolution, then a ResidualStack.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Encoder, self).__init__()
        half = num_hiddens // 2

        self._conv_1 = nn.Conv2d(in_channels, half, kernel_size=4, stride=2, padding=1)
        self._conv_2 = nn.Conv2d(half, num_hiddens, kernel_size=4, stride=2, padding=1)
        self._conv_3 = nn.Conv2d(num_hiddens, num_hiddens, kernel_size=3, stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)

    def forward(self, inputs):
        features = F.relu(self._conv_1(inputs))
        features = F.relu(self._conv_2(features))
        return self._residual_stack(self._conv_3(features))

In [None]:
class Model(nn.Module):
    """Full VQ-VAE: Encoder -> 1x1 projection to ``embedding_dim`` -> quantiser -> Decoder.

    ``decay > 0`` selects the EMA-updated codebook (VectorQuantizerEMA), otherwise the
    gradient-trained VectorQuantizer. forward() returns (vq_loss, reconstruction, perplexity).
    """

    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 num_embeddings, embedding_dim, commitment_cost, decay=0):
        super(Model, self).__init__()

        self._encoder = Encoder(3, num_hiddens, num_residual_layers, num_residual_hiddens)
        # 1x1 convolution maps encoder channels to the codebook dimensionality
        self._pre_vq_conv = nn.Conv2d(num_hiddens, embedding_dim, kernel_size=1, stride=1)
        if decay > 0.0:
            quantizer = VectorQuantizerEMA(num_embeddings, embedding_dim, commitment_cost, decay)
        else:
            quantizer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
        self._vq_vae = quantizer
        self._decoder = Decoder(embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)

    def forward(self, x):
        latents = self._pre_vq_conv(self._encoder(x))
        vq_loss, quantized, perplexity, _ = self._vq_vae(latents)
        return vq_loss, self._decoder(quantized), perplexity

## Train

In [None]:
# Build the VQ-VAE on the selected device (decay > 0 selects the EMA codebook),
# then an Adam optimizer over all of its parameters.
model = Model(num_hiddens=num_hiddens,
              num_residual_layers=num_residual_layers,
              num_residual_hiddens=num_residual_hiddens,
              num_embeddings=num_embeddings,
              embedding_dim=embedding_dim,
              commitment_cost=commitment_cost,
              decay=decay)
model = model.to(device)

optimizer = optim.Adam(params=model.parameters(), lr=learning_rate, amsgrad=False)

In [None]:
# Run training loop: num_training_updates optimizer steps, one batch per step.
def _endless_batches(loader):
    """Yield batches from ``loader`` forever, starting a new (reshuffled) epoch when one ends.

    Creating the iterator once per epoch avoids re-instantiating it (and drawing a new
    permutation) on every step, as ``next(iter(loader))`` would.

    Raises:
        ValueError: if the loader produces no batches at all (otherwise this would spin forever).
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("DataLoader yielded no batches")

model.train()
train_res_recon_error = []
train_res_perplexity = []
batches = _endless_batches(whale_loader)
for i in range(num_training_updates):
    data = next(batches).to(device).float()
    optimizer.zero_grad()

    vq_loss, data_recon, perplexity = model(data)
    # Normalised MSE: reconstruction error divided by the training-data variance
    recon_error = F.mse_loss(data_recon, data) / data_variance
    loss = recon_error + vq_loss
    loss.backward()

    optimizer.step()

    train_res_recon_error.append(recon_error.item())
    train_res_perplexity.append(perplexity.item())

    if (i+1) % 100 == 0:
        print('%d iterations' % (i+1))
        print('recon_error: %.3f' % np.mean(train_res_recon_error[-100:]))
        print('perplexity: %.3f' % np.mean(train_res_perplexity[-100:]))
        print()

100 iterations
recon_error: 75696.479
perplexity: 1.583

200 iterations
recon_error: 14903.742
perplexity: 1.808

300 iterations
recon_error: 9155.877
perplexity: 1.881

400 iterations
recon_error: 7959.294
perplexity: 3.148

500 iterations
recon_error: 6728.203
perplexity: 4.263

600 iterations
recon_error: 6272.222
perplexity: 6.077

700 iterations
recon_error: 5520.390
perplexity: 8.092

800 iterations
recon_error: 4502.855
perplexity: 9.417

900 iterations
recon_error: 3937.015
perplexity: 11.075

1000 iterations
recon_error: 3330.674
perplexity: 15.119

1100 iterations
recon_error: 2935.734
perplexity: 16.542

1200 iterations
recon_error: 2671.669
perplexity: 17.140

1300 iterations
recon_error: 2503.912
perplexity: 18.286

1400 iterations
recon_error: 2375.434
perplexity: 20.499

1500 iterations
recon_error: 2237.609
perplexity: 24.514

1600 iterations
recon_error: 2097.613
perplexity: 30.389

1700 iterations
recon_error: 2009.899
perplexity: 34.608

1800 iterations
recon_error: 

## Plot loss

In [None]:
# Smooth the per-step training curves and plot reconstruction error and perplexity side by side.
def _smooth_curve(values, window=201, polyorder=7):
    """Savitzky-Golay smoothing of ``values`` that also works for short training runs.

    savgol_filter fails when the window is longer than the series or not larger than
    ``polyorder``, so the window is clamped to the largest odd length <= len(values),
    and the raw values are returned when even that is too short for the polynomial.
    """
    values = np.asarray(values, dtype=float)
    longest_odd = len(values) if len(values) % 2 == 1 else len(values) - 1
    window = min(window, longest_odd)
    if window <= polyorder:
        return values
    return savgol_filter(values, window, polyorder)

train_res_recon_error_smooth = _smooth_curve(train_res_recon_error)
train_res_perplexity_smooth = _smooth_curve(train_res_perplexity)

f, (ax_recon, ax_perplexity) = plt.subplots(1, 2, figsize=(8, 3))
ax_recon.plot(train_res_recon_error_smooth)
ax_recon.set_yscale('log')
ax_recon.set_title('Smoothed NMSE.')
ax_recon.set_xlabel('iteration')

ax_perplexity.plot(train_res_perplexity_smooth)
ax_perplexity.set_title('Smoothed Average codebook usage (perplexity).')
ax_perplexity.set_xlabel('iteration')
plt.show()

## Visualize embeddings

In [None]:
# Reconstruct one validation batch. eval() freezes the EMA codebook; no_grad() skips
# building an autograd graph that would never be used.
model.eval()

with torch.no_grad():
    valid_originals = next(iter(whale_test_loader))
    valid_originals = valid_originals.to(device).float()

    vq_output_eval = model._pre_vq_conv(model._encoder(valid_originals))
    _, valid_quantize, _, _ = model._vq_vae(vq_output_eval)
    valid_reconstructions = model._decoder(valid_quantize)


In [None]:
# Reconstruct one training batch through the full model (encoder -> VQ -> decoder).
# The raw 3-channel images must not be passed to model._vq_vae directly: its
# view(-1, embedding_dim) silently reshapes them and returns codebook vectors, not images.
# eval() keeps the EMA codebook from being updated by this inspection pass.
model.eval()
with torch.no_grad():
    train_originals = next(iter(whale_loader))
    train_originals = train_originals.to(device).float()
    _, train_reconstructions, _ = model(train_originals)

In [None]:
def show(img):
    """Display a (C, H, W) image tensor, e.g. the grid returned by ``make_grid``, as RGB.

    matplotlib expects channels last, so the tensor is transposed to (H, W, C); showing
    ``img[0]`` would display only the first channel. Values are min-max rescaled to
    [0, 1] for display, so the picture is visible whatever scale the images are in.
    """
    npimg = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))
    lo = npimg.min()
    span = npimg.max() - lo
    npimg = (npimg - lo) / span if span > 0 else np.zeros_like(npimg)
    fig = plt.imshow(npimg, interpolation='nearest')
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)

show(make_grid(valid_reconstructions.data.cpu()+.5))

In [None]:
# Display the matching validation originals with the same +0.5 shift
original_grid = make_grid(valid_originals.cpu() + 0.5)
show(original_grid)

In [None]:
#!wget https://mikeyfarrow.github.io/nn/SPEECH_embeds512x64.torch -q -O SPEECH_embeds512x64.torch

#!wget https://mikeyfarrow.github.io/nn/WHALES_embeds512x64.torch -q -O WHALES_embeds512x64.torch
#whale_embed = torch.load('/content/WHALES_embeds512x64.torch', map_location=torch.device(device))

In [None]:
# Load the whale codebook (trained above) and the speech codebook (from file), attach
# uniform weights, and build each domain's squared-Euclidean cost matrix (ot.dist default).
embed_whale = model._vq_vae._embedding.weight.detach().cpu().numpy()
# The speech codebook comes from a downloaded file (see the commented-out wget above);
# weights_only=True refuses to unpickle arbitrary objects from it.
embed_speech = torch.load('/content/SPEECH_embeds512x64.torch',
                          map_location='cpu', weights_only=True).detach().numpy()

q_whale = ot.unif(len(embed_whale))
p_speech = ot.unif(len(embed_speech))

cost_matrix_whale = ot.dist(embed_whale, embed_whale)
cost_matrix_speech = ot.dist(embed_speech, embed_speech)


In [None]:
# Plot latent embedding space
def plot_scatter(xs, xt):
    """Scatter the first two coordinates of source (blue +) and target (red x) tensors on figure 1."""
    source = xs.detach().numpy()
    target = xt.detach().numpy()
    plt.figure(1)
    series = ((source, '+b', 'Source samples'), (target, 'xr', 'Target samples'))
    for points, marker, name in series:
        plt.plot(points[:, 0], points[:, 1], marker, label=name)
    plt.legend(loc=0)
    plt.title('Source and target distributions')

def plot_cost_matrix(matrix):
    """Plot a pairwise cost matrix between latent embeddings as a heatmap with a colour bar.

    Args:
        matrix: 2-D array of costs; rows are labelled 'Embeddings 1', columns 'Embeddings 2'.
    """
    plt.figure(figsize=(5, 4))
    # Plot the argument, not a global: the function must show whatever matrix it is given.
    plt.imshow(matrix, cmap='viridis', interpolation='nearest')
    plt.colorbar(label='Cost')
    plt.title('Cost Matrix between Latent Embeddings')
    plt.xlabel('Embeddings 2')
    plt.ylabel('Embeddings 1')
    plt.show()

# Gromov-Wasserstein alignment of the speech and whale codebooks. With log=True the
# solver returns (coupling, log); the scalar distance is log['gw_dist'].
gw_coupling, gw_log = ot.gromov.gromov_wasserstein(cost_matrix_speech, cost_matrix_whale,
                                                   p_speech, q_whale, 'kl_loss', log=True)
gw_dist = gw_log['gw_dist']
print("Gromov-Wasserstein Distance:", gw_dist)



## References


*   [VQ-VAE by Aäron van den Oord et al. in PyTorch:](https://colab.research.google.com/drive/104d_Z_WL5SFfohG0cCuhqSurOK-Zg59c#scrollTo=1cbm0ffGJaj6&uniqifier=2)
*   [google-deepmind/sonnet/vqvae.py](https://github.com/google-deepmind/sonnet/blob/v1/sonnet/python/modules/nets/vqvae.py)
*   [google-deepmind/sonnet/vqvae_example.ipynb](https://github.com/google-deepmind/sonnet/blob/v1/sonnet/examples/vqvae_example.ipynb)
*   [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937)
*   [AQ-VAE - final](https://colab.research.google.com/drive/1buDO6LC5ap_jK5CkuurV5QkZrYVJy5AD#scrollTo=esQWOX1b_iHs)





## Gromov-Wasserstein alignment


In [None]:
%pip install POT
import numpy as np
import ot
import matplotlib.pyplot as pl

# Sanity check: GW between the whale codebook and itself. The identity coupling has zero
# cost, so a value near 0 is expected (the solver is non-convex and may stop elsewhere).
latent_embeddings = model._vq_vae._embedding.weight.detach().cpu().numpy()

# Pairwise squared-Euclidean distances between codebook vectors
cost_matrix = ot.dist(latent_embeddings, latent_embeddings)

# Uniform weights; with log=True the call returns (coupling, log) and the distance is log['gw_dist']
p = ot.unif(len(latent_embeddings))
q = ot.unif(len(latent_embeddings))
gw_plan, gw_log = ot.gromov.gromov_wasserstein(
    cost_matrix, cost_matrix, p, q, 'square_loss', log=True)
gw_dist = gw_log['gw_dist']

print("Gromov-Wasserstein Distance:", gw_dist)

# Visualise the cost matrix (the transport plan itself is in gw_plan)
fig, ax = plt.subplots(figsize=(10, 8))
image = ax.imshow(cost_matrix, cmap='Blues', interpolation='nearest')
fig.colorbar(image, ax=ax, label='Cost')
ax.set_title('Cost Matrix between Latent Embeddings')
ax.set_xlabel('Embeddings 2')
ax.set_ylabel('Embeddings 1')
plt.show()
# Partial Wasserstein / partial Gromov-Wasserstein demo on synthetic point clouds
# (adapted from the POT gallery). Every name is defined before use, so the cell runs
# on a fresh kernel.
from scipy.linalg import sqrtm
from scipy.spatial.distance import cdist


def _show_plan_pair(left, right, left_title, right_title, suptitle=None):
    """Show two transport plans side by side in a new 10x5 figure."""
    fig = pl.figure(figsize=(10, 5))
    if suptitle is not None:
        fig.suptitle(suptitle)
    panels = ((left, left_title), (right, right_title))
    for position, (plan, title) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 2, position)
        ax.imshow(plan, cmap='jet')
        ax.set_title(title)
    pl.show()


np.random.seed(0)  # samples and outliers use NumPy's global RNG; fix it for reproducibility

n_samples = 20  # nb samples
n_noise = 10  # nb of samples (noise)

p = ot.unif(n_samples + n_noise)
q = ot.unif(n_samples + n_noise)

# 2-D source: standard Gaussian plus uniform outliers in [4, 8)^2
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0)

# --- Partial Wasserstein: needs a cross-domain cost, so the target is 2-D like the source
# (Gaussian around (4, 4) plus outliers in (-6, -3]^2).
mu_t_2d = np.array([4, 4])
cov_t_2d = np.array([[1, -.8], [-.8, 1]])
xt_2d = ot.datasets.make_2D_samples_gauss(n_samples, mu_t_2d, cov_t_2d)
xt_2d = np.concatenate((xt_2d, ((np.random.rand(n_noise, 2) + 1) * -3)), axis=0)
M = cdist(xs, xt_2d)

w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=0.5, log=True)
w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=0.1, m=0.5,
                                                 log=True)

print('Partial Wasserstein distance (m = 0.5): ' + str(log0['partial_w_dist']))
print('Entropic partial Wasserstein distance (m = 0.5): ' +
      str(log['partial_w_dist']))
_show_plan_pair(w0, w, 'Partial Wasserstein', 'Entropic partial Wasserstein')

# --- Partial Gromov-Wasserstein: compares intra-domain distances only, so the target
# may live in a different space (3-D Gaussian plus outliers in [10, 20)^3).
mu_t = np.array([0, 0, 0])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
P = sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0)

fig = pl.figure()
ax1 = fig.add_subplot(121)
ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
ax2 = fig.add_subplot(122, projection='3d')
ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
pl.show()

C1 = cdist(xs, xs)
C2 = cdist(xt, xt)

# transport 100% of the mass
print('------m = 1')
m = 1
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
                                                          m=m, log=True,
                                                          verbose=True)

print('Gromov-Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist']))
print('Entropic Gromov-Wasserstein distance (m = 1): ' + str(log['partial_gw_dist']))
_show_plan_pair(res0, res, 'Gromov-Wasserstein', 'Entropic Gromov-Wasserstein',
                suptitle="mass to be transported m = 1")

# transport 2/3 of the mass
print('------m = 2/3')
m = 2 / 3
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True,
                                                   verbose=True)
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
                                                          m=m, log=True,
                                                          verbose=True)

print('Partial Gromov-Wasserstein distance (m = 2/3): ' +
      str(log0['partial_gw_dist']))
print('Entropic partial Gromov-Wasserstein distance (m = 2/3): ' +
      str(log['partial_gw_dist']))
_show_plan_pair(res0, res, 'Partial Gromov-Wasserstein', 'Entropic partial Gromov-Wasserstein',
                suptitle="mass to be transported m = 2/3")


In [None]:
def plot_waveform(waveform, sample_rate):
    """Plot each channel of a (channels, frames) waveform tensor against time in seconds.

    One subplot per channel; channels are labelled only when there is more than one.
    """
    samples = waveform.numpy()

    n_channels, n_frames = samples.shape
    seconds = torch.arange(0, n_frames) / sample_rate

    figure, axes = plt.subplots(n_channels, 1)
    axes = [axes] if n_channels == 1 else axes
    for idx, channel in enumerate(samples):
        ax = axes[idx]
        ax.plot(seconds, channel, linewidth=1)
        ax.grid(True)
        if n_channels > 1:
            ax.set_ylabel(f"Channel {idx+1}")
    figure.suptitle("waveform")

In [None]:
# Codec-augmentation demo on one Speech Commands validation utterance.
# Self-contained: defines the waveform and the plotting/playback helpers it uses, and
# calls torchaudio.functional.apply_codec explicitly (``F`` in this notebook is
# torch.nn.functional, which has no apply_codec).
# NOTE(review): apply_codec is deprecated in recent torchaudio releases in favour of
# torchaudio.io.AudioEffector -- confirm it exists in the installed version.
import torchaudio.functional as TAF
from IPython.display import display


def plot_specgram(waveform, sample_rate, title="Spectrogram"):
    """Plot the spectrogram of the first channel of a (channels, frames) waveform."""
    fig, ax = plt.subplots()
    ax.specgram(waveform[0].numpy(), Fs=sample_rate)
    ax.set_title(title)
    ax.set_xlabel('time [s]')
    ax.set_ylabel('frequency [Hz]')
    plt.show()


def play_audio(waveform, sample_rate):
    """Render an inline audio player for a (channels, frames) waveform."""
    display(AudioPlayer(waveform.numpy(), rate=sample_rate))


# SPEECHCOMMANDS items are (waveform, sample_rate, label, speaker_id, utterance_number)
waveform, orig_sample_rate, *_ = test_speech[0]
# GSM-FR operates on 8 kHz audio, so resample the 16 kHz clip first
sample_rate = 8000
waveform = TAF.resample(waveform, orig_sample_rate, sample_rate)

plot_specgram(waveform, sample_rate, title="Original")
play_audio(waveform, sample_rate)

configs = [
    ({"format": "wav", "encoding": 'ULAW', "bits_per_sample": 8}, "8 bit mu-law"),
    ({"format": "gsm"}, "GSM-FR"),
    ({"format": "mp3", "compression": -9}, "MP3"),
    ({"format": "vorbis", "compression": -1}, "Vorbis"),
]
for param, title in configs:
    augmented = TAF.apply_codec(waveform, sample_rate, **param)
    plot_specgram(augmented, sample_rate, title=title)
    play_audio(augmented, sample_rate)