# Import libraries and setup

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]=''

import warnings
warnings.filterwarnings("ignore")

import sys
sys.path.append('waveglow/')

import matplotlib.pyplot as plt
%matplotlib inline

import IPython.display as ipd
import pickle as pkl
import torch
import hparams
from data_utils import TextMelLoader, TextMelCollate
from torch.utils.data import DataLoader
from model import Model
from text import text_to_sequence, sequence_to_text
from denoiser import Denoiser
from tqdm import tqdm_notebook as tqdm
import librosa

# Path to the trained teacher checkpoint; must be filled in before running.
checkpoint_path = ""

# CUDA_VISIBLE_DEVICES is set to '' at the top of this notebook, which hides
# every GPU from PyTorch, so an unconditional .cuda() would raise. Choose the
# device from what is actually available instead.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fail early with a clear message while the placeholder path is still empty.
if not checkpoint_path:
    raise ValueError('Set checkpoint_path to a trained teacher checkpoint before running.')

model = Model(hparams).to(device)
# map_location lets a checkpoint saved on GPU be restored on CPU (and vice versa).
model.load_state_dict(torch.load(checkpoint_path, map_location=device)['state_dict'])
# Inference only: disable dropout / training-mode behaviour.
_ = model.eval()

# Data filtering: run the teacher on every utterance and keep only those whose alignment is well-behaved

In [None]:
# Run the teacher model over every split and keep only utterances whose best
# attention map advances smoothly. For accepted utterances, the predicted
# melspectrogram and the selected alignment are pickled under
# hparams.teacher_path/{targets,alignments}/<file_name>.pkl.
datasets = ['train', 'val', 'test']

# Largest allowed jump of the argmax position between consecutive rows of the
# selected alignment; bigger jumps mark the utterance as invalid.
MAX_ALIGNMENT_STEP = 2

# Input directories are the same for every utterance, so resolve them once.
seq = os.path.join(f'{hparams.data_path}/preprocessed', 'sequence')
mel = os.path.join(f'{hparams.data_path}/preprocessed', 'melspectrogram')

# Create the output directories up front so the first pkl.dump cannot fail
# with FileNotFoundError.
target_dir = f'{hparams.teacher_path}/targets'
alignment_dir = f'{hparams.teacher_path}/alignments'
os.makedirs(target_dir, exist_ok=True)
os.makedirs(alignment_dir, exist_ok=True)

# Feed inputs to whatever device the model lives on instead of hard-coding .cuda().
device = next(model.parameters()).device

for dataset in datasets:
    # Each filelist line is 'file_name|<unused>|text'.
    with open(f'filelists/ljs_audio_text_{dataset}_filelist.txt', 'r', encoding='utf-8') as f:
        lines = [line.split('|') for line in f.read().splitlines()]

    for file_name, _, text in tqdm(lines):
        # NOTE(review): these pickles are produced by our own preprocessing;
        # never point this at untrusted files (pickle executes code on load).
        with open(f'{seq}/{file_name}_sequence.pkl', 'rb') as f:
            text_padded = pkl.load(f).unsqueeze(0)
        with open(f'{mel}/{file_name}_melspectrogram.pkl', 'rb') as f:
            mel_padded = pkl.load(f).unsqueeze(0)

        text_lengths = torch.LongTensor([text_padded.size(1)])
        mel_lengths = torch.LongTensor([mel_padded.size(2)])

        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            _, melspec, alignments, _, _ = model(text_padded.to(device),
                                                 mel_padded.to(device),
                                                 text_lengths.to(device),
                                                 mel_lengths.to(device))
        melspec = melspec[0]
        # Stack the per-layer attention maps; the result is indexed as
        # alignments[layer, head] below (assumes batch size 1 per layer output).
        alignments = torch.cat(alignments, dim=0)

        # Focus score per (layer, head): mean over rows of each row's max weight.
        F = torch.mean(torch.max(alignments, dim=-1)[0], dim=-1)
        # Unflatten the argmax using the actual number of heads instead of assuming 2.
        r, c = divmod(torch.argmax(F).item(), F.size(-1))

        # Argmax column for each row of the selected map; a valid alignment
        # never jumps by more than MAX_ALIGNMENT_STEP between consecutive rows.
        location = torch.max(alignments[r, c], dim=1)[1]
        diff = location[1:] - location[:-1]
        valid_data = bool((diff.abs() <= MAX_ALIGNMENT_STEP).all())

        if valid_data:
            with open(f'{target_dir}/{file_name}.pkl', 'wb') as f:
                pkl.dump(melspec.detach().cpu(), f)
            with open(f'{alignment_dir}/{file_name}.pkl', 'wb') as f:
                pkl.dump(alignments[r, c].detach().cpu(), f)

# Check data (inspect the last utterance processed by the filtering loop)

In [None]:
# Visual sanity check of the last utterance processed by the filtering loop:
# its text, the teacher's melspectrogram, and every layer/head attention map.
print("Text:")
print(text)
print()

print("Melspectrogram:")
plt.figure(figsize=(16, 4))
# imshow only accepts origin='upper' or 'lower'; 'bottom' raises ValueError on
# current matplotlib. 'lower' draws row 0 at the bottom of the image.
plt.imshow(melspec.detach().cpu().numpy(),
           aspect='auto',
           origin='lower',
           interpolation='none')
plt.show()
print()

print("Alignments:")
# Same focus score used for head selection: mean over rows of each row's max weight.
F = torch.mean(torch.max(alignments, dim=-1)[0], dim=-1)
# Size the grid from the data instead of hard-coding 6 layers x 2 heads;
# squeeze=False keeps `axes` 2-D even when there is a single layer or head.
n_layers, n_heads = F.shape
fig, axes = plt.subplots(n_layers, n_heads,
                         figsize=(10 * n_heads, 10 * n_layers),
                         squeeze=False)
for i in range(n_layers):
    for j in range(n_heads):
        axes[i, j].imshow(alignments[i, j].detach().cpu().numpy().T,
                          aspect='auto',
                          origin='lower',
                          interpolation='none')
        axes[i, j].set_title(f'Layer: {i} / Head: {j} / F: {F[i, j].item():.4f}', fontsize=15)

plt.show()