In [None]:
#| default_exp tts_data

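# TTS-Data

> Dataset and collate function that turn audio files plus matching `.TextGrid` alignments into padded `(phones, durations, mels)` batches.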

In [None]:
#| export
import torch
from torch.utils.data import Dataset, DataLoader
from fastspeech.loading import *
from fastspeech.preprocess import *
from torch import tensor
import fastcore.all as fc
from functools import partial
from tqdm.auto import tqdm

In [None]:
#| hide
# Relative locations (from this notebook) of the LJSpeech wav directory and the
# symbols file that TTSDataset uses to build its phone vocabulary.
path, path_vocab = (
    '../../data/LJSpeech-1.1/wavs/',
    '../../fastspeech/sample_data/cmudict-0.7b.symbols.txt',
)

In [None]:
#| hide
# String form of every audio file found under `path`.
file_paths = [str(audio_path) for audio_path in get_audio_files(path)]

In [None]:
#| hide
# Kept in sync with TTSDataset's default `sr` (22050 Hz, LJSpeech's native rate),
# which is the rate the dataset below is actually built with.
sampling_rate = 22050

In [None]:
#| export
class TTSDataset(Dataset):
    """Eagerly-built dataset of `(phones, durations, mels)` triples.

    Every audio file returned by `get_audio_files(path_data)` is paired with a
    file of the same name but a `.TextGrid` extension. All pairs are loaded and
    converted once, inside `__init__`; `__getitem__` then returns the cached
    triple for an index.

    Args:
        path_data: Directory searched for audio files (and their `.TextGrid` twins).
        path_vocab: Symbols file used to build the phone `Vocab`.
        sr: Sampling rate passed to `load_audio`; also passed, together with `hl`,
            to `get_phones_and_durations`.
        n_fft: `n_fft` argument of `melspectrogram`.
        hl: Hop length passed to `melspectrogram` and `get_phones_and_durations`.
        nb: `nb` argument of `melspectrogram` (presumably the number of mel bands).
    """
    def __init__(self, 
                 path_data: str,
                 path_vocab: str,
                 sr: int = 22050,
                 n_fft: int = 1024,
                 hl: int = 256,
                 nb: int = 80,
                ) -> None:
        super().__init__()
        replace_to_tg = partial(replace_extension, extension=".TextGrid")
        
        files = get_audio_files(path_data)
        files_tg = files.map(replace_to_tg)
        
        # "spn" is registered as an extra symbol; presumably the aligner's
        # spoken-noise label — confirm against Vocab.
        self.vocab = Vocab(path_vocab, ["spn"])
        
        self.data = []
        for audio_file, tg_file in tqdm(list(zip(files, files_tg))):
            wav = load_audio(audio_file, sr=sr)
            # The spectrogram is computed exactly once per file; its last axis
            # (presumably the frame axis) is what durations are aligned to.
            mels = tensor(melspectrogram(wav, n_fft=n_fft, hl=hl, nb=nb))
            phones, duration = get_phones_and_durations(tg_file, sr, hl)
            
            # NOTE(review): squeeze() drops every size-1 axis, so a single-phone
            # utterance would become a 0-d tensor.
            phone_ids = tensor(phones_list_to_num(phones, self.vocab)).squeeze()
            durations = round_and_align_durations(tensor(duration), mels.shape[-1]).to(torch.int)
            
            self.data.append((phone_ids, durations, mels))
            
    def __getitem__(self, index: int) -> tuple:
        """Return the cached `(phones, durations, mels)` triple at `index`."""
        return self.data[index]
    
    def __len__(self) -> int:
        """Number of loaded audio/TextGrid pairs."""
        return len(self.data)

In [None]:
#| hide
# Build the dataset with TTSDataset's default audio/spectrogram settings.
ds = TTSDataset(path_data=path, path_vocab=path_vocab)

  0%|          | 0/13082 [00:00<?, ?it/s]

In [None]:
#| hide
# Inspect one cached (phones, durations, mels) triple.
first_item = ds[0]
first_item

(tensor([27,  6, 69, 67, 38, 54, 81, 45, 66, 80, 30, 55, 26, 33, 22,  9, 67, 85,
         52, 38, 54, 69, 77, 27,  9, 65, 34, 83,  9, 55, 70, 66, 59, 55, 27,  9,
         24,  7, 24,  9, 53, 59, 55, 48,  9, 55, 82, 59, 65,  9, 55, 53, 48, 66,
         44, 24, 30, 53, 26]),
 tensor([ 3, 11,  8,  9, 15,  7,  8, 10, 17,  7,  5,  6,  5, 17, 10,  9,  7, 75,
          9, 12,  4,  6,  3,  3,  4,  9, 11,  5,  4,  6,  8,  7, 18, 13, 46,  3,
          4,  8,  5,  2,  8,  9,  3,  8,  3,  3,  8, 17,  8,  4,  5,  4,  4,  9,
          4, 10, 13, 15,  9], dtype=torch.int32),
 tensor([[3.7324e-06, 4.3111e-05, 4.4483e-05,  ..., 3.1676e-07, 3.7787e-07,
          3.6053e-06],
         [2.0968e-05, 8.7105e-05, 2.5985e-05,  ..., 4.2691e-06, 3.1428e-06,
          1.3746e-05],
         [6.8580e-05, 1.2529e-03, 4.5900e-03,  ..., 3.4039e-05, 1.2055e-05,
          1.7352e-05],
         ...,
         [7.6227e-08, 7.9915e-07, 1.8581e-05,  ..., 1.1654e-07, 7.4758e-08,
          3.9378e-08],
         [4.1749e-08, 2

In [None]:
#| export
def collate_fn(inp, pad_num: int):
    """Pad a list of `(phones, durations, mels)` items into one batch.

    Meant to be used as a `DataLoader` `collate_fn`, with `pad_num` bound via
    `functools.partial`.

    Args:
        inp: Sequence of `(phones, durations, mels)` tuples (e.g. `TTSDataset` items).
        pad_num: Value passed to `pad_phones` to pad the phone-id sequences.

    Returns:
        `(phones_batched, duration_batched, mel_batched, mel_attention, phones_attention)`.
        `mel_attention` / `phones_attention` hold, per item, the unpadded size of
        the last axis of its mels / phones.

    Raises:
        AssertionError: if the padded phones and durations differ in shape, or
            the rows of the duration batch do not all sum to the same total.
    """
    phones = [item[0] for item in inp]
    durations = [item[1] for item in inp]
    mels = [item[2] for item in inp]
    
    # Original (pre-padding) lengths along the last axis, one entry per item.
    mel_attention = tensor([mel.shape[-1] for mel in mels])
    phones_attention = tensor([ph.shape[-1] for ph in phones])
    
    mel_batched = pad_mels(mels, 0)
    phones_batched = pad_phones(pad_max_seq(phones), pad_num)
    mel_len = mel_batched.shape[-1]
    
    # Durations are padded against the padded mel length; presumably pad_duration
    # assigns the leftover frames so every row covers the same total (checked below).
    duration_batched = pad_duration(pad_max_seq(durations), mel_len)
    
    assert phones_batched.shape == duration_batched.shape, \
        f"phones {tuple(phones_batched.shape)} vs durations {tuple(duration_batched.shape)}"
    assert len(duration_batched.sum(dim=1).unique()) == 1, \
        f"duration rows sum to different totals: {duration_batched.sum(dim=1).tolist()}"
    
    return phones_batched, duration_batched, mel_batched, mel_attention, phones_attention

In [None]:
#| hide
pad_num = ds.vocab.pad_num
# Seeded generator so the shuffled batch displayed below is reproducible on
# Restart & Run All (42 is an arbitrary fixed seed).
dl = DataLoader(
    ds,
    batch_size=2,
    shuffle=True,
    collate_fn=partial(collate_fn, pad_num=pad_num),
    generator=torch.Generator().manual_seed(42),
)

In [None]:
#| hide
# Pull a single collated batch to eyeball padding.
first_batch = next(iter(dl))
first_batch

(tensor([[27, 45, 67, 45, 82,  9, 67, 65, 30, 68, 53, 48, 55, 30, 67,  9, 67, 31,
          66, 48, 80, 45, 27, 66,  9, 41,  2, 66, 26, 69, 77, 27,  9, 30, 40, 24,
          49, 22,  9, 55, 26, 67, 49, 22, 38, 84, 84, 84, 84, 84, 84, 84, 84, 84,
          84, 84],
         [27,  9, 51, 10, 51, 44, 55, 65,  6, 67, 44, 56, 67, 30, 55, 69,  9, 55,
          67, 10, 79, 26, 30, 70, 69, 59, 53, 26, 42, 45, 54, 42, 49, 66, 45, 25,
          53, 48, 26, 44, 82, 34, 79, 26, 27,  9, 65, 10, 55, 44, 68, 54,  9, 55,
          69,  0]]),
 tensor([[ 6, 10, 13, 16,  7,  4, 10,  6,  6,  9, 11,  5,  8,  6,  9,  4, 11,  9,
          10,  9,  7,  4,  5, 10,  4,  7,  6, 10,  6,  3,  4,  4, 10, 12, 13,  3,
          11, 23, 10,  4,  4, 16, 11, 17, 31,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0, 73],
         [ 4,  4, 10, 19, 10,  5,  7,  5, 11, 10,  4,  5, 10,  8,  4,  5,  5,  5,
           7,  4, 10,  6, 19, 17, 59, 12,  3,  9,  1,  6,  5,  5,  9, 13,  4, 19,
           5,  7,  6,  4,  8, 18,  8,  

In [None]:
import nbdev
nbdev.nbdev_export()