In [25]:
# Imports

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt

In [26]:
# --- Config ---
DATA_DIR = "../data/processed"
BATCH_SIZE = 16
LATENT_DIM = 256 #could be 128 or 256
SPEAKER_EMB_DIM = 64 # could be 32 or 64
EPOCHS = 10
LEARNING_RATE = 1e-3
BETA = 0.1 # KL divergence weight
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cuda


In [27]:
# --- Dataset ---
class MelDataset(Dataset):
    def __init__(self, data_dir):
        self.files = list(Path(data_dir).glob("*.pt"))
        self.speakers = sorted(list({f.stem.split('_')[0] for f in self.files}))
        self.spk2idx = {spk: i for i, spk in enumerate(self.speakers)}

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        sample = torch.load(self.files[idx])
        mel = sample['mel']
        speaker = sample['speaker_id']
        speaker_idx = self.spk2idx[speaker]
        return mel, speaker_idx

In [28]:
def collate_fn(batch, target_len=400):
    mels, spk_ids = zip(*batch)
    padded = []

    for mel in mels:
        if mel.shape[1] >= target_len:
            mel_fixed = mel[:, :target_len]
        else:
            pad_width = target_len - mel.shape[1]
            mel_fixed = F.pad(mel, (0, pad_width))  # pad on the right
        padded.append(mel_fixed)

    mels_tensor = torch.stack(padded)
    spk_ids_tensor = torch.tensor(spk_ids)
    return mels_tensor, spk_ids_tensor

In [30]:
# --- Model ---
class VAE(nn.Module):
    def __init__(self, input_dim=80, latent_dim=128, speaker_emb_dim=32, num_speakers=100):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(input_dim, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)
        )
        self.to_mu = nn.Linear(128, latent_dim)
        self.to_logvar = nn.Linear(128, latent_dim)

        self.speaker_embed = nn.Embedding(num_speakers, speaker_emb_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + speaker_emb_dim, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim * 400),
            nn.Tanh()
        )
        self.input_dim = input_dim

    def forward(self, x, speaker_id):
        h = self.encoder(x).squeeze(-1)  # [B, 128]
        mu = self.to_mu(h)
        logvar = self.to_logvar(h)
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)
        speaker_emb = self.speaker_embed(speaker_id)
        z_cat = torch.cat([z, speaker_emb], dim=-1)
        out = self.decoder(z_cat).view(-1, self.input_dim, 400)
        return out, mu, logvar

In [31]:
def vae_loss(recon, x, mu, logvar, beta=BETA):
    recon_loss = F.mse_loss(recon, x)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    return recon_loss + beta * kl_loss, recon_loss.item(), kl_loss.item()

In [32]:
# --- Training ---
def train(data_dir="../data/processed",
          latent_dim=256,
          speaker_emb_dim=64,
          batch_size=16,
          epochs=20,
          learning_rate=1e-3,
          beta=0.01,
          save_dir="../checkpoints"):

    dataset = MelDataset(data_dir)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    
    model = VAE(input_dim=80,
                latent_dim=latent_dim,
                speaker_emb_dim=speaker_emb_dim,
                num_speakers=len(dataset.speakers)).to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(epochs):
        model.train()
        total_loss, total_recon, total_kl = 0, 0, 0
        for x, spk in tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}"):
            x, spk = x.to(DEVICE), spk.to(DEVICE)
            recon, mu, logvar = model(x, spk)
            loss, recon_l, kl_l = vae_loss(recon, x, mu, logvar, beta=beta)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            total_recon += recon_l
            total_kl += kl_l

        print(f"Epoch {epoch+1}: Total={total_loss:.2f} | Recon={total_recon:.2f} | KL={total_kl:.2f}")

    model_name = f"vae_lat{latent_dim}_spk{speaker_emb_dim}_ep{epochs}_beta{int(beta*1000):03}.pt"
    save_path = os.path.join(save_dir, model_name)
    torch.save(model.state_dict(), save_path)
    print(f"✅ Model saved to {save_path}")


In [22]:
# Run training
train()

Epoch 1/10: 100%|██████████| 5521/5521 [00:47<00:00, 117.16it/s]


Epoch 1: Total=3629.66 | Recon=3629.47 | KL=1.97


Epoch 2/10: 100%|██████████| 5521/5521 [00:46<00:00, 119.37it/s]


Epoch 2: Total=3609.07 | Recon=3609.05 | KL=0.17


Epoch 3/10: 100%|██████████| 5521/5521 [00:47<00:00, 116.21it/s]


Epoch 3: Total=3605.94 | Recon=3605.93 | KL=0.13


Epoch 4/10: 100%|██████████| 5521/5521 [00:47<00:00, 116.93it/s]


Epoch 4: Total=3604.45 | Recon=3604.44 | KL=0.11


Epoch 5/10: 100%|██████████| 5521/5521 [00:46<00:00, 117.48it/s]


Epoch 5: Total=3603.61 | Recon=3603.60 | KL=0.11


Epoch 6/10: 100%|██████████| 5521/5521 [00:47<00:00, 116.83it/s]


Epoch 6: Total=3602.92 | Recon=3602.91 | KL=0.10


Epoch 7/10: 100%|██████████| 5521/5521 [00:49<00:00, 112.58it/s]


Epoch 7: Total=3602.53 | Recon=3602.52 | KL=0.09


Epoch 8/10: 100%|██████████| 5521/5521 [00:49<00:00, 111.93it/s]


Epoch 8: Total=3602.30 | Recon=3602.29 | KL=0.09


Epoch 9/10: 100%|██████████| 5521/5521 [00:47<00:00, 116.66it/s]


Epoch 9: Total=3602.07 | Recon=3602.06 | KL=0.09


Epoch 10/10: 100%|██████████| 5521/5521 [00:48<00:00, 113.18it/s]


Epoch 10: Total=3601.91 | Recon=3601.90 | KL=0.09
✅ Model saved as vae_model.pt


In [33]:
train(latent_dim=256, speaker_emb_dim=64, epochs=10, beta=0.1)

Epoch 1/10: 100%|██████████| 5521/5521 [00:52<00:00, 105.81it/s]


Epoch 1: Total=3628.87 | Recon=3628.46 | KL=4.11


Epoch 2/10: 100%|██████████| 5521/5521 [00:46<00:00, 118.25it/s]


Epoch 2: Total=3609.66 | Recon=3609.64 | KL=0.22


Epoch 3/10: 100%|██████████| 5521/5521 [00:47<00:00, 117.22it/s]


Epoch 3: Total=3606.77 | Recon=3606.75 | KL=0.17


Epoch 4/10: 100%|██████████| 5521/5521 [00:47<00:00, 115.99it/s]


Epoch 4: Total=3605.33 | Recon=3605.32 | KL=0.16


Epoch 5/10: 100%|██████████| 5521/5521 [00:47<00:00, 116.15it/s]


Epoch 5: Total=3604.61 | Recon=3604.60 | KL=0.15


Epoch 6/10: 100%|██████████| 5521/5521 [00:47<00:00, 116.26it/s]


Epoch 6: Total=3604.27 | Recon=3604.26 | KL=0.15


Epoch 7/10: 100%|██████████| 5521/5521 [00:48<00:00, 114.87it/s]


Epoch 7: Total=3603.97 | Recon=3603.95 | KL=0.15


Epoch 8/10: 100%|██████████| 5521/5521 [00:47<00:00, 115.58it/s]


Epoch 8: Total=3603.71 | Recon=3603.69 | KL=0.15


Epoch 9/10: 100%|██████████| 5521/5521 [00:47<00:00, 115.82it/s]


Epoch 9: Total=3603.51 | Recon=3603.50 | KL=0.15


Epoch 10/10: 100%|██████████| 5521/5521 [00:47<00:00, 115.54it/s]


Epoch 10: Total=3603.45 | Recon=3603.44 | KL=0.15
✅ Model saved to ../checkpoints\vae_lat256_spk64_ep10_beta100.pt
