In [1]:
from torch.utils.data import DataLoader
from utils.dataset import VCTKDataset
import numpy as np
import os
from pathlib import Path
import torch
import torch.nn as nn
from src.model.cvae_tacotron_wrapper import CVAETacotron2, cvae_taco_loss

In [2]:
# Build the VCTK dataset and wrap it in a shuffling, multi-worker loader that
# collates samples with the dataset's own CVAE collate function.
dataset = VCTKDataset("./dataset/VCTK")
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    collate_fn=VCTKDataset.collate_cvae,
)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Selected speakers: ['p323', 'p301', 'p240', 'p299', 'p225', 'p285', 'p252', 'p279', 'p287', 'p311']
Speaker 0: p323: F, 19yo, SouthAfrican (Pretoria)
Speaker 1: p301: F, 23yo, American (North Carolina)
Speaker 2: p240: F, 21yo, English (Southern England)
Speaker 3: p299: F, 25yo, American (California)
Speaker 4: p225: F, 23yo, English (Southern England)
Speaker 5: p285: M, 21yo, Scottish (Edinburgh)
Speaker 6: p252: M, 22yo, Scottish (Edinburgh)
Speaker 7: p279: M, 23yo, English (Leicester)
Speaker 8: p287: M, 23yo, English (York)
Speaker 9: p311: M, 21yo, American (Iowa)
Using device: cuda


In [3]:
# Sanity check: pull one collated batch from the loader and print the shape of
# each of its six tensors.
for batch in loader:
    (
        text,
        text_len,
        mel,
        mel_len,
        gate,
        spk,
    ) = batch
    print(f"text: {text.shape}, text_len: {text_len.shape}, mel: {mel.shape}, mel_len: {mel_len.shape}, gate: {gate.shape}, spk: {spk.shape}")
    # Only the first batch is inspected; drop this `break` to walk the full dataset.
    break
    

text: torch.Size([32, 76]), text_len: torch.Size([32]), mel: torch.Size([32, 80, 562]), mel_len: torch.Size([32]), gate: torch.Size([32, 562]), spk: torch.Size([32])


In [4]:
# Instantiate the CVAE-Tacotron2 wrapper, pointing it at the local Tacotron2
# checkpoint file and fixing the latent / speaker-embedding dimensions.
model = CVAETacotron2(
    ckpt_path="./src/model/tacotron2_pretrained.pt",
    z_dim=64,
    spk_dim_raw=256,
    spk_dim_proj=128,
)

Using cache found in C:\Users\jx/.cache\torch\hub\NVIDIA_DeepLearningExamples_torchhub
  speaker_look_up = torch.tensor([i for i in speaker_emb_dict.values()])


In [None]:

# One pass over the loader, optimizing every model parameter with Adam.
#
# All tensors are placed with `.to(device)` rather than a hard-coded `.cuda()`,
# so the cell honors the CPU fallback chosen when `device` was created instead
# of crashing on machines without a GPU.
model.to(device)  # in-place for nn.Module; a no-op if the model is already there
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# NOTE(review): `mel_len` is unpacked from the batch but is not passed to the
# model or the loss below — confirm that padding is meant to be unmasked.
for text, text_len, mel, mel_len, gate, spk_id in loader:
    text, text_len = text.to(device), text_len.to(device)
    mel, gate = mel.to(device), gate.to(device)
    # Index the model's speaker lookup table with the batch's speaker ids.
    spk_emd = model.spk_emb[spk_id].to(device)

    mel_post, mel_out, gate_out, mu, logvar = model(text, text_len, mel, spk_emd)
    loss, logs = cvae_taco_loss(mel_post, mel, gate_out, gate, mu, logvar)

    # Standard update: clear stale gradients, backpropagate, apply the step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


KeyboardInterrupt: 