# Transformer TTS

An implementation of [Transformer TTS](https://ojs.aaai.org/index.php/AAAI/article/view/4642) using PyTorch. A significant portion of the code was developed with reference to [Nvidia Tacotron2](https://github.com/NVIDIA/tacotron2) and [choiHkk Transformer TTS](https://github.com/choiHkk/Transformer-TTS).

<br>

**Paper reference**:  
Li, N., Liu, S., Liu, Y., Zhao, S., & Liu, M. (2019, July). Neural speech synthesis with transformer network. In Proceedings of the AAAI conference on artificial intelligence (Vol. 33, No. 01, pp. 6706-6713).

In [86]:
# packages
import os
import math
import librosa
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy import signal

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F

# torchaudio
import torchaudio
import torchaudio.functional as af
import torchaudio.transforms as T

In [87]:
class Hparams:
    """Hyperparameters shared by the dataset, the model and the training loop."""

    # audio
    preemphasis = 0.97      # coefficient handed to the pre-emphasis filter
    n_fft = 1024            # STFT size (also used to build the mel filterbank)
    hop_length = 256        # STFT hop
    win_length = 1024       # STFT window length
    n_mels = 80             # number of mel channels (read by both audio and model code)
    f_min = 0.0             # lower edge passed to the mel filterbank
    f_max = 8000.0          # upper edge passed to the mel filterbank
    ref_level_db = 20.0     # subtracted from the dB-scaled mel spectrogram
    min_level_db = -100.0   # dB floor used when normalizing
    max_abs_value = 4.0     # normalized mels are clipped to [-max_abs_value, max_abs_value]

    # train
    epochs = 10
    batch_size = 16
    vis_every = 50          # a prediction/target figure is saved every `vis_every` iterations

    # model
    n_phonemes = 70         # number of rows of the phoneme embedding table
    embedding_dim = 512     # not referenced by the code visible in this notebook
    d_model = 256           # transformer width

hp = Hparams()

In [88]:
# visualize mel spectrogram
def visualize(melspec):
    """Display one mel spectrogram as a heat map with a colorbar.

    `melspec` is passed straight to `imshow` and drawn with the origin at the
    bottom of the axes.
    """
    figure, axis = plt.subplots(1, 1, figsize=(12, 3))
    image = axis.imshow(
        melspec,
        aspect="auto",
        origin="lower",
        interpolation='none',
    )
    plt.colorbar(image, ax=axis)
    plt.tight_layout()
    plt.show()

## Dataset

The dataset used in the experiment was LJSpeech, which can be easily downloaded through torchaudio. To convert text to phonemes and encode it as numbers, we used the text processor from `torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH`. The preprocessing of wave files was referenced from Nvidia's Tacotron2 and Kyubyong's Tacotron.

In [89]:
class TransformerTTSDataset(torch.utils.data.Dataset):

    """
    LJSpeech wrapper yielding (encoded text, normalized mel spectrogram) pairs.

    The audio preprocessing follows
    https://github.com/Kyubyong/tacotron/blob/master/utils.py#L21
    """

    def __init__(self):
        # LJSpeech rooted at the current directory, requested with download=True
        self.dataset = torchaudio.datasets.LJSPEECH('.', download=True)
        bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH
        self.text_preprocess = bundle.get_text_processor()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):

        # waveform, sample rate, (ignored field), transcript
        waveform, sample_rate, _, transcript = self.dataset[idx]

        # transcript -> encoded text (the second return value is discarded)
        encoded, _ = self.text_preprocess(transcript)

        # waveform -> normalized mel spectrogram
        mel = self.audio_preprocess(waveform, sample_rate)

        # (text_len), (n_mel, mel_len)
        return encoded.squeeze(), mel

    def audio_preprocess(self, y, sr):
        """Turn a waveform `y` sampled at `sr` into a normalized mel spectrogram tensor."""

        # strip leading / trailing silence, then apply pre-emphasis
        trimmed, _ = librosa.effects.trim(y)
        emphasized = af.preemphasis(trimmed, hp.preemphasis)

        # magnitude of the short-time Fourier transform
        spectrum = librosa.stft(
            y=emphasized.numpy().squeeze(),
            n_fft=hp.n_fft,
            hop_length=hp.hop_length,
            win_length=hp.win_length,
            pad_mode='constant',
        )
        magnitude = np.abs(spectrum)

        # project onto the mel filterbank
        filterbank = librosa.filters.mel(
            sr=sr, n_fft=hp.n_fft, n_mels=hp.n_mels, fmin=hp.f_min, fmax=hp.f_max)
        mel = np.dot(filterbank, magnitude)

        # amplitude -> decibel (floored at 1e-5), shifted by the reference level
        mel_db = 20 * np.log10(np.maximum(1e-5, mel))
        mel_db = mel_db - hp.ref_level_db

        # rescale using min_level_db and clip to [-limit, limit]
        limit = hp.max_abs_value
        floor_db = hp.min_level_db
        rescaled = (2 * limit) * ((mel_db - floor_db) / (-floor_db)) - limit
        normalized = np.clip(rescaled, -limit, limit)

        return torch.FloatTensor(normalized)


class TransformerTTSCollate():

    """
    Collate function: zero-pads a list of (text, mel) samples into batch
    tensors, with samples ordered by decreasing text length.
    """

    def __init__(self):
        ...

    def __call__(self, batch):
        """
        Args:
            batch: list of (text, mel) pairs; text is indexed as 1-D and mel as
                (n_mels, mel_len).

        Returns:
            text_padded  (batch, max_text_len) long, zero padded
            text_lengths (batch) long, sorted in decreasing order
            mel_padded   (batch, max_mel_len, n_mels) float, zero padded
            gate_padded  (batch, max_mel_len) float, 1 from the last real frame onwards
            mel_lengths  (batch) long, in the same (text-sorted) order
        """

        n_items = len(batch)

        # order samples by text length, longest first
        raw_lengths = torch.LongTensor([len(text) for text, _ in batch])
        text_lengths, ids_sorted_decreasing = torch.sort(raw_lengths, dim=0, descending=True)
        ordered = [batch[j] for j in ids_sorted_decreasing.tolist()]

        # texts, right-padded with zeros up to the longest one
        text_padded = torch.zeros(n_items, int(text_lengths[0]), dtype=torch.long)
        for row, (text, _) in enumerate(ordered):
            text_padded[row, :text.size(0)] = text

        # mel geometry of this batch
        num_mels = batch[0][1].size(0)
        max_mel_len = max(mel.size(1) for _, mel in batch)

        mel_padded = torch.zeros(n_items, num_mels, max_mel_len, dtype=torch.float32)
        gate_padded = torch.zeros(n_items, max_mel_len, dtype=torch.float32)
        mel_lengths = torch.zeros(n_items, dtype=torch.long)

        for row, (_, mel) in enumerate(ordered):
            n_frames = mel.size(1)
            mel_padded[row, :, :n_frames] = mel
            # stop target: 1 on the last real frame and on every padded frame
            gate_padded[row, n_frames - 1:] = 1
            mel_lengths[row] = n_frames

        return (
            text_padded,
            text_lengths,
            mel_padded.transpose(1, 2),
            gate_padded,
            mel_lengths
        )

In [90]:
# building the dataset runs TransformerTTSDataset.__init__ (LJSPEECH with download=True + text processor)
dataset = TransformerTTSDataset()
# shuffled mini-batches of hp.batch_size samples, padded per batch by TransformerTTSCollate
dataloader = torch.utils.data.DataLoader(dataset, batch_size=hp.batch_size, shuffle=True, collate_fn=TransformerTTSCollate())

## Model

### Scaled Positional Encoding

**Page 3**

$$ PE(pos, 2i) = \sin (\frac{pos}{10000^{\frac{2i}{d_{model}}}}) \tag{6} $$

$$ PE(pos, 2i + 1) = \cos (\frac{pos}{10000^{\frac{2i}{d_{model}}}}) \tag{7} $$

<br>

In NMT, the embeddings for both source and target language are from language spaces, so the scales of these embeddings are similar. This condition doesn’t hold in the TTS scenario, since the source domain is of texts while the target domain is of mel spectrograms, hence using fixed positional embeddings may impose heavy constraints on both the encoder and decoder pre-nets (which will be described in Sec. 3.3 and 3.4).

In [91]:
class ScaledPositionalEncoding(nn.Module):

    """
    Fixed sinusoidal positional encoding multiplied by a trainable scalar
    `alpha` before being added to the input.

    https://pytorch.org/tutorials/beginner/transformer_tutorial
    """

    def __init__(self, d_model, max_len=5000):

        super().__init__()

        # angles[pos, i] = pos / 10000^(2i / d_model)
        positions = torch.arange(0, max_len).unsqueeze(1).float()
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
        angles = positions * inv_freq

        # sine on even channels, cosine on odd channels; stored as a non-trainable buffer
        table = torch.zeros(1, max_len, d_model)
        table[0, :, 0::2] = torch.sin(angles)
        table[0, :, 1::2] = torch.cos(angles)
        self.register_buffer('pe', table)

        # trainable scale, initialized to 1
        self.alpha = nn.Parameter(torch.ones(1))


    def forward(self, x):
        # x is indexed as (batch, seq_len, d_model); only the first seq_len positions are used
        seq_len = x.size(1)
        return x + self.alpha * self.pe[:, :seq_len]

### Encoder Prenet

**page 3**

The output of each convolution layer has 512 channels, followed by a batch normalization and ReLU activation, and a dropout layer as well. In addition, we add a linear projection after the final ReLU activation, since the output range of ReLU is $[0, +∞)$, while each dimension of these triangle positional embeddings is in $[−1, 1]$.

In [92]:
class EncoderPrenet(nn.Module):

    """
    Convolutional encoder pre-net: a stack of Conv1d + BatchNorm blocks, each
    followed by ReLU and dropout, then a linear projection.

    Every convolution runs on twice the nominal channel count
    (`in_size * 2` -> `out_size * 2`); the final projection maps
    `sizes[-1] * 2` features down to `sizes[-1]`.

    https://github.com/NVIDIA/tacotron2/blob/master/model.py#L89

    Args:
        in_dim: nominal input width; the first convolution expects
            `in_dim * 2` input channels.
        sizes: nominal output width of each convolution block.
        encoder_kernel_size: kernel size of every convolution.
    """

    def __init__(self, in_dim=512, sizes=(512, 512, 512), encoder_kernel_size=3):
        super().__init__()

        # work on a private list so the (default or caller-supplied) sequence is never shared
        sizes = list(sizes)

        # convolution layers
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [nn.Sequential(
                ConvNorm(in_size * 2, out_size * 2, kernel_size=encoder_kernel_size,
                      padding=int((encoder_kernel_size - 1) / 2),
                      dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(out_size * 2)
            ) for (in_size, out_size) in zip(in_sizes, sizes)]
        )
        self.projection = LinearNorm(sizes[-1] * 2, sizes[-1])


    def forward(self, x):
        """
        Args:
            x: tensor indexed as (batch, seq_len, in_dim * 2).

        Returns:
            tensor of shape (batch, seq_len, sizes[-1]).
        """
        # Conv1d / BatchNorm1d work channels-first
        x = x.transpose(1, 2)
        for conv_block in self.layers:
            # dropout follows the module mode, so it is switched off by eval()
            x = F.dropout(F.relu(conv_block(x)), p=0.5, training=self.training)

        # back to channels-last for the linear projection
        x = self.projection(x.transpose(1, 2))
        return x


class LinearNorm(nn.Module):

    """
    `nn.Linear` whose weight is Xavier-uniform initialized with the gain that
    `nn.init.calculate_gain` returns for `w_init_gain`.

    https://github.com/NVIDIA/tacotron2/blob/master/layers.py#L8
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super().__init__()
        layer = nn.Linear(in_dim, out_dim, bias=bias)
        gain = nn.init.calculate_gain(w_init_gain)
        nn.init.xavier_uniform_(layer.weight, gain=gain)
        self.linear_layer = layer

    def forward(self, x):
        projected = self.linear_layer(x)
        return projected

### Decoder Prenet

**page 3-4**

The mel spectrogram is first consumed by a neural network composed of two fully connected layers (each has 256 hidden units) with ReLU activation, named "decoder pre-net", and it plays an important role in the TTS system.

In [93]:
class DecoderPrenet(nn.Module):

    """
    Stack of bias-free linear layers, each followed by ReLU and dropout.

    Dropout is called with `training=True`, so it stays active whatever the
    module's train / eval mode is.

    https://github.com/NVIDIA/tacotron2/blob/master/model.py#L89
    """

    def __init__(self, in_dim=80, sizes=[256, 256]):
        super().__init__()

        # consecutive entries of `widths` are the (in, out) sizes of each layer
        widths = [in_dim] + sizes
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(LinearNorm(fan_in, fan_out, bias=False))
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        hidden = x
        for layer in self.layers:
            hidden = F.relu(layer(hidden))
            hidden = F.dropout(hidden, p=0.5, training=True)

        return hidden

### Postnet

**[Tacotron 2] page 2. Spectrogram Prediction Network**

The predicted mel spectrogram is passed through a $5$-layer convolutional post-net which predicts a residual to add to the prediction to improve the overall reconstruction. Each post-net layer is comprised of $512$ filters with shape $5 \times 1$ with batch
normalization, followed by tanh activations on all but the final layer.



In [94]:
class PostNet(nn.Module):

    """
    Convolutional post-net: Conv1d + BatchNorm blocks with tanh after all but
    the last block, and dropout after every block.

    https://github.com/NVIDIA/tacotron2/blob/master/model.py#L103
    """

    def __init__(self, n_mel_channels=80, postnet_embedding_dim=1024,
                 postnet_kernel_size=5, postnet_n_convolutions=5):
        super().__init__()

        # channel plan: n_mel -> hidden -> ... -> hidden -> n_mel
        hidden = [postnet_embedding_dim] * (postnet_n_convolutions - 1)
        channels = [n_mel_channels] + hidden + [n_mel_channels]
        padding = int((postnet_kernel_size - 1) / 2)

        # convolution layers
        blocks = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            blocks.append(nn.Sequential(
                ConvNorm(c_in, c_out,
                         kernel_size=postnet_kernel_size, stride=1,
                         padding=padding,
                         dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(c_out)
            ))
        self.convolutions = nn.ModuleList(blocks)


    def forward(self, x):
        # x is channels-first; every block is followed by dropout (p=0.5) in training mode
        *inner_blocks, last_block = self.convolutions
        for block in inner_blocks:
            x = F.dropout(F.tanh(block(x)), 0.5, self.training)
        # no tanh after the final block
        return F.dropout(last_block(x), 0.5, self.training)


class ConvNorm(nn.Module):

    """
    `nn.Conv1d` whose weight is Xavier-uniform initialized with the gain that
    `nn.init.calculate_gain` returns for `w_init_gain`.

    https://github.com/NVIDIA/tacotron2/blob/master/layers.py#L21
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super().__init__()

        # without an explicit padding, derive it from the (odd) kernel size and dilation
        if padding is None:
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        # 1d convolution
        conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

        # weight initialization
        gain = nn.init.calculate_gain(w_init_gain)
        nn.init.xavier_uniform_(conv.weight, gain=gain)
        self.conv = conv


    def forward(self, x):
        convolved = self.conv(x)
        return convolved


### Transformer TTS

In [95]:
class TransformerTTS(nn.Module):

    """
    Transformer TTS network, run with teacher forcing.

    Pipeline: phoneme embedding -> encoder pre-net -> scaled positional
    encoding -> transformer encoder; (go frame + shifted target mel) ->
    decoder pre-net -> scaled positional encoding -> transformer decoder ->
    mel / stop linear projections -> post-net residual.

    https://github.com/choiHkk/Transformer-TTS/blob/main/model.py#L20
    """

    def __init__(self, hp):
        """
        Args:
            hp: object exposing `n_phonemes`, `n_mels` and `d_model`.
        """
        super().__init__()

        n_phonemes = hp.n_phonemes
        n_mels = hp.n_mels
        d_model = hp.d_model

        # phoneme embedding (width d_model * 2, the input width of the encoder pre-net)
        self.n_mels = hp.n_mels
        self.phoneme_embedding = nn.Embedding(n_phonemes, d_model * 2)
        std = math.sqrt(2.0 / (n_phonemes + d_model))
        val = math.sqrt(3.0) * std
        self.phoneme_embedding.weight.data.uniform_(-val, val)

        # encoder pre-net
        self.encoder_prenet = EncoderPrenet(d_model, [d_model] * 3, 3)

        # decoder pre-net
        self.decoder_prenet = DecoderPrenet(n_mels, [d_model, d_model])

        # scaled positional encoding (one instance, shared by encoder and decoder inputs)
        self.scaled_positional_encoding = ScaledPositionalEncoding(d_model)

        # transformer encoder layer
        transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=4, dim_feedforward=1024, dropout=0.1,
            activation="relu", batch_first=True, norm_first=False, layer_norm_eps=1e-5)

        # transformer decoder layer
        transformer_decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=4, dim_feedforward=1024, dropout=0.1,
            activation="relu", batch_first=True, norm_first=False, layer_norm_eps=1e-5)

        # transformer encoder
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=transformer_encoder_layer, num_layers=3, norm=None)

        # transformer decoder
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer=transformer_decoder_layer, num_layers=3, norm=None)

        # mel linear projection
        self.mel_linear_projection = LinearNorm(d_model, n_mels)

        # stop linear projection
        self.stop_linear_projection = LinearNorm(d_model, 1, w_init_gain='sigmoid')

        # postnet
        self.postnet = PostNet(n_mels, 1024, 5, 5)


    def forward(self, text, text_lengths, mel, mel_lengths):
        """
        Teacher-forced forward pass.

        Args:
            text: (batch, text_len) phoneme ids.
            text_lengths: (batch) true text lengths.
            mel: (batch, mel_len, n_mels) target mel spectrogram.
            mel_lengths: (batch) true mel lengths; the masks are sized from
                `mel_lengths.max()`, which therefore has to equal `mel_len`.

        Returns:
            (post_mel, mel, stop): post-net refined mel and raw mel prediction,
            both (batch, mel_len, n_mels), and stop logits (batch, mel_len, 1).
            Padded frames are set to 0 in both mels and to 1e3 in the stop logits.
        """

        # NOTE(review): this forces training mode on every call and overrides a
        # preceding model.eval(); kept as in the reference implementation.
        self.train()
        self.initialize_masks(text_lengths, mel_lengths)

        # phoneme embedding + encoder prenet
        x = self.encoder_prenet(self.phoneme_embedding(text))

        # positional encoding
        x = self.scaled_positional_encoding(x)

        # transformer encoder
        memory = self.transformer_encoder(
            src=x,
            mask=self.src_mask,
            src_key_padding_mask=self.src_key_padding_mask)

        # decoder input: all-zero go frame followed by the target shifted right by one frame
        y = torch.cat([self.get_go_frame(memory).unsqueeze(1), mel[:,:-1,:]], dim=1)
        y = self.scaled_positional_encoding(self.decoder_prenet(y))

        # transformer decoder
        features = self.transformer_decoder(
            tgt=y,
            memory=memory,
            tgt_mask=self.tgt_mask,
            memory_mask=self.memory_mask,
            tgt_key_padding_mask=self.tgt_key_padding_mask,
            memory_key_padding_mask=self.src_key_padding_mask
        )

        # stop linear, mel linear, postnet (post-net output is a residual on top of mel)
        mel = self.mel_linear_projection(features)
        stop = self.stop_linear_projection(features)
        post_mel = self.postnet(mel.transpose(1, 2)).transpose(1, 2) + mel

        # return padding masked output
        return self.parse_output(post_mel, mel, stop)


    def get_go_frame(self, memory):
        """Return the all-zero first decoder input, shape (batch, n_mels), with memory's dtype/device."""

        batch = memory.size(0)
        return memory.new_zeros(batch, self.n_mels)


    def initialize_masks(self, text_lengths, mel_lengths):
        """Build the attention / padding masks for this batch and store them on `self`."""

        # transformer encoder layer mask
        # NOTE(review): src_mask is a causal (upper-triangular -inf) mask, so each
        # encoder position only attends to itself and earlier positions - confirm intended.
        src_len = text_lengths.max().item()
        self.src_mask = self.generate_square_subsequent_mask(src_len, src_len).to(
            device=text_lengths.device)
        self.src_key_padding_mask = self.generate_padding_mask(text_lengths).to(
            device=text_lengths.device)

        # transformer decoder layer mask
        tgt_len = mel_lengths.max().item()
        self.tgt_mask = self.generate_square_subsequent_mask(tgt_len, tgt_len).to(
            device=mel_lengths.device)
        self.tgt_key_padding_mask = self.generate_padding_mask(mel_lengths).to(
            device=mel_lengths.device)

        # transformer decoder input mask
        # NOTE(review): triangular as well, i.e. decoder step t only sees encoder
        # positions <= t - confirm intended.
        self.memory_mask = self.generate_square_subsequent_mask(tgt_len, src_len).to(
            device=mel_lengths.device)


    def parse_output(self, post_mel, mel, stop):
        """
        Overwrite padded frames: 0 for both mels, 1e3 for the stop logits.

        Uses out-of-place `masked_fill`, so the incoming tensors (which autograd
        may have saved for the backward pass) are left untouched and the masked
        positions receive zero gradient.
        """

        # (batch, mel_len, 1): broadcasts over the channel dimension
        padded = self.tgt_key_padding_mask.unsqueeze(-1)

        post_mel = post_mel.masked_fill(padded, 0.0)
        mel = mel.masked_fill(padded, 0.0)
        stop = stop.masked_fill(padded, 1e3)

        return post_mel, mel, stop


    def generate_square_subsequent_mask(self, lsz, rsz):

        """
        Return an (lsz, rsz) float mask with -inf strictly above the main
        diagonal and 0 elsewhere.

        https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#Transformer
        """

        return torch.triu(torch.ones(lsz, rsz) * float('-inf'), diagonal=1)


    def generate_padding_mask(self, lengths, max_len=None):

        """
        Return a (batch, max_len) boolean mask that is True at padded positions
        (index >= length); `max_len` defaults to `lengths.max()`.

        https://github.com/ming024/FastSpeech2/blob/master/utils/tools.py
        """

        batch_size = lengths.size(0)
        if max_len is None:
            max_len = lengths.max().item()
        ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(
            dtype=lengths.dtype, device=lengths.device)
        return ids >= lengths.unsqueeze(1).expand(-1, max_len)

## Loss function

**Page 4. Mel Linear, Stop Linear and Post-net**

It's worth mentioning that, for the stop linear, there is only one positive sample at the end of each sequence which means "stop", while hundreds of negative samples for other frames. This imbalance may result in unstoppable inference. We impose a positive weight ($5.0 \sim 8.0$) on the tail positive stop token when calculating binary cross entropy loss, and this problem was efficiently solved.


In [96]:
class TransformerTTSLoss():

    """
    Training loss: MSE on the raw and post-net mel predictions plus a stop
    token binary cross entropy (on logits) multiplied by `alpha`.

    https://github.com/choiHkk/Transformer-TTS/blob/main/loss_function.py
    """

    def __init__(self):
        self.alpha = 5.0
        self.stop_loss = nn.BCEWithLogitsLoss()
        self.mel_loss = nn.MSELoss()

    def __call__(self, post_mel, mel_pred, mel_target, stop_pred, stop_target):

        # flatten stop logits and targets to matching (N, 1) columns
        stop_logits = stop_pred.view(-1, 1)
        stop_labels = stop_target.view(-1, 1)

        # both mel predictions are compared against the same target
        raw_term = self.mel_loss(mel_pred, mel_target)
        refined_term = self.mel_loss(post_mel, mel_target)
        mel_term = raw_term + refined_term

        stop_term = self.stop_loss(stop_logits, stop_labels)
        return mel_term + stop_term * self.alpha

## Train

In [97]:
def visualize_training_mel(exp_dir, mel_pred, mel, epoch, step):
    """Save a two-panel figure (prediction on top, target below) into `exp_dir`.

    The file is named `epoch-{epoch}-step-{step}.png`; the figure is closed
    after saving.
    """

    fig, axes = plt.subplots(2, 1, figsize=(12, 6))

    # top panel: predicted mel spectrogram, bottom panel: target mel spectrogram
    panels = ((mel_pred, "Prediction"), (mel, "Target"))
    for axis, (spectrogram, title) in zip(axes, panels):
        image = axis.imshow(spectrogram, aspect="auto", origin="lower", interpolation='none')
        plt.colorbar(image, ax=axis)
        axis.set_title(title)
        axis.set_xlabel("Frames")
        axis.set_ylabel("Channels")

    plt.tight_layout()
    figure_path = os.path.join(exp_dir, f"epoch-{epoch}-step-{step}.png")
    plt.savefig(figure_path)
    plt.close()

In [98]:
def get_experiment_dir(base_dir: str="./exp"):
    """Create and return the next free `experiment<N>` directory under `base_dir`.

    `base_dir` is created when missing. N is one more than the largest
    integer suffix among existing `experiment*` sub-directories (1 when there
    is none); names whose suffix is not an integer are ignored.
    """

    prefix = "experiment"

    # make base directory
    if not os.path.exists(base_dir):
        os.makedirs(base_dir)

    def numeric_suffix(name):
        # integer following the prefix, or None when it does not parse
        try:
            return int(name[len(prefix):])
        except ValueError:
            return None

    # suffixes of all sub-directories whose name starts with the prefix
    suffixes = [
        numeric_suffix(entry)
        for entry in os.listdir(base_dir)
        if entry.startswith(prefix) and os.path.isdir(os.path.join(base_dir, entry))
    ]

    # highest experiment number so far, never below 0
    highest = max([0] + [number for number in suffixes if number is not None])

    # new directory path
    new_dir_path = os.path.join(base_dir, f"{prefix}{highest + 1}")
    os.makedirs(new_dir_path)

    return new_dir_path

In [99]:
# run on the GPU when one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epochs = hp.epochs

# model, loss and optimizer
model = TransformerTTS(hp).to(device)
criterion = TransformerTTSLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# fresh ./exp/experiment<N> directory for the figures of this run
os.makedirs("./exp", exist_ok=True)
experiment_dir = get_experiment_dir("./exp")


for epoch in range(epochs):

    # iteration counter, restarted at every epoch (used in the figure file names)
    step = 0
    train_tqdm_bar = tqdm(enumerate(dataloader), total=len(dataloader))

    for i, data in train_tqdm_bar:

        # unpack the collated batch and move every tensor to the target device
        text, text_len, mel, stop, mel_len = (tensor.to(device) for tensor in data)

        # forward pass + loss, reported on the progress bar
        post_mel, mel_pred, stop_pred = model(text, text_len, mel, mel_len)
        loss = criterion(post_mel, mel_pred, mel, stop_pred, stop)
        train_tqdm_bar.set_postfix(loss=loss.item())

        # back propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # every hp.vis_every iterations: save predicted vs. target mel of the first sample
        if i % hp.vis_every == 0:

            # (frames, channels) -> (channels, frames), as detached numpy arrays on the CPU
            _mel_pred = mel_pred[0].transpose(1, 0).detach().cpu().numpy()
            _mel_targ = mel[0].transpose(1, 0).detach().cpu().numpy()

            visualize_training_mel(experiment_dir, _mel_pred, _mel_targ,
                epoch, step)

        # increment step
        step += 1



100%|██████████| 819/819 [16:16<00:00,  1.19s/it, loss=0.562]
100%|██████████| 819/819 [15:57<00:00,  1.17s/it, loss=0.469]
100%|██████████| 819/819 [15:16<00:00,  1.12s/it, loss=0.465]
100%|██████████| 819/819 [15:17<00:00,  1.12s/it, loss=0.309]
100%|██████████| 819/819 [15:38<00:00,  1.15s/it, loss=0.385]
100%|██████████| 819/819 [15:23<00:00,  1.13s/it, loss=0.379]
100%|██████████| 819/819 [15:11<00:00,  1.11s/it, loss=0.33]
100%|██████████| 819/819 [14:59<00:00,  1.10s/it, loss=0.292]
100%|██████████| 819/819 [15:02<00:00,  1.10s/it, loss=0.271]
100%|██████████| 819/819 [15:00<00:00,  1.10s/it, loss=0.306]
