In [1]:
# %pip install matplotlib
# %pip install tqdm  
from torch.utils.data import DataLoader
from utils.dataset import VCTKDataset
import numpy as np
import os
from pathlib import Path
import torch
import torch.nn as nn
from src.model.cvae_tacotron_wrapper import CVAETacotron2, cvae_taco_loss
import argparse
from matplotlib import pyplot as plt
import ipywidgets
from tqdm.notebook import tqdm



In [2]:
dataset = VCTKDataset("./dataset/VCTK",)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0,
                    collate_fn=VCTKDataset.collate_cvae)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Selected speakers: ['p323', 'p301', 'p240', 'p299', 'p225', 'p285', 'p252', 'p279', 'p287', 'p311']
Speaker 0: p323: F, 19yo, SouthAfrican (Pretoria)
Speaker 1: p301: F, 23yo, American (North Carolina)
Speaker 2: p240: F, 21yo, English (Southern England)
Speaker 3: p299: F, 25yo, American (California)
Speaker 4: p225: F, 23yo, English (Southern England)
Speaker 5: p285: M, 21yo, Scottish (Edinburgh)
Speaker 6: p252: M, 22yo, Scottish (Edinburgh)
Speaker 7: p279: M, 23yo, English (Leicester)
Speaker 8: p287: M, 23yo, English (York)
Speaker 9: p311: M, 21yo, American (Iowa)
Using device: cuda


In [3]:
EPOCH = 5
BATCH_SIZE = 32
LR = 1e-3



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = CVAETacotron2(ckpt_path="./src/model/tacotron2_pretrained.pt", z_dim=64, spk_dim_raw=256, spk_dim_proj=128)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
model.train()

loss_tracker = []

for epoch in range(EPOCH):
    loss_sum = 0
    logsum = {'l1': 0, 'gate': 0, 'kl': 0}

    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCH}", leave=False, position=0, dynamic_ncols=True)

    for text, text_len, mel, mel_len, gate, spk_id in pbar:
        print(f"Processing batch with text length {text_len} and mel length {mel_len}")
        text, text_len = text.to(device), text_len.to(device)
        mel, gate = mel.to(device), gate.to(device)
        spk_emd = model.spk_emb[spk_id].to(device)
        print(1)

        mel_post, mel_out, gate_out, mu, logvar = model(text, text_len, mel, spk_emd)
        print(2)
        loss, logs = cvae_taco_loss(mel_post, mel, gate_out, gate, mu, logvar)
        print(3)
        

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        logsum['l1'] += logs['l1']
        logsum['gate'] += logs['gate']
        logsum['kl'] += logs['kl']

        pbar.set_postfix(loss=loss.item(), l1=logs['l1'], gate=logs['gate'], kl=logs['kl'])
        break

    loss_avg = loss_sum / len(dataloader)
    loss_tracker.append(loss_avg)
    print(f"[Epoch {epoch+1}/{EPOCH}] Loss: {loss_avg:.4f} | l1: {logsum['l1'] / len(dataloader):.4f}, gate: {logsum['gate'] / len(dataloader):.4f}, kl: {logsum['kl'] / len(dataloader):.4f}")

plt.plot(loss_tracker)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid()
plt.show()

model_save_path = Path("./src/model/cvae_tacotron2_trial1.pth")
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

Using device: cuda


Using cache found in C:\Users\jx/.cache\torch\hub\NVIDIA_DeepLearningExamples_torchhub


Epoch 1/5:   0%|          | 0/244 [00:00<?, ?it/s]

Processing batch with text length tensor([48, 29, 36, 12, 91, 55, 34, 26, 32, 38, 61, 39, 27, 25, 47, 24]) and mel length tensor([422, 181, 251, 204, 449, 352, 229, 243, 284, 346, 366, 240, 280, 204,
        370, 236])
1


RuntimeError: `lengths` array must be sorted in decreasing order when `enforce_sorted` is True. You can pass `enforce_sorted=False` to pack_padded_sequence and/or pack_sequence to sidestep this requirement if you do not need ONNX exportability.

In [None]:

dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
text, text_len, mel, mel_len, gate, spk_id = next(iter(dataloader))
text, text_len = text.to(device), text_len.to(device)
mel, gate = mel.to(device), gate.to(device)
spk_emd = model.spk_emb[spk_id].to(device)
mel_post, mel_out, gate_out, mu, logvar = model(text, text_len, mel, spk_emd)
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.imshow(mel_post[0].cpu().detach().numpy(), aspect='auto', origin='lower')
plt.title('Mel Post')
plt.subplot(1, 2, 2)



    