# Import libraries and set up matplotlib

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

import warnings
warnings.filterwarnings("ignore")

import sys
sys.path.append('waveglow/')

import matplotlib.pyplot as plt
%matplotlib inline

import IPython.display as ipd
import pickle as pkl
from text import *
import numpy as np
import torch
import hparams
from modules.model import Model
from denoiser import Denoiser
import soundfile

### Text preprocessing

In [2]:
from g2p_en import G2p
from text.symbols import symbols
from text.cleaners import custom_english_cleaners

# Bidirectional lookup tables between text symbols and their integer IDs.
# The ID -> symbol table is built first and then inverted; for a repeated
# symbol the highest index wins, exactly as a forward enumeration would give.
id_to_symbol = dict(enumerate(symbols))
symbol_to_id = {sym: idx for idx, sym in id_to_symbol.items()}

# One shared grapheme-to-phoneme converter, used by text2seq in 'phone' mode.
g2p = G2p()
def text2seq(text, data_type='char'):
    """Convert a transcript into symbol IDs framed by start/end tokens.

    The text is right-stripped and passed through ``custom_english_cleaners``.
    In ``'phone'`` mode the cleaned text is then lower-cased and converted to
    phonemes with ``g2p``. Any other ``data_type`` encodes the cleaned text
    character by character.

    Args:
        text: Raw transcript string.
        data_type: ``'char'`` (default) or ``'phone'``.

    Returns:
        list[int]: ``[symbol_to_id['^']] + token IDs + [symbol_to_id['~']]``.

    Raises:
        KeyError: if a character or phoneme token has no entry in
            ``symbol_to_id``.
    """
    text = custom_english_cleaners(text.rstrip())
    if data_type == 'phone':
        # A g2p token is mapped to its '@'-prefixed symbol when one exists;
        # otherwise (presumably spaces/punctuation) it is kept unchanged.
        # The token list always replaces `text`, even when every token was
        # prefixed, so 'phone' mode never falls back to character IDs.
        text = [
            '@' + s if '@' + s in symbol_to_id else s
            for s in g2p(text.lower())
        ]

    # Map every token to its ID, then add the SOS ('^') and EOS ('~') tokens.
    sequence = [symbol_to_id[c] for c in text]
    sequence = [symbol_to_id['^']] + sequence + [symbol_to_id['~']]
    return sequence

### WaveGlow

In [3]:
# Load the pretrained WaveGlow vocoder; the checkpoint stores a pickled
# nn.Module object under the 'model' key.
waveglow_path = 'training_log/waveglow_256channels.pt'
waveglow = torch.load(waveglow_path)['model']

# NOTE(review): presumably a compatibility patch for a checkpoint pickled with
# an older PyTorch whose Conv modules lack the `padding_mode` attribute that
# newer forward passes read; confirm.
for m in waveglow.modules():
    if 'Conv' in str(type(m)):
        setattr(m, 'padding_mode', 'zeros')

waveglow.cuda().eval()
# Force the invertible 1x1 convolutions to float32. This only matters if the
# model is cast to half precision somewhere (not in this cell) — TODO confirm.
for k in waveglow.convinv:
    k.float()

denoiser = Denoiser(waveglow)

# Validation filelist: one '|'-separated record per line. Decode as UTF-8
# explicitly so the result does not depend on the platform's locale encoding.
with open('filelists/ljs_audio_text_val_filelist.txt', 'r', encoding='utf-8') as f:
    lines = [line.split('|') for line in f.read().splitlines()]

In [4]:
# For every (input representation, training step) checkpoint, synthesize a few
# validation utterances and write them to wavs/ for side-by-side listening.
for data_type in ['char', 'phone']:
    for step in ['20000', '50000', '100000', '200000']:
        checkpoint_path = f"training_log/transformer-tts-{data_type}/checkpoint_{step}"

        # Drop the 'module.' prefix from parameter names (presumably added by a
        # DataParallel-style wrapper during training). Keys without the prefix
        # are kept as-is instead of having 7 characters chopped off blindly.
        prefix = 'module.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in torch.load(checkpoint_path)['state_dict'].items()
        }

        model = Model(hparams).cuda()
        model.load_state_dict(state_dict)
        model.eval()

        for i in [1, 6, 22]:
            file_name, _, text = lines[i]
            # (1, seq_len) batch of int64 token IDs on the GPU. The former
            # torch.autograd.Variable wrapper is a no-op in modern PyTorch.
            sequence = torch.tensor(text2seq(text, data_type), dtype=torch.long).unsqueeze(0).cuda()

            with torch.no_grad():
                melspec, enc_alignments, dec_alignments, enc_dec_alignments, stop = model.inference(sequence, max_len=2048)
                # Trim the last mel axis to len(stop) — presumably the number of
                # decoded frames; TODO confirm against Model.inference.
                melspec = melspec[:, :, :len(stop)]
                audio = waveglow.infer(melspec, sigma=0.666)

            out_path = f'wavs/{file_name}_{data_type}{step}.wav'
            # soundfile.write fails when the target directory does not exist.
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            # NOTE(review): 22050 Hz is hard-coded; it should match the
            # vocoder's training sampling rate — confirm against hparams.
            soundfile.write(out_path, audio.cpu().numpy()[0].astype(float), 22050)