In [38]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset

import torchaudio

import os

from tqdm.notebook import tqdm

In [None]:
torchaudio.utils.ffmpeg_utils.get_audio_decoders()

In [6]:
def precompute_mel_spectrogram(root_dir, output_dir, **mel_params):
    os.makedirs(output_dir, exist_ok=True)
    
    mel_transform = torchaudio.transforms.MelSpectrogram(**mel_params)
    amp_to_db = torchaudio.transforms.AmplitudeToDB(top_db=80.0)

    audio_files = [f for f in os.listdir(root_dir) if f.endswith((".mp3", ".wav", ".flac"))]

    for audio_file in tqdm(audio_files):
        waveform, sr = torchaudio.load(os.path.join(root_dir, audio_file))

        if sr != mel_params["sample_rate"]:
            resampler = torchaudio.transforms.Resample(sr, mel_params["sample_rate"])
            waveform = resampler(waveform)

        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Compute mel spectrogram
        mel_spec = mel_transform(waveform)
        # Convert to decibels
        mel_spec = amp_to_db(mel_spec)
        # Normalize
        mel_spec = (mel_spec - mel_spec.mean()) / mel_spec.std()

        # Save to disk
        torch.save(mel_spec, os.path.join(output_dir, f"{os.path.splitext(audio_file)[0]}.pt"))


In [7]:
mel_params = {
    "sample_rate": 32000,
    "n_fft": 2048,
    "hop_length": 512,
    "n_mels": 160,
    "f_min": 20,
    "f_max": 16000,
    "power": 2.0,
}

precompute_mel_spectrogram("./data/download", "./data/mel_spectrograms", **mel_params)

  0%|          | 0/1 [00:00<?, ?it/s]

In [8]:
test = torch.load("./data/mel_spectrograms/yes.pt")

In [10]:
test.shape

torch.Size([1, 160, 4028])

In [None]:
# During training: Keep complete chunks for stable VQ-VAE training
# During inference/feature extraction: Handle partial chunks by either:
# a) Zero-padding (simpler) 
# b) Overlapping strategy for the last chunk

In [29]:
class AudioDataset(Dataset):
    def __init__(self, root_dir, chunk_size, overlap):
        self.root_dir = root_dir
        self.chunk_size = chunk_size
        self.overlap = overlap

        self.audio_files = [f for f in os.listdir(root_dir) if f.endswith(".pt")]
        
    def __len__(self):
        return len(self.audio_files)
    
    def __getitem__(self, idx):
        mel_spec = torch.load(os.path.join(self.root_dir, self.audio_files[idx]))
        mel_spec = mel_spec.squeeze(0)

        chunks = []
        step_size = self.chunk_size - self.overlap
        for start in range(0, mel_spec.size(1) - self.chunk_size + 1, step_size):
            end = start + self.chunk_size
            chunk = mel_spec[:, start:end]
            chunks.append(chunk)
        return torch.stack(chunks)

In [30]:
dataset = AudioDataset("./data/mel_spectrograms", chunk_size=512, overlap=256)

In [31]:
dataset[0].shape

torch.Size([14, 160, 512])

In [None]:
# VQ-VAE Model

In [32]:
class Encoder(nn.Module):
    def __init__(self, in_channels, hidden_dims, n_embed):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels, hidden_dims, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_dims, hidden_dims, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_dims, n_embed, kernel_size=3, stride=2, padding=1)
        )
    
    def forward(self, x):
        return self.encoder(x)

In [33]:
class VectorQuantizer(nn.Module):
    def __init__(self, n_embed, embed_dim, beta=0.25):
        super().__init__()
        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.beta = beta
        
        self.embedding = nn.Embedding(n_embed, embed_dim)
        self.embedding.weight.data.uniform_(-1.0 / n_embed, 1.0 / n_embed)

    def forward(self, z):
        # Reshape z -> (batch, height, width, channel)
        z = z.permute(0, 2, 1).contiguous()
        z_flattened = z.view(-1, self.embed_dim)

        # Distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2ze
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - \
            2 * torch.matmul(z_flattened, self.embedding.weight.t())

        # Find nearest encoding
        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)

        # Preserve gradients
        z_q = z + (z_q - z).detach()

        # Compute losses
        commitment_loss = F.mse_loss(z_q.detach(), z)
        codebook_loss = F.mse_loss(z_q, z.detach())
        loss = codebook_loss + self.beta * commitment_loss

        # Reshape back to match original input shape
        z_q = z_q.permute(0, 2, 1).contiguous()

        return z_q, loss, min_encoding_indices

In [34]:
class Decoder(nn.Module):
    def __init__(self, in_channels, hidden_dims, out_channels):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(in_channels, hidden_dims, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(hidden_dims, hidden_dims, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(hidden_dims, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
        )
    
    def forward(self, x):
        return self.decoder(x)

In [35]:
class VQVAE(nn.Module):
    def __init__(self, in_channels, hidden_dims=128, n_embed=512, embed_dim=64):
        super().__init__()
        self.encoder = Encoder(in_channels, hidden_dims, embed_dim)
        self.vector_quantizer = VectorQuantizer(n_embed, embed_dim)
        self.decoder = Decoder(embed_dim, hidden_dims, in_channels)

    def forward(self, x):
        z = self.encoder(x)
        z_q, vq_loss, indices = self.vector_quantizer(z)
        x_recon = self.decoder(z_q)
        
        return x_recon, vq_loss, indices

    def encode(self, x):
        z = self.encoder(x)
        z_q, _, indices = self.vector_quantizer(z)
        return indices

    def decode(self, indices):
        z_q = self.vector_quantizer.embedding(indices)
        z_q = z_q.permute(0, 2, 1)  # Reshape for decoder
        x_recon = self.decoder(z_q)
        return x_recon

In [36]:
model = VQVAE(in_channels=160)

In [37]:
model.to("cuda")

VQVAE(
  (encoder): Encoder(
    (encoder): Sequential(
      (0): Conv1d(160, 128, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): ReLU()
      (2): Conv1d(128, 128, kernel_size=(3,), stride=(2,), padding=(1,))
      (3): ReLU()
      (4): Conv1d(128, 64, kernel_size=(3,), stride=(2,), padding=(1,))
    )
  )
  (vector_quantizer): VectorQuantizer(
    (embedding): Embedding(512, 64)
  )
  (decoder): Decoder(
    (decoder): Sequential(
      (0): ConvTranspose1d(64, 128, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
      (1): ReLU()
      (2): ConvTranspose1d(128, 128, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
      (3): ReLU()
      (4): ConvTranspose1d(128, 160, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    )
  )
)

In [46]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [47]:
def train_step(x):
    optimizer.zero_grad()
    x_recon, vq_loss, _ = model(x)
    recon_loss = F.mse_loss(x_recon, x)
    loss = recon_loss + vq_loss
    loss.backward()
    optimizer.step()
    return loss, recon_loss, vq_loss

In [48]:
epochs = 10

In [None]:
for epoch in range(epochs):
    loss, recon_loss, vq_loss = [], [], []

    model.train()
    for x in tqdm(dataset):
        for chunk in x:
            chunk = chunk.to("cuda")
            chunk = chunk.unsqueeze(0)

            l, r, v = train_step(chunk)
            loss.append(l.item())
            recon_loss.append(r.item())
            vq_loss.append(v.item())

    print(f"Epoch {epoch + 1} Loss: {torch.tensor(loss).mean().item()} Recon Loss: {torch.tensor(recon_loss).mean().item()} VQ Loss: {torch.tensor(vq_loss).mean().item()}")


  0%|          | 0/1 [00:00<?, ?it/s]

Epoch 1 Loss: 141755121664.0 Recon Loss: 0.48870864510536194 VQ Loss: 141755121664.0


  0%|          | 0/1 [00:00<?, ?it/s]

Epoch 2 Loss: 309209497600.0 Recon Loss: 0.48867377638816833 VQ Loss: 309209497600.0


  0%|          | 0/1 [00:00<?, ?it/s]

Epoch 3 Loss: 626874253312.0 Recon Loss: 0.4885919988155365 VQ Loss: 626874253312.0


  0%|          | 0/1 [00:00<?, ?it/s]

Epoch 4 Loss: 1194047832064.0 Recon Loss: 0.4885203242301941 VQ Loss: 1194047832064.0


  0%|          | 0/1 [00:00<?, ?it/s]

Epoch 5 Loss: 2155812487168.0 Recon Loss: 0.48846906423568726 VQ Loss: 2155812487168.0


  0%|          | 0/1 [00:00<?, ?it/s]

Epoch 6 Loss: 3716136566784.0 Recon Loss: 0.48841795325279236 VQ Loss: 3716136566784.0


  0%|          | 0/1 [00:00<?, ?it/s]

Epoch 7 Loss: 6152592556032.0 Recon Loss: 0.4883686602115631 VQ Loss: 6152592556032.0


  0%|          | 0/1 [00:00<?, ?it/s]

Epoch 8 Loss: 9832612495360.0 Recon Loss: 0.48832207918167114 VQ Loss: 9832612495360.0


  0%|          | 0/1 [00:00<?, ?it/s]

Epoch 9 Loss: 15231238537216.0 Recon Loss: 0.48827728629112244 VQ Loss: 15231238537216.0


  0%|          | 0/1 [00:00<?, ?it/s]

Epoch 10 Loss: 22949773967360.0 Recon Loss: 0.48823362588882446 VQ Loss: 22949773967360.0
