# VQ-VAE Hiérarchique pour Audio
## Inspiré de "I Hear Your True Colors"

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# Paramètres généraux


In [2]:
latent_dim = 64      # dimension de chaque latent vector
num_embeddings = 512 # taille du codebook
num_levels = 2       # nombre de niveaux hiérarchiques
batch_size = 16
lr = 2e-4

# Dataset exemple

In [3]:
class DummyAudioDataset(Dataset):
    def __init__(self, num_samples=1000, seq_len=16384):
        self.data = torch.randn(num_samples, 1, seq_len) # audio mono

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

dataset = DummyAudioDataset()
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Vector Quantizer

In [4]:
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1/self.num_embeddings, 1/self.num_embeddings)
        self.commitment_cost = commitment_cost
    
    def forward(self, z):
        # z : (B, C, T)
        z_flattened = z.permute(0,2,1).contiguous().view(-1, self.embedding_dim) # (B*T, C)
        distances = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
        )
        encoding_indices = torch.argmin(distances, dim=1)
        z_q = self.embedding(encoding_indices).view(z.size(0), z.size(2), z.size(1)).permute(0,2,1)
        
        # loss
        e_latent_loss = F.mse_loss(z_q.detach(), z)
        q_latent_loss = F.mse_loss(z_q, z.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        z_q = z + (z_q - z).detach()  # straight-through estimator
        return z_q, loss

# Encodeur

In [5]:
class AudioEncoder(nn.Module):
    def __init__(self, in_channels=1, hidden_dim=128, latent_dim=latent_dim):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, hidden_dim, 4, stride=2, padding=1)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, 4, stride=2, padding=1)
        self.conv3 = nn.Conv1d(hidden_dim, latent_dim, 3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        z = self.conv3(x)
        return z

# Décodeur

In [6]:
class AudioDecoder(nn.Module):
    def __init__(self, latent_dim=latent_dim, hidden_dim=128, out_channels=1):
        super().__init__()
        self.conv1 = nn.ConvTranspose1d(latent_dim, hidden_dim, 4, stride=2, padding=1)
        self.conv2 = nn.ConvTranspose1d(hidden_dim, hidden_dim, 4, stride=2, padding=1)
        self.conv3 = nn.Conv1d(hidden_dim, out_channels, 3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, z):
        x = self.relu(self.conv1(z))
        x = self.relu(self.conv2(x))
        x_recon = self.conv3(x)
        return x_recon

# VQ-VAE hiérarchique

In [7]:
class HierarchicalVQVAE(nn.Module):
    def __init__(self, num_levels=num_levels):
        super().__init__()
        self.num_levels = num_levels
        self.encoders = nn.ModuleList([AudioEncoder() for _ in range(num_levels)])
        self.decoders = nn.ModuleList([AudioDecoder() for _ in range(num_levels)])
        self.quantizers = nn.ModuleList([VectorQuantizer(num_embeddings, latent_dim) for _ in range(num_levels)])

    def forward(self, x):
        z_list = []
        loss_list = []
        for i in range(self.num_levels):
            z = self.encoders[i](x)
            z_q, vq_loss = self.quantizers[i](z)
            x = self.decoders[i](z_q)  # reconstruction intermédiaire
            z_list.append(z_q)
            loss_list.append(vq_loss)
        return x, sum(loss_list)

# Initialisation

In [8]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = HierarchicalVQVAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Boucle d'entraînement

In [9]:
num_epochs = 5
for epoch in range(num_epochs):
    for batch in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        recon, loss_vq = model(batch)
        recon_loss = F.mse_loss(recon, batch)
        loss = recon_loss + loss_vq
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")

Epoch 1/5, Loss: 0.9986
Epoch 2/5, Loss: 0.9954
Epoch 3/5, Loss: 10.7252
Epoch 4/5, Loss: 151.0423
Epoch 5/5, Loss: 460.8399
