# Import libraries and set up matplotlib

In [None]:
# Notebook environment setup. The statements below are order-sensitive:
# they run before the heavy imports further down in this cell.
import os
# Limit the CUDA devices visible to this process to physical GPU index 2.
# This is set before `import torch` so it is in place when CUDA is initialised;
# inside the process the selected GPU is then addressed as device 0.
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

import warnings
# Silence every warning for the rest of the session to keep cell output clean.
# NOTE(review): this also hides deprecation warnings from torch/matplotlib.
warnings.filterwarnings("ignore")

import sys
# Make modules located in the local 'waveglow/' directory importable by bare name.
# Presumably required for `from denoiser import Denoiser` and for unpickling the
# WaveGlow checkpoint later in the notebook -- TODO confirm.
sys.path.append('waveglow/')

import matplotlib.pyplot as plt
%matplotlib inline

import IPython.display as ipd
import pickle as pkl
from text import *
import numpy as np
import torch
import hparams
from model import Model
from denoiser import Denoiser

# Load model from checkpoint

### 1. TTS model

In [None]:
# Restore the trained TTS model from a training checkpoint.
checkpoint_path = "training_log/fastspeech/checkpoint_100000"

# Deserialize onto the CPU first. Without map_location, torch.load restores
# every tensor to the device index it was saved from; only one GPU is visible
# in this notebook, so a checkpoint written from any other index would fail
# to load.
state_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict']

model = Model(hparams)
model.load_state_dict(state_dict)

# Move to the GPU once and switch to inference mode. The result is assigned
# to `_` only to keep the notebook from echoing the module repr.
_ = model.cuda().eval()

### 2. WaveGlow

In [None]:
# Load the WaveGlow vocoder and build the denoiser on top of it.
waveglow_path = 'training_log/waveglow_256channels.pt'

# The 'model' entry is used directly as a module object, i.e. the checkpoint
# holds a pickled model rather than a bare state dict. torch.load unpickles
# arbitrary objects, so only load checkpoints from a trusted source.
# map_location='cpu' avoids restoring tensors to the GPU index they were saved
# from, which may not exist in this process (only one GPU is visible).
waveglow = torch.load(waveglow_path, map_location='cpu')['model']

# Give every convolution-like module a `padding_mode` attribute.
# Presumably a compatibility shim for checkpoints pickled with an older
# PyTorch whose Conv modules lacked this attribute -- TODO confirm.
for m in waveglow.modules():
    if 'Conv' in str(type(m)):
        m.padding_mode = 'zeros'

waveglow.cuda().eval()

# Force the modules in `convinv` to float32.
for k in waveglow.convinv:
    k.float()

denoiser = Denoiser(waveglow)

# Speech Synthesis

In [None]:
# Read the validation file list; each line is split on '|' into its fields.
# The encoding is given explicitly so that decoding does not depend on the
# platform's default locale.
with open('filelists/ljs_audio_text_val_filelist.txt', 'r', encoding='utf-8') as f:
    lines = [line.split('|') for line in f.read().splitlines()]

# Use the second entry of the list; it must have exactly three fields, of
# which the middle one is ignored.
file_name, _, text = lines[1]

# Encode the text as symbol ids and add a leading batch axis of size 1.
sequence = np.array(text_to_sequence(text, ['english_cleaners']))[None, :]

# torch.autograd.Variable is a deprecated no-op wrapper around tensors, so the
# tensor is used directly: move it to the GPU and cast to int64.
sequence = torch.from_numpy(sequence).cuda().long()

In [None]:
# Synthesize the selected sentence once per alpha value, play each result,
# and for alpha == 1.0 also plot the mel-spectrogram with per-frame symbols.
print(f"Script:\n{text}\n")

for alpha in (0.8, 0.9, 1.0, 1.1, 1.2):
    # alpha is forwarded to model.inference; presumably it scales the
    # predicted durations (speaking rate) -- TODO confirm against the model.
    with torch.no_grad():
        melspec, durations = model.inference(sequence, alpha)
        # log(10 ** (x / 10)): presumably converts a dB-scaled mel into the
        # natural-log scale WaveGlow expects -- TODO confirm.
        melspec = torch.log(10**(melspec / 10))
        audio = waveglow.infer(melspec, sigma=0.666)

    print(f"alpha: {alpha}")
    waveform = audio.cpu().numpy()
    ipd.display(ipd.Audio(waveform, rate=hparams.sampling_rate))

    # Only the alpha == 1.0 run is visualised.
    if alpha != 1.0:
        continue

    phoneme = sequence_to_text(sequence[0].tolist())
    duration = torch.round(durations[0]).tolist()
    # Repeat each symbol once per rounded duration unit, giving one label
    # per tick position.
    ticks = [phoneme[i] for i, d in enumerate(duration) for _ in range(int(d))]

    mel_image = melspec.detach().cpu()[0]
    frame_positions = range(melspec.size(2))

    # Wide view: the last mel axis runs along x.
    plt.figure(figsize=(20, 5))
    plt.imshow(mel_image, aspect='auto', origin='lower')
    plt.xticks(frame_positions, ticks)

    # Tall view: transposed, so the same axis runs along y.
    plt.figure(figsize=(15, 60))
    plt.imshow(mel_image.t(), aspect='auto')
    plt.yticks(frame_positions, ticks)

# Print the full symbol sequence, then the subset whose last character is a
# space, a comma, or one of the digits 0/1/2.
symbols = sequence_to_text(sequence[0].tolist())

print()
print("Sequence:")
print(list(symbols))
print()
print("Vowel:")
print([c for c in symbols if c[-1] in (' ', ',', '0', '1', '2')])
plt.show()

# Duration

In [None]:
# Report the frames-per-symbol ratio and the duration value for each symbol.
# NOTE(review): `melspec` and `durations` are whatever the synthesis loop left
# behind, i.e. the results for the last alpha it iterated (1.2), not 1.0 --
# confirm that this is the intended run to report.
n_frames = melspec.size(2)
n_symbols = sequence.size(1)
print(f'Ratio:\t{n_frames / n_symbols:.2f}')
print()

symbols = sequence_to_text(sequence[0].tolist())
for symbol, length in zip(symbols, durations[0].tolist()):
    print(f'{symbol}:\t{length:.2f}')