# StyleTTS Demo (LJSpeech)


### Utils

In [3]:
import sys
import os

# Add current directory explicitly
sys.path.insert(0, os.getcwd())

# Print to verify Python sees the correct directory
print("Current Working Directory:", os.getcwd())
print("Python Path:", sys.path)

# Now import
from models import *  
from utils import *


Current Working Directory: c:\Users\garym\OneDrive\Scripts\GM_Alienware\workspaces\StyleTTS
Python Path: ['c:\\Users\\garym\\OneDrive\\Scripts\\GM_Alienware\\workspaces\\StyleTTS', 'c:\\Users\\garym\\OneDrive\\Scripts\\GM_Alienware\\opt\\Conda\\envs\\styletts\\python39.zip', 'c:\\Users\\garym\\OneDrive\\Scripts\\GM_Alienware\\opt\\Conda\\envs\\styletts\\DLLs', 'c:\\Users\\garym\\OneDrive\\Scripts\\GM_Alienware\\opt\\Conda\\envs\\styletts\\lib', 'c:\\Users\\garym\\OneDrive\\Scripts\\GM_Alienware\\opt\\Conda\\envs\\styletts', '', 'c:\\Users\\garym\\OneDrive\\Scripts\\GM_Alienware\\opt\\Conda\\envs\\styletts\\lib\\site-packages', 'c:\\Users\\garym\\OneDrive\\Scripts\\GM_Alienware\\opt\\Conda\\envs\\styletts\\lib\\site-packages\\win32', 'c:\\Users\\garym\\OneDrive\\Scripts\\GM_Alienware\\opt\\Conda\\envs\\styletts\\lib\\site-packages\\win32\\lib', 'c:\\Users\\garym\\OneDrive\\Scripts\\GM_Alienware\\opt\\Conda\\envs\\styletts\\lib\\site-packages\\Pythonwin']


In [4]:
# Imports and setup for garo_Inference_LJSpeech.ipynb
import os
import sys

sys.path.insert(0, os.getcwd())  # Ensure the working directory is first in sys.path

import random
import yaml
from munch import Munch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa
from nltk.tokenize import word_tokenize
from models import *  # This should now work correctly
from utils import *
# %matplotlib inline # Keep or remove as needed
print("✅ Imports successful! All modules are loaded.")

✅ Imports successful! All modules are loaded.


In [5]:
# This is mostly for verification, not setup.  The environment
# variables should be set *before* launching the notebook.
from phonemizer.backend import EspeakBackend

try:
    backend = EspeakBackend("en-us")
    print("✅ eSpeak is properly detected by Phonemizer!")
except Exception as e:
    print("❌ eSpeak detection failed!", e)

✅ eSpeak is properly detected by Phonemizer!


In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [7]:
_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"

# Export all symbols:
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)

dicts = {symbols[i]: i for i in range(len(symbols))}

class TextCleaner:
    def __init__(self, dummy=None):
        self.word_index_dictionary = dicts

    def __call__(self, text):
        indexes = []
        for char in text:
            try:
                indexes.append(self.word_index_dictionary[char])
            except KeyError:
                print(f"⚠️ Warning: Character '{char}' not found in dictionary!")
        return indexes

# Initialize the text cleaner
textclenaer = TextCleaner()

# ✅ Test it immediately
test_text = "Hello, world!"
cleaned_text = textclenaer(test_text)

print("✅ Text Cleaner is working! Processed Output:", cleaned_text)


✅ Text Cleaner is working! Processed Output: [24, 47, 54, 54, 57, 3, 16, 65, 57, 60, 54, 46, 5]


In [8]:
# Define Mel Spectrogram transformation
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=80, n_fft=2048, win_length=1200, hop_length=300
)
mean, std = -4, 4

# Function to create a mask for padding
def length_to_mask(lengths):
    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask + 1, lengths.unsqueeze(1))
    return mask

# Preprocess waveform into Mel spectrogram
def preprocess(wave):
    wave_tensor = torch.from_numpy(wave).float()
    mel_tensor = to_mel(wave_tensor)
    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
    return mel_tensor

# Compute style embeddings for reference audio files
def compute_style(ref_dicts, model):
    reference_embeddings = {}
    for key, path in ref_dicts.items():
        wave, sr = librosa.load(path, sr=24000)
        audio, index = librosa.effects.trim(wave, top_db=30)
        if sr != 24000:
            audio = librosa.resample(audio, sr, 24000)
        mel_tensor = preprocess(audio).to(device)

        with torch.no_grad():
            ref = model.style_encoder(mel_tensor.unsqueeze(1))
        reference_embeddings[key] = (ref.squeeze(1), audio)
    
    return reference_embeddings

# ✅ Test the preprocess function with random noise
test_wave = np.random.randn(24000)  # 1 second of fake audio
mel_output = preprocess(test_wave)

print("✅ Preprocessing test successful! Mel Shape:", mel_output.shape)


✅ Preprocessing test successful! Mel Shape: torch.Size([1, 80, 81])


### Load models

In [9]:
# load phonemizer
import phonemizer
global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True,  with_stress=True)
test_text = "Hello, world!"
phonemes = global_phonemizer.phonemize([test_text])

print("✅ Phonemizer test successful! Output:", phonemes)


✅ Phonemizer test successful! Output: ['həlˈoʊ, wˈɜːld! ']


In [10]:
# load hifi-gan
import sys
sys.path.insert(0, "../Demo/hifi-gan")
sys.path.append(r"C:\Users\garym\OneDrive\Scripts\GM_Alienware\workspaces\StyleTTS\Demo\hifi-gan")

import glob
import os
import argparse
import json
import torch
from scipy.io.wavfile import write
from attrdict import AttrDict
from vocoder import Generator
print("✅ Successfully imported vocoder!")
import librosa
import numpy as np
import torchaudio

h = None

def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device, weights_only=True)
    print("Complete.")
    return checkpoint_dict

def scan_checkpoint(cp_dir, prefix):
    pattern = os.path.join(cp_dir, prefix + '*')
    cp_list = glob.glob(pattern)
    if len(cp_list) == 0:
        return ''
    return sorted(cp_list)[-1]

cp_g = scan_checkpoint("Vocoder/", 'g_')

config_file = os.path.join(os.path.split(cp_g)[0], 'config.json')
with open(config_file) as f:
    data = f.read()
json_config = json.loads(data)
h = AttrDict(json_config)

device = torch.device(device)
generator = Generator(h).to(device)

state_dict_g = load_checkpoint(cp_g, device)
generator.load_state_dict(state_dict_g['generator'])
generator.eval()
generator.remove_weight_norm()

✅ Successfully imported vocoder!


  WeightNorm.apply(module, name, dim)


Loading 'Vocoder\g_00750000'
Complete.
Removing weight norm...


In [23]:
# Load StyleTTS
model_path = "./Models/LJSpeech/epoch_2nd_00180.pth"
model_config_path = "./Models/LJSpeech/config.yml"

config = yaml.safe_load(open(model_config_path))

# Load pretrained ASR model
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)

# Load pretrained F0 model
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

# ✅ Fix: Load a default BERT model since it's missing in config.yml
from transformers import AutoModel, AutoConfig

bert_model_name = "bert-base-uncased"  # Default BERT model
bert_config = AutoConfig.from_pretrained(bert_model_name)
bert = AutoModel.from_pretrained(bert_model_name, config=bert_config)
print(config['model_params'].keys())  # Check which keys exist
if 'decoder' not in config['model_params']:
    print("⚠️ Warning: 'decoder' is missing in config! Using default settings.")
    config['model_params']['decoder'] = Munch({  # ✅ Convert it to Munch here
        'type': 'hifigan',  # Default decoder type
        'resblock_kernel_sizes': [3, 7, 11],
        'upsample_rates': [8, 8, 2, 2],
        'upsample_initial_channel': 512,
        'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        'upsample_kernel_sizes': [16, 16, 4, 4],
        'gen_istft_n_fft': 1024,
        'gen_istft_hop_size': 256
    })

# ✅ Add default max_dur if it's missing - and set it to 1 to match pretrained model
if 'max_dur' not in config['model_params']:
    print("⚠️ Warning: 'max_dur' is missing in config! Using default value: 1 (to match pretrained model).")
    config['model_params']['max_dur'] = 1 # Set max_dur to 1 

# ✅ Add default multispeaker if it's missing
if 'multispeaker' not in config['model_params']:
    print("⚠️ Warning: 'multispeaker' is missing in config! Using default value: False.")
    config['model_params']['multispeaker'] = False # Default to single speaker model

# ✅ Add default diffusion config with transformer and embedding_mask_proba
if 'diffusion' not in config['model_params']:
    print("⚠️ Warning: 'diffusion' is missing in config! Using default with transformer and embedding_mask_proba.")
    config['model_params']['diffusion'] = Munch({
        'transformer': Munch({ # Transformer config with default params
            'num_layers': 2,    # Example default value, adjust as needed
            'num_heads': 8,     # Example default value, adjust as needed
            'head_features': 64, # Example default value, adjust as needed
            'multiplier': 4      # Example default value, adjust as needed
        }),
        'dist': Munch({ # Added minimal dist config
            'mean': 0.0,
            'std': 1.0,
            'sigma_data': 0.5 # Example value
        }),
        'embedding_mask_proba': 0.1 # ✅ Default embedding_mask_proba
    })

# ✅ Add default slm config
if 'slm' not in config['model_params']:
    print("⚠️ Warning: 'slm' is missing in config! Using default values.")
    config['model_params']['slm'] = Munch({
        'hidden': 256,        # Example default hidden dimension, adjust as needed
        'nlayers': 2,         # Example default number of layers, adjust as needed
        'initial_channel': 512 # Example default initial channel, adjust as needed
    })


# Now call build_model() with the correct number of arguments
model = build_model(Munch(config['model_params']), text_aligner, pitch_extractor, bert)

# Load model parameters
params = torch.load(model_path, map_location='cpu')
params = params['net']
for key in model:
    if key in params:
        if not "discriminator" in key:
            print('%s loaded' % key)
            model[key].load_state_dict(params[key])

# Set model to evaluation mode
_ = [model[key].eval() for key in model]
_ = [model[key].to(device) for key in model]

dict_keys(['hidden_dim', 'n_token', 'style_dim', 'n_layer', 'dim_in', 'max_conv_dim', 'n_mels', 'dropout'])


  params = torch.load(model_path, map_location='cpu')


predictor loaded
decoder loaded


RuntimeError: Error(s) in loading state_dict for Decoder:
	Missing key(s) in state_dict: "decode.3.conv1x1.weight_g", "decode.3.conv1x1.weight_v", "decode.3.pool.bias", "decode.3.pool.weight_g", "decode.3.pool.weight_v", "encode.conv1.bias", "encode.conv1.weight_g", "encode.conv1.weight_v", "encode.conv2.bias", "encode.conv2.weight_g", "encode.conv2.weight_v", "encode.norm1.fc.weight", "encode.norm1.fc.bias", "encode.norm2.fc.weight", "encode.norm2.fc.bias", "encode.conv1x1.weight_g", "encode.conv1x1.weight_v", "F0_conv.bias", "F0_conv.weight_g", "F0_conv.weight_v", "N_conv.bias", "N_conv.weight_g", "N_conv.weight_v", "generator.m_source.l_linear.weight", "generator.m_source.l_linear.bias", "generator.noise_convs.0.weight", "generator.noise_convs.0.bias", "generator.noise_convs.1.weight", "generator.noise_convs.1.bias", "generator.noise_convs.2.weight", "generator.noise_convs.2.bias", "generator.noise_convs.3.weight", "generator.noise_convs.3.bias", "generator.ups.0.bias", "generator.ups.0.weight_g", "generator.ups.0.weight_v", "generator.ups.1.bias", "generator.ups.1.weight_g", "generator.ups.1.weight_v", "generator.ups.2.bias", "generator.ups.2.weight_g", "generator.ups.2.weight_v", "generator.ups.3.bias", "generator.ups.3.weight_g", "generator.ups.3.weight_v", "generator.noise_res.0.convs1.0.bias", "generator.noise_res.0.convs1.0.weight_g", "generator.noise_res.0.convs1.0.weight_v", "generator.noise_res.0.convs1.1.bias", "generator.noise_res.0.convs1.1.weight_g", "generator.noise_res.0.convs1.1.weight_v", "generator.noise_res.0.convs1.2.bias", "generator.noise_res.0.convs1.2.weight_g", "generator.noise_res.0.convs1.2.weight_v", "generator.noise_res.0.convs2.0.bias", "generator.noise_res.0.convs2.0.weight_g", "generator.noise_res.0.convs2.0.weight_v", "generator.noise_res.0.convs2.1.bias", "generator.noise_res.0.convs2.1.weight_g", "generator.noise_res.0.convs2.1.weight_v", "generator.noise_res.0.convs2.2.bias", "generator.noise_res.0.convs2.2.weight_g", "generator.noise_res.0.convs2.2.weight_v", "generator.noise_res.0.adain1.0.fc.weight", "generator.noise_res.0.adain1.0.fc.bias", "generator.noise_res.0.adain1.1.fc.weight", "generator.noise_res.0.adain1.1.fc.bias", "generator.noise_res.0.adain1.2.fc.weight", "generator.noise_res.0.adain1.2.fc.bias", "generator.noise_res.0.adain2.0.fc.weight", "generator.noise_res.0.adain2.0.fc.bias", "generator.noise_res.0.adain2.1.fc.weight", "generator.noise_res.0.adain2.1.fc.bias", "generator.noise_res.0.adain2.2.fc.weight", "generator.noise_res.0.adain2.2.fc.bias", "generator.noise_res.0.alpha1.0", "generator.noise_res.0.alpha1.1", "generator.noise_res.0.alpha1.2", "generator.noise_res.0.alpha2.0", "generator.noise_res.0.alpha2.1", "generator.noise_res.0.alpha2.2", "generator.noise_res.1.convs1.0.bias", "generator.noise_res.1.convs1.0.weight_g", "generator.noise_res.1.convs1.0.weight_v", "generator.noise_res.1.convs1.1.bias", "generator.noise_res.1.convs1.1.weight_g", "generator.noise_res.1.convs1.1.weight_v", "generator.noise_res.1.convs1.2.bias", "generator.noise_res.1.convs1.2.weight_g", "generator.noise_res.1.convs1.2.weight_v", "generator.noise_res.1.convs2.0.bias", "generator.noise_res.1.convs2.0.weight_g", "generator.noise_res.1.convs2.0.weight_v", "generator.noise_res.1.convs2.1.bias", "generator.noise_res.1.convs2.1.weight_g", "generator.noise_res.1.convs2.1.weight_v", "generator.noise_res.1.convs2.2.bias", "generator.noise_res.1.convs2.2.weight_g", "generator.noise_res.1.convs2.2.weight_v", "generator.noise_res.1.adain1.0.fc.weight", "generator.noise_res.1.adain1.0.fc.bias", "generator.noise_res.1.adain1.1.fc.weight", "generator.noise_res.1.adain1.1.fc.bias", "generator.noise_res.1.adain1.2.fc.weight", "generator.noise_res.1.adain1.2.fc.bias", "generator.noise_res.1.adain2.0.fc.weight", "generator.noise_res.1.adain2.0.fc.bias", "generator.noise_res.1.adain2.1.fc.weight", "generator.noise_res.1.adain2.1.fc.bias", "generator.noise_res.1.adain2.2.fc.weight", "generator.noise_res.1.adain2.2.fc.bias", "generator.noise_res.1.alpha1.0", "generator.noise_res.1.alpha1.1", "generator.noise_res.1.alpha1.2", "generator.noise_res.1.alpha2.0", "generator.noise_res.1.alpha2.1", "generator.noise_res.1.alpha2.2", "generator.noise_res.2.convs1.0.bias", "generator.noise_res.2.convs1.0.weight_g", "generator.noise_res.2.convs1.0.weight_v", "generator.noise_res.2.convs1.1.bias", "generator.noise_res.2.convs1.1.weight_g", "generator.noise_res.2.convs1.1.weight_v", "generator.noise_res.2.convs1.2.bias", "generator.noise_res.2.convs1.2.weight_g", "generator.noise_res.2.convs1.2.weight_v", "generator.noise_res.2.convs2.0.bias", "generator.noise_res.2.convs2.0.weight_g", "generator.noise_res.2.convs2.0.weight_v", "generator.noise_res.2.convs2.1.bias", "generator.noise_res.2.convs2.1.weight_g", "generator.noise_res.2.convs2.1.weight_v", "generator.noise_res.2.convs2.2.bias", "generator.noise_res.2.convs2.2.weight_g", "generator.noise_res.2.convs2.2.weight_v", "generator.noise_res.2.adain1.0.fc.weight", "generator.noise_res.2.adain1.0.fc.bias", "generator.noise_res.2.adain1.1.fc.weight", "generator.noise_res.2.adain1.1.fc.bias", "generator.noise_res.2.adain1.2.fc.weight", "generator.noise_res.2.adain1.2.fc.bias", "generator.noise_res.2.adain2.0.fc.weight", "generator.noise_res.2.adain2.0.fc.bias", "generator.noise_res.2.adain2.1.fc.weight", "generator.noise_res.2.adain2.1.fc.bias", "generator.noise_res.2.adain2.2.fc.weight", "generator.noise_res.2.adain2.2.fc.bias", "generator.noise_res.2.alpha1.0", "generator.noise_res.2.alpha1.1", "generator.noise_res.2.alpha1.2", "generator.noise_res.2.alpha2.0", "generator.noise_res.2.alpha2.1", "generator.noise_res.2.alpha2.2", "generator.noise_res.3.convs1.0.bias", "generator.noise_res.3.convs1.0.weight_g", "generator.noise_res.3.convs1.0.weight_v", "generator.noise_res.3.convs1.1.bias", "generator.noise_res.3.convs1.1.weight_g", "generator.noise_res.3.convs1.1.weight_v", "generator.noise_res.3.convs1.2.bias", "generator.noise_res.3.convs1.2.weight_g", "generator.noise_res.3.convs1.2.weight_v", "generator.noise_res.3.convs2.0.bias", "generator.noise_res.3.convs2.0.weight_g", "generator.noise_res.3.convs2.0.weight_v", "generator.noise_res.3.convs2.1.bias", "generator.noise_res.3.convs2.1.weight_g", "generator.noise_res.3.convs2.1.weight_v", "generator.noise_res.3.convs2.2.bias", "generator.noise_res.3.convs2.2.weight_g", "generator.noise_res.3.convs2.2.weight_v", "generator.noise_res.3.adain1.0.fc.weight", "generator.noise_res.3.adain1.0.fc.bias", "generator.noise_res.3.adain1.1.fc.weight", "generator.noise_res.3.adain1.1.fc.bias", "generator.noise_res.3.adain1.2.fc.weight", "generator.noise_res.3.adain1.2.fc.bias", "generator.noise_res.3.adain2.0.fc.weight", "generator.noise_res.3.adain2.0.fc.bias", "generator.noise_res.3.adain2.1.fc.weight", "generator.noise_res.3.adain2.1.fc.bias", "generator.noise_res.3.adain2.2.fc.weight", "generator.noise_res.3.adain2.2.fc.bias", "generator.noise_res.3.alpha1.0", "generator.noise_res.3.alpha1.1", "generator.noise_res.3.alpha1.2", "generator.noise_res.3.alpha2.0", "generator.noise_res.3.alpha2.1", "generator.noise_res.3.alpha2.2", "generator.resblocks.0.convs1.0.bias", "generator.resblocks.0.convs1.0.weight_g", "generator.resblocks.0.convs1.0.weight_v", "generator.resblocks.0.convs1.1.bias", "generator.resblocks.0.convs1.1.weight_g", "generator.resblocks.0.convs1.1.weight_v", "generator.resblocks.0.convs1.2.bias", "generator.resblocks.0.convs1.2.weight_g", "generator.resblocks.0.convs1.2.weight_v", "generator.resblocks.0.convs2.0.bias", "generator.resblocks.0.convs2.0.weight_g", "generator.resblocks.0.convs2.0.weight_v", "generator.resblocks.0.convs2.1.bias", "generator.resblocks.0.convs2.1.weight_g", "generator.resblocks.0.convs2.1.weight_v", "generator.resblocks.0.convs2.2.bias", "generator.resblocks.0.convs2.2.weight_g", "generator.resblocks.0.convs2.2.weight_v", "generator.resblocks.0.adain1.0.fc.weight", "generator.resblocks.0.adain1.0.fc.bias", "generator.resblocks.0.adain1.1.fc.weight", "generator.resblocks.0.adain1.1.fc.bias", "generator.resblocks.0.adain1.2.fc.weight", "generator.resblocks.0.adain1.2.fc.bias", "generator.resblocks.0.adain2.0.fc.weight", "generator.resblocks.0.adain2.0.fc.bias", "generator.resblocks.0.adain2.1.fc.weight", "generator.resblocks.0.adain2.1.fc.bias", "generator.resblocks.0.adain2.2.fc.weight", "generator.resblocks.0.adain2.2.fc.bias", "generator.resblocks.0.alpha1.0", "generator.resblocks.0.alpha1.1", "generator.resblocks.0.alpha1.2", "generator.resblocks.0.alpha2.0", "generator.resblocks.0.alpha2.1", "generator.resblocks.0.alpha2.2", "generator.resblocks.1.convs1.0.bias", "generator.resblocks.1.convs1.0.weight_g", "generator.resblocks.1.convs1.0.weight_v", "generator.resblocks.1.convs1.1.bias", "generator.resblocks.1.convs1.1.weight_g", "generator.resblocks.1.convs1.1.weight_v", "generator.resblocks.1.convs1.2.bias", "generator.resblocks.1.convs1.2.weight_g", "generator.resblocks.1.convs1.2.weight_v", "generator.resblocks.1.convs2.0.bias", "generator.resblocks.1.convs2.0.weight_g", "generator.resblocks.1.convs2.0.weight_v", "generator.resblocks.1.convs2.1.bias", "generator.resblocks.1.convs2.1.weight_g", "generator.resblocks.1.convs2.1.weight_v", "generator.resblocks.1.convs2.2.bias", "generator.resblocks.1.convs2.2.weight_g", "generator.resblocks.1.convs2.2.weight_v", "generator.resblocks.1.adain1.0.fc.weight", "generator.resblocks.1.adain1.0.fc.bias", "generator.resblocks.1.adain1.1.fc.weight", "generator.resblocks.1.adain1.1.fc.bias", "generator.resblocks.1.adain1.2.fc.weight", "generator.resblocks.1.adain1.2.fc.bias", "generator.resblocks.1.adain2.0.fc.weight", "generator.resblocks.1.adain2.0.fc.bias", "generator.resblocks.1.adain2.1.fc.weight", "generator.resblocks.1.adain2.1.fc.bias", "generator.resblocks.1.adain2.2.fc.weight", "generator.resblocks.1.adain2.2.fc.bias", "generator.resblocks.1.alpha1.0", "generator.resblocks.1.alpha1.1", "generator.resblocks.1.alpha1.2", "generator.resblocks.1.alpha2.0", "generator.resblocks.1.alpha2.1", "generator.resblocks.1.alpha2.2", "generator.resblocks.2.convs1.0.bias", "generator.resblocks.2.convs1.0.weight_g", "generator.resblocks.2.convs1.0.weight_v", "generator.resblocks.2.convs1.1.bias", "generator.resblocks.2.convs1.1.weight_g", "generator.resblocks.2.convs1.1.weight_v", "generator.resblocks.2.convs1.2.bias", "generator.resblocks.2.convs1.2.weight_g", "generator.resblocks.2.convs1.2.weight_v", "generator.resblocks.2.convs2.0.bias", "generator.resblocks.2.convs2.0.weight_g", "generator.resblocks.2.convs2.0.weight_v", "generator.resblocks.2.convs2.1.bias", "generator.resblocks.2.convs2.1.weight_g", "generator.resblocks.2.convs2.1.weight_v", "generator.resblocks.2.convs2.2.bias", "generator.resblocks.2.convs2.2.weight_g", "generator.resblocks.2.convs2.2.weight_v", "generator.resblocks.2.adain1.0.fc.weight", "generator.resblocks.2.adain1.0.fc.bias", "generator.resblocks.2.adain1.1.fc.weight", "generator.resblocks.2.adain1.1.fc.bias", "generator.resblocks.2.adain1.2.fc.weight", "generator.resblocks.2.adain1.2.fc.bias", "generator.resblocks.2.adain2.0.fc.weight", "generator.resblocks.2.adain2.0.fc.bias", "generator.resblocks.2.adain2.1.fc.weight", "generator.resblocks.2.adain2.1.fc.bias", "generator.resblocks.2.adain2.2.fc.weight", "generator.resblocks.2.adain2.2.fc.bias", "generator.resblocks.2.alpha1.0", "generator.resblocks.2.alpha1.1", "generator.resblocks.2.alpha1.2", "generator.resblocks.2.alpha2.0", "generator.resblocks.2.alpha2.1", "generator.resblocks.2.alpha2.2", "generator.resblocks.3.convs1.0.bias", "generator.resblocks.3.convs1.0.weight_g", "generator.resblocks.3.convs1.0.weight_v", "generator.resblocks.3.convs1.1.bias", "generator.resblocks.3.convs1.1.weight_g", "generator.resblocks.3.convs1.1.weight_v", "generator.resblocks.3.convs1.2.bias", "generator.resblocks.3.convs1.2.weight_g", "generator.resblocks.3.convs1.2.weight_v", "generator.resblocks.3.convs2.0.bias", "generator.resblocks.3.convs2.0.weight_g", "generator.resblocks.3.convs2.0.weight_v", "generator.resblocks.3.convs2.1.bias", "generator.resblocks.3.convs2.1.weight_g", "generator.resblocks.3.convs2.1.weight_v", "generator.resblocks.3.convs2.2.bias", "generator.resblocks.3.convs2.2.weight_g", "generator.resblocks.3.convs2.2.weight_v", "generator.resblocks.3.adain1.0.fc.weight", "generator.resblocks.3.adain1.0.fc.bias", "generator.resblocks.3.adain1.1.fc.weight", "generator.resblocks.3.adain1.1.fc.bias", "generator.resblocks.3.adain1.2.fc.weight", "generator.resblocks.3.adain1.2.fc.bias", "generator.resblocks.3.adain2.0.fc.weight", "generator.resblocks.3.adain2.0.fc.bias", "generator.resblocks.3.adain2.1.fc.weight", "generator.resblocks.3.adain2.1.fc.bias", "generator.resblocks.3.adain2.2.fc.weight", "generator.resblocks.3.adain2.2.fc.bias", "generator.resblocks.3.alpha1.0", "generator.resblocks.3.alpha1.1", "generator.resblocks.3.alpha1.2", "generator.resblocks.3.alpha2.0", "generator.resblocks.3.alpha2.1", "generator.resblocks.3.alpha2.2", "generator.resblocks.4.convs1.0.bias", "generator.resblocks.4.convs1.0.weight_g", "generator.resblocks.4.convs1.0.weight_v", "generator.resblocks.4.convs1.1.bias", "generator.resblocks.4.convs1.1.weight_g", "generator.resblocks.4.convs1.1.weight_v", "generator.resblocks.4.convs1.2.bias", "generator.resblocks.4.convs1.2.weight_g", "generator.resblocks.4.convs1.2.weight_v", "generator.resblocks.4.convs2.0.bias", "generator.resblocks.4.convs2.0.weight_g", "generator.resblocks.4.convs2.0.weight_v", "generator.resblocks.4.convs2.1.bias", "generator.resblocks.4.convs2.1.weight_g", "generator.resblocks.4.convs2.1.weight_v", "generator.resblocks.4.convs2.2.bias", "generator.resblocks.4.convs2.2.weight_g", "generator.resblocks.4.convs2.2.weight_v", "generator.resblocks.4.adain1.0.fc.weight", "generator.resblocks.4.adain1.0.fc.bias", "generator.resblocks.4.adain1.1.fc.weight", "generator.resblocks.4.adain1.1.fc.bias", "generator.resblocks.4.adain1.2.fc.weight", "generator.resblocks.4.adain1.2.fc.bias", "generator.resblocks.4.adain2.0.fc.weight", "generator.resblocks.4.adain2.0.fc.bias", "generator.resblocks.4.adain2.1.fc.weight", "generator.resblocks.4.adain2.1.fc.bias", "generator.resblocks.4.adain2.2.fc.weight", "generator.resblocks.4.adain2.2.fc.bias", "generator.resblocks.4.alpha1.0", "generator.resblocks.4.alpha1.1", "generator.resblocks.4.alpha1.2", "generator.resblocks.4.alpha2.0", "generator.resblocks.4.alpha2.1", "generator.resblocks.4.alpha2.2", "generator.resblocks.5.convs1.0.bias", "generator.resblocks.5.convs1.0.weight_g", "generator.resblocks.5.convs1.0.weight_v", "generator.resblocks.5.convs1.1.bias", "generator.resblocks.5.convs1.1.weight_g", "generator.resblocks.5.convs1.1.weight_v", "generator.resblocks.5.convs1.2.bias", "generator.resblocks.5.convs1.2.weight_g", "generator.resblocks.5.convs1.2.weight_v", "generator.resblocks.5.convs2.0.bias", "generator.resblocks.5.convs2.0.weight_g", "generator.resblocks.5.convs2.0.weight_v", "generator.resblocks.5.convs2.1.bias", "generator.resblocks.5.convs2.1.weight_g", "generator.resblocks.5.convs2.1.weight_v", "generator.resblocks.5.convs2.2.bias", "generator.resblocks.5.convs2.2.weight_g", "generator.resblocks.5.convs2.2.weight_v", "generator.resblocks.5.adain1.0.fc.weight", "generator.resblocks.5.adain1.0.fc.bias", "generator.resblocks.5.adain1.1.fc.weight", "generator.resblocks.5.adain1.1.fc.bias", "generator.resblocks.5.adain1.2.fc.weight", "generator.resblocks.5.adain1.2.fc.bias", "generator.resblocks.5.adain2.0.fc.weight", "generator.resblocks.5.adain2.0.fc.bias", "generator.resblocks.5.adain2.1.fc.weight", "generator.resblocks.5.adain2.1.fc.bias", "generator.resblocks.5.adain2.2.fc.weight", "generator.resblocks.5.adain2.2.fc.bias", "generator.resblocks.5.alpha1.0", "generator.resblocks.5.alpha1.1", "generator.resblocks.5.alpha1.2", "generator.resblocks.5.alpha2.0", "generator.resblocks.5.alpha2.1", "generator.resblocks.5.alpha2.2", "generator.resblocks.6.convs1.0.bias", "generator.resblocks.6.convs1.0.weight_g", "generator.resblocks.6.convs1.0.weight_v", "generator.resblocks.6.convs1.1.bias", "generator.resblocks.6.convs1.1.weight_g", "generator.resblocks.6.convs1.1.weight_v", "generator.resblocks.6.convs1.2.bias", "generator.resblocks.6.convs1.2.weight_g", "generator.resblocks.6.convs1.2.weight_v", "generator.resblocks.6.convs2.0.bias", "generator.resblocks.6.convs2.0.weight_g", "generator.resblocks.6.convs2.0.weight_v", "generator.resblocks.6.convs2.1.bias", "generator.resblocks.6.convs2.1.weight_g", "generator.resblocks.6.convs2.1.weight_v", "generator.resblocks.6.convs2.2.bias", "generator.resblocks.6.convs2.2.weight_g", "generator.resblocks.6.convs2.2.weight_v", "generator.resblocks.6.adain1.0.fc.weight", "generator.resblocks.6.adain1.0.fc.bias", "generator.resblocks.6.adain1.1.fc.weight", "generator.resblocks.6.adain1.1.fc.bias", "generator.resblocks.6.adain1.2.fc.weight", "generator.resblocks.6.adain1.2.fc.bias", "generator.resblocks.6.adain2.0.fc.weight", "generator.resblocks.6.adain2.0.fc.bias", "generator.resblocks.6.adain2.1.fc.weight", "generator.resblocks.6.adain2.1.fc.bias", "generator.resblocks.6.adain2.2.fc.weight", "generator.resblocks.6.adain2.2.fc.bias", "generator.resblocks.6.alpha1.0", "generator.resblocks.6.alpha1.1", "generator.resblocks.6.alpha1.2", "generator.resblocks.6.alpha2.0", "generator.resblocks.6.alpha2.1", "generator.resblocks.6.alpha2.2", "generator.resblocks.7.convs1.0.bias", "generator.resblocks.7.convs1.0.weight_g", "generator.resblocks.7.convs1.0.weight_v", "generator.resblocks.7.convs1.1.bias", "generator.resblocks.7.convs1.1.weight_g", "generator.resblocks.7.convs1.1.weight_v", "generator.resblocks.7.convs1.2.bias", "generator.resblocks.7.convs1.2.weight_g", "generator.resblocks.7.convs1.2.weight_v", "generator.resblocks.7.convs2.0.bias", "generator.resblocks.7.convs2.0.weight_g", "generator.resblocks.7.convs2.0.weight_v", "generator.resblocks.7.convs2.1.bias", "generator.resblocks.7.convs2.1.weight_g", "generator.resblocks.7.convs2.1.weight_v", "generator.resblocks.7.convs2.2.bias", "generator.resblocks.7.convs2.2.weight_g", "generator.resblocks.7.convs2.2.weight_v", "generator.resblocks.7.adain1.0.fc.weight", "generator.resblocks.7.adain1.0.fc.bias", "generator.resblocks.7.adain1.1.fc.weight", "generator.resblocks.7.adain1.1.fc.bias", "generator.resblocks.7.adain1.2.fc.weight", "generator.resblocks.7.adain1.2.fc.bias", "generator.resblocks.7.adain2.0.fc.weight", "generator.resblocks.7.adain2.0.fc.bias", "generator.resblocks.7.adain2.1.fc.weight", "generator.resblocks.7.adain2.1.fc.bias", "generator.resblocks.7.adain2.2.fc.weight", "generator.resblocks.7.adain2.2.fc.bias", "generator.resblocks.7.alpha1.0", "generator.resblocks.7.alpha1.1", "generator.resblocks.7.alpha1.2", "generator.resblocks.7.alpha2.0", "generator.resblocks.7.alpha2.1", "generator.resblocks.7.alpha2.2", "generator.resblocks.8.convs1.0.bias", "generator.resblocks.8.convs1.0.weight_g", "generator.resblocks.8.convs1.0.weight_v", "generator.resblocks.8.convs1.1.bias", "generator.resblocks.8.convs1.1.weight_g", "generator.resblocks.8.convs1.1.weight_v", "generator.resblocks.8.convs1.2.bias", "generator.resblocks.8.convs1.2.weight_g", "generator.resblocks.8.convs1.2.weight_v", "generator.resblocks.8.convs2.0.bias", "generator.resblocks.8.convs2.0.weight_g", "generator.resblocks.8.convs2.0.weight_v", "generator.resblocks.8.convs2.1.bias", "generator.resblocks.8.convs2.1.weight_g", "generator.resblocks.8.convs2.1.weight_v", "generator.resblocks.8.convs2.2.bias", "generator.resblocks.8.convs2.2.weight_g", "generator.resblocks.8.convs2.2.weight_v", "generator.resblocks.8.adain1.0.fc.weight", "generator.resblocks.8.adain1.0.fc.bias", "generator.resblocks.8.adain1.1.fc.weight", "generator.resblocks.8.adain1.1.fc.bias", "generator.resblocks.8.adain1.2.fc.weight", "generator.resblocks.8.adain1.2.fc.bias", "generator.resblocks.8.adain2.0.fc.weight", "generator.resblocks.8.adain2.0.fc.bias", "generator.resblocks.8.adain2.1.fc.weight", "generator.resblocks.8.adain2.1.fc.bias", "generator.resblocks.8.adain2.2.fc.weight", "generator.resblocks.8.adain2.2.fc.bias", "generator.resblocks.8.alpha1.0", "generator.resblocks.8.alpha1.1", "generator.resblocks.8.alpha1.2", "generator.resblocks.8.alpha2.0", "generator.resblocks.8.alpha2.1", "generator.resblocks.8.alpha2.2", "generator.resblocks.9.convs1.0.bias", "generator.resblocks.9.convs1.0.weight_g", "generator.resblocks.9.convs1.0.weight_v", "generator.resblocks.9.convs1.1.bias", "generator.resblocks.9.convs1.1.weight_g", "generator.resblocks.9.convs1.1.weight_v", "generator.resblocks.9.convs1.2.bias", "generator.resblocks.9.convs1.2.weight_g", "generator.resblocks.9.convs1.2.weight_v", "generator.resblocks.9.convs2.0.bias", "generator.resblocks.9.convs2.0.weight_g", "generator.resblocks.9.convs2.0.weight_v", "generator.resblocks.9.convs2.1.bias", "generator.resblocks.9.convs2.1.weight_g", "generator.resblocks.9.convs2.1.weight_v", "generator.resblocks.9.convs2.2.bias", "generator.resblocks.9.convs2.2.weight_g", "generator.resblocks.9.convs2.2.weight_v", "generator.resblocks.9.adain1.0.fc.weight", "generator.resblocks.9.adain1.0.fc.bias", "generator.resblocks.9.adain1.1.fc.weight", "generator.resblocks.9.adain1.1.fc.bias", "generator.resblocks.9.adain1.2.fc.weight", "generator.resblocks.9.adain1.2.fc.bias", "generator.resblocks.9.adain2.0.fc.weight", "generator.resblocks.9.adain2.0.fc.bias", "generator.resblocks.9.adain2.1.fc.weight", "generator.resblocks.9.adain2.1.fc.bias", "generator.resblocks.9.adain2.2.fc.weight", "generator.resblocks.9.adain2.2.fc.bias", "generator.resblocks.9.alpha1.0", "generator.resblocks.9.alpha1.1", "generator.resblocks.9.alpha1.2", "generator.resblocks.9.alpha2.0", "generator.resblocks.9.alpha2.1", "generator.resblocks.9.alpha2.2", "generator.resblocks.10.convs1.0.bias", "generator.resblocks.10.convs1.0.weight_g", "generator.resblocks.10.convs1.0.weight_v", "generator.resblocks.10.convs1.1.bias", "generator.resblocks.10.convs1.1.weight_g", "generator.resblocks.10.convs1.1.weight_v", "generator.resblocks.10.convs1.2.bias", "generator.resblocks.10.convs1.2.weight_g", "generator.resblocks.10.convs1.2.weight_v", "generator.resblocks.10.convs2.0.bias", "generator.resblocks.10.convs2.0.weight_g", "generator.resblocks.10.convs2.0.weight_v", "generator.resblocks.10.convs2.1.bias", "generator.resblocks.10.convs2.1.weight_g", "generator.resblocks.10.convs2.1.weight_v", "generator.resblocks.10.convs2.2.bias", "generator.resblocks.10.convs2.2.weight_g", "generator.resblocks.10.convs2.2.weight_v", "generator.resblocks.10.adain1.0.fc.weight", "generator.resblocks.10.adain1.0.fc.bias", "generator.resblocks.10.adain1.1.fc.weight", "generator.resblocks.10.adain1.1.fc.bias", "generator.resblocks.10.adain1.2.fc.weight", "generator.resblocks.10.adain1.2.fc.bias", "generator.resblocks.10.adain2.0.fc.weight", "generator.resblocks.10.adain2.0.fc.bias", "generator.resblocks.10.adain2.1.fc.weight", "generator.resblocks.10.adain2.1.fc.bias", "generator.resblocks.10.adain2.2.fc.weight", "generator.resblocks.10.adain2.2.fc.bias", "generator.resblocks.10.alpha1.0", "generator.resblocks.10.alpha1.1", "generator.resblocks.10.alpha1.2", "generator.resblocks.10.alpha2.0", "generator.resblocks.10.alpha2.1", "generator.resblocks.10.alpha2.2", "generator.resblocks.11.convs1.0.bias", "generator.resblocks.11.convs1.0.weight_g", "generator.resblocks.11.convs1.0.weight_v", "generator.resblocks.11.convs1.1.bias", "generator.resblocks.11.convs1.1.weight_g", "generator.resblocks.11.convs1.1.weight_v", "generator.resblocks.11.convs1.2.bias", "generator.resblocks.11.convs1.2.weight_g", "generator.resblocks.11.convs1.2.weight_v", "generator.resblocks.11.convs2.0.bias", "generator.resblocks.11.convs2.0.weight_g", "generator.resblocks.11.convs2.0.weight_v", "generator.resblocks.11.convs2.1.bias", "generator.resblocks.11.convs2.1.weight_g", "generator.resblocks.11.convs2.1.weight_v", "generator.resblocks.11.convs2.2.bias", "generator.resblocks.11.convs2.2.weight_g", "generator.resblocks.11.convs2.2.weight_v", "generator.resblocks.11.adain1.0.fc.weight", "generator.resblocks.11.adain1.0.fc.bias", "generator.resblocks.11.adain1.1.fc.weight", "generator.resblocks.11.adain1.1.fc.bias", "generator.resblocks.11.adain1.2.fc.weight", "generator.resblocks.11.adain1.2.fc.bias", "generator.resblocks.11.adain2.0.fc.weight", "generator.resblocks.11.adain2.0.fc.bias", "generator.resblocks.11.adain2.1.fc.weight", "generator.resblocks.11.adain2.1.fc.bias", "generator.resblocks.11.adain2.2.fc.weight", "generator.resblocks.11.adain2.2.fc.bias", "generator.resblocks.11.alpha1.0", "generator.resblocks.11.alpha1.1", "generator.resblocks.11.alpha1.2", "generator.resblocks.11.alpha2.0", "generator.resblocks.11.alpha2.1", "generator.resblocks.11.alpha2.2", "generator.alphas.0", "generator.alphas.1", "generator.alphas.2", "generator.alphas.3", "generator.alphas.4", "generator.conv_post.bias", "generator.conv_post.weight_g", "generator.conv_post.weight_v". 
	Unexpected key(s) in state_dict: "to_out.0.bias", "to_out.0.weight_g", "to_out.0.weight_v", "decode.4.conv1.bias", "decode.4.conv1.weight_g", "decode.4.conv1.weight_v", "decode.4.conv2.bias", "decode.4.conv2.weight_g", "decode.4.conv2.weight_v", "decode.4.norm1.fc.weight", "decode.4.norm1.fc.bias", "decode.4.norm2.fc.weight", "decode.4.norm2.fc.bias", "decode.2.pool.bias", "decode.2.pool.weight_g", "decode.2.pool.weight_v", "encode.0.conv1.bias", "encode.0.conv1.weight_g", "encode.0.conv1.weight_v", "encode.0.conv2.bias", "encode.0.conv2.weight_g", "encode.0.conv2.weight_v", "encode.0.norm1.weight", "encode.0.norm1.bias", "encode.0.norm2.weight", "encode.0.norm2.bias", "encode.0.conv1x1.weight_g", "encode.0.conv1x1.weight_v", "encode.1.conv1.bias", "encode.1.conv1.weight_g", "encode.1.conv1.weight_v", "encode.1.conv2.bias", "encode.1.conv2.weight_g", "encode.1.conv2.weight_v", "encode.1.norm1.weight", "encode.1.norm1.bias", "encode.1.norm2.weight", "encode.1.norm2.bias", "F0_conv.0.conv1.bias", "F0_conv.0.conv1.weight_g", "F0_conv.0.conv1.weight_v", "F0_conv.0.conv2.bias", "F0_conv.0.conv2.weight_g", "F0_conv.0.conv2.weight_v", "F0_conv.0.norm1.weight", "F0_conv.0.norm1.bias", "F0_conv.0.norm2.weight", "F0_conv.0.norm2.bias", "F0_conv.0.conv1x1.weight_g", "F0_conv.0.conv1x1.weight_v", "F0_conv.0.pool.bias", "F0_conv.0.pool.weight_g", "F0_conv.0.pool.weight_v", "F0_conv.1.bias", "F0_conv.1.weight_g", "F0_conv.1.weight_v", "F0_conv.2.weight", "F0_conv.2.bias", "N_conv.0.conv1.bias", "N_conv.0.conv1.weight_g", "N_conv.0.conv1.weight_v", "N_conv.0.conv2.bias", "N_conv.0.conv2.weight_g", "N_conv.0.conv2.weight_v", "N_conv.0.norm1.weight", "N_conv.0.norm1.bias", "N_conv.0.norm2.weight", "N_conv.0.norm2.bias", "N_conv.0.conv1x1.weight_g", "N_conv.0.conv1x1.weight_v", "N_conv.0.pool.bias", "N_conv.0.pool.weight_g", "N_conv.0.pool.weight_v", "N_conv.1.bias", "N_conv.1.weight_g", "N_conv.1.weight_v", "N_conv.2.weight", "N_conv.2.bias", "asr_res.1.weight", "asr_res.1.bias". 
	size mismatch for decode.2.conv1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decode.2.conv1.weight_g: copying a param with shape torch.Size([512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1, 1]).
	size mismatch for decode.2.conv1.weight_v: copying a param with shape torch.Size([512, 1090, 3]) from checkpoint, the shape in current model is torch.Size([1024, 1090, 3]).
	size mismatch for decode.2.conv2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decode.2.conv2.weight_g: copying a param with shape torch.Size([512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1, 1]).
	size mismatch for decode.2.conv2.weight_v: copying a param with shape torch.Size([512, 512, 3]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 3]).
	size mismatch for decode.2.norm2.fc.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([2048, 128]).
	size mismatch for decode.2.norm2.fc.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([2048]).
	size mismatch for decode.2.conv1x1.weight_g: copying a param with shape torch.Size([512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1, 1]).
	size mismatch for decode.2.conv1x1.weight_v: copying a param with shape torch.Size([512, 1090, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1090, 1]).
	size mismatch for decode.3.conv1.weight_v: copying a param with shape torch.Size([512, 512, 3]) from checkpoint, the shape in current model is torch.Size([512, 1090, 3]).
	size mismatch for decode.3.norm1.fc.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([2180, 128]).
	size mismatch for decode.3.norm1.fc.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([2180]).

### Synthesize speech

In [None]:
# get first 3 training sample as references

train_path = config.get('train_data', None)
val_path = config.get('val_data', None)
train_list, val_list = get_data_path_list(train_path, val_path)

ref_dicts = {}
for j in range(3):
    filename = train_list[j].split('|')[0]
    name = filename.split('/')[-1].replace('.wav', '')
    ref_dicts[name] = filename
    
reference_embeddings = compute_style(ref_dicts, model)

In [None]:
# synthesize a text
text = ''' StyleTTS is a style-based generative model for parallel TTS that can synthesize diverse speech with natural prosody from a reference speech utterance. '''

In [None]:
# tokenize
ps = global_phonemizer.phonemize([text])
ps = word_tokenize(ps[0])
ps = ' '.join(ps)
tokens = textclenaer(ps)
tokens.insert(0, 0)
tokens.append(0)
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

In [None]:
converted_samples = {}

with torch.no_grad():
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    m = length_to_mask(input_lengths).to(device)
    t_en = model.text_encoder(tokens, input_lengths, m)
        
    for key, (ref, _) in reference_embeddings.items():
        
        s = ref.squeeze(1)
        style = s
        
        d = model.predictor.text_encoder(t_en, style, input_lengths, m)

        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)
        
        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)

        # encode prosody
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        style = s.expand(en.shape[0], en.shape[1], -1)

        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)), 
                                F0_pred, N_pred, ref.squeeze().unsqueeze(0))


        c = out.squeeze()
        y_g_hat = generator(c.unsqueeze(0))
        y_out = y_g_hat.squeeze().cpu().numpy()

        c = out.squeeze()
        y_g_hat = generator(c.unsqueeze(0))
        y_out = y_g_hat.squeeze()
        
        converted_samples[key] = y_out.cpu().numpy()

In [None]:
import IPython.display as ipd
for key, wave in converted_samples.items():
    print('Synthesized: %s' % key)
    display(ipd.Audio(wave, rate=24000))
    try:
        print('Reference: %s' % key)
        display(ipd.Audio(reference_embeddings[key][-1], rate=24000))
    except:
        continue