#### This notebook walks you through the code behind [VITS](https://github.com/jaywalnut310/vits), a classic end-to-end (e2e) TTS model.

#### Import required package

In [1]:
import os
import json
import argparse
import itertools
import math
import logging
import json
import subprocess
import re
from unidecode import unidecode
from phonemizer import phonemize
import numpy as np
from scipy.io.wavfile import read
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
import random
import librosa
import librosa.util as librosa_util
from librosa.util import normalize, pad_center, tiny
from scipy.signal import get_window
from scipy.io.wavfile import read
from librosa.filters import mel as librosa_mel_fn

c:\Users\TokaiTeio\.conda\envs\pytorch\lib\site-packages\numpy\.libs\libopenblas.EL2C6PLE4ZYW3ECEVIV3OXXGRN2NRFM2.gfortran-win_amd64.dll
c:\Users\TokaiTeio\.conda\envs\pytorch\lib\site-packages\numpy\.libs\libopenblas.GK7GX5KEQ4F6UYO3P26ULGBQYHGQO7J4.gfortran-win_amd64.dll
c:\Users\TokaiTeio\.conda\envs\pytorch\lib\site-packages\numpy\.libs\libopenblas.XWYDX2IKJW2NMTWSFYNGFUWKQU3LYTCZ.gfortran-win_amd64.dll


#### In an e2e TTS system, the first step is to understand the training data.
##### The training data of VITS consists of two ingredients: text and the corresponding audio.

VITS reads its data from a .txt file list. Each line holds an audio path and its transcript, separated by `|`, as shown below.

DUMMY1/LJ050-0234.wav|It has used...

DUMMY1/LJ019-0373.wav|to avail himself...

Let's dive into how VITS cleans and preprocesses the data.

In [2]:
# Symbol inventory used by the text front-end.
# You do not need to understand these on a first read;
# the following cells show how each one is used.
_pad = '_'
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
# Full symbol table, in ID order: pad (ID 0), punctuation, ASCII letters, IPA.
# NOTE: _letters_ipa contains the apostrophe twice, so `symbols` has a duplicate entry.
symbols = [_pad, *_punctuation, *_letters, *_letters_ipa]
# ID of the space character (the last punctuation entry).
SPACE_ID = symbols.index(" ")

In [3]:
# Text cleaners.
# cleaner.ipynb walks through each of these functions in detail.

# Matches any run of whitespace characters.
_whitespace_re = re.compile(r'\s+')

# (compiled pattern, expansion) pairs. Each pattern matches the abbreviation at a
# word boundary followed by a literal period, case-insensitively ("Dr." -> "doctor").
_abbreviations = [
  (re.compile(r'\b' + short + r'\.', re.IGNORECASE), long_form)
  for short, long_form in (
    ('mrs', 'misess'),
    ('mr', 'mister'),
    ('dr', 'doctor'),
    ('st', 'saint'),
    ('co', 'company'),
    ('jr', 'junior'),
    ('maj', 'major'),
    ('gen', 'general'),
    ('drs', 'doctors'),
    ('rev', 'reverend'),
    ('lt', 'lieutenant'),
    ('hon', 'honorable'),
    ('sgt', 'sergeant'),
    ('capt', 'captain'),
    ('esq', 'esquire'),
    ('ltd', 'limited'),
    ('col', 'colonel'),
    ('ft', 'fort'),
  )
]
def expand_abbreviations(text):
  '''Replace each abbreviation in `_abbreviations` (e.g. "Mr.") with its spoken form ("mister").'''
  expanded = text
  for pattern, spoken in _abbreviations:
    expanded = pattern.sub(spoken, expanded)
  return expanded


def expand_numbers(text):
  '''Spell out the numbers in `text` using a `normalize_numbers` function.

  NOTE: `normalize_numbers` is not defined or imported in the visible notebook
  cells, and none of the cleaners here call this function. Define or import it
  before using this helper.

  Raises:
    NotImplementedError: if no `normalize_numbers` function is available.
  '''
  try:
    normalizer = normalize_numbers
  except NameError:
    # Give a clear message instead of a bare NameError from deep inside a cleaner.
    raise NotImplementedError(
      'expand_numbers requires a normalize_numbers function, '
      'which is not defined in this notebook') from None
  return normalizer(text)


def lowercase(text):
  '''Return `text` converted to lower case.'''
  lowered = text.lower()
  return lowered


def collapse_whitespace(text):
  '''Replace every run of whitespace in `text` with a single space.'''
  return _whitespace_re.sub(' ', text)


def convert_to_ascii(text):
  '''Transliterate `text` to plain ASCII with unidecode.'''
  ascii_text = unidecode(text)
  return ascii_text


def basic_cleaners(text):
  '''Lowercase `text` and collapse whitespace runs; no transliteration is applied.'''
  return collapse_whitespace(lowercase(text))


def transliteration_cleaners(text):
  '''Pipeline for non-English text: transliterate to ASCII, lowercase, collapse whitespace.'''
  return collapse_whitespace(lowercase(convert_to_ascii(text)))


def english_cleaners(text):
  '''English pipeline: ASCII transliteration, lowercasing and abbreviation expansion,
  then phonemization with espeak (en-us) and whitespace collapsing.'''
  for step in (convert_to_ascii, lowercase, expand_abbreviations):
    text = step(text)
  ipa = phonemize(text, language='en-us', backend='espeak', strip=True)
  return collapse_whitespace(ipa)


def english_cleaners2(text):
  '''English pipeline: ASCII transliteration, lowercasing and abbreviation expansion,
  then espeak (en-us) phonemization that keeps punctuation and stress marks,
  followed by whitespace collapsing.'''
  for step in (convert_to_ascii, lowercase, expand_abbreviations):
    text = step(text)
  ipa = phonemize(text, language='en-us', backend='espeak', strip=True,
                  preserve_punctuation=True, with_stress=True)
  return collapse_whitespace(ipa)

In [7]:
# Module-level caches. mel_basis is not used by the code shown here;
# hann_window is filled lazily by spectrogram_torch.
mel_basis = {}
# https://en.wikipedia.org/wiki/Hann_function
hann_window = {}
# Lookup tables between a symbol and its integer ID (its index in `symbols`).
# For a duplicated symbol, _symbol_to_id keeps the later index.
_symbol_to_id = {sym: idx for idx, sym in enumerate(symbols)}
_id_to_symbol = dict(enumerate(symbols))

# For a closer look at this class, open the file dataset.ipynb
class TextAudioLoader(torch.utils.data.Dataset):
    """Dataset that loads the VITS training data.

    Each item is a ``(text, spec, wav)`` triple built from one
    ``audio_path|transcript`` line of the file list:

    * ``text`` -- LongTensor of symbol IDs for the cleaned transcript;
    * ``spec`` -- linear spectrogram from ``spectrogram_torch`` (cached on disk
      next to the audio as ``<name>.spec.pt``);
    * ``wav``  -- the time-domain waveform divided by ``max_wav_value``, with a
      leading channel axis added.
    """

    def __init__(self, audiopaths_and_text, hparams):
        """Read the file list, shuffle it reproducibly, and filter it by text length.

        Args:
            audiopaths_and_text: path to a file list whose lines look like
                ``DUMMY1/LJ050-0234.wav|It has used...``.
            hparams: data config. It must provide ``text_cleaners``,
                ``max_wav_value``, ``sampling_rate``, ``filter_length``,
                ``hop_length``, ``win_length`` and ``add_blank``.
                ``cleaned_text``, ``min_text_len`` and ``max_text_len`` are
                optional (defaults False, 1 and 190).
        """
        self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
        self.text_cleaners = hparams.text_cleaners
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length

        # NOTE: cleaned_text and add_blank are stored for config compatibility,
        # but get_text currently ignores both.
        self.cleaned_text = getattr(hparams, "cleaned_text", False)
        self.add_blank = hparams.add_blank
        self.min_text_len = getattr(hparams, "min_text_len", 1)
        self.max_text_len = getattr(hparams, "max_text_len", 190)

        # Fixed seed so the shuffled order is the same on every run.
        random.seed(1234)
        random.shuffle(self.audiopaths_and_text)
        self._filter()

    def _filter(self):
        """Keep entries whose transcript length lies in [min_text_len, max_text_len].

        Also stores an estimated spectrogram length per kept entry in
        ``self.lengths`` (used for bucketing). From the original repo:

            wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
            spec_length = wav_length // hop_length

        i.e. the estimate assumes mono 16-bit audio and counts the wav header too.
        """
        audiopaths_and_text_new = []
        lengths = []
        for audiopath, text in self.audiopaths_and_text:
            if self.min_text_len <= len(text) <= self.max_text_len:
                audiopaths_and_text_new.append([audiopath, text])
                lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
        self.audiopaths_and_text = audiopaths_and_text_new
        self.lengths = lengths

    def get_audio_text_pair(self, audiopath_and_text):
        """Return ``(text_ids, spectrogram, waveform)`` for one ``[audio_path, transcript]`` entry."""
        audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
        text = self.get_text(text)
        spec, wav = self.get_audio(audiopath)
        return (text, spec, wav)

    @staticmethod
    def _spec_path(filename):
        """Return the spectrogram cache path: ``filename`` with its extension replaced by ``.spec.pt``.

        Only the final extension is replaced. This avoids rewriting ".wav" inside
        directory names, and it gives non-.wav files a distinct cache path
        instead of their own path.
        """
        return os.path.splitext(filename)[0] + ".spec.pt"

    def get_audio(self, filename):
        """Load ``filename`` and return ``(spec, audio_norm)``.

        ``audio_norm`` is the waveform divided by ``max_wav_value`` with a channel
        axis prepended. ``spec`` is loaded from the ``.spec.pt`` cache when it
        exists; otherwise it is computed with ``spectrogram_torch`` and saved there.
        """
        audio, sampling_rate = load_wav_to_torch(filename)
        # The sampling-rate check from the original repo is disabled in this notebook:
        # if sampling_rate != self.sampling_rate:
        #     raise ValueError("{} {} SR doesn't match target {} SR".format(
        #         sampling_rate, self.sampling_rate))
        # Presumably max_wav_value matches the sample format (e.g. 32768.0 for int16),
        # which maps the samples into [-1, 1] -- confirm against the config.
        audio_norm = audio / self.max_wav_value
        audio_norm = audio_norm.unsqueeze(0)  # add a channel axis
        spec_filename = self._spec_path(filename)
        if os.path.exists(spec_filename):
            spec = torch.load(spec_filename)
        else:
            spec = spectrogram_torch(audio_norm, self.filter_length,
                self.sampling_rate, self.hop_length, self.win_length,
                center=False)
            spec = torch.squeeze(spec, 0)  # drop the channel axis
            torch.save(spec, spec_filename)
        return spec, audio_norm

    def get_text(self, text):
        """Clean ``text`` with ``self.text_cleaners`` and return its symbol IDs as a LongTensor.

        The ``cleaned_text`` and ``add_blank`` options are not applied here
        (the original repo's branches for them are disabled in this notebook).
        Example (presumably with english_cleaners2): the line
        "Mrs. De Mohrenschildt thought that Oswald," is cleaned to
        "mɪsˈɛs də mˈoʊɹɪnstʃˌaɪlt θˈɔːt ðæt ˈɑːswəld," before being mapped to IDs.
        """
        text_norm = text_to_sequence(text, self.text_cleaners)
        text_norm = torch.LongTensor(text_norm)
        return text_norm

    def __getitem__(self, index):
        """Return the ``(text, spec, wav)`` triple for ``dataset[index]``."""
        return self.get_audio_text_pair(self.audiopaths_and_text[index])

    def __len__(self):
        """Number of entries kept after filtering."""
        return len(self.audiopaths_and_text)
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    """Compute the linear magnitude spectrogram of ``y`` with a Hann-windowed STFT.

    Args:
        y: waveform tensor expected to lie in [-1, 1]; out-of-range values are
            only reported with ``print``. The reflect padding below inserts an
            axis at dim 1, so ``y`` is presumably ``(batch, samples)``, e.g.
            ``(1, T)`` -- TODO confirm for other ranks.
        n_fft: FFT size; the one-sided output has ``n_fft // 2 + 1`` frequency bins.
        sampling_rate: not used by this function; kept for API compatibility.
        hop_size: number of samples between successive frames.
        win_size: Hann window length in samples.
        center: forwarded to ``torch.stft``.

    Returns:
        Tensor ``sqrt(re**2 + im**2 + 1e-6)``, with the frequency axis placed
        before the frame axis.
    """
    # After dividing by max_wav_value, samples should not leave [-1, 1].
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    # Cache one Hann window per (window size, dtype, device) in the module-level dict.
    global hann_window
    dtype_device = str(y.dtype) + '_' + str(y.device)
    wnsize_dtype_device = str(win_size) + '_' + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

    # Reflect-pad (n_fft - hop_size) / 2 samples on each side. With center=False,
    # this yields about samples // hop_size frames.
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
    y = y.squeeze(1)

    # Short-time Fourier transform: time domain -> frequency domain.
    stft_kwargs = dict(hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
                       center=center, pad_mode='reflect', normalized=False, onesided=True)
    try:
        # PyTorch >= 2.0 requires return_complex for real input. view_as_real
        # restores the trailing (real, imag) axis that older versions returned.
        spec = torch.view_as_real(torch.stft(y, n_fft, return_complex=True, **stft_kwargs))
    except TypeError:
        # PyTorch < 1.7 has no return_complex argument and already returns (..., 2).
        spec = torch.stft(y, n_fft, **stft_kwargs)
    # Magnitude sqrt(re^2 + im^2). The 1e-6 keeps the result strictly positive
    # (safe for a later log) and keeps the sqrt gradient finite at zero.
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec

def load_wav_to_torch(full_path):
  '''Read a wav file and return (samples as a float32 tensor, sampling rate).

  Sample values are not rescaled; e.g. int16 audio stays in the int16 range.
  '''
  sr, samples = read(full_path)
  tensor = torch.FloatTensor(samples.astype(np.float32))
  return tensor, sr

# Load the training file list. Each line holds fields separated by "|",
# e.g. "DUMMY1/LJ050-0234.wav|It has used..." -> ["DUMMY1/LJ050-0234.wav", "It has used..."].
def load_filepaths_and_text(filename, split="|"):
  '''Read `filename` (UTF-8) and return one list of `split`-separated fields per line.

  Leading and trailing whitespace is stripped from each line before splitting.
  '''
  rows = []
  with open(filename, encoding='utf-8') as handle:
    for raw_line in handle:
      rows.append(raw_line.strip().split(split))
  return rows

def text_to_sequence(text, cleaner_names):
  '''Convert a string to the list of symbol IDs of its cleaned form.

    Args:
      text: string to convert
      cleaner_names: names of the cleaner functions to run the text through, in order
    Returns:
      List of integer IDs, one per character of the cleaned text, e.g. [1, 3, 5].
      A character missing from the symbol table raises KeyError.
  '''
  cleaned = _clean_text(text, cleaner_names)
  return [_symbol_to_id[ch] for ch in cleaned]
# Run text through each named cleaner function, in order.
def _clean_text(text, cleaner_names):
  '''Apply the cleaners named in `cleaner_names` to `text`, in order.

  Args:
    text: raw input string.
    cleaner_names: iterable of cleaner function names defined in this
      module/notebook, e.g. ["english_cleaners2"].
  Returns:
    The cleaned text.
  Raises:
    ValueError: if a name does not refer to a callable in this module's globals.
  '''
  for name in cleaner_names:
    # Look the cleaner up by name rather than using eval(): eval(name + '()')
    # called the cleaner with no arguments (TypeError), and it would also
    # execute arbitrary strings coming from the config.
    cleaner = globals().get(name)
    if not callable(cleaner):
      raise ValueError('Unknown cleaner: %s' % name)
    text = cleaner(text)
  return text

#### TextAudioLoader converts each training example into this form:

#### (text, spectrogram, waveform)
##### Note: text is not a str but a LongTensor of symbol IDs, such as [1, 5, 3].

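The spectrogram is a frequency-domain tensor of shape (filter_length // 2 + 1, frames), i.e. (frequency, frames). The waveform is the normalized time-domain audio of shape (1, samples) for mono input.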