# StyleTTS Demo (LibriTTS)


### Utils

In [1]:
%cd /home/melissa/ArtiVoice-GTR

/home/melissa/ArtiVoice-GTR


In [2]:
# load packages
import random
import yaml
from munch import Munch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa

from models import *
from utils import *

%matplotlib inline

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  method='lar', copy_X=True, eps=np.finfo(np.float).eps,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  method='lar', copy_X=True, eps=np.finfo(np.float).eps,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_Gram=True, verbose=0,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_X=True, fit_path=True,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_X=True, fit_path=True,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes

In [3]:
# Use the first CUDA GPU when torch reports one is available; otherwise run on the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [4]:
# Symbol inventory: one padding symbol, then punctuation, Latin letters and IPA characters.
_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"


# Export all symbols: the pad symbol first, then every character of each group in order.
symbols = [_pad, *_punctuation, *_letters, *_letters_ipa]

# Symbol -> position in `symbols`. A symbol that occurs more than once keeps its last position.
dicts = {symbol: index for index, symbol in enumerate(symbols)}

class TextCleaner:
    """Convert text into a list of symbol indices using the module-level `dicts` table."""

    def __init__(self, dummy=None):
        # `dummy` is accepted but never used.
        self.word_index_dictionary = dicts

    def __call__(self, text):
        """Return the table index of every known character of `text`, in order.

        A character that is missing from the table is printed and left out
        of the returned list.
        """
        table = self.word_index_dictionary
        indexes = []
        for char in text:
            if char in table:
                indexes.append(table[char])
            else:
                print(char)
        return indexes

# Shared tokenizer instance. The misspelled name is what the tokenize cells call, so it is kept.
textclenaer = TextCleaner()

In [5]:
# Mel-spectrogram transform that preprocess() applies to raw waveforms.
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
# Offset and scale preprocess() uses to normalize log-mel values: (log_mel - mean) / std.
mean, std = -4, 4

def length_to_mask(lengths):
    """Build a boolean padding mask from a 1-D tensor of sequence lengths.

    Returns a tensor of shape (len(lengths), lengths.max()) that is False
    inside each sequence and True at positions at or beyond its length.
    """
    batch_size = lengths.shape[0]
    # 0-based positions [0, max_len), one row per sequence, cast to the type of `lengths`.
    positions = torch.arange(lengths.max()).unsqueeze(0).expand(batch_size, -1).type_as(lengths)
    # Position p is padding when p + 1 exceeds the sequence length.
    return (positions + 1) > lengths.unsqueeze(1)

def preprocess(wave):
    """Turn a NumPy waveform into a normalized log-mel tensor.

    The waveform is converted to a float tensor, passed through the
    module-level `to_mel` transform, given a leading dimension,
    log-compressed with a 1e-5 floor, and normalized with the module-level
    `mean` and `std`.
    """
    mel = to_mel(torch.from_numpy(wave).float())
    log_mel = torch.log(1e-5 + mel.unsqueeze(0))
    return (log_mel - mean) / std

def compute_style(ref_dicts):
    """Compute a style embedding for each reference recording.

    Args:
        ref_dicts: mapping from a reference name to the path of an audio file.

    Returns:
        Dict mapping each name to ``(style_embedding, trimmed_audio)``, where
        ``trimmed_audio`` is the silence-trimmed waveform as a NumPy array.
        References whose style encoding raises are reported and left out.
    """
    reference_embeddings = {}
    for key, path in ref_dicts.items():
        # librosa.load resamples to the requested rate, so the audio is already
        # at 24 kHz here and no separate resampling step is needed.
        wave, _ = librosa.load(path, sr=24000)
        # Trim leading/trailing audio that is more than 30 dB below the reference level.
        audio, _ = librosa.effects.trim(wave, top_db=30)
        mel_tensor = preprocess(audio).to(device)
        try:
            with torch.no_grad():
                ref = model.style_encoder(mel_tensor.unsqueeze(1))
            reference_embeddings[key] = (ref.squeeze(1), audio)
        except Exception as err:
            # Stay best-effort (skip this reference) but say why, and let
            # KeyboardInterrupt/SystemExit propagate instead of swallowing them.
            print('Skipping reference %s (%s): %s' % (key, path, err))
            continue

    return reference_embeddings

### Load models

In [6]:
# load phonemizer
import phonemizer
# espeak backend for US English, configured to keep punctuation and emit stress marks.
global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True,  with_stress=True)

In [58]:
# load hifi-gan

import sys
sys.path.insert(0, "./Demo/hifi-gan")

import glob
import os
import json
import torch
from attrdict import AttrDict
from vocoder import Generator
import numpy as np

h = None

def load_checkpoint(filepath, device):
    """Load a checkpoint with ``torch.load`` mapped onto `device` and return it.

    Asserts that `filepath` is an existing file and prints a message before
    and after loading.
    """
    assert os.path.isfile(filepath)
    print(f"Loading '{filepath}'")
    loaded = torch.load(filepath, map_location=device)
    print("Complete.")
    return loaded

def scan_checkpoint(cp_dir, prefix):
    """Return the last path (in sorted order) in `cp_dir` whose name starts with `prefix`.

    Returns an empty string when nothing matches.
    """
    matches = sorted(glob.glob(os.path.join(cp_dir, prefix + '*')))
    return matches[-1] if matches else ''

# Pick the last checkpoint (in sorted path order) whose file name starts with 'g_'.
cp_g = scan_checkpoint("/storageNVME/melissa/ckpts/stylettsCN/pretrained/Vocoder/libritts", 'g_')

# The vocoder's config.json is read from the same directory as the checkpoint.
config_file = os.path.join(os.path.split(cp_g)[0], 'config.json')
with open(config_file) as f:
    data = f.read()
json_config = json.loads(data)
# Wrap the parsed JSON in an AttrDict before handing it to Generator.
h = AttrDict(json_config)

# Rebinds `device` from the string chosen earlier to a torch.device object.
device = torch.device(device)
generator = Generator(h).to(device)

# Restore the weights stored under the checkpoint's 'generator' key, then switch to eval mode.
state_dict_g = load_checkpoint(cp_g, device)
generator.load_state_dict(state_dict_g['generator'])
generator.eval()
generator.remove_weight_norm()

Loading '/storageNVME/melissa/ckpts/stylettsCN/pretrained/Vocoder/libritts/g_00060000'
Complete.
Removing weight norm...


In [40]:
# load StyleTTS
model_path = "/storageNVME/melissa/ckpts/stylettsCN/debug/epoch_2nd_00012.pth"
model_config_path = "/storageNVME/melissa/ckpts/stylettsCN/debug/config_2nd.yml"

# Read the training configuration; the context manager closes the file handle.
with open(model_config_path) as config_fh:
    config = yaml.safe_load(config_fh)

# load pretrained ASR model
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)

# load pretrained F0 model
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

model = build_model(Munch(config['model_params']), text_aligner, pitch_extractor)

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
params = torch.load(model_path, map_location='cpu')
params = params['net']
# Restore every non-discriminator module that has weights in the checkpoint.
# A module whose weights fail to load is reported and left as built.
for key in model:
    if key not in params or "discriminator" in key:
        continue
    try:
        model[key].load_state_dict(params[key])
        print('%s loaded' % key)
    except Exception as e:
        print("%s loaded failed" % key)
        print(e)

# Switch every module to eval mode and move it to the inference device.
for key in model:
    model[key].eval()
    model[key].to(device)

  "num_layers={}".format(dropout, num_layers))


predictor loaded
decoder loaded
pitch_extractor loaded
text_encoder loaded
style_encoder loaded
text_aligner loaded


### Resynthesis from the first training stage generator

In [59]:
# Load the saved "-gt" and "-mel_rec" arrays (presumably ground-truth and reconstructed
# mel-spectrograms, going by the file names) and move them to `device`.
mels_gt = torch.tensor(np.load("./synth/aishell3_s1-epoch_1st_00008-gt.npy")).to(device)
mels_rec = torch.tensor(np.load("./synth/aishell3_s1-epoch_1st_00008-mel_rec.npy")).to(device)

# Run both batches through the vocoder without tracking gradients.
with torch.no_grad():
    y_gt = generator(mels_gt)
    y_rec = generator(mels_rec)

import IPython.display as ipd
# Play every vocoded "gt" item; playback of the reconstruction is commented out in this cell.
for i, (gt, rec) in enumerate(zip(y_gt, y_rec)):
    # if i == 2: break
    print('GT after vocoder:', i)
    display(ipd.Audio(gt.squeeze().detach().cpu().numpy(), rate=24000))
    # print('Recon after vocoder:', i)
    # display(ipd.Audio(rec.squeeze().detach().cpu().numpy(), rate=24000))

GT after vocoder: 0


GT after vocoder: 1


GT after vocoder: 2


GT after vocoder: 3


In [62]:
# from scipy.io.wavfile import write

# for i, gt in enumerate(y_gt):
#     write(f"gt_hifigan_{i}.wav", 24000, gt.squeeze().detach().cpu().numpy().astype(np.int16))

import soundfile as sf

# Casting float samples straight to int16 truncates them toward zero, which writes
# (near-)silent files. Scale to the 16-bit range before the cast.
# NOTE(review): assumes the vocoder output is float audio within [-1, 1] — confirm.
for i, gt in enumerate(y_gt):
    wav = gt.squeeze().detach().cpu().numpy()
    pcm = (np.clip(wav, -1.0, 1.0) * 32767.0).astype(np.int16)
    sf.write(f"gt_hifigan_{i}.wav", pcm, 24000)


In [7]:
# Load the saved "-gt" and "-mel_rec" arrays (presumably ground-truth and reconstructed
# mel-spectrograms, going by the file names) and move them to `device`.
mels_gt = torch.tensor(np.load("./synth/aishell3_s1-epoch_1st_00030-gt.npy")).to(device)
mels_rec = torch.tensor(np.load("./synth/aishell3_s1-epoch_1st_00030-mel_rec.npy")).to(device)

# Run both batches through the vocoder without tracking gradients.
with torch.no_grad():
    y_gt = generator(mels_gt)
    y_rec = generator(mels_rec)

import IPython.display as ipd
# Play each vocoded "gt" item followed by its vocoded reconstruction.
for i, (gt, rec) in enumerate(zip(y_gt, y_rec)):
    print('GT after vocoder:', i)
    display(ipd.Audio(gt.squeeze().detach().cpu().numpy(), rate=24000))
    print('Recon after vocoder:', i)
    display(ipd.Audio(rec.squeeze().detach().cpu().numpy(), rate=24000))

GT after vocoder: 0


Recon after vocoder: 0


GT after vocoder: 1


Recon after vocoder: 1


GT after vocoder: 2


Recon after vocoder: 2


GT after vocoder: 3


Recon after vocoder: 3


In [6]:
# Load the saved "-gt" and "-mel_rec" arrays (presumably ground-truth and reconstructed
# mel-spectrograms, going by the file names) and move them to `device`.
mels_gt = torch.tensor(np.load("./synth/libritts_aishell3_s1-epoch_1st_00016-gt.npy")).to(device)
mels_rec = torch.tensor(np.load("./synth/libritts_aishell3_s1-epoch_1st_00016-mel_rec.npy")).to(device)

# Run both batches through the vocoder without tracking gradients.
with torch.no_grad():
    y_gt = generator(mels_gt)
    y_rec = generator(mels_rec)

import IPython.display as ipd

# Play the vocoded "gt" item and its reconstruction; the loop stops after the first two items.
for i, (gt, rec) in enumerate(zip(y_gt, y_rec)):
    if i == 2: break
    print('GT after vocoder:', i)
    display(ipd.Audio(gt.squeeze().detach().cpu().numpy(), rate=24000))
    print('Recon after vocoder:', i)
    display(ipd.Audio(rec.squeeze().detach().cpu().numpy(), rate=24000))

GT after vocoder: 0


Recon after vocoder: 0


GT after vocoder: 1


Recon after vocoder: 1


In [30]:
# Load the saved "-gt" and "-mel_rec" arrays (presumably ground-truth and reconstructed
# mel-spectrograms, going by the file names) and move them to `device`.
mels_gt = torch.tensor(np.load("./synth/libritts_aishell3_s1-epoch_1st_00024-gt.npy")).to(device)
mels_rec = torch.tensor(np.load("./synth/libritts_aishell3_s1-epoch_1st_00024-mel_rec.npy")).to(device)

# Run both batches through the vocoder without tracking gradients.
with torch.no_grad():
    y_gt = generator(mels_gt)
    y_rec = generator(mels_rec)

import IPython.display as ipd
# Play the vocoded "gt" item and its reconstruction; the loop stops after the first two items.
for i, (gt, rec) in enumerate(zip(y_gt, y_rec)):
    if i == 2: break
    print('GT after vocoder:', i)
    display(ipd.Audio(gt.squeeze().detach().cpu().numpy(), rate=24000))
    print('Recon after vocoder:', i)
    display(ipd.Audio(rec.squeeze().detach().cpu().numpy(), rate=24000))

GT after vocoder: 0


Recon after vocoder: 0


GT after vocoder: 1


Recon after vocoder: 1


In [35]:
# Load the saved "-gt" and "-mel_rec" arrays (presumably ground-truth and reconstructed
# mel-spectrograms, going by the file names) and move them to `device`.
mels_gt = torch.tensor(np.load("./synth/gtr_s1-epoch_1st_00080-gt.npy")).to(device)
mels_rec = torch.tensor(np.load("./synth/gtr_s1-epoch_1st_00080-mel_rec.npy")).to(device)

# Run both batches through the vocoder without tracking gradients.
with torch.no_grad():
    y_gt = generator(mels_gt)
    y_rec = generator(mels_rec)

import IPython.display as ipd
# Play each vocoded "gt" item followed by its vocoded reconstruction.
for i, (gt, rec) in enumerate(zip(y_gt, y_rec)):
    # if i == 2: break
    print('GT after vocoder:', i)
    display(ipd.Audio(gt.squeeze().detach().cpu().numpy(), rate=24000))
    print('Recon after vocoder:', i)
    display(ipd.Audio(rec.squeeze().detach().cpu().numpy(), rate=24000))

GT after vocoder: 0


Recon after vocoder: 0


GT after vocoder: 1


Recon after vocoder: 1


In [36]:
from scipy.io.wavfile import write

# Casting float samples straight to int16 truncates them toward zero, which writes
# (near-)silent files. Scale to the 16-bit range before the cast.
# NOTE(review): assumes the vocoder output is float audio within [-1, 1] — confirm.
for i, gt in enumerate(y_gt):
    wav = gt.squeeze().detach().cpu().numpy()
    pcm = (np.clip(wav, -1.0, 1.0) * 32767.0).astype(np.int16)
    write(f"gtr_s1_{i}.wav", 24000, pcm)

### Synthesize speech (seen speakers, LibriTTS train-clean-100)

In [41]:
# Get the first 3 training samples as references.
# NOTE: the dictionary built by the loop below is replaced further down by a
# single hard-coded reference, so only that one file is actually embedded.

# train_path = config.get('train_data', None)
# val_path = config.get('val_data', None)
train_path = "Data/train_list_aishell3_renamed.txt"
val_path = "Data/val_list_aishell3_renamed.txt"
train_list, val_list = get_data_path_list(train_path, val_path)

ref_dicts = {}
for j in range(3):
    # Each list entry is '|'-separated; the first field is the wav path.
    filename = train_list[j].split('|')[0]
    # Use the file name without its '.wav' extension as the reference key.
    name = filename.split('/')[-1].replace('.wav', '')
    ref_dicts[name] = filename

# ref_dicts = {'SSB00430444': '/storageNVME/melissa/aishell3/train/wav/SSB0043/SSB00430444.wav'}
# Overrides the references collected above with one hard-coded recording.
ref_dicts = {'103': '/storageNVME/melissa/libritts_part/LibriTTS/train-clean-100/103/1241/103_1241_000032_000001.wav'}

reference_embeddings = compute_style(ref_dicts)

In [42]:
ref_dicts

{'103': '/storageNVME/melissa/libritts_part/LibriTTS/train-clean-100/103/1241/103_1241_000032_000001.wav'}

In [11]:
# Text to synthesize.
# NOTE: the tokenize cell that follows uses a hard-coded phoneme string (its
# phonemizer call is commented out), so this `text` is not what gets synthesized there.
text = ''' StyleTTS is a style based generative model that can synthesize diverse speech with natural prosody from a reference speech utterance. '''
# text = ''' but it isn't, it's firmly fastened on one end.'''

In [43]:
# tokenize
# ps = global_phonemizer.phonemize([text])
# tokens = textclenaer(ps[0])
# Hard-coded phoneme string used in place of the phonemizer output (the calls above are commented out).
ps = "ʈʂʰˈʐtɕʰˈi xˈʊŋˈiŋtɕʰˈjɑŋ ʈʂˈweɪkˈan tˈweɪfˈɑŋ pˈankˈʊŋlˈi"  # '/storageNVME/melissa/aishell3/train/wav/SSB0043/SSB00430444.wav'
# ps = "bˌʌt ɪɾ ˌɪzəntɪts fˈɜːmli fˈæsənd æt wˈʌn ˈɛnd."  # /storageNVME/melissa/libritts_part/LibriTTS/train-clean-460/103/1241/103_1241_000032_000001.wav
# Map each phoneme character to its symbol index, pad both ends with index 0
# (the "$" pad symbol), and add a leading batch dimension.
tokens = textclenaer(ps)
tokens.insert(0, 0)
tokens.append(0)
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

In [57]:
# Synthesize the tokenized text once per reference style and collect the waveforms.
converted_samples = {}


with torch.no_grad():
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    m = length_to_mask(input_lengths).to(device)
    t_en = model.text_encoder(tokens, input_lengths, m)

    for key, (ref, _) in reference_embeddings.items():

        s = ref.squeeze(1)

        d = model.predictor.text_encoder(t_en, s, input_lengths, m)

        # Predict a duration per token, rounded and clamped to at least 1.
        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # Hard alignment matrix (tokens x frames): token i is 1 on the consecutive
        # frames assigned to it. Sizes are passed as plain ints, not as a tensor.
        pred_aln_trg = torch.zeros(int(input_lengths[0]), int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            n_frames = int(pred_dur[i].data)
            pred_aln_trg[i, c_frame:c_frame + n_frames] = 1
            c_frame += n_frames
        aln = pred_aln_trg.unsqueeze(0).to(device)

        # encode prosody: expand per-token features to per-frame features via the alignment
        en = d.transpose(-1, -2) @ aln

        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        out = model.decoder(t_en @ aln, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

        # Run the vocoder once; a second identical pass would only recompute the same output.
        c = out.squeeze()
        print(c.shape)
        y_g_hat = generator(c.unsqueeze(0))
        y_out = y_g_hat.squeeze()

        converted_samples[key] = y_out.cpu().numpy()

RuntimeError: input.size(-1) must be equal to input_size. Expected 896, got 640

In [53]:
# converted_samples = {}

# with torch.no_grad():
#     input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
#     m = length_to_mask(input_lengths).to(device)
#     t_en = model.text_encoder(tokens, input_lengths, m)
        
#     for key, (ref, _) in reference_embeddings.items():
        
#         s = ref.squeeze(1)
#         style = s

#         print(s)
#         prosody = model.predictor.embedding(s).permute(0, 2, 1)
#         texts = torch.cat([texts, prosody], axis=1)
#         d = model.predictor.text_encoder(texts, style, input_lengths, m)

#         x, _ = model.predictor.lstm(d)
#         duration = model.predictor.duration_proj(x)
#         pred_dur = torch.round(duration.squeeze()).clamp(min=1)
        
#         pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
#         c_frame = 0
#         for i in range(pred_aln_trg.size(0)):
#             pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
#             c_frame += int(pred_dur[i].data)

#         # encode prosody
#         en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
#         style = s.expand(en.shape[0], en.shape[1], -1)

#         F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

#         out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)), 
#                                 F0_pred, N_pred, ref.squeeze().unsqueeze(0))


#         c = out.squeeze()
#         print(c.shape)
#         y_g_hat = generator(c.unsqueeze(0))
#         y_out = y_g_hat.squeeze().cpu().numpy()

#         c = out.squeeze()
#         y_g_hat = generator(c.unsqueeze(0))
#         y_out = y_g_hat.squeeze()
        
#         converted_samples[key] = y_out.cpu().numpy()

tensor([[ 0.1292, -0.2295, -0.1842, -0.0956,  0.0729,  0.1443, -0.0544,  0.1090,
         -0.1570, -0.1284, -0.1494,  0.0418, -0.0159,  0.2407,  0.2559,  0.1489,
          0.0660,  0.0641,  0.1657, -0.2643, -0.2320, -0.1753, -0.1330,  0.0551,
         -0.1224,  0.0679,  0.0779,  0.2501,  0.1059, -0.0688,  0.0519, -0.2095,
         -0.0819, -0.1876,  0.2514, -0.1462, -0.2772, -0.2105,  0.2003, -0.2154,
          0.0317,  0.2308, -0.0849,  0.1862, -0.2936,  0.1281, -0.1678,  0.2647,
         -0.0817, -0.1009,  0.0531, -0.2499,  0.2476,  0.2015,  0.0128,  0.2889,
         -0.1953,  0.0413,  0.0047,  0.0435,  0.1980,  0.1089, -0.1900,  0.1218,
          0.0785,  0.0963, -0.0446,  0.1427,  0.1922, -0.0351,  0.0078,  0.2167,
         -0.0309,  0.0232,  0.0475,  0.0218,  0.2510,  0.3043,  0.1012,  0.1566,
         -0.0466, -0.0365, -0.0850, -0.1666,  0.1104,  0.1309, -0.0296,  0.2797,
         -0.0044,  0.0045,  0.1860, -0.1747,  0.1385, -0.2701,  0.2023,  0.1253,
         -0.1217, -0.0799,  

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)

In [25]:
import IPython.display as ipd
# Play each synthesized waveform, followed by the reference audio stored with its style embedding.
for key, wave in converted_samples.items():
    print('Synthesized: %s' % key)
    display(ipd.Audio(wave, rate=24000))
    try:
        print('Reference: %s' % key)
        display(ipd.Audio(reference_embeddings[key][-1], rate=24000))
    except Exception as err:
        # Stay best-effort, but report why the reference could not be played.
        print('Could not play reference %s: %s' % (key, err))

### Zero-shot TTS (unseen speakers, LibriTTS test-clean)

In [None]:
test_clean_path = '/share/naplab/users/yl4579/data/LibriTTS/test-clean/'

ref_dicts = {}
# pick first 3 speakers from test-clean
# (the first three directories in the order os.scandir returns them)
spks = [ f.path for f in os.scandir(test_clean_path) if f.is_dir() ]
spks = spks[:3]
for spk in spks:
    spk_path = spk
    # The last path component is used as the reference key.
    spk = spk.split('/')[-1]
    # Descend into one randomly chosen entry of the speaker directory (unseeded np.random).
    spk_path = spk_path + "/" + (np.random.choice(os.listdir(spk_path), size=1)[0])
    # Every .wav overwrites the previous one, so the last .wav listed is the one kept.
    for f in os.listdir(spk_path):
        if f.endswith('.wav'):
            ref_dicts[spk] = spk_path + "/" + f
reference_embeddings = compute_style(ref_dicts)

In [None]:
# Text to synthesize; the tokenize cell that follows phonemizes this string.
text = ''' StyleTTS is a style based generative model that can synthesize diverse speech with natural prosody from a reference speech utterance. '''

In [15]:
# tokenize
# Phonemize `text`, then map each phoneme character to its symbol index.
ps = global_phonemizer.phonemize([text])
tokens = textclenaer(ps[0])
# Pad both ends with index 0 (the "$" pad symbol) and add a leading batch dimension.
tokens.insert(0, 0)
tokens.append(0)
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

[0,
 61,
 62,
 156,
 43,
 102,
 54,
 16,
 62,
 157,
 51,
 158,
 62,
 157,
 51,
 158,
 156,
 86,
 61,
 16,
 102,
 68,
 16,
 70,
 16,
 61,
 62,
 156,
 43,
 102,
 54,
 16,
 44,
 156,
 47,
 102,
 61,
 62,
 16,
 46,
 147,
 156,
 86,
 56,
 85,
 123,
 83,
 62,
 157,
 102,
 64,
 16,
 55,
 156,
 69,
 158,
 46,
 83,
 54,
 16,
 81,
 72,
 62,
 16,
 53,
 72,
 56,
 16,
 61,
 156,
 102,
 56,
 119,
 83,
 61,
 157,
 43,
 102,
 68,
 16,
 46,
 43,
 102,
 64,
 156,
 87,
 158,
 61,
 16,
 61,
 58,
 156,
 51,
 158,
 62,
 131,
 16,
 65,
 102,
 81,
 16,
 56,
 156,
 72,
 62,
 131,
 85,
 123,
 83,
 54,
 16,
 58,
 123,
 156,
 69,
 158,
 61,
 83,
 46,
 51,
 16,
 48,
 123,
 138,
 55,
 16,
 70,
 16,
 123,
 156,
 86,
 48,
 123,
 83,
 56,
 61,
 16,
 61,
 58,
 156,
 51,
 158,
 62,
 131,
 16,
 156,
 138,
 125,
 85,
 123,
 83,
 56,
 61,
 4,
 16,
 0]

In [None]:
# Synthesize the tokenized text once per reference style and collect the waveforms.
converted_samples = {}

with torch.no_grad():
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    m = length_to_mask(input_lengths).to(device)
    t_en = model.text_encoder(tokens, input_lengths, m)

    for key, (ref, _) in reference_embeddings.items():

        s = ref.squeeze(1)

        d = model.predictor.text_encoder(t_en, s, input_lengths, m)

        # Predict a duration per token, rounded and clamped to at least 1.
        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # Hard alignment matrix (tokens x frames): token i is 1 on the consecutive
        # frames assigned to it. Sizes are passed as plain ints, not as a tensor.
        pred_aln_trg = torch.zeros(int(input_lengths[0]), int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            n_frames = int(pred_dur[i].data)
            pred_aln_trg[i, c_frame:c_frame + n_frames] = 1
            c_frame += n_frames
        aln = pred_aln_trg.unsqueeze(0).to(device)

        # encode prosody: expand per-token features to per-frame features via the alignment
        en = d.transpose(-1, -2) @ aln

        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        out = model.decoder(t_en @ aln, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

        # Run the vocoder once; a second identical pass would only recompute the same output.
        c = out.squeeze()
        y_g_hat = generator(c.unsqueeze(0))
        y_out = y_g_hat.squeeze()

        converted_samples[key] = y_out.cpu().numpy()

In [None]:
import IPython.display as ipd
# Play each synthesized waveform, followed by the reference audio stored with its style embedding.
for key, wave in converted_samples.items():
    print('Synthesized: %s' % key)
    display(ipd.Audio(wave, rate=24000))
    try:
        print('Reference: %s' % key)
        display(ipd.Audio(reference_embeddings[key][-1], rate=24000))
    except Exception as err:
        # Stay best-effort, but report why the reference could not be played.
        print('Could not play reference %s: %s' % (key, err))