In [1]:
import os, sys
from IPython.display import Audio

sys.path.append("../")

import torch
import torchaudio
import soundfile as sf
import numpy as np
from tqdm import tqdm
from glob import glob
from numpy import trim_zeros

from src.spk_embedding.StyleEmbedding import StyleEmbedding
from src.tts.vocoders.hifigan.HiFiGAN import HiFiGANGenerator
from src.tts.models.fastporta.FastPorta import FastPorta
from src.tts.models.fastporta.FastPorta2 import FastPorta2
from src.tts.models.fastporta.FastPorta3 import FastPorta3
from src.tts.models.fastporta.FastPortaVAE import FastPortaVAE
from src.tts.models.fastspeech2.FastSpeech2 import FastSpeech2
from src.datasets.fastspeech_dataset import (
    FastSpeechDataset,
    build_path_to_transcript_dict_libri_tts,
)
from src.utility.tokenizer import ArticulatoryCombinedTextFrontend as Tokenizer

from src.preprocessing.audio_processing import AudioPreprocessor

from src.pipelines.fastporta.train_loop import collate_and_pad

# Default to the CPU and switch to the GPU only when PyTorch reports one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cpu'

In [2]:
# Directory handed to build_path_to_transcript_dict_libri_tts to collect the
# evaluation utterances. Relative paths resolve against the kernel's working
# directory.
TEST_CLEAN_PATH = '../data/test-clean'

# Checkpoint files for the vocoder, aligner, acoustic models and style encoder.
# The "lastest" spelling is part of the file names and must match what is on disk.
AVOCODO_CHECKPOINT = "../saved_models/Avocodo.pt"
ALIGNER_CHECKPOINT = "../saved_models/aligner.pt"
FASTPORTA_CHECKPOINT = "../saved_models/fastporta/fastporta_checkpoint_lastest.pt"
FASTPORTA2_CHECKPOINT = "../saved_models/fastporta/fastporta2_parallel_p1-0.5_checkpoint_lastest.pt"
# NOTE(review): the constant is named FASTPORTA3 but the file is "fastporta4_checkpoint_260.pt"
# — confirm this is the checkpoint intended for FastPorta3.
FASTPORTA3_CHECKPOINT = "../saved_models/fastporta/fastporta4_checkpoint_260.pt"
FASTPORTAVAE_CHECKPOINT = "../saved_models/fastporta/fastporta-vae_checkpoint_lastest.pt"
FASTSPEECH2_CHECKPOINT = "../saved_models/fastspeech2/fastspeech2_checkpoint_lastest.pt"
STYLE_EMBED_CHECKPOINT = "../saved_models/embedding_function.pt"

In [None]:
# Lookup that later cells index as transcript_dict[audio_path] to get an
# utterance's text; built from the files under TEST_CLEAN_PATH.
transcript_dict = build_path_to_transcript_dict_libri_tts(TEST_CLEAN_PATH)

In [None]:
# Dataset over the utterances listed in `transcript_dict`.
dataset = FastSpeechDataset(
    lang="en",
    device=device,
    cache_dir="./librispeech",
    path_to_transcript_dict=transcript_dict,
    acoustic_checkpoint_path=ALIGNER_CHECKPOINT,  # aligner.pt
    loading_processes=2,  # tune to the number of CPU cores available
)

In [None]:
# Vocoder: weights are stored under the "generator" key of the Avocodo checkpoint.
vocoder = HiFiGANGenerator()
vocoder = vocoder.to(device)
avocodo_check_dict = torch.load(AVOCODO_CHECKPOINT, map_location=device)
vocoder.load_state_dict(avocodo_check_dict["generator"])
vocoder.eval()

# Style encoder: frozen, used only to embed the reference spectrogram.
style_embed_function = StyleEmbedding()
style_embed_function = style_embed_function.to(device)
style_embed_check_dict = torch.load(STYLE_EMBED_CHECKPOINT, map_location=device)
style_embed_function.load_state_dict(style_embed_check_dict["style_emb_func"])
style_embed_function.requires_grad_(False)
style_embed_function.eval()

# Acoustic model. To evaluate another architecture, replace the constructor and
# the checkpoint constant below with one of these pairs:
#   FastPorta2(mix_style_p=0)  with FASTPORTA2_CHECKPOINT
#   FastPortaVAE()             with FASTPORTAVAE_CHECKPOINT
#   FastSpeech2()              with FASTSPEECH2_CHECKPOINT
acoustic_model = FastPorta3()
acoustic_model = acoustic_model.to(device)
check_dict = torch.load(FASTPORTA3_CHECKPOINT, map_location=device)
acoustic_model.load_state_dict(check_dict["model"])
acoustic_model.eval()

In [None]:
# Reference utterance: its collated batch feeds the style embedding further down.

sample_id = 15
sample = dataset[sample_id]

# The last field of a dataset item is the path of its audio file.
input_audio_path = sample[-1]
input_text = transcript_dict[input_audio_path]
input_wave, sr = sf.read(input_audio_path)

# Collate the single item into a batch of size one.
batch = collate_and_pad([sample])

print("Path: ", input_audio_path)
print("Text: ", input_text)
Audio(data=input_wave, rate=sr)

In [None]:
# Second utterance: only its text is used as the synthesis input.

text_sample_id = 4
text_sample = dataset[text_sample_id]
text_input = transcript_dict[text_sample[-1]]
text_batch = collate_and_pad([text_sample])
print(text_input)


In [None]:
# Pure inference: run every forward pass with autograd disabled so no graph is
# recorded for the acoustic model and the vocoder (only the style encoder had
# its gradients switched off when it was loaded).
with torch.no_grad():
    # Style embedding of the reference utterance (fields 2 and 3 of its batch
    # are passed as the spectrograms and their lengths).
    style_embedding = style_embed_function(
        batch_of_spectrograms=batch[2].to(device),
        batch_of_spectrogram_lengths=batch[3].to(device),
    )

    # Text comes from the second utterance, style from the reference one.
    mel = acoustic_model.inference(
        text=text_batch[0][0].to(device),
        speech=None,
        alpha=1.0,
        utterance_embedding=style_embedding[0],
        return_duration_pitch_energy=False,
        lang_id=batch[8][0].to(device),
    )

    # The vocoder is fed the transposed mel; keep the first item of its output.
    waveform = vocoder(mel.transpose(1, 0))[0]

waveform = waveform.detach().cpu()

In [None]:

# Save the reference recording at the sample rate soundfile reported for it
# (`sr`, set when the file was read) rather than a hard-coded 16000, so the
# written file is correct whatever the rate of the input.
# NOTE(review): unsqueeze(0) adds the channel dimension and therefore assumes
# mono (1-D) audio — confirm for multi-channel inputs.
torchaudio.save(
    'origin.wav',
    src=torch.Tensor(input_wave).unsqueeze(0),
    sample_rate=sr
)

In [None]:
# Write the synthesized waveform to disk.
# NOTE(review): 24000 presumably matches the vocoder's output rate — confirm.
torchaudio.save('synth.wav', src=waveform, sample_rate=24000)

In [24]:
# Directory from which the next cell loads 'checkpoint_lastest.pt'.
# The commented-out line selects the FastPorta directory instead.
# PATH_TO_CHECKPOINTS = '../saved_models/fastporta'
PATH_TO_CHECKPOINTS = '../saved_models/fastspeech2'


In [26]:
# Replace the acoustic model with a FastSpeech2 restored from PATH_TO_CHECKPOINTS.
# NOTE(review): this file name differs from the one in FASTSPEECH2_CHECKPOINT
# ("fastspeech2_checkpoint_lastest.pt") — confirm which file is intended.
fastspeech2_checkpoint_path = os.path.join(PATH_TO_CHECKPOINTS, 'checkpoint_lastest.pt')
acoustic_model = FastSpeech2().to(device)
acoustic_model.load_state_dict(
    torch.load(fastspeech2_checkpoint_path, map_location=device)["model"]
)
acoustic_model.eval()

FastSpeech2(
  (encoder): Conformer(
    (embed): Sequential(
      (0): Linear(in_features=62, out_features=100, bias=True)
      (1): Tanh()
      (2): Linear(in_features=100, out_features=384, bias=True)
    )
    (pos_enc): RelPositionalEncoding(
      (dropout): Dropout(p=0.2, inplace=False)
    )
    (output_norm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
    (hs_emb_projection): Linear(in_features=448, out_features=384, bias=True)
    (language_embedding): Embedding(8000, 384)
    (encoders): MultiSequential(
      (0): EncoderLayer(
        (self_attn): RelPositionMultiHeadedAttention(
          (linear_q): Linear(in_features=384, out_features=384, bias=True)
          (linear_k): Linear(in_features=384, out_features=384, bias=True)
          (linear_v): Linear(in_features=384, out_features=384, bias=True)
          (linear_out): Linear(in_features=384, out_features=384, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
          (linear_pos): Linear(in

In [27]:
# Checkpoint locations of the style encoder and the vocoder.
STYLE_EMBED_CHECKPOINT = '../saved_models/embedding_function.pt'
AVOCODO_CHECKPOINT = '../saved_models/Avocodo.pt'

# Style encoder, frozen for inference.
style_embed_function = StyleEmbedding()
style_embed_function = style_embed_function.to(device)
style_embed_check_dict = torch.load(STYLE_EMBED_CHECKPOINT, map_location=device)
style_embed_function.load_state_dict(style_embed_check_dict["style_emb_func"])
style_embed_function.requires_grad_(False)
style_embed_function.eval()

# Vocoder; its weights sit under the "generator" key.
vocoder = HiFiGANGenerator()
vocoder = vocoder.to(device)
avocodo_check_dict = torch.load(AVOCODO_CHECKPOINT, map_location=device)
vocoder.load_state_dict(avocodo_check_dict["generator"])
vocoder.eval()

# Audio front end used by `inference` to normalize the reference recording and
# extract its mel spectrogram.
ap = AudioPreprocessor(
    input_sr=16000,
    output_sr=16000,
    melspec_buckets=80,
    hop_length=256,
    n_fft=1024,
    cut_silence=False,
    device=device,
)

In [28]:
def inference(text, ref_path, acoustic_model, ap, style_embed_function, vocoder, alpha=1.0, lang="en"):
    """Synthesize `text` using a style embedding taken from a reference recording.

    Args:
        text: Text to synthesize; converted to a tensor by the tokenizer for `lang`.
        ref_path: Path of the reference audio file, read with soundfile.
        acoustic_model: Object whose `inference(...)` output is passed (transposed)
            to the vocoder.
        ap: Audio preprocessor providing `audio_to_wave_tensor` and
            `audio_to_mel_spec_tensor`.
        style_embed_function: Callable taking `batch_of_spectrograms` and
            `batch_of_spectrogram_lengths`; the first item of its result is used
            as the utterance embedding.
        vocoder: Callable applied to the transposed acoustic-model output.
        alpha: Forwarded unchanged to `acoustic_model.inference`.
        lang: Language code given to the tokenizer.

    Returns:
        The first item of the vocoder output, detached and moved to the CPU.

    Note:
        Tensors are moved to the module-level `device`, not to a device derived
        from the models passed in.
    """
    # NOTE(review): the file's own sample rate is discarded and 16000 is passed
    # to the mel extraction below — confirm reference files are always 16 kHz.
    wave, _ = sf.read(ref_path)
    norm_wave = ap.audio_to_wave_tensor(normalize=True, audio=wave)

    # Strip leading/trailing zero samples. The tensor is detached and moved to
    # host memory first because Tensor.numpy() raises for non-CPU tensors.
    norm_wave = torch.tensor(trim_zeros(norm_wave.detach().cpu().numpy()))
    cached_speech = ap.audio_to_mel_spec_tensor(
        audio=norm_wave, normalize=False, explicit_sampling_rate=16000
    ).transpose(0, 1)

    # Length is taken before the batch dimension is added.
    cached_speech_len = torch.LongTensor([len(cached_speech)])
    cached_speech = cached_speech.unsqueeze(0)

    tokenizer = Tokenizer(language=lang)
    embed_text = tokenizer.string_to_tensor(
        text, handle_missing=False, input_phonemes=False
    )

    # NOTE(review): the language id is fixed at 12 regardless of `lang`;
    # presumably 12 is the id for "en" — confirm before using other languages.
    lang_id = torch.tensor([12], dtype=torch.int64, device=device)

    # Inference only: disable autograd so no graph is recorded for the forward passes.
    with torch.no_grad():
        style_embedding = style_embed_function(
            batch_of_spectrograms=cached_speech.to(device),
            batch_of_spectrogram_lengths=cached_speech_len.to(device),
        )

        mel = acoustic_model.inference(
            text=embed_text.to(device),
            speech=None,
            alpha=alpha,
            utterance_embedding=style_embedding[0],
            return_duration_pitch_energy=False,
            lang_id=lang_id,
        )
        waveform = vocoder(mel.transpose(1, 0))[0]

    return waveform.detach().cpu()

In [38]:
# Sentence passed to `inference` as the text to synthesize.
text = 'The late astounding events, however, had rendered Procope manifestly uneasy, and not the less so from his consciousness that the count secretly partook of his own anxiety'
# Reference recording passed to `inference`; its style embedding conditions the synthesis.
ref_path = '../data/test-clean/5105/28241/5105-28241-0009.flac'

In [39]:
# Synthesize `text` with the style of the recording at `ref_path`.
waveform = inference(
    text=text,
    ref_path=ref_path,
    acoustic_model=acoustic_model,
    ap=ap,
    style_embed_function=style_embed_function,
    vocoder=vocoder,
    alpha=1.0,
    lang="en",
)

In [40]:
# Write the synthesized waveform to disk.
# NOTE(review): 24000 presumably matches the vocoder's output rate — confirm.
torchaudio.save('synth.wav', src=waveform, sample_rate=24000)