# StyleTTS 2 Stage 2 Inference (LibriTTS English model, Vocos decoder)

### Utils

In [1]:
# Seed every random number generator the notebook touches and pin cuDNN to
# deterministic kernels, so that repeated runs are as reproducible as possible.
import torch
import random
import numpy as np

torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
%cd ..

In [3]:
# load packages
import time
import random
import yaml
from munch import Munch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa

from models import *
from utils import *

%matplotlib inline

In [4]:
import IPython.display as ipd

# Mel-spectrogram front end shared by preprocess(): 80 mel bands, 2048-point
# FFT, 1200-sample window, 300-sample hop.
# NOTE(review): sample_rate is not passed, so torchaudio's default is used even
# though audio is loaded at 24 kHz in compute_style; presumably this matches
# the front end the checkpoint was trained with -- do not change without retraining.
to_mel = torchaudio.transforms.MelSpectrogram(
    n_fft=2048,
    win_length=1200,
    hop_length=300,
    n_mels=80,
)
# Normalisation applied to the log-mel: (log_mel - mean) / std.
mean, std = -4, 4

def length_to_mask(lengths):
    """Build a padding mask from a 1-D tensor of sequence lengths.

    The result has shape ``(len(lengths), lengths.max())``. Entry ``[b, t]``
    is True when ``t + 1 > lengths[b]`` (a padded position) and False for
    valid positions.
    """
    batch_size = lengths.shape[0]
    positions = torch.arange(lengths.max()).unsqueeze(0)
    positions = positions.expand(batch_size, -1).type_as(lengths)
    return (positions + 1) > lengths.unsqueeze(1)

def preprocess(wave):
    """Turn a NumPy waveform into a normalised log-mel tensor.

    The waveform is cast to float32 and passed through the module-level
    ``to_mel`` transform. A leading batch axis is added, then
    ``log(1e-5 + mel)`` is taken and standardised with the module-level
    ``mean``/``std``.
    """
    mel = to_mel(torch.from_numpy(wave).float())
    log_mel = torch.log(1e-5 + mel.unsqueeze(0))
    return (log_mel - mean) / std

def compute_style(path):
    """Encode the reference utterance at ``path`` into a single style tensor.

    The audio is loaded with librosa at 24 kHz (no silence trimming) and
    converted to a log-mel with ``preprocess``. It is then encoded, without
    gradients, by both ``model.style_encoder`` and
    ``model.predictor_encoder``. The two outputs are concatenated along
    dim 1, with the style-encoder output first.
    """
    waveform, _ = librosa.load(path, sr=24000)
    mel = preprocess(waveform).to(device).unsqueeze(1)

    with torch.no_grad():
        acoustic_style = model.style_encoder(mel)
        prosodic_style = model.predictor_encoder(mel)

    return torch.cat([acoustic_style, prosodic_style], dim=1)

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

### Load models

In [None]:
# Pretrained LibriTTS / Vocos assets. Repo-relative alternatives:
#   config:     "Models/LibriTTS_vocos/config_libritts_vocos.yml"
#   checkpoint: "Models/LibriTTS_vocos/epoch_2nd_00029.pth"
# Read the config through a context manager so the file handle is closed.
with open("/data/ckpts/stts2/LibriTTS_vocos/config_libritts_vocos.yml", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

# load pretrained ASR model (passed to the model builder as the text aligner)
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)

# load pretrained F0 model (passed to the model builder as the pitch extractor)
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

# load PL-BERT. NOTE: build_model_no_bert below does not take it; it is only
# consumed by the build_model(...) variant shown in the comment further down.
from Utils.PLBERT.util import load_plbert
BERT_path = config.get('PLBERT_dir', False)
plbert = load_plbert(BERT_path)

model_params = recursive_munch(config['model_params'])
model = build_model_no_bert(model_params, text_aligner, pitch_extractor)
# If you want to use BERT, you can use the following code
# model = build_model(model_params, text_aligner, pitch_extractor, plbert)

for key in model:
    model[key].eval()
    model[key].to(device)

# NOTE: torch.load unpickles the file, which can execute arbitrary code --
# only load checkpoints from a trusted source.
params_whole = torch.load("/data/ckpts/stts2/LibriTTS_vocos/epoch_2nd_00029.pth", map_location='cpu')
params = params_whole['net']
ignore_modules = []  # module names to skip when copying weights from the checkpoint


def _load_state_by_position(module, state_dict, name):
    """Load ``state_dict`` into ``module`` by pairing entries in order.

    This is the fallback for when the checkpoint's key names do not match the
    module's. It raises ValueError when the entry counts differ, because a
    positional pairing would then silently drop or misassign tensors. Shape
    mismatches are still reported by ``load_state_dict`` itself.
    """
    from collections import OrderedDict

    target_keys = list(module.state_dict().keys())
    print(f'{name} key length: {len(target_keys)}, state_dict length: {len(state_dict)}')
    if len(target_keys) != len(state_dict):
        raise ValueError(
            f'{name}: checkpoint has {len(state_dict)} entries but the module expects '
            f'{len(target_keys)}; refusing to pair them by position.'
        )
    remapped = OrderedDict(zip(target_keys, state_dict.values()))
    module.load_state_dict(remapped, strict=True)


for key in model:
    if key not in params or key in ignore_modules:
        continue
    print('%s loaded' % key)
    try:
        model[key].load_state_dict(params[key], strict=True)
    except RuntimeError:
        # Key names do not line up with the module's; fall back to matching
        # checkpoint entries to module parameters by position.
        _load_state_by_position(model[key], params[key], key)


In [8]:
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule

# Sampler used by inference() to draw style vectors from the model's diffusion
# network: ADPM2 steps over a Karras sigma schedule, with clamp=False.
_adpm2_sampler = ADPM2Sampler()
_karras_schedule = KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0)  # empirical parameters
sampler = DiffusionSampler(
    model.diffusion.diffusion,
    sampler=_adpm2_sampler,
    sigma_schedule=_karras_schedule,
    clamp=False,
)

### Synthesize speech

In [9]:
import phonemizer

# espeak backend for US English; punctuation is preserved and stress marks
# are included in the phoneme output.
global_phonemizer = phonemizer.backend.EspeakBackend(
    language='en-us',
    preserve_punctuation=True,
    with_stress=True,
)

In [10]:
from symbols_en import symbols
# If you want to use multi lingual vocabs, you can use the following code
# from symbols import symbols
from nltk.tokenize import word_tokenize

# Symbol -> integer id lookup table; if a symbol repeats, its last index wins.
dicts = {symbol: index for index, symbol in enumerate(symbols)}

class TextCleaner:
    """Convert a (phoneme) string into a list of integer symbol ids.

    Ids come from the module-level ``dicts`` lookup table. Characters that are
    not in the table are reported on stdout and skipped.
    """

    def __init__(self, dummy=None):
        # ``dummy`` is accepted but unused.
        self.word_index_dictionary = dicts

    def __call__(self, text, cleaned=True):
        """Return the symbol ids for ``text``.

        When ``cleaned`` is False, ``text`` is first phonemized with
        ``global_phonemizer``, then tokenized with ``word_tokenize``, and the
        tokens are re-joined with single spaces. Otherwise ``text`` is used
        as-is.
        """
        if cleaned:
            phonemes = text
        else:
            phonemized = global_phonemizer.phonemize([text])
            phonemes = ' '.join(word_tokenize(phonemized[0]))

        lookup = self.word_index_dictionary
        ids = []
        for symbol in phonemes:
            if symbol in lookup:
                ids.append(lookup[symbol])
            else:
                print(f"Unknown character: {symbol}")
                print(text)
        return ids

# Module-level text cleaner used by inference(). NOTE: the name is misspelled
# ("clenaer"); it is kept as-is because inference() calls it by this name.
textclenaer = TextCleaner()

In [11]:
import time

def _shift_right_one_frame(x):
    """Delay ``x`` by one frame along its last axis, repeating the first frame.

    ``out[..., 0] == x[..., 0]`` and ``out[..., t] == x[..., t - 1]`` for
    ``t >= 1``. The final input frame is dropped, so the length is unchanged.
    """
    shifted = torch.zeros_like(x)
    shifted[..., 0] = x[..., 0]
    shifted[..., 1:] = x[..., :-1]
    return shifted


def _durations_to_alignment(pred_dur):
    """Expand per-token frame counts into a hard (tokens x frames) alignment.

    Row ``i`` is 1 on the ``pred_dur[i]`` consecutive frames assigned to token
    ``i`` and 0 elsewhere. The matrix is created on the CPU as float32.
    """
    # A single host transfer instead of one int() sync per token. The values
    # are already rounded, so truncation via .long() is exact.
    frame_counts = pred_dur.long().tolist()
    alignment = torch.zeros(len(frame_counts), sum(frame_counts))
    start = 0
    for row, count in enumerate(frame_counts):
        alignment[row, start:start + count] = 1
        start += count
    return alignment


def inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1, is_cleaned=False):
    """Synthesize speech for ``text`` in the style of a reference utterance.

    Args:
        text: Input text. It is phonemized unless ``is_cleaned`` is True, in
            which case it is treated as an already-phonemized string.
        ref_s: Reference style, expected in the layout produced by
            ``compute_style``: decoder-style half followed by prosody-style
            half along the last dim.
        alpha: Weight of the diffusion-sampled decoder style against the
            reference's first half (1.0 = sampled only).
        beta: The same blend for the prosody (second) half.
        diffusion_steps: Number of diffusion sampler steps.
        embedding_scale: Forwarded to the sampler as ``embedding_scale``.
        is_cleaned: Skip phonemization when True.

    Returns:
        The decoded waveform as a NumPy array, with its last 50 samples removed.
    """
    text = text.strip()
    tokens = [0] + textclenaer(text, cleaned=is_cleaned) + [0]  # token id 0 at both ends
    print('The length of tokens is:', len(tokens))
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
    n_tokens = tokens.shape[-1]

    # The two halves of ref_s have equal width (128 each for the stock model,
    # matching the previously hard-coded split).
    full_style_dim = ref_s.shape[-1]
    half = full_style_dim // 2

    with torch.no_grad():
        input_lengths = torch.LongTensor([n_tokens]).to(device)
        text_mask = length_to_mask(input_lengths).to(device)

        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        d_en = model.prosodic_text_encoder(tokens, input_lengths, text_mask)

        # Sample a style vector, conditioned on the prosodic text encoding
        # and on the reference style.
        s_pred = sampler(
            noise=torch.randn((1, full_style_dim)).unsqueeze(1).to(device),
            embedding=d_en.transpose(-1, -2),
            embedding_scale=embedding_scale,
            features=ref_s,  # reference from the same speaker as the embedding
            num_steps=diffusion_steps,
        ).squeeze(1)

        # Blend sampled and reference styles. The first half feeds the
        # decoder; the second half feeds the prosody predictor.
        ref = alpha * s_pred[:, :half] + (1 - alpha) * ref_s[:, :half]
        s = beta * s_pred[:, half:] + (1 - beta) * ref_s[:, half:]

        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = model.predictor.lstm(d)
        duration = torch.sigmoid(model.predictor.duration_proj(x)).sum(axis=-1)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # (1, tokens, frames) alignment, moved to the device once and reused.
        aln = _durations_to_alignment(pred_dur).unsqueeze(0).to(device)

        # The one-frame shift below was originally gated on
        # model_params.decoder.type == "hifigan". This notebook applies it
        # for every decoder, including Vocos; that behavior is kept.
        en = _shift_right_one_frame(d.transpose(-1, -2) @ aln)
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        asr = _shift_right_one_frame(t_en @ aln)

        start = time.perf_counter()
        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))
        if out.is_cuda:
            # CUDA kernels run asynchronously. Wait for them so the timing
            # measures the decoder's compute, not just the kernel launches.
            torch.cuda.synchronize(out.device)
        print("model.decoder time:", time.perf_counter() - start)

    # TODO: the model emits a spurious pulse at the very end; trim it until fixed.
    return out.squeeze().cpu().numpy()[..., :-50]

In [None]:
import shutil
import soundfile as sf
import IPython.display as ipd

import os  # used for path handling below; not imported by the earlier cells

# Alternative input: pre-phonemized (IPA) transcripts. Uncomment these and set
# is_cleaned = True to bypass the phonemizer.
# lines = [
#     "LibriTTS/train-clean-360/100/121669/100_121669_000031_000001.wav|hiː wʌz vˈɛɹi ˈæŋɡɹi, ˌɪndˈiːd, fɚðə pˈɪɡ wʌzɐ ɡɹˈeɪt pˈɛt, ænd hiː hæd wˈɔntᵻd tə kˈiːp ɪt tˈɪl ɪt ɡɹˈuː vˈɛɹi bˈɪɡ.|1079",
#     "LibriTTS/train-clean-360/100/121669/100_121669_000032_000000.wav|sˌoʊ hiː pˌʊt ˌɔn hɪz kˈoʊt ænd bˈʌkəld ɐ stɹˈæp ɚɹˈaʊnd hɪz wˈeɪst, ænd wɛnt dˌaʊn tə ðə vˈɪlɪdʒ tə sˈiː ɪf hiː kʊd fˈaɪnd ˈaʊt hˌuː hæd stˈoʊlən hɪz pˈɪɡ.|1079",
#     "LibriTTS/train-clean-360/100/121669/100_121669_000033_000000.wav|ˌʌp ænd dˌaʊn ðə stɹˈiːt hiː wˈɛnt, ænd ɪn ænd ˈaʊt ðə lˈeɪnz, bˌʌt nˈoʊ tɹˈeɪsᵻz ʌvðə pˈɪɡ kʊd hiː fˈaɪnd ˈɛnɪwˌɛɹ.|1079",
#     "LibriTTS/train-clean-360/100/121669/100_121669_000033_000001.wav|ænd ðæt wʌz nˈoʊ ɡɹˈeɪt wˈʌndɚ, fɚðə pˈɪɡ wʌz ˈiːʔn̩ baɪ ðæt tˈaɪm ænd ɪts bˈoʊnz pˈɪkt klˈiːn.|1079",
# ]
# is_cleaned = True

# Each line is "<reference wav path relative to wav_dir>|<text>|<speaker id>".
lines = [
    "LibriTTS/train-clean-360/100/121669/100_121669_000031_000001.wav|The quick brown fox jumps over the lazy dog.|1079",
    "LibriTTS/train-clean-360/100/121669/100_121669_000032_000000.wav|She sells seashells by the seashore|1079",
    "LibriTTS/train-clean-360/100/121669/100_121669_000033_000000.wav|Peter Piper picked a peck of pickled peppers.|1079",
    "LibriTTS/train-clean-360/100/121669/100_121669_000033_000001.wav|How much wood would a woodchuck chuck if a woodchuck could chuck wood?|1079",
]
is_cleaned = False

silence = torch.zeros(24000 // 2).numpy()  # 0.5 s gap at 24 kHz

download_dir_name = 'Outputs/libritts_vocos_stage2'
os.makedirs(download_dir_name, exist_ok=True)

# wav_dir = 'Your wav directory'
wav_dir = '/data/LibriTTS/'

# Use `line` as the loop variable so the `lines` list is not shadowed.
for i, line in enumerate(lines):
    wav_path, text, _ = line.split('|')
    print('Text:', text)
    ref_wav_path = os.path.join(wav_dir, wav_path)
    ref_s = compute_style(ref_wav_path)
    wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=15, is_cleaned=is_cleaned)

    ref = librosa.load(ref_wav_path, sr=24000)[0]
    res = np.concatenate([ref, silence, wav], axis=0)

    # <stem>_<i><ext>. Only the real extension is touched, unlike a
    # str.replace on ".wav" anywhere in the name.
    stem, ext = os.path.splitext(os.path.basename(wav_path))
    out_path = os.path.join(download_dir_name, f'{stem}_{i}{ext}')

    # Save the reference, silence and synthesized audio back to back in one file.
    sf.write(out_path, res, 24000)

    # Optional: listen inline instead of (or in addition to) saving.
    # print('Synthesized:', i)
    # display(ipd.Audio(wav, rate=24000, normalize=False))
    # print('Reference:')
    # display(ipd.Audio(ref_wav_path, rate=24000, normalize=False))