In [None]:
import os, sys
from IPython.display import Audio

sys.path.append("../")

import torch
import torchaudio
import librosa
import numpy as np

from src.spk_embedding.StyleEmbedding import StyleEmbedding
from src.tts.vocoders.hifigan.HiFiGAN import HiFiGANGenerator
from src.tts.models.fastporta.FastPorta import FastPorta
from src.datasets.fastspeech_dataset import (
    FastSpeechDataset,
    build_path_to_transcript_dict_libri_tts,
)
from src.pipelines.fastporta.train_loop import collate_and_pad

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device  # echo the selected device as the cell output

In [None]:
# Dataset root handed to build_path_to_transcript_dict_libri_tts below.
TEST_CLEAN_PATH = '../data/test-clean'

# Checkpoint files loaded later in this notebook (all relative to the notebook directory).
AVOCODO_CHECKPOINT = "../saved_models/Avocodo.pt"  # vocoder weights, read from its "generator" entry
ALIGNER_CHECKPOINT = "../saved_models/aligner.pt"  # passed to FastSpeechDataset as acoustic_checkpoint_path
# NOTE(review): "lastest" presumably matches the file name written by training — confirm before renaming.
FASTPORTA_CHECKPOINT = "../saved_models/fastporta/checkpoint_lastest.pt"  # FastPorta weights, "model" entry
STYLE_EMBED_CHECKPOINT = "../saved_models/embedding_function.pt"  # StyleEmbedding weights, "style_emb_func" entry

In [None]:
# Build the lookup that is later indexed by an audio file path to obtain its
# transcript (`transcript_dict[input_audio_path]`), and that is also passed to
# FastSpeechDataset as `path_to_transcript_dict`.
transcript_dict = build_path_to_transcript_dict_libri_tts(TEST_CLEAN_PATH)

In [None]:
# Gather the dataset settings in one mapping, then build the dataset from it.
dataset_options = dict(
    path_to_transcript_dict=transcript_dict,
    acoustic_checkpoint_path=ALIGNER_CHECKPOINT,  # the aligner.pt checkpoint
    cache_dir="./librispeech",
    lang="en",
    loading_processes=2,  # tune to the number of CPU cores available
    device=device,
)
dataset = FastSpeechDataset(**dataset_options)

In [None]:
# Load the three pretrained networks used for synthesis. This notebook only
# runs inference, so every model is switched to eval() and has gradient
# tracking disabled on its parameters (previously only the style embedding
# function was frozen).
# NOTE: torch.load unpickles the file; only point these paths at trusted checkpoints.

# Vocoder (turns the predicted mel output into a waveform); weights live under "generator".
vocoder = HiFiGANGenerator().to(device)
avocodo_check_dict = torch.load(AVOCODO_CHECKPOINT, map_location=device)
vocoder.load_state_dict(avocodo_check_dict["generator"])
vocoder.eval()
vocoder.requires_grad_(False)

# Style embedding function; weights live under "style_emb_func".
style_embed_function = StyleEmbedding().to(device)
style_embed_check_dict = torch.load(STYLE_EMBED_CHECKPOINT, map_location=device)
style_embed_function.load_state_dict(style_embed_check_dict["style_emb_func"])
style_embed_function.eval()
style_embed_function.requires_grad_(False)

# Acoustic model (FastPorta); weights live under "model". The checkpoint
# variable keeps its existing name in case other cells refer to it.
acoustic_model = FastPorta().to(device)
fastspeech2_check_dict = torch.load(FASTPORTA_CHECKPOINT, map_location=device)
acoustic_model.load_state_dict(fastspeech2_check_dict["model"])
acoustic_model.eval()
# requires_grad_ returns the module itself, so the cell still displays the
# acoustic model as its output.
acoustic_model.requires_grad_(False)

In [None]:
# Pick one utterance from the dataset and listen to the original recording.
sample_id = 1
sample = dataset[sample_id]
input_audio_path = sample[-1]  # the last field of a dataset item is used as the audio path
# sr=None keeps the file's native sampling rate. Without it librosa.load
# resamples to 22050 Hz, and playing that signal at a hard-coded 24000 Hz
# makes the reference clip sound too fast and too high-pitched.
input_wave, input_sample_rate = librosa.load(input_audio_path, sr=None)
input_text = transcript_dict[input_audio_path]
batch = collate_and_pad([sample])  # batch containing just this one sample

print(input_audio_path)
# Play the clip at the rate it was actually loaded with.
Audio(data=input_wave, rate=input_sample_rate)

In [None]:
# Synthesis: reference spectrogram -> style embedding -> mel output -> waveform.
# Everything runs under torch.no_grad() because nothing here is trained;
# this avoids building an autograd graph for the forward passes.
# Batch indices are taken from the keyword names they are passed to
# (presumably text=0, spectrograms=2, spectrogram lengths=3, language id=8 —
# TODO confirm against collate_and_pad).
with torch.no_grad():
    style_embedding = style_embed_function(
        batch_of_spectrograms=batch[2].to(device),
        batch_of_spectrogram_lengths=batch[3].to(device),
    )

    mel = acoustic_model.inference(
        text=batch[0][0].to(device),
        speech=None,
        alpha=1.0,
        utterance_embedding=style_embedding[0],
        return_duration_pitch_energy=False,
        lang_id=batch[8][0].to(device),
    )

    # The vocoder is fed the mel output with its first two dimensions swapped;
    # the first element of its result is kept as the waveform.
    waveform = vocoder(mel.transpose(1, 0))[0]

# Move to the CPU for playback; detach() is kept as a cheap safeguard in case
# a submodule re-enables gradients internally.
waveform = waveform.detach().cpu()

In [None]:
# Inspect the dimensions of the predicted mel output (the synthesis cell swaps
# its first two dimensions with transpose(1, 0) before calling the vocoder).
mel.shape

In [None]:
# Print the transcript, then render a player for the synthesized audio.
print(input_text)
synth_player = Audio(data=waveform, rate=24000, autoplay=True)
synth_player

In [None]:
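# Optional: write both clips to disk (disabled by default).
# NOTE: `sample_rate` must equal the actual rate of the tensor being saved.
# The synthesized waveform is played back at 24000 Hz in this notebook, and
# `input_wave` has whatever rate `librosa.load` produced, so confirm that
# 16000 is really intended before enabling these calls.
# torchaudio.save(
#     'synth.wav',
#     src=waveform,
#     sample_rate=16000
# )

# torchaudio.save(
#     'origin.wav',
#     src=torch.Tensor(input_wave).unsqueeze(0),
#     sample_rate=16000
# )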