Notebook — VQ-VAE hiérarchique audio 1D 

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math
import random
import soundfile as sf

Hyper-paramètres (article)

In [2]:
SAMPLE_RATE = 16000
DURATION = 4  # secondes pour le test
NUM_SAMPLES = SAMPLE_RATE * DURATION

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

BATCH_SIZE = 8
LR = 2e-4
EPOCHS = 20 

# Architecture VQ-VAE Hiérarchique - Audio

Signal audio brut (4s, 16kHz, 64 000 échantillons)  
      │  
      ▼  
[Encodeur Audio 1D - 5 couches conv, stride 2]  
      │  
      │  
      ├─ Après 3 convolutions → **h_low (LOW)**  
      │      ↓ downsampling factor ≈ 8  
      │      ↓ shape : (BATCH_SIZE, hidden=128, T_low=8000)  
      │      │  
      │      ▼  
      │   [VQ Low]  
      │      ↓ quantification  
      │      ↓ output : **z_low** (BBATCH_SIZE, 128, 8000)  
      │  
      └─ Après 5 convolutions → **h_up (UP)**  
             ↓ downsampling factor ≈ 32  
             ↓ shape : (BATCH_SIZE, hidden=128, T_up=2000)  
             │  
             ▼  
         [VQ Up]  
             ↓ quantification  
             ↓ output : **z_up** (BATCH_SIZE, 128, 2000)  
             │  
             ▼  
[Décodage Audio 1D - 5 couches conv transpose]  
             │  
             ▼  
Signal audio reconstruit (4s, 16kHz, 64 000 échantillons)  

---

### Explications :
- **Hiérarchie multi-niveaux** :  
  - LOW capture la **structure globale / rythme**  
  - UP capture le **timbre et les détails fins**  
- Chaque niveau possède son propre **codebook VQ** (2048 codes, dim 128)  
- Les vecteurs latents sont **discrets** grâce à la quantification  
- La reconstruction se fait uniquement à partir du latent **UP** dans ce modèle


Dataset audio (signal brut 1D)

In [3]:
class AudioDataset(Dataset):
    def __init__(self, file_list):
        self.file_list = file_list

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        wav, sr = torchaudio.load(self.file_list[idx])
        wav = wav.mean(dim=0)  # mono

        if sr != SAMPLE_RATE:
            wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)

        wav = wav[:NUM_SAMPLES]
        if wav.shape[0] < NUM_SAMPLES:
            wav = F.pad(wav, (0, NUM_SAMPLES - wav.shape[0]))

        return wav

Encodeur 1D hiérarchique (5 convs)

In [4]:
class AudioEncoder(nn.Module):
    def __init__(self, hidden=128):
        super().__init__()
        self.convs = nn.ModuleList([
            nn.Conv1d(1, hidden, 4, stride=2, padding=1),
            nn.Conv1d(hidden, hidden, 4, stride=2, padding=1),
            nn.Conv1d(hidden, hidden, 4, stride=2, padding=1),  # LOW
            nn.Conv1d(hidden, hidden, 4, stride=2, padding=1),
            nn.Conv1d(hidden, hidden, 4, stride=2, padding=1),  # UP
        ])

    def forward(self, x):
        h = x.unsqueeze(1)  # (B, 1, T)
        for i, conv in enumerate(self.convs):
            h = F.relu(conv(h))
            if i == 2:
                h_low = h
        h_up = h
        return h_low, h_up

Vector Quantizer (VQ)

In [5]:
class VectorQuantizer(nn.Module):
    def __init__(self, num_codes=2048, dim=128, beta=0.25):
        super().__init__()
        self.embedding = nn.Embedding(num_codes, dim)
        self.embedding.weight.data.uniform_(-1/num_codes, 1/num_codes)
        self.beta = beta

    def forward(self, z):
        z_perm = z.permute(0, 2, 1).contiguous()
        z_flat = z_perm.view(-1, z_perm.size(-1))

        dist = (
            z_flat.pow(2).sum(1, keepdim=True)
            - 2 * z_flat @ self.embedding.weight.t()
            + self.embedding.weight.pow(2).sum(1)
        )

        indices = dist.argmin(1)
        z_q = self.embedding(indices).view(z_perm.shape)
        z_q = z_q.permute(0, 2, 1)

        loss = (
            (z_q.detach() - z).pow(2).mean()
            + self.beta * (z_q - z.detach()).pow(2).mean()
        )

        z_q = z + (z_q - z).detach()
        return z_q, loss

Décodeur 1D

In [6]:
class AudioDecoder(nn.Module):
    def __init__(self, hidden=128):
        super().__init__()
        self.deconvs = nn.ModuleList([
            nn.ConvTranspose1d(hidden, hidden, 4, 2, 1),
            nn.ConvTranspose1d(hidden, hidden, 4, 2, 1),
            nn.ConvTranspose1d(hidden, hidden, 4, 2, 1),
            nn.ConvTranspose1d(hidden, hidden, 4, 2, 1),
            nn.ConvTranspose1d(hidden, 1, 4, 2, 1),
        ])

    def forward(self, z):
        h = z
        for deconv in self.deconvs:
            h = F.relu(deconv(h))
        return h.squeeze(1)


STFT loss (perceptuelle)

In [7]:
def stft_loss(x, x_hat):
    X = torch.stft(x, n_fft=1024, hop_length=256, return_complex=True)
    X_hat = torch.stft(x_hat, n_fft=1024, hop_length=256, return_complex=True)
    return (X.abs() - X_hat.abs()).pow(2).mean()

Modèle VQ-VAE complet

In [8]:
class VQVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = AudioEncoder()
        self.vq_low = VectorQuantizer()
        self.vq_up = VectorQuantizer()
        self.decoder = AudioDecoder()

    def forward(self, x):
        h_low, h_up = self.encoder(x)
        z_low, loss_low = self.vq_low(h_low)
        z_up, loss_up = self.vq_up(h_up)
        x_hat = self.decoder(z_up)
        return x_hat, loss_low + loss_up

Génération d’un dataset audio provisoire (DEBUG)

Objectif
Cette cellule génère un petit dataset artificiel et temporaire, uniquement pour :

vérifier que le VQ-VAE s’entraîne

débugger l’architecture

tester la reconstruction

⚠️ Ce dataset n’a aucune valeur sémantique
Il sera remplacé plus tard par VGGSound / ESC-50 / etc.

In [9]:
"""
DATASET PROVISOIRE (DEBUG) - SANS BRUIT

- Généré artificiellement pour valider :
    - Chaîne encodeur → VQ → décodeur
    - Stabilité des loss
    - Shapes et reconstruction audio
- Signaux volontairement simples :
    - sinusoïdes
    - glissando (chirp)
⚠️ À NE PAS UTILISER pour des résultats finaux
"""

NUM_EXAMPLES = 20       # Petit dataset debug
class DebugAudioDataset(Dataset):
    def __init__(self, num_examples=NUM_EXAMPLES):
        self.num_examples = num_examples

    def __len__(self):
        return self.num_examples

    def __getitem__(self, idx):
        # Création de la timeline
        t = torch.linspace(0, DURATION, NUM_SAMPLES)

        # Choix aléatoire du type de signal (sans bruit)
        choice = random.choice(["sine", "chirp"])

        if choice == "sine":
            # Sinusoïde simple
            freq = random.uniform(100, 2000)
            audio = torch.sin(2 * math.pi * freq * t)

        elif choice == "chirp":
            # Glissando linéaire de f0 à f1
            f0 = random.uniform(100, 500)
            f1 = random.uniform(1000, 4000)
            audio = torch.sin(2 * math.pi * (f0 * t + (f1 - f0) * t**2))

        # Normalisation entre -1 et 1
        audio = audio / audio.abs().max()

        # Retourne un tensor float32
        return audio.float()


Dataset DataLoader

Objectif
Charger le dataset provisoire exactement comme un vrai dataset audio
(VGGSound plus tard), sans changer le reste du code.

In [10]:
"""
DATALOADER POUR DATASET PROVISOIRE
"""

dataset = DebugAudioDataset(NUM_EXAMPLES)

dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=0 
)

# Test
audio_batch = next(iter(dataloader))
print("Batch shape:", audio_batch.shape)

Batch shape: torch.Size([8, 64000])


Entraînement

In [11]:
model = VQVAE().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

In [12]:
for epoch in range(EPOCHS):
    for audio in dataloader:
        audio = audio.to(DEVICE)

        audio_hat, loss_vq = model(audio)
        loss_rec = stft_loss(audio, audio_hat)

        loss = loss_vq + loss_rec

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch} | Loss {loss.item():.4f}")


  return _VF.stft(  # type: ignore[attr-defined]


Epoch 0 | Loss 1120.9595
Epoch 1 | Loss 1071.0157
Epoch 2 | Loss 1027.5417
Epoch 3 | Loss 989.4696
Epoch 4 | Loss 951.3610
Epoch 5 | Loss 914.2968
Epoch 6 | Loss 878.8477
Epoch 7 | Loss 846.5333
Epoch 8 | Loss 810.1745
Epoch 9 | Loss 778.4980
Epoch 10 | Loss 745.1718
Epoch 11 | Loss 712.9871
Epoch 12 | Loss 676.6991
Epoch 13 | Loss 647.6391
Epoch 14 | Loss 622.5359
Epoch 15 | Loss 598.3116
Epoch 16 | Loss 580.7785
Epoch 17 | Loss 579.5220
Epoch 18 | Loss 594.6170
Epoch 19 | Loss 647.9504


Vérification

In [14]:
# On récupère le batch audio et le reconstruit
model.eval()
with torch.no_grad():
    audio = next(iter(dataloader)).to(DEVICE)
    audio_hat, _ = model(audio)

# Conversion en numpy et ajout d'une dimension channel (1, N)
original = audio[0].cpu().numpy()
reconstructed = audio_hat[0].cpu().numpy()

# Sauvegarde avec soundfile
sf.write("audio_original.wav", original, SAMPLE_RATE)
sf.write("audio_reconstructed.wav", reconstructed, SAMPLE_RATE)