### Import libraries, metadata

In [None]:
import os
import librosa
from librosa.filters import mel as librosa_mel_fn
import pickle as pkl
import IPython.display as ipd
from tqdm.notebook import tqdm
import torch
import codecs
import matplotlib.pyplot as plt
%matplotlib inline

from g2p_en import G2p
from text import *
from text import cmudict
from text.cleaners import custom_english_cleaners
from text.symbols import symbols

# Mappings from symbol to numeric ID and vice versa:
symbol_to_id = {s: i for i, s in enumerate(symbols)}
id_to_symbol = {i: s for i, s in enumerate(symbols)}

csv_file = '/media/disk1/lyh/LJSpeech-1.1/metadata.csv'
root_dir = '/media/disk1/lyh/LJSpeech-1.1/wavs'
data_dir = '/media/disk1/lyh/LJSpeech-1.1/waveglow'

# Build `metadata`: utterance id -> {'char': cleaned text, 'phone': g2p tokens}.
# Each non-blank line of the CSV must have exactly three '|'-separated fields;
# the first is used as the key, the second is ignored and the third is the
# text that gets cleaned and converted.
g2p = G2p()
metadata = {}
with codecs.open(csv_file, 'r', 'utf-8') as fid:
    # Iterate the file lazily instead of materialising every line up front.
    for line in fid:
        # A blank (e.g. trailing) line would otherwise break the unpacking.
        if not line.strip():
            continue

        # `utt_id` rather than `id`, so the builtin is not shadowed globally.
        utt_id, _, text = line.split("|")

        clean_char = custom_english_cleaners(text.rstrip())
        # A g2p token is stored with an '@' prefix when that prefixed form is
        # a known symbol; any other token is kept unchanged.
        clean_phone = [
            '@' + s if '@' + s in symbol_to_id else s
            for s in g2p(clean_char.lower())
        ]

        metadata[utt_id] = {'char': clean_char,
                            'phone': clean_phone}

In [None]:
def text2seq(text):
    """Convert an iterable of symbols into a list of numeric IDs.

    Every element of `text` is looked up in `symbol_to_id`; the ID of the
    EOS token '~' is placed at the end of the returned list.
    """
    return [*(symbol_to_id[symbol] for symbol in text), symbol_to_id['~']]

### STFT

In [None]:
from stft import STFT

class TacotronSTFT(torch.nn.Module):
    """Compute log-mel spectrograms from waveforms.

    Wraps the project `STFT` transform and a librosa mel filterbank.
    """

    def __init__(self,
                 filter_length=1024,
                 hop_length=256,
                 win_length=1024,
                 n_mel_channels=80,
                 sampling_rate=22050,
                 mel_fmin=0.0,
                 mel_fmax=8000.0):
        """Create the STFT transform and the mel filterbank.

        Args:
            filter_length: Passed to `STFT` and used as `n_fft` of the
                mel filterbank.
            hop_length: Passed to `STFT`.
            win_length: Passed to `STFT`.
            n_mel_channels: Number of mel bands in the filterbank.
            sampling_rate: Sampling rate used to build the filterbank.
            mel_fmin: Lowest frequency of the filterbank.
            mel_fmax: Highest frequency of the filterbank.
        """
        super().__init__()
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        # Keyword arguments: recent librosa releases no longer accept these
        # positionally, while the keyword names work on old and new versions.
        mel_basis = librosa_mel_fn(sr=sampling_rate,
                                   n_fft=filter_length,
                                   n_mels=n_mel_channels,
                                   fmin=mel_fmin,
                                   fmax=mel_fmax)
        self.mel_basis = torch.from_numpy(mel_basis).float()

    def wav_to_specs(self, y):
        """Return the log-mel spectrogram of `y`.

        Args:
            y: Waveform tensor with all values in [-1, 1]; presumably shaped
                (batch, samples), as `get_mel` passes it -- TODO confirm
                against `STFT.transform`.

        Returns:
            The natural log of the mel-projected STFT magnitudes.

        Raises:
            AssertionError: If any value of `y` lies outside [-1, 1].
        """
        assert(torch.min(y.data) >= -1)
        assert(torch.max(y.data) <= 1)

        magnitudes, _ = self.stft_fn.transform(y)
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        # NOTE(review): no floor is applied before the log, so a mel bin that
        # is exactly zero becomes -inf. Left unchanged to keep existing
        # features identical; consider clamping if that is not intended.
        melspec = torch.log(mel_output)

        return melspec

# Shared extractor built with the default settings; get_mel calls it.
stft = TacotronSTFT()

### Others

In [None]:
def get_mel(filename):
    """Load an audio file and compute its log-mel spectrogram.

    Args:
        filename: Path of the audio file; it is loaded at 22050 Hz.

    Returns:
        A tuple `(melspec, wav)`: the output of `stft.wav_to_specs` with the
        leading batch dimension squeezed out, and the waveform as a float32
        tensor.
    """
    wav, _ = librosa.load(filename, sr=22050)
    # Convert through torch only: the previous `np.float32` cast relied on a
    # `np` name that this file never imports.
    wav = torch.as_tensor(wav, dtype=torch.float32)

    # wav_to_specs is given a batch of one; drop that dimension again.
    melspec = stft.wav_to_specs(wav.unsqueeze(0))
    return melspec.squeeze(0), wav


def save_file(fname, *, save=False):
    """Build the training features of one utterance and optionally store them.

    Args:
        fname: Key into `metadata`; also the wav file name (without the
            '.wav' extension) inside `root_dir`.
        save: When true, pickle the character sequence, phoneme sequence and
            mel spectrogram under `data_dir`. Defaults to False, which only
            computes and returns the features.

    Returns:
        A tuple `(text, char_seq, phone_seq, melspec)`: the cleaned text,
        the LongTensor ID sequences of its characters and phonemes, and the
        spectrogram returned by `get_mel`.

    Raises:
        KeyError: If `fname` is not in `metadata`, or a symbol has no ID.
    """
    wav_name = os.path.join(root_dir, fname) + '.wav'
    entry = metadata[fname]
    text = entry['char']
    char_seq = torch.LongTensor(text2seq(text))
    phone_seq = torch.LongTensor(text2seq(entry['phone']))

    # The waveform itself is not needed here, only the spectrogram.
    melspec, _ = get_mel(wav_name)

    if save:
        outputs = (
            ('char_seq', f'{fname}_sequence.pkl', char_seq),
            ('phone_seq', f'{fname}_sequence.pkl', phone_seq),
            ('melspectrogram', f'{fname}_melspectrogram.pkl', melspec),
        )
        for subdir, basename, obj in outputs:
            out_dir = os.path.join(data_dir, subdir)
            # Create the target folder on demand so a fresh data_dir works.
            os.makedirs(out_dir, exist_ok=True)
            with open(os.path.join(out_dir, basename), 'wb') as f:
                pkl.dump(obj, f)

    return text, char_seq, phone_seq, melspec

### Save and Inspect Data

In [None]:
# Process every utterance in metadata; show the text and spectrogram of one
# reference utterance as a sanity check.
for k in tqdm(metadata.keys()):
    text, char_seq, phone_seq, melspec = save_file(k)
    if k != 'LJ001-0019':
        continue

    # Same three output lines as before: label, text, blank line.
    print("Text:", text, "", sep="\n")
    print("Melspectrogram:")
    plt.figure(figsize=(16, 4))
    plt.imshow(melspec, origin='lower', aspect='auto')
    plt.show()