In [1]:
import utils
import models
import re
import torch
import commons
import numpy as np

from torch.utils.data import DataLoader
from text.symbols import symbols
from data_utils import TextMelLoader, TextMelCollate
import re
from text import _clean_text
# Read the hyper-parameters from the base config; later cells use `hps.data`
# (file lists, CMU dict path) and `hps.train` (batch size).
hps = utils.get_hparams_from_file("./configs/base.json")

In [2]:
import logging
# Raise the root logger's threshold to CRITICAL so that records of lower
# severity are filtered out by it for the rest of the notebook.
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

In [3]:
from text import text_to_sequence, cmudict
# Load the hyper-parameters from the base config (same call as in the first cell)
# and build the CMU pronouncing dictionary from the path stored in the config.
# `cmu_dict` is passed to `text_to_sequence` by WordPhoneMelLoader.
hps = utils.get_hparams_from_file("./configs/base.json")
cmu_dict = cmudict.CMUDict(hps.data.cmudict_path)

In [4]:
import transformers
from transformers import DistilBertModel, DistilBertTokenizer, DistilBertConfig
# Pretrained cased DistilBERT tokenizer; WordPhoneMelLoader uses this module-level
# object to turn sentences and single words into wordpiece ids.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')

In [5]:
class WordPhoneMelLoader(TextMelLoader):
    """Dataset items extended with word-level alignment matrices.

    On top of the phone ids and mel returned by ``get_mel_text_pair``, every
    item carries the wordpiece ids of the cleaned sentence (from the
    module-level ``tokenizer``) and two 0/1 matrices telling which wordpieces
    and which phone symbols belong to each space-separated word.
    """

    # Symbol id that, when it directly follows a word in the phone sequence, is
    # attributed to that word as well.
    # NOTE(review): presumably the id of the space symbol in `text.symbols` -- confirm.
    WORD_BOUNDARY_ID = 11

    @staticmethod
    def _find_span(sequence, pattern, start):
        """Return the first index >= ``start`` where ``pattern`` occurs in ``sequence``.

        Only windows that fit entirely inside ``sequence`` are compared, so the
        comparison can never hit a shape mismatch. Returns -1 when there is no
        occurrence. An empty ``pattern`` matches at ``start``.
        """
        pattern = np.asarray(pattern)
        width = len(pattern)
        for pos in range(start, len(sequence) - width + 1):
            if np.array_equal(sequence[pos:pos + width], pattern):
                return pos
        return -1

    def __getitem__(self, index):
        """Build one training item.

        Returns:
            tuple ``(wordpieces, phones, mel, wordpiece_attn, phone_attn)`` where
            ``wordpieces`` is an IntTensor of token ids including the special
            tokens, ``phones`` and ``mel`` come from ``get_mel_text_pair``,
            ``wordpiece_attn`` has shape ``(len(wordpieces), n_words)`` and
            ``phone_attn`` has shape ``(len(phones), n_words)``.

        Raises:
            AssertionError: if some word could not be located in the wordpiece
                or in the phone sequence.
        """
        audiopath, sent = self.audiopaths_and_text[index]
        phones, mel = self.get_mel_text_pair((audiopath, sent))

        clean_sent = _clean_text(sent, ['english_cleaners'])
        wordpieces = tokenizer.encode(clean_sent, add_special_tokens=True)
        wordpieces = torch.IntTensor(wordpieces)

        words = clean_sent.split(" ")
        wordpiece_attn = torch.zeros((len(wordpieces), len(words)))
        phone_attn = torch.zeros((len(phones), len(words)))

        wordpieces_ = wordpieces.numpy()
        phones_ = phones.numpy()

        # Search cursors: each word is looked up at or after the end of the
        # previous word's span. Advancing past the match keeps a repeated word
        # (e.g. "had had") or a short word whose symbols also occur inside the
        # previous word from being aligned to the previous word's span.
        wp_idx = 0
        ph_idx = 0

        for i, word in enumerate(words):
            phs = text_to_sequence(word, ['english_cleaners'], cmu_dict)
            wps = tokenizer.encode(word, add_special_tokens=False)

            wp_start = self._find_span(wordpieces_, wps, wp_idx)
            if wp_start >= 0:
                wp_end = wp_start + len(wps)
                wordpiece_attn[wp_start:wp_end, i] = 1
                wp_idx = wp_end
            # On a miss the cursor is left untouched so later words can still be
            # aligned; the assertion below reports the failure.

            ph_start = self._find_span(phones_, phs, ph_idx)
            if ph_start >= 0:
                ph_end = ph_start + len(phs)
                phone_attn[ph_start:ph_end, i] = 1
                # A boundary symbol right after the word is counted as part of it.
                if ph_end < len(phones_) and phones_[ph_end] == self.WORD_BOUNDARY_ID:
                    phone_attn[ph_end, i] = 1
                    ph_end += 1
                ph_idx = ph_end

        # Every word must own at least one wordpiece and at least one phone symbol.
        assert torch.all(wordpiece_attn.sum(dim=0)), \
            "word without aligned wordpieces in %s" % audiopath
        assert torch.all(phone_attn.sum(dim=0)), \
            "word without aligned phones in %s" % audiopath

        return wordpieces, phones, mel, wordpiece_attn, phone_attn

In [6]:
class WordPhoneMelCollate(TextMelCollate):
    """Collate function that zero-pads WordPhoneMelLoader items into batch tensors."""

    def __call__(self, batch):
        """Pad every field to the longest item in ``batch`` and stack the results.

        Each element of ``batch`` is indexed as
        ``(wordpieces, phones, mel, wordpiece_attn, phone_attn)``.

        Returns:
            tuple ``(text_padded, text_lengths, phones_padded, phone_lengths,
            mel_padded, mel_lengths, wordpiece_attn_padded, phone_attn_padded)``.
        """
        batch_size = len(batch)

        # Wordpiece ids -> (batch, max_text_len), right-padded with zeros.
        text_lengths = torch.LongTensor([len(sample[0]) for sample in batch])
        max_text_len = max(text_lengths)
        text_padded = torch.LongTensor(batch_size, max_text_len).zero_()
        for row, sample in enumerate(batch):
            ids = sample[0]
            text_padded[row, :ids.size(0)] = ids

        # Phone ids -> (batch, max_phone_len), right-padded with zeros.
        phone_lengths = torch.LongTensor([len(sample[1]) for sample in batch])
        max_phone_len = max(phone_lengths)
        phones_padded = torch.LongTensor(batch_size, max_phone_len).zero_()
        for row, sample in enumerate(batch):
            phone_ids = sample[1]
            phones_padded[row, :phone_ids.size(0)] = phone_ids

        # Mels -> (batch, num_mels, max_target_len); the time axis is rounded up
        # to the next multiple of n_frames_per_step.
        num_mels = batch[0][2].size(0)
        max_target_len = max(sample[2].size(1) for sample in batch)
        remainder = max_target_len % self.n_frames_per_step
        if remainder != 0:
            max_target_len += self.n_frames_per_step - remainder
            assert max_target_len % self.n_frames_per_step == 0
        mel_padded = torch.FloatTensor(batch_size, num_mels, max_target_len).zero_()
        mel_lengths = torch.LongTensor(batch_size)
        for row, sample in enumerate(batch):
            spec = sample[2]
            n_frames = spec.size(1)
            mel_padded[row, :, :n_frames] = spec
            mel_lengths[row] = n_frames

        # Alignment matrices share the word axis, padded to the largest word count.
        max_word_count = max(sample[3].size(1) for sample in batch)
        wordpiece_attn_padded = torch.zeros((batch_size, max_text_len, max_word_count))
        phone_attn_padded = torch.zeros((batch_size, max_phone_len, max_word_count))
        for row, sample in enumerate(batch):
            wp_attn = sample[3]
            ph_attn = sample[4]
            wordpiece_attn_padded[row, :wp_attn.size(0), :wp_attn.size(1)] = wp_attn
            phone_attn_padded[row, :ph_attn.size(0), :ph_attn.size(1)] = ph_attn

        return (text_padded, text_lengths, phones_padded, phone_lengths,
                mel_padded, mel_lengths, wordpiece_attn_padded, phone_attn_padded)

In [13]:
# Collate callable used by the loaders. The positional 1 goes to the inherited
# constructor (presumably n_frames_per_step -- confirm against TextMelCollate).
collate_fn = WordPhoneMelCollate(1)

# Training split: fixed order, and the final partial batch is kept.
train_dataset = WordPhoneMelLoader(hps.data.training_files, hps.data)
train_loader = DataLoader(
    train_dataset,
    batch_size=hps.train.batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=8,
    pin_memory=True,
    collate_fn=collate_fn,
)


In [17]:
# Test split: fixed order, and the final partial batch is kept.
test_dataset = WordPhoneMelLoader('filelists/ljs_audio_text_test_filelist.txt', hps.data)
test_loader = DataLoader(
    test_dataset,
    batch_size=hps.train.batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=8,
    pin_memory=True,
    collate_fn=collate_fn,
)

In [16]:
from tqdm import tqdm
import os

# Dump every collated training batch to disk: one .npy file per field per batch,
# named <field><batch_idx>.npy.
save_dir = 'temp_data/train/'
# makedirs with exist_ok creates the missing parent directory too and lets the
# cell be re-run without failing on an existing directory.
os.makedirs(save_dir, exist_ok=True)
file_bases = ['wp%d.npy', 'wp_len%d.npy', 'ph%d.npy', 'ph_len%d.npy', 'mel_%d.npy', 'mel_len%d.npy', 'wp_attn%d.npy', 'ph_attn%d.npy']
# tqdm wraps the loader itself (not enumerate) so it can read the loader's
# length and display a total.
for batch_idx, data in enumerate(tqdm(train_loader)):
    # `data` is the tuple produced by the collate function, in file_bases order.
    for item, file_base in zip(data, file_bases):
        np.save(os.path.join(save_dir, file_base % batch_idx), item.numpy())
    

391it [01:10,  5.51it/s]


In [18]:
# Dump every collated test batch to disk: one .npy file per field per batch,
# named <field><batch_idx>.npy.
save_dir = 'temp_data/test/'
# makedirs with exist_ok creates the missing parent directory too and lets the
# cell be re-run without failing on an existing directory.
os.makedirs(save_dir, exist_ok=True)
file_bases = ['wp%d.npy', 'wp_len%d.npy', 'ph%d.npy', 'ph_len%d.npy', 'mel_%d.npy', 'mel_len%d.npy', 'wp_attn%d.npy', 'ph_attn%d.npy']
# tqdm wraps the loader itself (not enumerate) so it can read the loader's
# length and display a total.
for batch_idx, data in enumerate(tqdm(test_loader)):
    # `data` is the tuple produced by the collate function, in file_bases order.
    for item, file_base in zip(data, file_bases):
        np.save(os.path.join(save_dir, file_base % batch_idx), item.numpy())

16it [00:03,  4.07it/s]


In [32]:
import random
class disk_loader:
    """Iterate over batches that were saved to a directory as .npy files.

    A batch is identified by its ``wp_attn`` file; the files of the other
    fields are found by swapping that prefix for each name in ``ITEMS``.
    """

    # Field prefixes, in the order the arrays are yielded.
    ITEMS = ['wp', 'wp_len', 'ph', 'ph_len', 'mel_', 'mel_len', 'wp_attn', 'ph_attn']
    # Prefix whose files act as the per-batch anchor when listing the directory.
    dummy = 'wp_attn'

    def __init__(self, path):
        """Collect the anchor files found in ``path`` and shuffle their order once."""
        self.path = path
        self.filenames = []
        for entry in os.listdir(path):
            if self.dummy in entry:
                self.filenames.append(entry)
        random.shuffle(self.filenames)

    def loader(self):
        """Yield, per batch, the list of arrays loaded in ``ITEMS`` order."""
        for anchor in self.filenames:
            batch = []
            for item_name in self.ITEMS:
                sibling = anchor.replace(self.dummy, item_name)
                batch.append(np.load(os.path.join(self.path, sibling)))
            yield batch



In [33]:
# Smoke test: open the saved training batches and print only the first batch
# (the list of arrays in disk_loader.ITEMS order), then stop iterating.
dl = disk_loader('temp_data/train/')
for i in (dl.loader()):
    print(i)
    break

[array([[ 101, 1103,  107, ...,    0,    0,    0],
       [ 101, 1105,  107, ...,    0,    0,    0],
       [ 101,  170, 1748, ...,  119,  102,    0],
       ...,
       [ 101, 1103, 3137, ...,    0,    0,    0],
       [ 101, 1103, 3318, ...,    0,    0,    0],
       [ 101, 1598, 5923, ...,    0,    0,    0]]), array([18, 25, 32, 27, 29, 30, 26, 20, 25, 23,  9, 18, 28, 18, 20, 29, 18,
       33, 32, 24, 22, 13, 21, 28, 27, 22, 22, 13, 26, 21, 15, 17]), array([[ 91,  73,  11, ...,   0,   0,   0],
       [ 73, 119,  90, ...,   0,   0,   0],
       [102,  11, 104, ...,   0,   0,   0],
       ...,
       [ 91,  73,  11, ...,   0,   0,   0],
       [ 91,  73,  11, ...,   0,   0,   0],
       [116,  73, 119, ...,   0,   0,   0]]), array([ 57,  73, 106,  99, 111, 117,  87,  82,  89, 115,  26,  78, 120,
        59,  71, 106,  55, 131, 112,  86,  87,  48,  68, 129,  94,  73,
        98,  49, 105,  72,  81,  69]), array([[[ -9.737273 ,  -8.349141 ,  -6.771012 , ...,   0.       ,
           0. 

In [7]:
# NOTE(review): this cell rebinds collate_fn, train_dataset, train_loader,
# test_dataset and test_loader, which other cells of this notebook also define
# (there with drop_last=False). Whichever cell ran last determines the loaders in use.

# The positional 1 goes to the inherited constructor
# (presumably n_frames_per_step -- confirm against TextMelCollate).
collate_fn = WordPhoneMelCollate(1)

# Training split: fixed order, incomplete final batch dropped.
train_dataset = WordPhoneMelLoader(hps.data.training_files, hps.data)
train_loader = DataLoader(
    train_dataset,
    batch_size=hps.train.batch_size,
    shuffle=False,
    drop_last=True,
    num_workers=8,
    pin_memory=True,
    collate_fn=collate_fn,
)

# Test split: fixed order, incomplete final batch dropped.
test_dataset = WordPhoneMelLoader('filelists/ljs_audio_text_test_filelist.txt', hps.data)
test_loader = DataLoader(
    test_dataset,
    batch_size=hps.train.batch_size,
    shuffle=False,
    drop_last=True,
    num_workers=8,
    pin_memory=True,
    collate_fn=collate_fn,
)