# StyleTTS 2 Stage 1 Inference (for Korean model)

### Utils

In [1]:
import random

import numpy as np
import torch

# Seed the Python, NumPy and PyTorch RNGs with 0 so repeated runs match.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Disable cuDNN benchmarking and request deterministic cuDNN algorithms,
# giving up autotuning speed for reproducibility.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
%cd ..

In [3]:
# load packages
import time
import random
import yaml
from munch import Munch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa

from models import *
from utils import *

%matplotlib inline

In [4]:
import IPython.display as ipd

# Mel-spectrogram front end applied to the 24 kHz waveforms loaded in `preprocess`.
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=80,
    n_fft=2048,
    win_length=1200,
    hop_length=300,
)
# Normalization constants used as (log_mel - mean) / std.
mean = -4
std = 4

def length_to_mask(lengths):
    """Build a padding mask from a 1-D tensor of sequence lengths.

    Returns a boolean tensor of shape (len(lengths), lengths.max()).
    Entry [b, t] is True when position t lies at or beyond lengths[b],
    i.e. when it is padding.
    """
    batch_size = lengths.shape[0]
    max_len = lengths.max()
    positions = torch.arange(max_len).unsqueeze(0).expand(batch_size, -1)
    # type_as also moves the index grid onto the same device as `lengths`.
    positions = positions.type_as(lengths)
    return (positions + 1) > lengths.unsqueeze(1)

def preprocess(path):
    """Load an audio file and turn it into a normalized log-mel spectrogram.

    Processing steps:
    - Load the waveform at 24 kHz.
    - Pad it with 5000 zero samples on each side.
    - Pass it through the module-level `to_mel` transform.
    - Normalize as ``(log(1e-5 + mel) - mean) / std`` using the module-level
      `mean` / `std` constants.

    Args:
        path: Path to an audio file readable by ``librosa.load``.

    Returns:
        A float tensor with a leading batch axis of size 1 (added by
        ``unsqueeze(0)``); presumably (1, n_mels, frames) — confirm against
        the `to_mel` transform.
    """
    target_sr = 24000
    wave, sr = librosa.load(path, sr=target_sr)
    # wave, index = librosa.effects.trim(wave, top_db=30)
    if sr != target_sr:
        # librosa >= 0.10 makes orig_sr/target_sr keyword-only; the previous
        # positional call raised TypeError there.
        wave = librosa.resample(wave, orig_sr=sr, target_sr=target_sr)
    # Surround the waveform with 5000 zero samples on each side.
    wave = np.concatenate([np.zeros([5000]), wave, np.zeros([5000])], axis=0)

    # display(ipd.Audio(wave, rate=sr, normalize=False))

    wave_tensor = torch.from_numpy(wave).float()
    mel_tensor = to_mel(wave_tensor)
    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
    return mel_tensor


In [None]:
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device  # echo the selected device in the notebook output

### Load models

In [None]:
# config = yaml.safe_load(open("Models/LibriTTS_vocos/config_libritts_vocos.yml"))
# Read the training config. The `with` block closes the file handle that the
# bare `open(...)` call used to leak.
with open("/data/ckpts/stts2/LibriTTS_vocos/config_libritts_vocos.yml", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

# load pretrained ASR model (used below as the text aligner)
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)

# load pretrained F0 model
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

# load BERT model
from Utils.PLBERT.util import load_plbert
BERT_path = config.get('PLBERT_dir', False)
plbert = load_plbert(BERT_path)

In [None]:
model_params = recursive_munch(config['model_params'])
# model = build_model_no_bert(model_params, text_aligner, pitch_extractor)
model = build_model(model_params, text_aligner, pitch_extractor, plbert)
# Put every sub-network into eval mode and move it to the selected device.
# A plain loop replaces the side-effect-only list comprehensions.
for key in model:
    model[key].eval()
    model[key].to(device)

In [None]:
# Load the first-stage checkpoint onto the CPU; the tensors are copied into
# the already-placed model modules in the next cell.
# Alternative relative path:
# params_whole = torch.load("Models/LibriTTS_vocos/LibriTTS_vocos_first_stage.pth", map_location='cpu')
params_whole = torch.load(
    "/data/ckpts/stts2/LibriTTS_vocos/LibriTTS_vocos_first_stage.pth",
    map_location='cpu',
)

# Per-module state dicts live under the 'net' entry.
params = params_whole['net']

In [None]:
ignore_modules = ['diffusion', 'wd', 'bert', 'bert_encoder',]


def _remap_state_dict_by_position(module, state_dict):
    """Rename checkpoint tensors to `module`'s parameter names, pairing them by order.

    This is a fallback for checkpoints whose key names differ from the
    module's. It only works if both dicts list their tensors in the same
    order. A length mismatch is therefore rejected instead of being silently
    truncated by ``zip``, which could drop or misassign tensors.
    """
    from collections import OrderedDict

    model_keys = list(module.state_dict().keys())
    if len(model_keys) != len(state_dict):
        raise RuntimeError(
            f'cannot remap state_dict by position: module has {len(model_keys)} '
            f'entries but checkpoint has {len(state_dict)}'
        )
    return OrderedDict(zip(model_keys, state_dict.values()))


for key in model:
    if key in params and key not in ignore_modules:
        print('%s loaded' % key)
        try:
            model[key].load_state_dict(params[key], strict=True)
        except RuntimeError:
            # load_state_dict raises RuntimeError on key/shape mismatches; retry
            # with the checkpoint's tensors renamed by position.
            state_dict = params[key]
            print(f'{key} key length: {len(model[key].state_dict().keys())}, state_dict length: {len(state_dict.keys())}')
            model[key].load_state_dict(
                _remap_state_dict_by_position(model[key], state_dict), strict=True
            )
for key in model:
    model[key].eval()


### Synthesize speech

In [11]:
from symbols_en import symbols
# If you want to use multi lingual vocabs
# from symbols import symbols

# Map each symbol to its index in `symbols`. If a symbol were duplicated, the
# later index wins, exactly as with the original index loop.
dicts = {symbol: index for index, symbol in enumerate(symbols)}

class TextCleaner:
    """Convert a string into the list of integer symbol indexes from `dicts`.

    Characters that are not in the table are reported with `print` and
    skipped; they produce no index.
    """

    def __init__(self, dummy=None):
        # `dummy` is accepted but not used.
        self.word_index_dictionary = dicts

    def __call__(self, text, cleaned=True):
        # `cleaned` is accepted for call-site compatibility and not used.
        table = self.word_index_dictionary
        result = []
        for ch in text:
            if ch in table:
                result.append(table[ch])
            else:
                print(f"Unknown character: {ch}")
                print(text)
        return result

# Shared cleaner instance. `inference` calls it under this (misspelled) name,
# so keep the spelling unless every caller is updated as well.
textclenaer = TextCleaner()

In [12]:
import matplotlib.pyplot as plt
# If `text_aligner` is wrapped (e.g. by DataParallel), the real module sits
# under `.module`; otherwise read `n_down` from the aligner directly.
# Only AttributeError is caught, so unrelated errors are no longer hidden.
try:
    n_down = model.text_aligner.module.n_down
except AttributeError:
    n_down = model.text_aligner.n_down
    

def _plot_attention(attn):
    """Show one attention matrix as a heatmap with a colorbar."""
    plt.figure(figsize=(10, 5))
    plt.imshow(attn.squeeze().cpu().numpy(), aspect='auto', origin='lower')
    plt.colorbar()
    plt.show()


def inference(text, ref_wav_path):
    """Resynthesize `text`, driving the decoder with features from a reference recording.

    This is the stage-1 check. The processing steps are:
    - Align the tokens to the reference mels with `model.text_aligner`.
    - Expand the `model.text_encoder` output along that attention.
    - Decode with `model.decoder`, conditioned on:
      - the F0 from `model.pitch_extractor`,
      - `log_norm` of the reference mels,
      - the style vector from `model.style_encoder`.

    The function also prints debug information and displays the attention map.

    Args:
        text: Phoneme string. Characters missing from the symbol table are
            reported and skipped by `textclenaer`.
        ref_wav_path: Path to the reference audio file.

    Returns:
        A tuple ``(wav, s2s_attn, mels)``:
        - wav: the decoder output as a numpy array with the last 50 samples
          removed.
        - s2s_attn: the masked attention used for expansion.
        - mels: the cropped reference mel tensor.
    """
    # text = text.strip()
    tokens = textclenaer(text, cleaned=True)
    # Wrap the sequence with index 0 on both ends.
    tokens.insert(0, 0)
    tokens.append(0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

    mels = preprocess(ref_wav_path).to(device)
    # Drop the final mel frame.
    mels = mels[:, :, :mels.size(-1) - 1]
    mel_input_length = torch.LongTensor([mels.size(-1)]).to(device)
    print(mel_input_length)

    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        # Express the length in units of 2**n_down frames (presumably the
        # aligner's downsampling factor) and crop the mels to a whole multiple.
        downsample = 2 ** n_down
        mel_input_length = mel_input_length // downsample
        print(mel_input_length)
        mels = mels[..., :mel_input_length.item() * downsample]
        # Use the selected device; this was hard-coded to 'cuda' and crashed
        # on CPU-only machines.
        mask = length_to_mask(mel_input_length).to(device)

        _, s2s_pred, s2s_attn = model.text_aligner(mels, mask, tokens)

        # Debug: the aligner's argmax symbol predictions next to the input tokens.
        _, amax_s2s = torch.max(s2s_pred, dim=2)
        print(''.join([symbols[s1.item()] for s1 in amax_s2s[0]]))
        print(''.join([symbols[s2.item()] for s2 in tokens[0]]))

        # Drop the first entry along dim -2 of the attention. Presumably this is
        # an extra leading text position emitted by the aligner — confirm.
        s2s_attn = s2s_attn.transpose(-1, -2)
        s2s_attn = s2s_attn[..., 1:]
        s2s_attn = s2s_attn.transpose(-1, -2)

        # Zero attention weights wherever either the text or the mel position
        # is padding.
        text_mask = length_to_mask(input_lengths).to(tokens.device)
        attn_mask = (~mask).unsqueeze(-1).expand(mask.shape[0], mask.shape[1], text_mask.shape[-1]).float().transpose(-1, -2)
        attn_mask = attn_mask.float() * (~text_mask).unsqueeze(-1).expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1]).float()
        attn_mask = (attn_mask < 1)
        s2s_attn.masked_fill_(attn_mask, 0.0)

        # Monotonic alignment is computed but not used; switch to it by
        # uncommenting the assignment below.
        mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length)
        s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
        # s2s_attn = s2s_attn_mono

        # encode
        t_en = model.text_encoder(tokens, input_lengths, text_mask)

        # Expand the text encoding along the alignment.
        asr = (t_en @ s2s_attn)

        F0_real, _, _ = model.pitch_extractor(mels.unsqueeze(1))
        F0_real = F0_real.unsqueeze(0)
        s = model.style_encoder(mels.unsqueeze(1))
        real_norm = log_norm(mels.unsqueeze(1)).squeeze(1)
        out = model.decoder(asr, F0_real, real_norm, s)

        # plot one s2s attention
        _plot_attention(s2s_attn)

    return out.squeeze().cpu().numpy()[..., :-50], s2s_attn, mels # weird pulse at the end of the model, need to be fixed later

In [None]:
import IPython.display as ipd

# These lines are from the validation set.
# Each entry has the form "relative_wav_path|phonemized_text|id". The third
# field (presumably a speaker index — confirm) is ignored by the loop below.
lines = [
    "LibriTTS/train-clean-360/1392/140654/1392_140654_000034_000002.wav|lˈiːv ðə sˈɪnz ʌvðə bˈɑːdi, ænd wɪð ðˌaɪ bˈɑːdi pɹˈæktɪs vˈɜːtʃuː!|1030",
    "LibriTTS/train-clean-360/14/208/14_208_000017_000000.wav|ɪt wʌz tˈuː dˈɜːɾi fɔːɹ mˈɪsɪz ˈælən tʊ ɐkˈʌmpəni hɜː hˈʌsbənd tə ðə pˈʌmp ɹˈuːm; hiː ɐkˈoːɹdɪŋli sˈɛt ˈɔf baɪ hɪmsˈɛlf, ænd kˈæθɹɪn hæd bˈɛɹli wˈɑːtʃt hˌɪm dˌaʊn ðə stɹˈiːt wɛn hɜː nˈoʊɾɪs wʌz klˈeɪmd baɪ ðɪ ɐpɹˈoʊtʃ ʌvðə sˈeɪm tˈuː ˈoʊpən kˈæɹɪdʒᵻz, kəntˈeɪnɪŋ ðə sˈeɪm θɹˈiː pˈiːpəl ðæt hæd sɚpɹˈaɪzd hɜː sˈoʊ mˌʌtʃ ɐ fjˈuː mˈɔːɹnɪŋz bˈæk.|1119",
    "LibriTTS/train-clean-360/1401/146770/1401_146770_000031_000004.wav|hiː hɐdbɪn tə tʃˈɛdkuːm, ænd wʌz kˈʌmɪŋ bˈæk.|422",
    "LibriTTS/train-clean-360/1401/174511/1401_174511_000039_000000.wav|hiː ˈʌɾɚd ɐ ɡɹˈaʊl ænd ðˈɛn θɹˈuː bˈæk hɪz kˈoʊt, dɪsplˈeɪɪŋ ɐ bˈædʒ ɐtˈætʃt tə hɪz vˈɛst.|422",
]

# `os` was never imported in this file; import it here so the path join does
# not rely on a star import leaking it.
import os

wav_dir = '/data/LibriTTS/'
# Half a second of silence at 24 kHz, placed between reference and synthesis.
silence = torch.zeros(24000 // 2).numpy()
# The loop variable is `line`; naming it `lines` shadowed the list being iterated.
for i, line in enumerate(lines):
    wavs = []
    wav_path, text, _ = line.split('|')
    ref_wav_path = os.path.join(wav_dir, wav_path)
    # ref_wav_path = wav_path
    print('Text:', text)
    wav, s2s_attn, mels = inference(text, ref_wav_path)
    ref, sr = librosa.load(ref_wav_path, sr=24000)

    wavs.append(ref)
    wavs.append(silence)
    wavs.append(wav)

    res = np.concatenate(wavs, axis=0)

    # Play the Reference, Silence, and Synthesized
    print('Synthesized:', i)
    # ipd.display works outside a live notebook too; the bare `display`
    # builtin only exists inside IPython.
    ipd.display(ipd.Audio(res, rate=24000, normalize=True))
