# Import libraries and set up matplotlib

In [1]:
import os
# Expose only GPU index 2 to this process. This is set before torch is
# imported further down in this cell.
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

import warnings
# Suppress all warnings for the rest of the notebook session.
warnings.filterwarnings("ignore")

import sys
# Add the waveglow/ directory (relative to the working directory) to the
# module search path. NOTE(review): presumably needed for the `denoiser`
# import below and for unpickling the WaveGlow checkpoint -- confirm.
sys.path.append('waveglow/')

import matplotlib.pyplot as plt
%matplotlib inline

import IPython.display as ipd
import pickle as pkl
from text import *
import numpy as np
import torch
import hparams
from model import Model
from denoiser import Denoiser

# Load model from checkpoint

### 1. TTS model

In [None]:
# Path to the trained TTS checkpoint; fill this in before running the cell.
checkpoint_path = ""

# Deserialize the weights onto the CPU first. Only one GPU is visible to this
# process (see CUDA_VISIBLE_DEVICES above), so a checkpoint saved from any
# other `cuda:N` device would otherwise fail to load. `load_state_dict` then
# copies the values into the parameters that already live on the GPU.
state_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict']

model = Model(hparams).cuda()
model.load_state_dict(state_dict)
# Switch to inference mode; assigning to `_` keeps the notebook from echoing
# the module repr. The model is already on the GPU, so no second `.cuda()`.
_ = model.eval()

### 2. WaveGlow

In [None]:
# Load the WaveGlow vocoder. The checkpoint stores the whole module object
# under the 'model' key, so `torch.load` unpickles arbitrary code here --
# only use checkpoint files from a trusted source.
waveglow_path = f'{hparams.output_directory}/waveglow_256channels.pt'
waveglow = torch.load(waveglow_path)['model']

# Give every convolution-like submodule a `padding_mode` attribute.
# NOTE(review): presumably a compatibility shim for checkpoints pickled with
# a PyTorch version that predates this attribute -- confirm.
for module in waveglow.modules():
    if 'Conv' in str(type(module)):
        module.padding_mode = 'zeros'

# Move to the GPU and switch to inference mode.
waveglow.cuda()
waveglow.eval()

# Cast the entries of `convinv` to float32.
for inv_conv in waveglow.convinv:
    inv_conv.float()

# Build the denoiser on top of the loaded vocoder.
denoiser = Denoiser(waveglow)

# Speech Synthesis

In [None]:
# Read the validation filelist. Each line is split on '|' into its fields.
# The encoding is given explicitly so the result does not depend on the
# platform's default locale.
with open('filelists/ljs_audio_text_val_filelist.txt', 'r',
          encoding='utf-8') as f:
    lines = [line.split('|') for line in f.read().splitlines()]

# Pick the second entry; it must have exactly three '|'-separated fields,
# of which the first and the last are used.
file_name, _, text = lines[1]

# Convert the text to symbol IDs and add a leading batch dimension of 1.
sequence = np.array(text_to_sequence(text, ['english_cleaners']))[None, :]
# `torch.autograd.Variable` is a deprecated no-op wrapper (tensors carry
# autograd state themselves), so the tensor is used directly.
sequence = torch.from_numpy(sequence).cuda().long()

In [None]:
# Run text -> mel-spectrogram -> waveform without tracking gradients.
with torch.no_grad():
    melspec, alignments = model.inference(sequence, max_len=768)
    # Computes ln(10 ** (x / 10)) element-wise.
    # NOTE(review): presumably converts a dB-scaled (10*log10) spectrogram to
    # the natural-log scale the vocoder expects -- confirm.
    melspec = torch.log(10**(melspec / 10))
    audio = waveglow.infer(melspec, sigma=0.666)


print("Text:")
print(text)
print()

print("Audio:")
# Denoise the waveform and take index 0 along the second axis of the result.
audio_denoised = denoiser(audio, strength=0.01)[:, 0]
ipd.display(ipd.Audio(audio_denoised.cpu().numpy(),
                      rate=hparams.sampling_rate))
print()

print("Melspectrogram:")
plt.figure(figsize=(16, 4))
# `origin` must be 'upper' or 'lower'; the former value 'bottom' is rejected
# with a ValueError by matplotlib versions that validate this argument.
plt.imshow(melspec[0].cpu().numpy(),
           aspect='auto',
           origin='lower',
           interpolation='none')
plt.show()
print()


print("Alignments:")
# One subplot per (layer, head) pair: 6 rows by 2 columns.
fig, axes = plt.subplots(6, 2, figsize=(20, 60))
for i in range(6):
    for j in [0, 1]:
        # Crop the third axis to the spectrogram length, then transpose so
        # that axis is drawn horizontally.
        axes[i, j].imshow(alignments[i, j, :melspec.size(2)].cpu().numpy().T,
                          aspect='auto',
                          origin='lower',
                          interpolation='none')
        axes[i, j].set_title(f'Layer: {i} / Head: {j}', fontsize=15)

plt.show()