In [None]:
import torch

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]='0'

import warnings
warnings.filterwarnings("ignore")

import sys
sys.path.append('waveglow/')

import matplotlib.pyplot as plt
%matplotlib inline

import IPython.display as ipd
import pickle as pkl
import torch
import hparams
from torch.utils.data import DataLoader
from modules.model import Model
from text import text_to_sequence, sequence_to_text
from denoiser import Denoiser
from tqdm import tqdm_notebook as tqdm
import librosa
from modules.loss import MDNLoss

# Checkpoint to align with (an earlier alternative: training_log/aligntts/checkpoint_40000).
checkpoint_path = "training_log/aligntts/checkpoint_100000"

# The original code dropped the first 7 characters of every key, which is the
# length of the "module." prefix (presumably added because the model was wrapped
# in DataParallel when it was saved -- TODO confirm). Strip it only when it is
# actually present so that an unwrapped checkpoint is not silently corrupted.
_DP_PREFIX = 'module.'

# Deserialize onto the CPU first: the weights are copied into the CUDA model by
# load_state_dict, and this avoids depending on the device index the checkpoint
# was saved from.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict = {
    (k[len(_DP_PREFIX):] if k.startswith(_DP_PREFIX) else k): v
    for k, v in checkpoint['state_dict'].items()
}

model = Model(hparams).cuda()
model.load_state_dict(state_dict)
_ = model.eval()  # inference mode; the model is already on the GPU
criterion = MDNLoss()

In [None]:
from utils.utils import decode_text, display_hparams

# Presumably prints the active hyper-parameters so the run configuration is
# visible in the notebook output -- display_hparams is a project helper.
display_hparams(hparams)

# Input representation name taken from hparams; a later cell uses it as the
# sub-directory holding the preprocessed text sequences.
data_type = hparams.data_type

In [None]:
import numpy as np
from utils.utils import decode_text

def save_alignment_diagram(alignments, text_padded, text_lengths, mel_padded, file_path='align.png'):
    """Plot a hard alignment above its mel-spectrogram and save it to disk.

    The top panel shows the alignment matrix and the bottom panel the
    mel-spectrogram. Both share the x axis; vertical lines mark the last frame
    of every aligned token and the x ticks carry the token labels.

    Args:
        alignments: 2-D tensor whose entries equal to 1 mark aligned
            (frame, token) pairs; the first axis is the frame index and the
            second the token index.
        text_padded: Encoded text, passed to ``decode_text``.
        text_lengths: Text lengths, passed to ``decode_text``.
        mel_padded: 3-D tensor; only the first item along axis 0 is drawn.
        file_path: Destination of the saved figure.
    """
    # Labels come from the decoded model input rather than from a notebook
    # global, so the function no longer depends on hidden state.
    # NOTE(review): assumes decode_text returns one label per input token
    # (indexable by token position) -- confirm against utils.utils.
    decoded_text = decode_text(text_padded, text_lengths)

    # Group the aligned frame indices by token. nonzero() yields indices in
    # row-major order, so every list ends up sorted by frame.
    frames_by_token = {}
    for frame_idx, token_idx in (alignments == 1).nonzero().tolist():
        frames_by_token.setdefault(token_idx, []).append(frame_idx)

    tick_positions = []
    tick_labels = []
    boundaries = []
    for token_idx in sorted(frames_by_token):
        frames = frames_by_token[token_idx]
        tick_positions.append(np.mean(frames))       # centre of the token span
        tick_labels.append(decoded_text[token_idx])  # label by index, not by position
        boundaries.append(frames[-1] + .5)           # right edge of the last frame

    fig, axes = plt.subplots(2, 1, figsize=(60, 16), sharex=True)
    try:
        # 'reversed' is not one of the values imshow documents for `origin`
        # ('upper' / 'lower'); 'lower' places row 0 at the bottom.
        axes[0].imshow(alignments.detach().cpu().numpy().T, aspect='auto', origin='lower')
        axes[1].imshow(mel_padded.detach().cpu().numpy()[0], aspect='auto', origin='lower')

        # With sharex=True the tick positions apply to both panels.
        axes[1].set_xticks(tick_positions)
        axes[1].set_xticklabels(tick_labels, fontsize=24)
        for ax in axes:
            ax.set_yticks([])

        for x in boundaries:
            axes[0].axvline(x=x, color='white')
            axes[1].axvline(x=x, color='red')

        fig.tight_layout()
        fig.savefig(file_path)
    finally:
        # This is called once per utterance; without closing, every large
        # figure stays alive in pyplot's registry.
        plt.close(fig)

In [None]:
# When True, utterances whose alignment pickle already exists are not recomputed.
skip_existing = True

datasets = ['train', 'val', 'test']

# All inputs and outputs live under the preprocessed LJSpeech directory.
preprocessed_dir = '../Dataset/LJSpeech-1.1/preprocessed'
alignment_dir = f'{preprocessed_dir}/alignments'
seq = os.path.join(preprocessed_dir, f'{data_type}')
mel = os.path.join(preprocessed_dir, 'melspectrogram')

# Writing the first result would fail if the output directory were missing.
os.makedirs(alignment_dir, exist_ok=True)

for dataset in datasets:
    with open(f'filelists/ljs_audio_text_{dataset}_filelist.txt', 'r') as f:
        # One "file_name|<unused>|text" record per line; blank lines are skipped.
        lines = [line.split('|') for line in f.read().splitlines() if line.strip()]

    for file_name, _, text in tqdm(lines, desc=dataset):
        save_name = f'{alignment_dir}/{file_name}.pkl'

        if skip_existing and os.path.isfile(save_name):
            continue

        # `text` is rebound with start/end markers added. Nothing else in this
        # loop reads it; it is kept because code in other cells may reference
        # the module-level name.
        text = '^' + text + '~'

        # NOTE: pickle executes arbitrary code on load -- only use this on
        # files produced by the project's own preprocessing.
        with open(f'{seq}/{file_name}_sequence.pkl', 'rb') as f:
            text_padded = pkl.load(f).unsqueeze(0).cuda()
        with open(f'{mel}/{file_name}_melspectrogram.pkl', 'rb') as f:
            mel_padded = pkl.load(f).unsqueeze(0).cuda()

        # Min-max scale the spectrogram; the shifted tensor is computed once.
        mel_shifted = mel_padded - torch.min(mel_padded)
        mel_padded = mel_shifted / torch.max(mel_shifted)

        # Batch of one: the lengths are the full sizes of the loaded tensors.
        text_lengths = torch.LongTensor([text_padded.size(1)]).cuda()
        mel_lengths = torch.LongTensor([mel_padded.size(2)]).cuda()

        # Inference only: no autograd graph is needed, and building one for
        # every utterance wastes GPU memory.
        with torch.no_grad():
            encoder_input = model.Prenet(text_padded)
            hidden_states, _ = model.FFT_lower(encoder_input, text_lengths)
            mu_sigma = model.get_mu_sigma(hidden_states)
            _, log_prob_matrix = criterion(mu_sigma, mel_padded, text_lengths, mel_lengths)

            alignments = model.viterbi(log_prob_matrix[0:1], text_lengths[0:1], mel_lengths[0:1])[0].t()

        # Write to a temporary name and rename: skip_existing trusts the mere
        # existence of the final file, so a partially written pickle from an
        # interrupted run must never appear under that name.
        tmp_name = save_name + '.tmp'
        with open(tmp_name, 'wb') as f:
            pkl.dump(alignments, f)
        os.replace(tmp_name, save_name)

        save_alignment_diagram(alignments, text_padded, text_lengths, mel_padded,
                               f'{alignment_dir}/{file_name}.png')
        