In [None]:
import sys
sys.path.append("../")

import torch
import torchaudio
import librosa
import numpy as np
from numpy import trim_zeros


from src.datasets.fastspeech_dataset import (
    build_path_to_transcript_dict_libri_tts,
    FastSpeechDataset)
from src.tts.models.fastporta.FastPorta import FastPorta
from src.pipelines.fastporta.train_loop import train_loop
from src.spk_embedding.StyleEmbedding import StyleEmbedding

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device  # bare last expression so the notebook echoes the chosen device

In [None]:
# Root directory handed to the transcript-dictionary builder.
TEST_CLEAN_PATH = "../data/test-clean"
# Checkpoint passed to the dataset as `acoustic_checkpoint_path`.
ALIGNER_CHECKPOINT = "../saved_models/aligner.pt"
# Embedding-function checkpoint; the training cell passes this same path
# as `path_to_embed_model`.
EMBED_MODEL = "../saved_models/embedding_function.pt"

In [None]:
# Build the path -> transcript lookup (going by the helper's name) for the
# data stored under TEST_CLEAN_PATH.
transcript_dict = build_path_to_transcript_dict_libri_tts(
    TEST_CLEAN_PATH
)

In [None]:
# Construct the dataset from the transcript dictionary and the aligner
# checkpoint; "./librispeech" is given as the cache directory.
dataset = FastSpeechDataset(
    path_to_transcript_dict=transcript_dict,
    acoustic_checkpoint_path=ALIGNER_CHECKPOINT,  # aligner.pt
    lang="en",
    cache_dir="./librispeech",
    device=device,
    loading_processes=2,  # tune to the number of CPU cores available
)

In [None]:
# Instantiate the FastPorta model with its default constructor arguments.
net = FastPorta()

In [None]:
# Very short training run: 4 phase-1 steps, 0 phase-2 steps, a save interval
# of 1 step, no checkpoint path and resume disabled.
train_loop(
    net,
    dataset,
    device=device,
    batch_size=2,
    save_directory="../saved_models",
    path_to_checkpoint=None,
    resume=False,
    phase_1_steps=4,
    phase_2_steps=0,
    steps_per_save=1,
    # Use the config-cell constant rather than repeating the literal path,
    # so editing EMBED_MODEL changes the path that is passed here.
    path_to_embed_model=EMBED_MODEL,
    lr=0.01,
)

In [None]:
# Synthetic inputs for smoke-testing the model's forward passes.
# A dedicated, seeded generator makes the random tensors identical on every
# run of this cell without touching the global RNG state.
# NOTE(review): these tensors are created on the default device; confirm
# that it matches the device the model lives on after train_loop.
dummy_rng = torch.Generator().manual_seed(0)

# assumes a (batch, tokens, features) layout — TODO confirm against FastPorta
text_tensors = torch.rand(2, 5, 62, generator=dummy_rng)
text_lens = torch.tensor([5, 3])
speech_lens = torch.tensor([10, 7])
gold_speech = torch.rand(2, 10, 80, generator=dummy_rng)
# Each row sums to the matching entry of speech_lens (10 and 7); the second
# row is zero beyond its text length of 3.
gold_durations = torch.tensor([
    [2, 2, 2, 2, 2],
    [2, 2, 3, 0, 0],
])
gold_pitch = torch.rand(2, 5, 1, generator=dummy_rng)
gold_energy = torch.rand(2, 5, 1, generator=dummy_rng)
# `is_inference` and `alpha` are not referenced by the visible cells; they
# are kept because other cells may rely on them.
is_inference = False
alpha = 1.0
utterance_embedding = torch.rand(2, 64, generator=dummy_rng)
lang_ids = None

In [None]:
# Inference-mode call of the internal forward: only the text inputs, their
# lengths, the utterance embedding and the language ids are supplied.
# NOTE(review): net.eval() is never called in the visible cells — confirm
# which mode train_loop leaves the model in.
outs = net._forward(
    text_tensors=text_tensors,
    text_lens=text_lens,
    utterance_embedding=utterance_embedding,
    is_inference=True,
    lang_ids=lang_ids,
)

# Print the shape of every returned item; None entries are printed as None.
# `is not None` is an identity check, so it never goes through the tensor's
# overloaded `!=` operator the way `v != None` does.
for v in outs:
    print(v.shape if v is not None else None)

In [None]:
# Non-inference call of the internal forward: the gold speech, durations,
# pitch and energy are supplied alongside the text inputs.
outs = net._forward(
    text_tensors=text_tensors,
    text_lens=text_lens,
    gold_speech=gold_speech,
    speech_lens=speech_lens,
    gold_durations=gold_durations,
    gold_pitch=gold_pitch,
    gold_energy=gold_energy,
    utterance_embedding=utterance_embedding,
    is_inference=False,
    lang_ids=lang_ids,
)

# Print the shape of every returned item; None entries are printed as None.
# `is not None` is an identity check, so it never goes through the tensor's
# overloaded `!=` operator the way `v != None` does.
for v in outs:
    print(v.shape if v is not None else None)

In [None]:
# Call the public forward with the gold targets while autograd tracking is
# disabled, then print whatever it returns.
# Note that forward takes `text_lengths` / `speech_lengths`, whereas
# `_forward` in the earlier cells takes `text_lens` / `speech_lens`.
with torch.no_grad():
    outs = net.forward(
        text_tensors=text_tensors,
        text_lengths=text_lens,
        utterance_embedding=utterance_embedding,
        lang_ids=lang_ids,
        gold_speech=gold_speech,
        speech_lengths=speech_lens,
        gold_durations=gold_durations,
        gold_pitch=gold_pitch,
        gold_energy=gold_energy,
    )

    print(outs)