In [1]:
import sys
sys.path.append("../")

import torch
import torchaudio
import librosa
import numpy as np
from numpy import trim_zeros


from src.datasets.fastspeech_dataset import (
    build_path_to_transcript_dict_libri_tts,
    FastSpeechDataset)
from src.tts.models.fastporta.FastPortaVAE import FastPortaVAE
from src.pipelines.fastporta_vae.train_loop import train_loop
from src.spk_embedding.StyleEmbedding import StyleEmbedding

# Choose the compute device once; later cells pass this string on to the
# dataset builder and the training loop.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device  # bare last expression so the notebook displays the selected device

'cpu'

In [2]:
# Locations used by the rest of the notebook, all relative to this notebook's
# directory.
SAVED_MODELS_DIR = "../saved_models"

TEST_CLEAN_PATH = "../data/test-clean"
ALIGNER_CHECKPOINT = f"{SAVED_MODELS_DIR}/aligner.pt"
EMBED_MODEL = f"{SAVED_MODELS_DIR}/embedding_function.pt"

In [3]:
# Scan the test-clean directory with the project's LibriTTS helper.
# NOTE(review): judging by the helper's name the result maps audio paths to
# transcripts — confirm against src.datasets.fastspeech_dataset.
transcript_dict = build_path_to_transcript_dict_libri_tts(TEST_CLEAN_PATH)

In [4]:
# Build the FastSpeech dataset from the transcript mapping. The summary printed
# beneath this cell reports how many datapoints were prepared and where.
dataset = FastSpeechDataset(
    path_to_transcript_dict=transcript_dict,
    acoustic_checkpoint_path=ALIGNER_CHECKPOINT,  # path to aligner.pt
    cache_dir="./librispeech",
    lang="en",
    loading_processes=2,  # adjust to the number of CPU cores available
    device=device,
)

Prepared a FastSpeech dataset with 10 datapoints in ./librispeech.


In [5]:
# Instantiate the model with its constructor defaults; this same object is
# trained and then probed with dummy inputs in the cells that follow.
net = FastPortaVAE()

In [6]:
# Short smoke-test training run on the tiny dataset built above: it starts from
# scratch (no checkpoint, resume disabled) and requests a save every step.
train_loop(
    net,
    dataset,
    device=device,
    batch_size=2,
    save_directory="../saved_models",
    path_to_checkpoint=None,
    resume=False,
    phase_1_steps=4,
    phase_2_steps=0,
    steps_per_save=1,
    # Reuse the constant from the configuration cell instead of repeating the
    # literal path, so the two cannot drift apart.
    path_to_embed_model=EMBED_MODEL,
    lr=0.01,
)

100%|██████████| 5/5 [01:25<00:00, 17.13s/it]



Steps: 1
Train Loss: 11.680 - Mel Loss: 2.389 - Glow Loss: 3.566 - Duration Loss: 2.453 - Pitch Loss: 1.210 - Energy Loss: 1.694 - Cycle Loss: 2.631 - KL Loss: 0.041 - VAE Reconstruction Loss: 0.327
Time elapsed:  1 Minutes


100%|██████████| 5/5 [01:32<00:00, 18.51s/it]



Steps: 2
Train Loss: 10.395 - Mel Loss: 2.287 - Glow Loss: 3.499 - Duration Loss: 1.592 - Pitch Loss: 1.141 - Energy Loss: 1.508 - Cycle Loss: 2.637 - KL Loss: 0.039 - VAE Reconstruction Loss: 0.327
Time elapsed:  2 Minutes


100%|██████████| 5/5 [01:48<00:00, 21.65s/it]



Steps: 3
Train Loss: 9.771 - Mel Loss: 2.099 - Glow Loss: 3.178 - Duration Loss: 1.418 - Pitch Loss: 1.139 - Energy Loss: 1.570 - Cycle Loss: 2.646 - KL Loss: 0.037 - VAE Reconstruction Loss: 0.329
Time elapsed:  2 Minutes


100%|██████████| 5/5 [01:22<00:00, 16.55s/it]



Steps: 4
Train Loss: 8.611 - Mel Loss: 1.909 - Glow Loss: 2.683 - Duration Loss: 1.359 - Pitch Loss: 0.984 - Energy Loss: 1.315 - Cycle Loss: 2.606 - KL Loss: 0.035 - VAE Reconstruction Loss: 0.326
Time elapsed:  2 Minutes


In [7]:
# Hand-built dummy batch of two items for exercising the model's forward passes.
# Seed the global RNG first so the random tensors below are identical on every
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# NOTE(review): axes look like (batch, padded length, feature size) — confirm
# against the model's expected input layout.
text_tensors = torch.rand(2, 5, 62)
text_lens = torch.tensor([5, 3])      # second item uses only 3 of the 5 text positions
speech_lens = torch.tensor([10, 7])   # second item uses only 7 of the 10 speech frames
gold_speech = torch.rand(2, 10, 80)
# Each row sums to the matching entry of speech_lens (10 and 7); the trailing
# zeros in the second row fall on its padded text positions.
gold_durations = torch.tensor([
    [2, 2, 2, 2, 2],
    [2, 2, 3, 0, 0],
])
gold_pitch = torch.rand(2, 5, 1)
gold_energy = torch.rand(2, 5, 1)
is_inference = False
alpha = 1.0
utterance_embedding = torch.rand(2, 64)
lang_ids = None

In [8]:
# Inference-mode pass: only the text inputs and the utterance embedding are
# supplied, no gold targets.
outs = net._forward(
    text_tensors=text_tensors,
    text_lens=text_lens,
    utterance_embedding=utterance_embedding,
    is_inference=True,
    lang_ids=lang_ids,
)

# Print the shape of every returned item. Some entries may be None, so test
# identity with `is not None`; `!= None` would go through the tensor's own
# comparison operator instead of a plain identity check.
for v in outs:
    print(v.shape if v is not None else None)

torch.Size([2, 10, 80])
torch.Size([2, 5])
torch.Size([2, 5, 1])
torch.Size([2, 5, 1])
None
torch.Size([])
torch.Size([])


In [9]:
# Training-mode pass: same text inputs as the inference check, plus the gold
# speech, durations, pitch and energy targets.
outs = net._forward(
    text_tensors=text_tensors,
    text_lens=text_lens,
    gold_speech=gold_speech,
    speech_lens=speech_lens,
    gold_durations=gold_durations,
    gold_pitch=gold_pitch,
    gold_energy=gold_energy,
    utterance_embedding=utterance_embedding,
    is_inference=False,
    lang_ids=lang_ids,
)

# Print the shape of every returned item. Test identity with `is not None`;
# `!= None` would go through the tensor's own comparison operator instead of a
# plain identity check.
for v in outs:
    print(v.shape if v is not None else None)

torch.Size([2, 10, 80])
torch.Size([2, 5])
torch.Size([2, 5, 1])
torch.Size([2, 5, 1])
torch.Size([])
torch.Size([])
torch.Size([])


In [10]:
# Call the public forward() on the dummy batch with autograd disabled and show
# the tuple of scalar tensors it returns.
# NOTE(review): the model is not switched to eval mode here — confirm whether
# that is intended for this check.
with torch.no_grad():
    outs = net.forward(
        # text side
        text_tensors=text_tensors,
        text_lengths=text_lens,
        # speech side and gold targets
        gold_speech=gold_speech,
        speech_lengths=speech_lens,
        gold_durations=gold_durations,
        gold_pitch=gold_pitch,
        gold_energy=gold_energy,
        # conditioning
        utterance_embedding=utterance_embedding,
        lang_ids=lang_ids,
    )

# Printing needs no autograd context, so it sits outside the no_grad block.
print(outs)

(tensor(4.4390), tensor(1.1356), tensor(1.0728), tensor(0.4119), tensor(0.7976), tensor(0.6318), tensor(0.0305), tensor(0.3588))
