In [6]:
!pip install --upgrade pip

[0m

In [None]:
!pip install torch
!pip install speechbrain
!pip install transformers
!pip install tgt

In [None]:
!pip install --upgrade torchvision torchaudio

### PyTorch Modules

In [18]:
%%file /notebooks/models/TransformerTTS.py
# @title Bringing it all in one file in speechbrain

"""
Neural network modules for the Tacotron2 end-to-end neural
Text-to-Speech (TTS) model

Authors
* Salman Hussain Ali 2024
"""
import math
# This code uses a significant portion of the NVidia implementation, even though it
# has been modified and enhanced

# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/model.py
# *****************************************************************************
#  Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#      * Redistributions of source code must retain the above copyright
#        notice, this list of conditions and the following disclaimer.
#      * Redistributions in binary form must reproduce the above copyright
#        notice, this list of conditions and the following disclaimer in the
#        documentation and/or other materials provided with the distribution.
#      * Neither the name of the NVIDIA CORPORATION nor the
#        names of its contributors may be used to endorse or promote products
#        derived from this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
#  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
#  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
#  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
#  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
#  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************

from math import sqrt

from speechbrain.lobes.models.transformer.Transformer import get_mask_from_lengths
import torch
import torchaudio
from torch import nn
from torch.nn import functional as F
from collections import namedtuple


class LinearNorm(torch.nn.Module):
    """A linear layer with Xavier initialization

    Arguments
    ---------
    in_dim: int
        the input dimension
    out_dim: int
        the output dimension
    bias: bool
        whether or not to use a bias
    w_init_gain: linear
        the weight initialization gain type (see torch.nn.init.calculate_gain)

    Example
    -------
    >>> import torch
    >>> from speechbrain.lobes.models.Tacotron2 import LinearNorm
    >>> layer = LinearNorm(in_dim=5, out_dim=3)
    >>> x = torch.randn(3, 5)
    >>> y = layer(x)
    >>> y.shape
    torch.Size([3, 3])
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
        super().__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain),
        )

    def forward(self, x):
        """Computes the forward pass

        Arguments
        ---------
        x: torch.Tensor
            a (batch, features) input tensor


        Returns
        -------
        output: torch.Tensor
            the linear layer output

        """
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    """A 1D convolution layer with Xavier initialization

    Arguments
    ---------
    in_channels: int
        the number of input channels
    out_channels: int
        the number of output channels
    kernel_size: int
        the kernel size
    stride: int
        the convolutional stride
    padding: int
        the amount of padding to include. If not provided, it will be calculated
        as dilation * (kernel_size - 1) / 2
    dilation: int
        the dilation of the convolution
    bias: bool
        whether or not to use a bias
    w_init_gain: linear
        the weight initialization gain type (see torch.nn.init.calculate_gain)

    Example
    -------
    >>> import torch
    >>> from speechbrain.lobes.models.Tacotron2 import ConvNorm
    >>> layer = ConvNorm(in_channels=10, out_channels=5, kernel_size=3)
    >>> x = torch.randn(3, 10, 5)
    >>> y = layer(x)
    >>> y.shape
    torch.Size([3, 5, 5])
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=None,
            dilation=1,
            bias=True,
            w_init_gain="linear",
    ):
        super().__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
        )

    def forward(self, signal):
        """Computes the forward pass

        Arguments
        ---------
        signal: torch.Tensor
            the input to the convolutional layer

        Returns
        -------
        output: torch.Tensor
            the output
        """
        return self.conv(signal)


class EncoderPreNetBlock(nn.Module):
    """
        A block for preprocessing text inputs in an encoder network, consisting of a 1D convolutional layer followed by batch normalization, ReLU activation, dropout, and linear transformation.

        Arguments
        ---------
        in_dim: int
            Number of input channels.
        kernel_size: int
            Size of the convolutional kernel.
        stride: int
            Stride of the convolution.
        padding: int, optional
            Amount of padding to include. If not provided, it will be calculated as dilation * (kernel_size - 1) / 2.
        dilation: int, optional
            Dilation factor of the convolution.
        bias: bool, optional
            Whether to include bias in the convolutional layer.
        dropout: float, optional
            Dropout probability.

        Example
        -------
        >>> import torch
        >>> from speechbrain.lobes.models.TransformerTTS import EncoderPreNetBlock
        >>> block = EncoderPreNetBlock(in_dim=10, kernel_size=3)
        >>> x = torch.randn(3, 10, 5)  # Input tensor shape: (batch_size, in_dim, sequence_length)
        >>> y = block(x)
        >>> y.shape  # Output shape: (batch_size, in_dim, sequence_length)
        torch.Size([3, 10, 5])
        """

    def __init__(
            self,
            embed_dim=512,
            kernel_size=1,
            stride=1,
            padding=None,
            dilation=1,
            bias=True,
            dropout=0.5
    ):
        super().__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = ConvNorm(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=kernel_size
        )

        self.batch = nn.BatchNorm1d(embed_dim)

        self.dropout = nn.Dropout(dropout)

        self.activation = torch.nn.ReLU()

    def forward(self, text):
        """Computes the forward pass

        Arguments
        ---------
        text: torch.Tensor
            the input to the convolutional layer

        Returns
        -------
        output: torch.Tensor
            the output
        """
        # text = text.transpose(1, 2)
        out = self.conv(text)
        out = self.batch(out)
        out = self.activation(out)
        out = self.dropout(out)
        return out


class EncoderPrenet(nn.Module):
    """The Tacotron pre-net module consisting of a specified number of
    normalized (Xavier-initialized) linear layers

    Arguments
    ---------
    in_dim: int
        the input dimensions
    sizes: int
        the dimension of the hidden layers/output
    dropout: float
        the dropout probability

    Example
    -------
    >>> import torch
    >>> from speechbrain.lobes.models.Tacotron2 import Prenet
    >>> layer = Prenet()
    >>> x = torch.randn(862, 2, 80)
    >>> output = layer(x)
    >>> output.shape
    torch.Size([862, 2, 256])
    """

    def __init__(
            self,
            emb_dim=512,
            hidden_dim=256,
            kernel_size=1,
            stride=1,
            padding=None,
            dilation=1,
            bias=True,
            num_layers=3,
            dropout=0.1):
        super().__init__()

        self.layers = nn.ModuleList(
            [EncoderPreNetBlock(emb_dim, kernel_size, stride, padding, dilation, bias) for i in range(num_layers)]
        )

        self.dropout = nn.Dropout(dropout)

        self.linear = nn.Linear(emb_dim, hidden_dim)

    def forward(self, x):
        """Computes the forward pass for the prenet

        Arguments
        ---------
        x: torch.Tensor
            the prenet inputs

        Returns
        -------
        output: torch.Tensor
            the output
        """
        # x = x.transpose(1,2) TODO review
        for layer in self.layers:
            x = layer(x)
        x = x.transpose(1, 2)
        x = self.linear(x)
        return x


class DecoderPrenet(nn.Module):
    """The Tacotron pre-net module consisting of a specified number of
        normalized (Xavier-initialized) linear layers

        Arguments
        ---------
        in_dim: int
            the input dimensions
        sizes: int
            the dimension of the hidden layers/output
        dropout: float
            the dropout probability

        Example
        -------
        >>> import torch
        >>> from speechbrain.lobes.models.TransformerTTS import DecoderPrenet
        >>> layer = DecoderPrenet()
        >>> x = torch.randn(862, 2, 80)
        >>> output = layer(x)
        >>> output.shape
        torch.Size([862, 2, 256])
        """

    def __init__(
            self,
            mel_dims=80,
            hidden_dims=256,
            d_model=512,
            dropout=0.5):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Linear(mel_dims, hidden_dims),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dims, d_model),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        self.linear = nn.Linear(hidden_dims, d_model)

    def forward(self, mel):
        """Computes the forward pass for the prenet

        Arguments
        ---------
        x: torch.Tensor
            the prenet inputs

        Returns
        -------
        output: torch.Tensor
            the output
        """
        out = self.layers(mel)
        out = self.linear(out)
        return out


class Postnet(nn.Module):
    """The TransformerTTS postnet consists of a number of 1-d convolutional layers
    with Xavier initialization and a tanh activation, with batch normalization.
    Depending on configuration, the postnet may either refine the MEL spectrogram
    or upsample it to a linear spectrogram. It has the same architecture as Tacotron2's Post-Net

    Arguments
    ---------
    n_mel_channels: int
        the number of MEL spectrogram channels
    postnet_embedding_dim: int
        the postnet embedding dimension
    postnet_kernel_size: int
        the kernel size of the convolutions within the decoders
    postnet_n_convolutions: int
        the number of convolutions in the postnet

    Example
    -------
    >>> import torch
    >>> from speechbrain.lobes.models.Tacotron2 import Postnet
    >>> layer = Postnet()
    >>> x = torch.randn(2, 80, 861)
    >>> output = layer(x)
    >>> output.shape
    torch.Size([2, 80, 861])
    """

    def __init__(
            self,
            n_mel_channels=80,
            postnet_embedding_dim=512,
            postnet_kernel_size=5,
            postnet_n_convolutions=5,
    ):
        super().__init__()
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(
                    n_mel_channels,
                    postnet_embedding_dim,
                    kernel_size=postnet_kernel_size,
                    stride=1,
                    padding=int((postnet_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="tanh",
                ),
                nn.BatchNorm1d(postnet_embedding_dim),
            )
        )

        for i in range(1, postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(
                        postnet_embedding_dim,
                        postnet_embedding_dim,
                        kernel_size=postnet_kernel_size,
                        stride=1,
                        padding=int((postnet_kernel_size - 1) / 2),
                        dilation=1,
                        w_init_gain="tanh",
                    ),
                    nn.BatchNorm1d(postnet_embedding_dim),
                )
            )

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(
                    postnet_embedding_dim,
                    n_mel_channels,
                    kernel_size=postnet_kernel_size,
                    stride=1,
                    padding=int((postnet_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="linear",
                ),
                nn.BatchNorm1d(n_mel_channels),
            )
        )
        self.n_convs = len(self.convolutions)

    def forward(self, x):
        """Computes the forward pass of the postnet

        Arguments
        ---------
        x: torch.Tensor
            the postnet input (usually a MEL spectrogram)

        Returns
        -------
        output: torch.Tensor
            the postnet output (a refined MEL spectrogram or a
            linear spectrogram depending on how the model is
            configured)
        """
        i = 0
        for conv in self.convolutions:
            if i < self.n_convs - 1:
                x = F.dropout(torch.tanh(conv(x)), 0.5, training=self.training)
            else:
                x = F.dropout(conv(x), 0.5, training=self.training)
            i += 1

        return x


class TransformerTTS(nn.Module):
    """The Transformer text-to-speech model, based on the NVIDIA implementation.

    This class is the main entry point for the model, which is responsible
    for instantiating all submodules, which, in turn, manage the individual
    neural network layers

    Simplified STRUCTURE: input->word embedding ->encoder ->attention \
    ->decoder(+prenet) -> postnet ->output

    prenet(input is decoder previous time step) output is input to decoder
    concatenated with the attention output

    Arguments
    ---------
    mask_padding: bool
        whether or not to mask pad-outputs of tacotron
    n_mel_channels: int
        number of mel channels for constructing spectrogram
    n_symbols:  int=128
        number of accepted char symbols defined in textToSequence
    symbols_embedding_dim: int
        number of embedding dimension for symbols fed to nn.Embedding
    encoder_prenet_kernel_size: int
        size of kernel processing the embeddings
    encoder_prenet_n_convolutions: int
        number of convolution layers in encoder
    encoder_prenet_embedding_dim: int
        number of kernels in encoder, this is also the dimension
        of the bidirectional LSTM in the encoder
    n_frames_per_step: int=1
        only 1 generated mel-frame per step is supported for the decoder as of now.
    decoder_rnn_dim: int
        number of 2 unidirectional stacked LSTM units
    prenet_dim: int
        dimension of linear prenet layers
    max_decoder_steps: int
        maximum number of steps/frames the decoder generates before stopping
    gate_threshold: int
        cut off level any output probability above that is considered
        complete and stops generation so we have variable length outputs
    p_attention_dropout: float
        attention drop out probability
    p_decoder_dropout: float
        decoder drop  out probability
    postnet_embedding_dim: int
        number os postnet dfilters
    postnet_kernel_size: int
        1d size of posnet kernel
    postnet_n_convolutions: int
        number of convolution layers in postnet
    decoder_no_early_stopping: bool
        determines early stopping of decoder
        along with gate_threshold . The logical inverse of this is fed to the decoder

    Example
    -------
    >>> import torch
    >>> _ = torch.manual_seed(213312)
    >>> from speechbrain.lobes.models.Tacotron2 import Tacotron2
    >>> model = Tacotron2(
    ...    mask_padding=True,
    ...    n_mel_channels=80,
    ...    n_symbols=148,
    ...    symbols_embedding_dim=512,
    ...    encoder_prenet_kernel_size=5,
    ...    encoder_prenet_n_convolutions=3,
    ...    encoder_prenet_embedding_dim=512,
    ...    attention_rnn_dim=1024,
    ...    attention_dim=128,
    ...    attention_location_n_filters=32,
    ...    attention_location_kernel_size=31,
    ...    n_frames_per_step=1,
    ...    decoder_rnn_dim=1024,
    ...    prenet_dim=256,
    ...    max_decoder_steps=32,
    ...    gate_threshold=0.5,
    ...    p_attention_dropout=0.1,
    ...    p_decoder_dropout=0.1,
    ...    postnet_embedding_dim=512,
    ...    postnet_kernel_size=5,
    ...    postnet_n_convolutions=5,
    ...    decoder_no_early_stopping=False
    ... )
    >>> _ = model.eval()
    >>> inputs = torch.tensor([
    ...     [13, 12, 31, 14, 19],
    ...     [31, 16, 30, 31, 0],
    ... ])
    >>> input_lengths = torch.tensor([5, 4])
    >>> outputs, output_lengths, alignments = model.infer(inputs, input_lengths)
    >>> outputs.shape, output_lengths.shape, alignments.shape
    (torch.Size([2, 80, 1]), torch.Size([2]), torch.Size([2, 1, 5]))
    """

    def __init__(
            self,
            mask_padding=True,
            # mel generation parameter in data io
            n_mel_channels=80,
            # symbols
            n_symbols=52,
            symbols_embedding_dim=512,
            # Encoder Pre-Net parameters
            encoder_prenet_kernel_size=5,
            encoder_prenet_n_convolutions=3,
            encoder_prenet_padding=None,
            encoder_prenet_dilation=1,
            encoder_prenet_bias=True,
            encoder_prenet_dropout=0.5,
            encoder_prenet_stride=1,
            # Decoder Pre-Net parameters
            n_frames_per_step=1,
            decoder_prenet_hidden_dims=256,
            decoder_prenet_dropout=0.15,
            # Transformer Parameters
            d_model=256,
            transformer_nhead=8,
            transformer_num_encoder_layers=6,
            transformer_num_decoder_layers=6,
            transformer_d_ffn=2048,
            transformer_dropout=0.1,
            transformer_activation="relu",
            custom_encoder_module=None,
            custom_decoder_module=None,
            batch_first=False,
            norm_first=False,
            layer_norm_eps=1e-5,
            # Mel-post processing network parameters
            postnet_embedding_dim=512,
            postnet_kernel_size=5,
            postnet_n_convolutions=5,
            gate_threshold=0.5,
            max_decoder_steps=32,
            early_stopping=False,
            padding_idx=0
    ):
        super().__init__()
        self.mask_padding = mask_padding
        self.n_mel_channels = n_mel_channels
        self.n_frames_per_step = n_frames_per_step
        self.max_decoder_steps = max_decoder_steps
        self.early_stopping = early_stopping

        self.gate_threshold = gate_threshold

        self.encoder_embedding = nn.Embedding(n_symbols, symbols_embedding_dim, padding_idx=padding_idx)

        std = sqrt(2.0 / (n_symbols + d_model))
        val = sqrt(3.0) * std  # uniform bounds for std
        self.encoder_embedding.weight.data.uniform_(-val, val)

        if custom_encoder_module is None:
            encoder_block = torch.nn.TransformerEncoderLayer(d_model=d_model,
                                                         nhead=transformer_nhead,
                                                         dim_feedforward=transformer_d_ffn,
                                                         dropout=transformer_dropout,
                                                         activation=transformer_activation,
                                                         batch_first=batch_first,
                                                         norm_first=norm_first,
                                                         layer_norm_eps=layer_norm_eps)
        else:
            encoder_block = custom_encoder_module

        if custom_decoder_module is None:
            decoder_block = torch.nn.TransformerDecoderLayer(d_model=d_model,
                                                         nhead=transformer_nhead,
                                                         dim_feedforward=transformer_d_ffn,
                                                         dropout=transformer_dropout,
                                                         activation=transformer_activation,
                                                         batch_first=batch_first,
                                                         norm_first=norm_first,
                                                         layer_norm_eps=layer_norm_eps)
        else:
            decoder_block = custom_decoder_module

        self.transformer_encoder = nn.TransformerEncoder(encoder_block, num_layers=transformer_num_encoder_layers)

        self.transformer_decoder = nn.TransformerDecoder(decoder_block, num_layers=transformer_num_decoder_layers)

        self.encoder_prenet = EncoderPrenet(
            emb_dim=symbols_embedding_dim,
            hidden_dim=d_model,
            kernel_size=encoder_prenet_kernel_size,
            num_layers=encoder_prenet_n_convolutions,
            stride=encoder_prenet_stride,
            padding=encoder_prenet_padding,
            dilation=encoder_prenet_dilation,
            bias=encoder_prenet_bias,
            dropout=encoder_prenet_dropout)

        self.decoder_prenet = DecoderPrenet(
            mel_dims=n_mel_channels,
            hidden_dims=decoder_prenet_hidden_dims,
            dropout=decoder_prenet_dropout,
            d_model=d_model
        )

        self.postnet = Postnet(
            n_mel_channels=n_mel_channels,
            postnet_embedding_dim=postnet_embedding_dim,
            postnet_kernel_size=postnet_kernel_size,
            postnet_n_convolutions=postnet_n_convolutions,
        )

        self.encoder_positional_encoding = PositionalEncoding(input_size=d_model)
        self.decoder_positional_encoding = PositionalEncoding(input_size=d_model)

        self.mel_linear = LinearNorm(d_model, n_mel_channels)
        self.stop_linear = LinearNorm(d_model, 1, w_init_gain='sigmoid')

    def parse_output(self, outputs, output_lengths):
        """
        Masks the padded part of output

        Arguments
        ---------
        outputs: list
            a list of tensors - raw outputs
        output_lengths: torch.Tensor
            a tensor representing the lengths of all outputs

        Returns
        -------
        mel_outputs: torch.Tensor
        mel_outputs_postnet: torch.Tensor
        gate_outputs: torch.Tensor
        """
        mel_outputs, mel_outputs_postnet, gate_outputs = outputs

        if self.mask_padding and output_lengths is not None:
            mask = get_mask_from_lengths(
                output_lengths, max_len=mel_outputs.size(-1)
            )
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            mel_outputs.clone().masked_fill_(mask, 0.0)
            mel_outputs_postnet.masked_fill_(mask, 0.0)
            gate_outputs.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        return mel_outputs, mel_outputs_postnet, gate_outputs

    def parse_decoder_outputs(self, mel_outputs, gate_outputs):
        """Prepares decoder outputs for output

        Arguments
        ---------
        mel_outputs: torch.Tensor
            MEL-scale spectrogram outputs
        gate_outputs: torch.Tensor
            gate output energies

        Returns
        -------
        mel_outputs: torch.Tensor
            MEL-scale spectrogram outputs
        gate_outputs: torch.Tensor
            gate output energies
        """
        # (T_out, B) -> (B, T_out)
        if gate_outputs.dim() == 1:
            gate_outputs = gate_outputs.unsqueeze(0)
        else:
            gate_outputs = gate_outputs.transpose(0, 1).contiguous()

        # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
        mel_outputs = mel_outputs.transpose(0, 1).contiguous()
        # decouple frames per step
        shape = (mel_outputs.shape[0], -1, self.n_mel_channels)
        mel_outputs = mel_outputs.view(*shape)
        # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
        mel_outputs = mel_outputs.transpose(1, 2)

        return mel_outputs, gate_outputs

    def get_go_frame(self, memory):
        """Gets all zeros frames to use as first decoder input

        Arguments
        ---------
        memory: torch.Tensor
            decoder outputs

        Returns
        -------
        decoder_input: torch.Tensor
            all zeros frames
        """
        B = memory.size(0)
        dtype = memory.dtype
        device = memory.device
        decoder_input = torch.zeros(
            B,
            self.n_mel_channels * self.n_frames_per_step,
            dtype=dtype,
            device=device,
        )
        return decoder_input

    def forward(self, inputs, masks):
        """Decoder forward pass for training

        Arguments
        ---------
        inputs: tuple
            batch object
        alignments_dim: int
            the desired dimension of the alignments along the last axis
            Optional but needed for data-parallel training

        Returns
        -------
        mel_outputs: torch.Tensor
            mel outputs from the decoder
        mel_outputs_postnet: torch.Tensor
            mel outputs from postnet
        gate_outputs: torch.Tensor
            gate outputs from the decoder
        output_lengths: torch.Tensor
            length of the output without padding
        """
        inputs, input_lengths, mel_padded, max_len, mel_len = inputs  # (text_padded, input_lengths, mel_padded, max_len, output_lengths)
        tgt_mask, src_mask, src_key_padding_mask, tgt_key_padding_mask = masks

        # Generate Embeddings
        embedded_inputs = self.encoder_embedding(inputs).transpose(1, 2)

        # Pass through encoder pre-net
        encoder_prenet_outputs = self.encoder_prenet(embedded_inputs)

        # Pass through decoder pre-net
        decoder_prenet_outputs = self.decoder_prenet(mel_padded.transpose(1, 2))
        
        encoder_prenet_outputs = encoder_prenet_outputs.transpose(0, 1)
        decoder_prenet_outputs = decoder_prenet_outputs.transpose(0, 1)

        # Add Scaled Positional Embeddings to encoder and decoder pre-nets
        encoder_prenet_outputs = encoder_prenet_outputs + self.encoder_positional_encoding(encoder_prenet_outputs)
        decoder_prenet_outputs = decoder_prenet_outputs + self.decoder_positional_encoding(decoder_prenet_outputs)

        # Input embedded phonemes into transformer's encoder
        memory = self.transformer_encoder(
            src=encoder_prenet_outputs,
            src_key_padding_mask=src_key_padding_mask,
        )

        # Input memory and mel spectograms into transformer's decoder
        mel_outputs = self.transformer_decoder(tgt=decoder_prenet_outputs,
                                               memory=memory,
                                               tgt_mask=tgt_mask,
                                               tgt_key_padding_mask=tgt_key_padding_mask,
                                                memory_key_padding_mask = src_key_padding_mask)

        # Calculate mel linear, mel stop, and post-net
        # Stop Linear
        stop_token = self.stop_linear(mel_outputs).squeeze()

        # Mel Linear and Post-Net
        mel_outputs = self.mel_linear(mel_outputs)

        mel_outputs, stop_token = self.parse_decoder_outputs(mel_outputs, stop_token)

        mel_outputs_postnet = self.postnet(mel_outputs)

        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, stop_token], mel_len)

    def infer(self, inputs, input_lengths):
        """Produces outputs


        Arguments
        ---------
        inputs: torch.tensor
            text or phonemes converted

        input_lengths: torch.tensor
            the lengths of input parameters

        Returns
        -------
        mel_outputs_postnet: torch.Tensor
            final mel output of tacotron 2
        mel_lengths: torch.Tensor
            length of mels
        alignments: torch.Tensor
            sequence of attention weights
        """

        # Generate Embeddings
        embedded_inputs = self.encoder_embedding(inputs).transpose(1, 2)

        # Pass through encoder pre-net
        encoder_prenet_outputs = self.encoder_prenet(embedded_inputs)

        # Add Scaled Positional Embeddings to encoder and decoder pre-nets
        encoder_prenet_outputs = encoder_prenet_outputs + self.encoder_positional_encoding(encoder_prenet_outputs)

        encoder_prenet_outputs = encoder_prenet_outputs.transpose(0, 1)

        # Input embedded phonemes into transformer's encoder
        src_key_padding_mask = get_mask_from_lengths(input_lengths).to(inputs.device, non_blocking=True)
        memory = self.transformer_encoder(
            src=encoder_prenet_outputs,
            src_key_padding_mask=src_key_padding_mask,
        )

        decoder_input = self.get_go_frame(memory)
        print(f"Decoder Input Shape:{decoder_input.shape}")

        mask = get_mask_from_lengths(input_lengths)

        mel_lengths = torch.zeros(
            [memory.size(0)], dtype=torch.int32, device=memory.device
        )
        not_finished = torch.ones(
            [memory.size(0)], dtype=torch.int32, device=memory.device
        )

        mel_outputs, gate_outputs = (
            torch.zeros(1),
            torch.zeros(1),
        )
        #print(f"Memory shape: {memory.shape}")

        first_iter = True
        while True:
            decoder_input = self.decoder_prenet(decoder_input)
            if len(decoder_input.shape) != 3:
                decoder_input = decoder_input.unsqueeze(1)
            
            #print(f"decoder_input shape: {decoder_input.shape}")
            decoder_output = self.transformer_decoder(memory=memory, tgt=decoder_input)
            #print(f"Decoder Output Shape: {decoder_output.shape}")

            # Calculate mel linear and stop token
            # Stop Linear
            gate_output = self.stop_linear(decoder_output).squeeze()
            #print(f"Stop Linear Output Shape: {gate_output.shape}")

            # Mel Linear and Post-Net
            mel_output = self.mel_linear(decoder_output)
            #print(f"Mel Linear Output Shape: {mel_output.shape}")

            if first_iter:
                mel_outputs = mel_output
                gate_outputs = gate_output
                first_iter = False
            else:
                mel_outputs = torch.cat(
                    (mel_outputs, mel_output), dim=0
                )
                gate_outputs = torch.cat((gate_outputs, gate_output), dim=0)
            #print(f"Mel Linear Output Shape: {mel_output.shape}")
            #print(f"Mel_Outputs Shape: {mel_outputs.shape}")
            
            sigmoid = torch.sigmoid(gate_output).unsqueeze(1)

            dec = (
                torch.le(sigmoid, torch.tensor(self.gate_threshold))
                .to(torch.int32)
                .squeeze(1)
            )

            not_finished = not_finished * dec
            mel_lengths += not_finished
            if self.early_stopping and torch.sum(not_finished) == 0:
                break
            if len(mel_outputs) == self.max_decoder_steps:
                break
            
            #print(f"Pre_Assignment Output Shape: {mel_output.shape}")
            decoder_input = mel_output

        mel_outputs, gate_outputs = self.parse_decoder_outputs(
            mel_outputs, gate_outputs
        )

        return mel_outputs, gate_outputs

def infer(model, text_sequences, input_lengths):
    """
    An inference hook for pretrained synthesizers

    Arguments
    ---------
    model: TransformerTTS
        the tacotron model
    text_sequences: torch.Tensor
        encoded text sequences
    input_lengths: torch.Tensor
        input lengths

    Returns
    -------
    result: tuple
        (mel_outputs_postnet, mel_lengths) - the exact
        model output
    """
    return model.infer(text_sequences, input_lengths)


LossStats = namedtuple(
    "TransformerLoss", "loss mel_loss gate_loss"
)


class Loss(nn.Module):
    """The TransformerTTS loss implementation based on Tacotron2

    The loss consists of an MSE loss on the spectrogram and a BCE gate loss

    The output of the module is a LossStats tuple, which includes both the
    total loss

    Arguments
    ---------
    gate_loss_weight: float
        The constant by which the gate loss will be multiplied. In the paper, it is 5.0 ~ 8.0

    Example
    -------
    >>> import torch
    >>> _ = torch.manual_seed(42)
    >>> from speechbrain.lobes.models.Tacotron2 import Loss
    >>> loss = Loss(guided_attention_sigma=0.2)
    >>> mel_target = torch.randn(2, 80, 861)
    >>> gate_target = torch.randn(1722, 1)
    >>> mel_out = torch.randn(2, 80, 861)
    >>> mel_out_postnet = torch.randn(2, 80, 861)
    >>> gate_out = torch.randn(2, 861)
    >>> alignments = torch.randn(2, 861, 173)
    >>> targets = mel_target, gate_target
    >>> model_outputs = mel_out, mel_out_postnet, gate_out, alignments
    >>> input_lengths = torch.tensor([173,  91])
    >>> target_lengths = torch.tensor([861, 438])
    >>> loss(model_outputs, targets, input_lengths, target_lengths, 1)
    TacotronLoss(loss=tensor(4.8566), mel_loss=tensor(4.0097), gate_loss=tensor(0.8460), attn_loss=tensor(0.0010), attn_weight=tensor(1.))
    """

    def __init__(
            self,
            gate_loss_weight=5.0
    ):
        super().__init__()
        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.gate_loss_weight = gate_loss_weight

    def forward(
            self, model_output, targets
    ):
        """Computes the loss

        Arguments
        ---------
        model_output: tuple
            the output of the model's forward():
            (mel_outputs, mel_outputs_postnet, gate_outputs, alignments)
        targets: tuple
            the targets

        Returns
        -------
        result: LossStats
            the total loss - and individual losses (mel and gate)

        """
        mel_target, gate_target = targets[0], targets[1]
        mel_target.requires_grad = False
        gate_target.requires_grad = False
        gate_target = gate_target.view(-1, 1)

        mel_out, mel_out_postnet, gate_out = model_output

        gate_out = gate_out.view(-1, 1)

        mel_loss = self.mse_loss(mel_out, mel_target) + self.mse_loss(
            mel_out_postnet, mel_target
        )

        gate_loss = self.gate_loss_weight * self.bce_loss(gate_out, gate_target)  # Applying weight to stop token loss
        
        total_loss = mel_loss + gate_loss
        return LossStats(
            total_loss, mel_loss, gate_loss
        )


class TextMelCollate:
    """Zero-pads model inputs and targets based on number of frames per step

    Arguments
    ---------
    n_frames_per_step: int
        the number of output frames per step
    """

    def __init__(self, n_frames_per_step=1):
        self.n_frames_per_step = n_frames_per_step

    # TODO: Make this more intuitive, use the pipeline
    def __call__(self, batch):
        """Collate's training batch from normalized text and mel-spectrogram

        Arguments
        ---------
        batch: list
            [text_normalized, mel_normalized]

        Returns
        -------
        text_padded: torch.Tensor
        input_lengths: torch.Tensor
        mel_padded: torch.Tensor
        gate_padded: torch.Tensor
        output_lengths: torch.Tensor
        len_x: torch.Tensor
        labels: torch.Tensor
        wavs: torch.Tensor
        """

        # TODO: Remove for loops and this dirty hack
        raw_batch = list(batch)
        for i in range(
                len(batch)
        ):  # the pipeline return a dictionary with one element
            batch[i] = batch[i]["mel_text_pair"]

        # Right zero-pad all one-hot text sequences to max input length
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True
        )
        max_input_len = input_lengths[0]

        text_padded = torch.LongTensor(len(batch), max_input_len)
        text_padded.zero_()
        for i in range(len(ids_sorted_decreasing)):
            text = batch[ids_sorted_decreasing[i]][0]
            text_padded[i, : text.size(0)] = text

        # Right zero-pad mel-spec
        num_mels = batch[0][1].size(0)
        max_target_len = max([x[1].size(1) for x in batch])
        if max_target_len % self.n_frames_per_step != 0:
            max_target_len += (
                    self.n_frames_per_step - max_target_len % self.n_frames_per_step
            )
            assert max_target_len % self.n_frames_per_step == 0

        # include mel padded and gate padded
        mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
        mel_padded.zero_()
        gate_padded = torch.FloatTensor(len(batch), max_target_len)
        gate_padded.zero_()
        output_lengths = torch.LongTensor(len(batch))
        labels, wavs = [], []
        for i in range(len(ids_sorted_decreasing)):
            idx = ids_sorted_decreasing[i]
            mel = batch[idx][1]
            mel_padded[i, :, : mel.size(1)] = mel
            gate_padded[i, mel.size(1) - 1:] = 1
            output_lengths[i] = mel.size(1)
            labels.append(raw_batch[idx]["label"])
            wavs.append(raw_batch[idx]["wav"])

        # count number of items - characters in text
        len_x = [x[2] for x in batch]
        len_x = torch.Tensor(len_x)

        return (
            text_padded,
            input_lengths,
            mel_padded,
            gate_padded,
            output_lengths,
            len_x,
            labels,
            wavs,
        )


class PositionalEncoding(nn.Module):
    """This class implements the absolute sinusoidal positional encoding function.
    PE(pos, 2i)   = sin(pos/(10000^(2i/dmodel)))
    PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))

    Based on Cornell & Zhong's implementation in Transformer.py

    Arguments
    ---------
    input_size: int
        Embedding dimension.
    max_len : int, optional
        Max length of the input sequences (default 2500).

    Example
    -------
    >>> a = torch.rand((8, 120, 512))
    >>> enc = PositionalEncoding(input_size=a.shape[-1])
    >>> b = enc(a)
    >>> b.shape
    torch.Size([1, 120, 512])
    """

    def __init__(self, input_size, max_len=2500):
        super().__init__()
        if input_size % 2 != 0:
            raise ValueError(
                f"Cannot use sin/cos positional encoding with odd channels (got channels={input_size})"
            )
        self.max_len = max_len
        pe = torch.zeros(self.max_len, input_size, requires_grad=False)
        positions = torch.arange(0, self.max_len).unsqueeze(1).float()
        denominator = torch.exp(
            torch.arange(0, input_size, 2).float()
            * -(math.log(10000.0) / input_size)
        )

        pe[:, 0::2] = torch.sin(positions * denominator)
        pe[:, 1::2] = torch.cos(positions * denominator)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

        # Define learnable scaling parameter
        self.alpha = nn.Parameter(torch.Tensor(1))  # 1-dimensional tensor

        # Initialize alpha parameter
        nn.init.normal_(self.alpha)  # Initialize with random values

    def forward(self, x):
        """
        Arguments
        ---------
        x : torch.Tensor
            Input feature shape (batch, time, fea)

        Returns
        -------
        The positional encoding.
        """
        scaled_pos_embedding = self.alpha * self.pe[:, : x.size(1)].clone().detach()
        return scaled_pos_embedding


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Dynamic range compression for audio signals"""
    return torch.log(torch.clamp(x, min=clip_val) * C)


def mel_spectogram(
        sample_rate,
        hop_length,
        win_length,
        n_fft,
        n_mels,
        f_min,
        f_max,
        power,
        normalized,
        norm,
        mel_scale,
        compression,
        audio,
):
    """calculates MelSpectrogram for a raw audio signal

    Arguments
    ---------
    sample_rate : int
        Sample rate of audio signal.
    hop_length : int
        Length of hop between STFT windows.
    win_length : int
        Window size.
    n_fft : int
        Size of FFT.
    n_mels : int
        Number of mel filterbanks.
    f_min : float
        Minimum frequency.
    f_max : float
        Maximum frequency.
    power : float
        Exponent for the magnitude spectrogram.
    normalized : bool
        Whether to normalize by magnitude after stft.
    norm : str or None
        If "slaney", divide the triangular mel weights by the width of the mel band
    mel_scale : str
        Scale to use: "htk" or "slaney".
    compression : bool
        whether to do dynamic range compression
    audio : torch.Tensor
        input audio signal

    Returns
    -------
    mel : torch.Tensor
        The computed mel spectrogram features.
    """
    from torchaudio import transforms

    audio_to_mel = transforms.MelSpectrogram(
        sample_rate=sample_rate,
        hop_length=hop_length,
        win_length=win_length,
        n_fft=n_fft,
        n_mels=n_mels,
        f_min=f_min,
        f_max=f_max,
        power=power,
        normalized=normalized,
        norm=norm,
        mel_scale=mel_scale,
    ).to(audio.device)

    mel = audio_to_mel(audio)

    if compression:
        mel = dynamic_range_compression(mel)

    return mel

Overwriting /notebooks/models/TransformerTTS.py


### LJSpeech Preparation

In [None]:
%%file /notebooks/ljspeech_prepare.py
"""
LJspeech data preparation.
Download: https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2

Authors
 * Yingzhi WANG 2022
 * Sathvik Udupa 2022
 * Pradnya Kandarkar 2023
"""

import os
import csv
import json
import random
import logging
import torch
import torchaudio
import numpy as np
from tqdm import tqdm
from speechbrain.utils.data_utils import download_file
from speechbrain.dataio.dataio import load_pkl, save_pkl
import tgt
from speechbrain.inference.text import GraphemeToPhoneme
import re
from unidecode import unidecode
from speechbrain.utils.text_to_sequence import _g2p_keep_punctuations


logger = logging.getLogger(__name__)
OPT_FILE = "opt_ljspeech_prepare.pkl"
METADATA_CSV = "metadata.csv"
TRAIN_JSON = "train.json"
VALID_JSON = "valid.json"
TEST_JSON = "test.json"
WAVS = "wavs"
DURATIONS = "durations"

logger = logging.getLogger(__name__)
OPT_FILE = "opt_ljspeech_prepare.pkl"


def prepare_ljspeech(
    data_folder,
    save_folder,
    splits=["train", "valid"],
    split_ratio=[90, 10],
    model_name=None,
    seed=1234,
    pitch_n_fft=1024,
    pitch_hop_length=256,
    pitch_min_f0=65,
    pitch_max_f0=400,
    skip_prep=False,
    use_custom_cleaner=False,
    device="cpu",
):
    """
    Prepares the csv files for the LJspeech datasets.

    Arguments
    ---------
    data_folder : str
        Path to the folder where the original LJspeech dataset is stored
    save_folder : str
        The directory where to store the csv/json files
    splits : list
        List of dataset splits to prepare
    split_ratio : list
        Proportion for dataset splits
    model_name : str
        Model name (used to prepare additional model specific data)
    seed : int
        Random seed
    pitch_n_fft : int
        Number of fft points for pitch computation
    pitch_hop_length : int
        Hop length for pitch computation
    pitch_min_f0 : int
        Minimum f0 for pitch computation
    pitch_max_f0 : int
        Max f0 for pitch computation
    skip_prep : bool
        If True, skip preparation
    use_custom_cleaner : bool
        If True, uses custom cleaner defined for this recipe
    device : str
        Device for to be used for computation (used as required)

    Returns
    -------
    None

    Example
    -------
    >>> from recipes.LJSpeech.TTS.ljspeech_prepare import prepare_ljspeech
    >>> data_folder = 'data/LJspeech/'
    >>> save_folder = 'save/'
    >>> splits = ['train', 'valid']
    >>> split_ratio = [90, 10]
    >>> seed = 1234
    >>> prepare_ljspeech(data_folder, save_folder, splits, split_ratio, seed)
    """
    # Sets seeds for reproducible code
    random.seed(seed)

    if skip_prep:
        return

    # Creating configuration for easily skipping data_preparation stage
    conf = {
        "data_folder": data_folder,
        "splits": splits,
        "split_ratio": split_ratio,
        "save_folder": save_folder,
        "seed": seed,
    }
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)

    # Setting output files
    meta_csv = os.path.join(data_folder, METADATA_CSV)
    wavs_folder = os.path.join(data_folder, WAVS)

    save_opt = os.path.join(save_folder, OPT_FILE)
    save_json_train = os.path.join(save_folder, TRAIN_JSON)
    save_json_valid = os.path.join(save_folder, VALID_JSON)
    save_json_test = os.path.join(save_folder, TEST_JSON)

    phoneme_alignments_folder = None
    duration_folder = None
    pitch_folder = None
    # Setting up additional folders required for FastSpeech2
    if model_name is not None and "FastSpeech2" in model_name:
        # This step requires phoneme alignments to be present in the data_folder
        # We automatically download the alignments from https://www.dropbox.com/s/v28x5ldqqa288pu/LJSpeech.zip
        # Download and unzip LJSpeech phoneme alignments from here: https://drive.google.com/drive/folders/1DBRkALpPd6FL9gjHMmMEdHODmkgNIIK4
        alignment_URL = (
            "https://www.dropbox.com/s/v28x5ldqqa288pu/LJSpeech.zip?dl=1"
        )
        phoneme_alignments_folder = os.path.join(
            data_folder, "TextGrid", "LJSpeech"
        )
        download_file(
            alignment_URL, data_folder + "/alignments.zip", unpack=True
        )

        duration_folder = os.path.join(data_folder, "durations")
        if not os.path.exists(duration_folder):
            os.makedirs(duration_folder)

        # extract pitch for both Fastspeech2 and FastSpeech2WithAligner models
        pitch_folder = os.path.join(data_folder, "pitch")
        if not os.path.exists(pitch_folder):
            os.makedirs(pitch_folder)

    # Check if this phase is already done (if so, skip it)
    if skip(splits, save_folder, conf):
        logger.info("Skipping preparation, completed in previous run.")
        return

    # Additional check to make sure metadata.csv and wavs folder exists
    assert os.path.exists(meta_csv), "metadata.csv does not exist"
    assert os.path.exists(wavs_folder), "wavs/ folder does not exist"

    # Prepare data splits
    msg = "Creating json file for ljspeech Dataset.."
    logger.info(msg)
    data_split, meta_csv = split_sets(data_folder, splits, split_ratio)

    if "train" in splits:
        prepare_json(
            model_name,
            data_split["train"],
            save_json_train,
            wavs_folder,
            meta_csv,
            phoneme_alignments_folder,
            duration_folder,
            pitch_folder,
            pitch_n_fft,
            pitch_hop_length,
            pitch_min_f0,
            pitch_max_f0,
            use_custom_cleaner,
            device,
        )
    if "valid" in splits:
        prepare_json(
            model_name,
            data_split["valid"],
            save_json_valid,
            wavs_folder,
            meta_csv,
            phoneme_alignments_folder,
            duration_folder,
            pitch_folder,
            pitch_n_fft,
            pitch_hop_length,
            pitch_min_f0,
            pitch_max_f0,
            use_custom_cleaner,
            device,
        )
    if "test" in splits:
        prepare_json(
            model_name,
            data_split["test"],
            save_json_test,
            wavs_folder,
            meta_csv,
            phoneme_alignments_folder,
            duration_folder,
            pitch_folder,
            pitch_n_fft,
            pitch_hop_length,
            pitch_min_f0,
            pitch_max_f0,
            use_custom_cleaner,
            device,
        )
    save_pkl(conf, save_opt)


def skip(splits, save_folder, conf):
    """
    Detects if the ljspeech data_preparation has been already done.
    If the preparation has been done, we can skip it.

    Arguments
    ---------
    splits : list
        The portions of data to review.
    save_folder : str
        The path to the directory containing prepared files.
    conf : dict
        Configuration to match against saved config.

    Returns
    -------
    bool
        if True, the preparation phase can be skipped.
        if False, it must be done.
    """
    # Checking json files
    skip = True

    split_files = {
        "train": TRAIN_JSON,
        "valid": VALID_JSON,
        "test": TEST_JSON,
    }

    for split in splits:
        if not os.path.isfile(os.path.join(save_folder, split_files[split])):
            skip = False

    #  Checking saved options
    save_opt = os.path.join(save_folder, OPT_FILE)
    if skip is True:
        if os.path.isfile(save_opt):
            opts_old = load_pkl(save_opt)
            if opts_old == conf:
                skip = True
            else:
                skip = False
        else:
            skip = False
    return skip


def split_sets(data_folder, splits, split_ratio):
    """Randomly splits the wav list into training, validation, and test lists.
    Note that a better approach is to make sure that all the classes have the
    same proportion of samples for each session.

    Arguments
    ---------
    data_folder : str
        The path to the directory containing the data.
    splits : list
        The list of the selected splits.
    split_ratio : list
        List composed of three integers that sets split ratios for train,
        valid, and test sets, respectively.
        For instance split_ratio=[80, 10, 10] will assign 80% of the sentences
        to training, 10% for validation, and 10% for test.

    Returns
    -------
    dictionary containing train, valid, and test splits.
    """
    meta_csv = os.path.join(data_folder, METADATA_CSV)
    csv_reader = csv.reader(
        open(meta_csv), delimiter="|", quoting=csv.QUOTE_NONE
    )

    meta_csv = list(csv_reader)

    index_for_sessions = []
    session_id_start = "LJ001"
    index_this_session = []
    for i in range(len(meta_csv)):
        session_id = meta_csv[i][0].split("-")[0]
        if session_id == session_id_start:
            index_this_session.append(i)
            if i == len(meta_csv) - 1:
                index_for_sessions.append(index_this_session)
        else:
            index_for_sessions.append(index_this_session)
            session_id_start = session_id
            index_this_session = [i]

    session_len = [len(session) for session in index_for_sessions]

    data_split = {}
    for i, split in enumerate(splits):
        data_split[split] = []
        for j in range(len(index_for_sessions)):
            if split == "train":
                random.shuffle(index_for_sessions[j])
                n_snts = int(session_len[j] * split_ratio[i] / sum(split_ratio))
                data_split[split].extend(index_for_sessions[j][0:n_snts])
                del index_for_sessions[j][0:n_snts]
            if split == "valid":
                if "test" in splits:
                    random.shuffle(index_for_sessions[j])
                    n_snts = int(
                        session_len[j] * split_ratio[i] / sum(split_ratio)
                    )
                    data_split[split].extend(index_for_sessions[j][0:n_snts])
                    del index_for_sessions[j][0:n_snts]
                else:
                    data_split[split].extend(index_for_sessions[j])
            if split == "test":
                data_split[split].extend(index_for_sessions[j])

    return data_split, meta_csv


def prepare_json(
    model_name,
    seg_lst,
    json_file,
    wavs_folder,
    csv_reader,
    phoneme_alignments_folder,
    durations_folder,
    pitch_folder,
    pitch_n_fft,
    pitch_hop_length,
    pitch_min_f0,
    pitch_max_f0,
    use_custom_cleaner=False,
    device="cpu",
):
    """
    Creates json file given a list of indexes.

    Arguments
    ---------
    model_name : str
        Model name (used to prepare additional model specific data)
    seg_lst : list
        The list of json indexes of a given data split
    json_file : str
        Output json path
    wavs_folder : str
        LJspeech wavs folder
    csv_reader : _csv.reader
        LJspeech metadata
    phoneme_alignments_folder : path
        Path where the phoneme alignments are stored
    durations_folder : path
        Folder where to store the duration values of each audio
    pitch_folder : path
        Folder where to store the pitch of each audio
    pitch_n_fft : int
        Number of fft points for pitch computation
    pitch_hop_length : int
        Hop length for pitch computation
    pitch_min_f0 : int
        Minimum f0 for pitch computation
    pitch_max_f0 : int
        Max f0 for pitch computation
    use_custom_cleaner : bool
        If True, uses custom cleaner defined for this recipe
    device : str
        Device for to be used for computation (used as required)
    """
    g2p = GraphemeToPhoneme.from_hparams(
            "speechbrain/soundchoice-g2p", run_opts={"device": device}
        )
    logger.info(f"preparing {json_file}.")
    if model_name in ["Tacotron2", "FastSpeech2WithAlignment", "TransformerTTS"]:
        logger.info(
            "Computing phonemes for LJSpeech labels using SpeechBrain G2P. This may take a while."
        )
    if model_name is not None and "FastSpeech2" in model_name:
        logger.info(
            "Computing pitch as required for FastSpeech2. This may take a while."
        )

    json_dict = {}
    for index in tqdm(seg_lst):
        # Common data preparation
        id = list(csv_reader)[index][0]
        wav = os.path.join(wavs_folder, f"{id}.wav")
        label = list(csv_reader)[index][2]
        if use_custom_cleaner:
            label = custom_clean(label, model_name)

        json_dict[id] = {
            "uttid": id,
            "wav": wav,
            "label": label,
            "segment": True if "train" in json_file else False,
        }

        # FastSpeech2 specific data preparation
        if model_name == "FastSpeech2":
            audio, fs = torchaudio.load(wav)

            # Parses phoneme alignments
            textgrid_path = os.path.join(
                phoneme_alignments_folder, f"{id}.TextGrid"
            )
            textgrid = tgt.io.read_textgrid(
                textgrid_path, include_empty_intervals=True
            )

            last_phoneme_flags = get_last_phoneme_info(
                textgrid.get_tier_by_name("words"),
                textgrid.get_tier_by_name("phones"),
            )
            (
                phonemes,
                duration,
                start,
                end,
                trimmed_last_phoneme_flags,
            ) = get_alignment(
                textgrid.get_tier_by_name("phones"),
                fs,
                pitch_hop_length,
                last_phoneme_flags,
            )

            # Gets label phonemes
            label_phoneme = " ".join(phonemes)
            spn_labels = [0] * len(phonemes)
            for i in range(1, len(phonemes)):
                if phonemes[i] == "spn":
                    spn_labels[i - 1] = 1
            if start >= end:
                print(f"Skipping {id}")
                continue

            # Saves durations
            duration_file_path = os.path.join(durations_folder, f"{id}.npy")
            np.save(duration_file_path, duration)

            # Computes pitch
            audio = audio[:, int(fs * start) : int(fs * end)]
            pitch_file = wav.replace(".wav", ".npy").replace(
                wavs_folder, pitch_folder
            )
            if not os.path.isfile(pitch_file):
                pitch = torchaudio.functional.detect_pitch_frequency(
                    waveform=audio,
                    sample_rate=fs,
                    frame_time=(pitch_hop_length / fs),
                    win_length=3,
                    freq_low=pitch_min_f0,
                    freq_high=pitch_max_f0,
                ).squeeze(0)

                # Concatenate last element to match duration.
                pitch = torch.cat([pitch, pitch[-1].unsqueeze(0)])

                # Mean and Variance Normalization
                mean = 256.1732939688805
                std = 328.319759158607

                pitch = (pitch - mean) / std

                pitch = pitch[: sum(duration)]
                np.save(pitch_file, pitch)

            # Updates data for the utterance
            json_dict[id].update({"label_phoneme": label_phoneme})
            json_dict[id].update({"spn_labels": spn_labels})
            json_dict[id].update({"start": start})
            json_dict[id].update({"end": end})
            json_dict[id].update({"durations": duration_file_path})
            json_dict[id].update({"pitch": pitch_file})
            json_dict[id].update(
                {"last_phoneme_flags": trimmed_last_phoneme_flags}
            )

        # FastSpeech2WithAlignment specific data preparation
        if model_name == "FastSpeech2WithAlignment":
            audio, fs = torchaudio.load(wav)
            # Computes pitch
            pitch_file = wav.replace(".wav", ".npy").replace(
                wavs_folder, pitch_folder
            )
            if not os.path.isfile(pitch_file):
                if torchaudio.__version__ < "2.1":
                    pitch = torchaudio.functional.compute_kaldi_pitch(
                        waveform=audio,
                        sample_rate=fs,
                        frame_length=(pitch_n_fft / fs * 1000),
                        frame_shift=(pitch_hop_length / fs * 1000),
                        min_f0=pitch_min_f0,
                        max_f0=pitch_max_f0,
                    )[0, :, 0]
                else:
                    pitch = torchaudio.functional.detect_pitch_frequency(
                        waveform=audio,
                        sample_rate=fs,
                        frame_time=(pitch_hop_length / fs),
                        win_length=3,
                        freq_low=pitch_min_f0,
                        freq_high=pitch_max_f0,
                    ).squeeze(0)

                    # Concatenate last element to match duration.
                    pitch = torch.cat([pitch, pitch[-1].unsqueeze(0)])

                    # Mean and Variance Normalization
                    mean = 256.1732939688805
                    std = 328.319759158607

                    pitch = (pitch - mean) / std

                np.save(pitch_file, pitch)

            phonemes = _g2p_keep_punctuations(g2p, label)
            # Updates data for the utterance
            json_dict[id].update({"phonemes": phonemes})
            json_dict[id].update({"pitch": pitch_file})

        if model_name == "TransformerTTS":
            phonemes = _g2p_keep_punctuations(g2p, label)
            # Updates data for the utterance
            json_dict[id].update({"phonemes": phonemes})

    # Writing the dictionary to the json file
    with open(json_file, mode="w") as json_f:
        json.dump(json_dict, json_f, indent=2)

    logger.info(f"{json_file} successfully created!")


def get_alignment(tier, sampling_rate, hop_length, last_phoneme_flags):
    """
    Returns phonemes, phoneme durations (in frames), start time (in seconds), end time (in seconds).
    This function is adopted from https://github.com/ming024/FastSpeech2/blob/master/preprocessor/preprocessor.py

    Arguments
    ---------
    tier : tgt.core.IntervalTier
        For an utterance, contains Interval objects for phonemes and their start time and end time in seconds
    sampling_rate : int
        Sample rate if audio signal
    hop_length : int
        Hop length for duration computation
    last_phoneme_flags : list
        List of (phoneme, flag) tuples with flag=1 if the phoneme is the last phoneme else flag=0


    Returns
    -------
    (phones, durations, start_time, end_time) : tuple
        The phonemes, durations, start time, and end time for an utterance
    """

    sil_phones = ["sil", "sp", "spn", ""]

    phonemes = []
    durations = []
    start_time = 0
    end_time = 0
    end_idx = 0
    trimmed_last_phoneme_flags = []

    flag_iter = iter(last_phoneme_flags)

    for t in tier._objects:
        s, e, p = t.start_time, t.end_time, t.text
        current_flag = next(flag_iter)

        # Trims leading silences
        if phonemes == []:
            if p in sil_phones:
                continue
            else:
                start_time = s

        if p not in sil_phones:
            # For ordinary phones
            # Removes stress indicators
            if p[-1].isdigit():
                phonemes.append(p[:-1])
            else:
                phonemes.append(p)
            trimmed_last_phoneme_flags.append(current_flag[1])
            end_time = e
            end_idx = len(phonemes)
        else:
            # Uses a unique token for all silent phones
            phonemes.append("spn")
            trimmed_last_phoneme_flags.append(current_flag[1])

        durations.append(
            int(
                np.round(e * sampling_rate / hop_length)
                - np.round(s * sampling_rate / hop_length)
            )
        )

    # Trims tailing silences
    phonemes = phonemes[:end_idx]
    durations = durations[:end_idx]

    return phonemes, durations, start_time, end_time, trimmed_last_phoneme_flags


def get_last_phoneme_info(words_seq, phones_seq):
    """This function takes word and phoneme tiers from a TextGrid file as input
    and provides a list of tuples for the phoneme sequence indicating whether
    each of the phonemes is the last phoneme of a word or not.

    Each tuple of the returned list has this format: (phoneme, flag)


    Arguments
    ---------
    words_seq : tier
        word tier from a TextGrid file
    phones_seq : tier
        phoneme tier from a TextGrid file

    Returns
    -------
    last_phoneme_flags : list
        each tuple of the returned list has this format: (phoneme, flag)
    """

    # Gets all phoneme objects for the entire sequence
    phoneme_objects = phones_seq._objects
    phoneme_iter = iter(phoneme_objects)

    # Stores flags to show if an element (phoneme) is a the last phoneme of a word
    last_phoneme_flags = list()

    # Matches the end times of the phoneme and word objects to get the last phoneme information
    for word_obj in words_seq._objects:
        word_end_time = word_obj.end_time

        current_phoneme = next(phoneme_iter, None)
        while current_phoneme:
            phoneme_end_time = current_phoneme.end_time
            if phoneme_end_time == word_end_time:
                last_phoneme_flags.append((current_phoneme.text, 1))
                break
            else:
                last_phoneme_flags.append((current_phoneme.text, 0))
            current_phoneme = next(phoneme_iter, None)

    return last_phoneme_flags


def custom_clean(text, model_name):
    """
    Uses custom criteria to clean text.

    Arguments
    ---------
    text : str
        Input text to be cleaned
    model_name : str
        whether to treat punctuations

    Returns
    -------
    text : str
        Cleaned text
    """

    _abbreviations = [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            ("mrs", "missus"),
            ("mr", "mister"),
            ("dr", "doctor"),
            ("st", "saint"),
            ("co", "company"),
            ("jr", "junior"),
            ("maj", "major"),
            ("gen", "general"),
            ("drs", "doctors"),
            ("rev", "reverend"),
            ("lt", "lieutenant"),
            ("hon", "honorable"),
            ("sgt", "sergeant"),
            ("capt", "captain"),
            ("esq", "esquire"),
            ("ltd", "limited"),
            ("col", "colonel"),
            ("ft", "fort"),
        ]
    ]
    text = unidecode(text.lower())
    if model_name != "FastSpeech2WithAlignment":
        text = re.sub("[:;]", " - ", text)
        text = re.sub(r'[)(\[\]"]', " ", text)
        text = text.strip().strip().strip("-")

    text = re.sub(" +", " ", text)
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text


### Hyperparameters

In [13]:
%%file /notebooks/train.yaml

############################################################################
# Model: TransformerTTS
# Tokens: Phonemes (ARPABET)
# Training: LJSpeech
# Authors: Salman Hussain Ali, 2024
# ############################################################################


###################################
# Experiment Parameters and setup #
###################################
seed: 1234
__set_seed: !apply:torch.manual_seed [!ref <seed>]
output_folder: !ref ./results/transformertts/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt
epochs: 750
keep_checkpoint_interval: 50
model_name: "TransformerTTS"

###################################
# Progress Samples                #
###################################
# Progress samples are used to monitor the progress
# of an ongoing training session by outputting samples
# of spectrograms, alignments, etc at regular intervals

# Whether to enable progress samples
progress_samples: True

# The path where the samples will be stored
progress_sample_path: !ref <output_folder>/samples
# The interval, in epochs. For instance, if it is set to 5,
# progress samples will be output every 5 epochs
progress_samples_interval: 1
# The sample size for raw batch samples saved in batch.pth
# (useful mostly for model debugging)
progress_batch_sample_size: 3

#################################
# Data files and pre-processing #
#################################
data_folder: !PLACEHOLDER # e.g, /localscratch/ljspeech

train_json: !ref <save_folder>/train.json
valid_json: !ref <save_folder>/valid.json
test_json: !ref <save_folder>/test.json

splits: ["train","valid"]
split_ratio: [90,10]

skip_prep: False

# Use the original preprocessing from nvidia
# The cleaners to be used (applicable to nvidia only)
text_cleaners: ['english_cleaners']

################################
# Audio Parameters             #
################################
sample_rate: 16000
hop_length: 200
win_length: 1024
n_mel_channels: 80
n_fft: 1024
mel_fmin: 0.0
mel_fmax: 8000.0
mel_normalized: False
power: 1
norm: "slaney"
mel_scale: "slaney"
dynamic_range_compression: True

################################
# Optimization Hyperparameters #
################################
learning_rate: 0.001
weight_decay: 0.000006
batch_size: 32 #minimum 2
num_workers: 8
mask_padding: True
gate_loss_weight: 7.0

train_dataloader_opts:
  batch_size: !ref <batch_size>
  drop_last: True  #True #False
  num_workers: !ref <num_workers>
  collate_fn: !new:models.TransformerTTS.TextMelCollate

valid_dataloader_opts:
  batch_size: !ref <batch_size>
  num_workers: !ref <num_workers>
  collate_fn: !new:models.TransformerTTS.TextMelCollate

test_dataloader_opts:
  batch_size: !ref <batch_size>
  num_workers: !ref <num_workers>
  collate_fn: !new:models.TransformerTTS.TextMelCollate

################################
# Model Parameters and model   #
################################

# Input parameters
lexicon:
    - "AA"
    - "AE"
    - "AH"
    - "AO"
    - "AW"
    - "AY"
    - "B"
    - "CH"
    - "D"
    - "DH"
    - "EH"
    - "ER"
    - "EY"
    - "F"
    - "G"
    - "HH"
    - "IH"
    - "IY"
    - "JH"
    - "K"
    - "L"
    - "M"
    - "N"
    - "NG"
    - "OW"
    - "OY"
    - "P"
    - "R"
    - "S"
    - "SH"
    - "T"
    - "TH"
    - "UH"
    - "UW"
    - "V"
    - "W"
    - "Y"
    - "Z"
    - "ZH"
    - "-"
    - "!"
    - "'"
    - "("
    - ")"
    - ","
    - "."
    - ":"
    - ";"
    - "?"
    - " "

n_symbols: 52 #fixed depending on symbols in the lexicon (+1 for a dummy symbol used for padding, +1 for unknown)
padding_idx: 0
symbols_embedding_dim: 512

# Encoder Pre-Net parameters
encoder_kernel_size: 5
encoder_n_convolutions: 3
encoder_embedding_dim: 512
encoder_stride: 1
encoder_prenet_dropout: 0.1
encoder_dilation: 1
encoder_bias: True
encoder_padding: "same"

# Decoder Pre-Net parameters
# The number of frames in the target per encoder step
n_frames_per_step: 1
decoder_prenet_dim: 256
max_decoder_steps: 1000 
p_decoder_dropout: 0.5

# Transformer parameters
d_model: 256
transformer_nhead: 8
transformer_num_encoder_layers: 6
transformer_num_decoder_layers: 6
transformer_d_ffn: 2048
transformer_dropout: 0.1
transformer_activation: "relu"

# Mel-post processing network parameters
postnet_embedding_dim: 512
postnet_kernel_size: 5
postnet_n_convolutions: 5
gate_threshold: 0.5 # TODO - maybe remove


mel_spectogram: !name:models.TransformerTTS.mel_spectogram
  sample_rate: !ref <sample_rate>
  hop_length: !ref <hop_length>
  win_length: !ref <win_length>
  n_fft: !ref <n_fft>
  n_mels: !ref <n_mel_channels>
  f_min: !ref <mel_fmin>
  f_max: !ref <mel_fmax>
  power: !ref <power>
  normalized: !ref <mel_normalized>
  norm: !ref <norm>
  mel_scale: !ref <mel_scale>
  compression: !ref <dynamic_range_compression>

#model
model: !new:models.TransformerTTS.TransformerTTS
  mask_padding: !ref <mask_padding>
  n_mel_channels: !ref <n_mel_channels>
  # symbols
  n_symbols: !ref <n_symbols>
  symbols_embedding_dim: !ref <symbols_embedding_dim>
  # encoder pre-net
  #encoder_embedding_dim: !ref <encoder_embedding_dim>
  encoder_prenet_kernel_size: !ref <encoder_kernel_size>
  encoder_prenet_n_convolutions: !ref <encoder_n_convolutions>
  #encoder_prenet_padding: !ref <encoder_padding>
  encoder_prenet_dilation: !ref <encoder_dilation>
  encoder_prenet_bias: !ref <encoder_bias>
  encoder_prenet_dropout: !ref <encoder_prenet_dropout>
  encoder_prenet_stride: !ref <encoder_stride>
  # decoder pre-net
  n_frames_per_step: !ref <n_frames_per_step>
  decoder_prenet_hidden_dims: !ref <decoder_prenet_dim>
  decoder_prenet_dropout: !ref <p_decoder_dropout>
  # transformer
  d_model: !ref <d_model>
  transformer_nhead: !ref <transformer_nhead>
  transformer_num_encoder_layers: !ref <transformer_num_encoder_layers>
  transformer_num_decoder_layers: !ref <transformer_num_decoder_layers>
  transformer_d_ffn: !ref <transformer_d_ffn>
  transformer_dropout: !ref <transformer_dropout>
  transformer_activation: !ref <transformer_activation>
  # postnet
  postnet_embedding_dim: !ref <postnet_embedding_dim>
  postnet_kernel_size: !ref <postnet_kernel_size>
  postnet_n_convolutions: !ref <postnet_n_convolutions>
  #decoder_no_early_stopping: !ref <decoder_no_early_stopping>
  gate_threshold: !ref <gate_threshold>
  padding_idx: !ref <padding_idx>

criterion: !new:models.TransformerTTS.Loss
  gate_loss_weight: !ref <gate_loss_weight>

# Masks
lookahead_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_lookahead_mask
padding_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_mask_from_lengths

modules:
  model: !ref <model>

#optimizer
opt_class: !name:torch.optim.Adam
  lr: !ref <learning_rate>
  weight_decay: !ref <weight_decay>

#epoch object
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
  limit: !ref <epochs>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
  save_file: !ref <train_log>

#annealing_function
lr_annealing: !new:speechbrain.nnet.schedulers.IntervalScheduler
  intervals:
    - steps: 6000
      lr: 0.0005
    - steps: 8000
      lr: 0.0003
    - steps: 10000
      lr: 0.0001

#checkpointer
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
  checkpoints_dir: !ref <save_folder>
  recoverables:
    model: !ref <model>
    counter: !ref <epoch_counter>
    scheduler: !ref <lr_annealing>

infer: !name:models.TransformerTTS.infer

progress_sample_logger: !new:speechbrain.utils.train_logger.ProgressSampleLogger
  output_path: !ref <progress_sample_path>
  batch_sample_size: !ref <progress_batch_sample_size>
  formats:
    raw_batch: raw

input_encoder: !new:speechbrain.dataio.encoder.TextEncoder


Overwriting /notebooks/train.yaml


### Brain Class

In [23]:
%%file /notebooks/train.py

"""
 Recipe for training the TransformerTTS Text-To-Speech model, an end-to-end
 neural text-to-speech (TTS) system

 To run this recipe, do the following:
 # python train.py --device=cuda:0 --max_grad_norm=1.0 --data_folder=/your_folder/LJSpeech-1.1 hparams/train.yaml

 to infer simply load saved model and do
 savemodel.infer(text_Sequence,len(textsequence))

 where text_Sequence is the output of the text_to_sequence function from
 textToSequence.py (from textToSequence import text_to_sequence)

 Authors
 * Georges Abous-Rjeili 2021
 * Artem Ploujnikov 2021
 * Yingzhi Wang 2022
"""
import torch
import speechbrain as sb
import sys
import logging
from hyperpyyaml import load_hyperpyyaml

from speechbrain.inference import GraphemeToPhoneme
from speechbrain.utils.text_to_sequence import text_to_sequence, _g2p_keep_punctuations, _clean_text
from speechbrain.utils.data_utils import scalarize

logger = logging.getLogger(__name__)


class TransformerTTSBrain(sb.Brain):
    """The Brain implementation for TransformerTTS"""

    def on_fit_start(self):
        """Gets called at the beginning of ``fit()``, on multiple processes
        if ``distributed_count > 0`` and backend is ddp and initializes statistics
        """
        self.hparams.progress_sample_logger.reset()
        self.last_epoch = 0
        self.last_batch = None
        self.last_loss_stats = {}
        self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
        return super().on_fit_start()

    def compute_forward(self, batch, stage):
        """Computes the forward pass

        Arguments
        ---------
        batch: str
            a single batch
        stage: speechbrain.Stage
            the training stage

        Returns
        -------
        the model output
        """
        # Batch is the results of TextMelCollate
        effective_batch = self.batch_to_device(batch)

        inputs, y, num_items, _, _ = effective_batch

        _, input_lengths, mel, _, output_lengths = inputs

        # Getting target mask (to avoid looking ahead)
        mask_size = mel.shape[2]
        tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(mask_size, device=self.device)
        #tgt_mask = torch.triu(torch.ones(mask_size, mask_size) * float('-inf'), diagonal=1).to(self.device, non_blocking=True)
        #tgt_mask = self.hparams.lookahead_mask(mel).to(self.device, non_blocking=True)

        # Padding masks for source and targets (use padding_mask)
        src_key_padding_mask = self.hparams.padding_mask(input_lengths).to(self.device, non_blocking=True)
        
        tgt_key_padding_mask = self.hparams.padding_mask(output_lengths).to(self.device, non_blocking=True)
        
        masks = (tgt_mask, None, src_key_padding_mask, tgt_key_padding_mask)

        return self.modules.model(inputs,masks)

    def on_fit_batch_end(self, batch, outputs, loss, should_step):
        """At the end of the optimizer step, apply noam annealing."""
        if should_step:
            self.hparams.lr_annealing(self.optimizer)

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss given the predicted and targeted outputs.
        Arguments
        ---------
        predictions : torch.Tensor
            The model generated spectrograms and other metrics from `compute_forward`.
        batch : PaddedBatch
            This batch object contains all the relevant tensors for computation.
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST.
        Returns
        -------
        loss : torch.Tensor
            A one-element tensor used for backpropagating the gradient.
        """
        effective_batch = self.batch_to_device(batch)
        # Hold on to the batch for the inference sample. This is needed because
        # the inference sample is run from on_stage_end only, where
        # batch information is not available
        self.last_batch = effective_batch
        # Hold on to a sample (for logging)
        self._remember_sample(effective_batch, predictions)
        # Compute the loss
        loss = self._compute_loss(predictions, effective_batch, stage)
        return loss

    def _compute_loss(self, predictions, batch, stage):
        """Computes the value of the loss function and updates stats

        Arguments
        ---------
        predictions: tuple
            model predictions
        batch: PaddedBatch
            Inputs for this training iteration.
        stage: sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST.

        Returns
        -------
        loss: torch.Tensor
            the loss value
        """
        inputs, targets, num_items, labels, wavs = batch
        text_padded, input_lengths, _, max_len, output_lengths = inputs
        
        mel_target, _ = targets
        mel_out, mel_out_postnet, gate_out = predictions
        
        target=self._get_spectrogram_sample(mel_target),
        output=self._get_spectrogram_sample(mel_out),

        loss_stats = self.hparams.criterion(
            predictions, targets
        )
        self.last_loss_stats[stage] = scalarize(loss_stats)
        return loss_stats.loss

    def _remember_sample(self, batch, predictions):
        """Remembers samples of spectrograms and the batch for logging purposes

        Arguments
        ---------
        batch: tuple
            a training batch
        predictions: tuple
            predictions (raw output of the TransformerTTS model)
        """
        inputs, targets, num_items, labels, wavs = batch
        text_padded, input_lengths, _, max_len, output_lengths = inputs
        mel_target, _ = targets
        mel_out, mel_out_postnet, gate_out = predictions

        self.hparams.progress_sample_logger.remember(
            target=self._get_spectrogram_sample(mel_target),
            output=self._get_spectrogram_sample(mel_out),
            output_postnet=self._get_spectrogram_sample(mel_out_postnet),
            raw_batch=self.hparams.progress_sample_logger.get_batch_sample(
                {
                    "text_padded": text_padded,
                    "input_lengths": input_lengths,
                    "mel_target": mel_target,
                    "mel_out": mel_out,
                    "mel_out_postnet": mel_out_postnet,
                    "gate_out": gate_out,
                    "labels": labels,
                    "wavs": wavs,
                }
            ),
        )

    def batch_to_device(self, batch):
        """Transfers the batch to the target device

        Arguments
        ---------
        batch: tuple
            the batch to use

        Returns
        -------
        batch: tuple
            the batch on the correct device
        """
        (
            text_padded,
            input_lengths,
            mel_padded,
            gate_padded,
            output_lengths,
            len_x,
            labels,
            wavs,
        ) = batch
        text_padded = text_padded.to(self.device, non_blocking=True).long()
        input_lengths = input_lengths.to(self.device, non_blocking=True).long()
        max_len = torch.max(input_lengths.data).item()
        mel_padded = mel_padded.to(self.device, non_blocking=True).float()
        gate_padded = gate_padded.to(self.device, non_blocking=True).float()

        output_lengths = output_lengths.to(
            self.device, non_blocking=True
        ).long()
        x = (text_padded, input_lengths, mel_padded, max_len, output_lengths)
        y = (mel_padded, gate_padded)
        len_x = torch.sum(output_lengths)
        return (x, y, len_x, labels, wavs)

    def _get_spectrogram_sample(self, raw):
        """Converts a raw spectrogram to one that can be saved as an image
        sample  = sqrt(exp(raw))

        Arguments
        ---------
        raw: torch.Tensor
            the raw spectrogram (as used in the model)

        Returns
        -------
        sample: torch.Tensor
            the spectrogram, for image saving purposes
        """
        sample = raw[0]
        return torch.sqrt(torch.exp(sample))

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch.
        Arguments
        ---------
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, sb.Stage.TEST
        stage_loss : float
            The average loss for all of the data processed in this stage.
        epoch : int
            The currently-starting epoch. This is passed
            `None` during the test stage.
        """

        # Store the train loss until the validation stage
        
        # At the end of validation, we can write
        if stage == sb.Stage.VALID:
            # Update learning rate
            print("Stage End")
            lr = self.optimizer.param_groups[-1]["lr"]
            self.last_epoch = epoch

            # The train_logger writes a summary to stdout and to the logfile.
            self.hparams.train_logger.log_stats(  # 1#2#
                stats_meta={"Epoch": epoch, "lr": lr},
                train_stats=self.last_loss_stats[sb.Stage.TRAIN],
                valid_stats=self.last_loss_stats[sb.Stage.VALID],
            )

            # Save the current checkpoint and delete previous checkpoints.
            epoch_metadata = {
                **{"epoch": epoch},
                **self.last_loss_stats[sb.Stage.VALID],
            }
            self.checkpointer.save_and_keep_only(
                meta=epoch_metadata,
                min_keys=["loss"],
                ckpt_predicate=(
                    (
                        lambda ckpt: (
                            ckpt.meta["epoch"]
                            % self.hparams.keep_checkpoint_interval
                            != 0
                        )
                    )
                    if self.hparams.keep_checkpoint_interval is not None
                    else None
                ),
            )
            output_progress_sample = (
                self.hparams.progress_samples
                and epoch % self.hparams.progress_samples_interval == 0
            )
            # if output_progress_sample:
            #     self.run_inference_sample()
            #     self.hparams.progress_sample_logger.save(epoch)

        # We also write statistics about test data to stdout and to the logfile.
        if stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=self.last_loss_stats[sb.Stage.TEST],
            )
            if self.hparams.progress_samples:
                self.run_inference_sample()
                self.hparams.progress_sample_logger.save("test")

    def run_inference_sample(self):
        """Produces a sample in inference mode. This is called when producing
        samples and can be useful because"""
        if self.last_batch is None:
            return
        inputs, _, _, _, _ = self.last_batch
        text_padded, input_lengths, _, _, _ = inputs
        mel_out, _, _ = self.hparams.model.infer(
            text_padded[:1], input_lengths[:1]
        )
        self.hparams.progress_sample_logger.remember(
            inference_mel_out=self._get_spectrogram_sample(mel_out)
        )
        
def dataio_prepare(hparams):
    
    lexicon = hparams["lexicon"]
    input_encoder = hparams.get("input_encoder")
    
    # add a dummy symbol for idx 0 - used for padding.
    lexicon = ["@@"] + lexicon
    input_encoder.update_from_iterable(lexicon, sequence_input=False)
    input_encoder.add_unk()
    
    # Define audio pipeline:
    @sb.utils.data_pipeline.takes("wav","phonemes")
    @sb.utils.data_pipeline.provides("mel_text_pair")
    def audio_pipeline(wav, phonemes):
        # Calculate the mel spectrogram for the audio files
        audio = sb.dataio.dataio.read_audio(wav)
        mel = hparams["mel_spectogram"](audio=audio)

        # Encode phonemes to get the sequence of IDs corresponding to the symbols in the text.
        encoded_phonemes = input_encoder.encode_sequence_torch(phonemes).int()

        len_phonemes = len(encoded_phonemes)

        return encoded_phonemes, mel, len_phonemes


    datasets = {}
    data_info = {
        "train": hparams["train_json"],
        "valid": hparams["valid_json"],
        "test": hparams["test_json"],
    }
    for dataset in hparams["splits"]:
        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
            json_path=data_info[dataset],
            replacements={"data_root": hparams["data_folder"]},
            dynamic_items=[audio_pipeline],
            output_keys=["mel_text_pair", "wav", "label"],
        )

    return datasets



if __name__ == "__main__":
    # Load hyperparameters file with command-line overrides
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # create ddp_group with the right communication protocol
    sb.utils.distributed.ddp_init_group(run_opts)

    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    from ljspeech_prepare import prepare_ljspeech

    sb.utils.distributed.run_on_main(
        prepare_ljspeech,
        kwargs={
            "model_name": hparams["model_name"],
            "data_folder": hparams["data_folder"],
            "save_folder": hparams["save_folder"],
            "splits": hparams["splits"],
            "split_ratio": hparams["split_ratio"],
            "seed": hparams["seed"],
            "skip_prep": hparams["skip_prep"],
            "device": run_opts["device"]
        },
    )

    datasets = dataio_prepare(hparams)
    
    #print(hparams)

    # Brain class initialization
    TransformerTTS_brain = TransformerTTSBrain(
        modules=hparams["modules"],
        opt_class=hparams["opt_class"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    # Training
    TransformerTTS_brain.fit(
        TransformerTTS_brain.hparams.epoch_counter,
        train_set=datasets["train"],
        valid_set=datasets["valid"],
        train_loader_kwargs=hparams["train_dataloader_opts"],
        valid_loader_kwargs=hparams["valid_dataloader_opts"],
    )

    # Test
    if "test" in datasets:
        TransformerTTS_brain.evaluate(
            datasets["test"],
            test_loader_kwargs=hparams["test_dataloader_opts"],
        )


Overwriting /notebooks/train.py


### Training Script

Script used to remove previous checkpoint

In [20]:
!rm -rf /notebooks/results/transformertts/1234/save/CKPT+2024-04-27+06-43-30+00

Training Script

In [None]:
!python train.py --device=cuda:0 --max_grad_norm=1.0 --data_folder=/notebooks/LJSpeech-1.1 train.yaml

speechbrain.core - Beginning experiment!
speechbrain.core - Experiment folder: ./results/transformertts/1234
ljspeech_prepare - Skipping preparation, completed in previous run.
speechbrain.core - Gradscaler enabled: False. Using precision: fp32.
speechbrain.core - 26.0M trainable parameters in TransformerTTSBrain
speechbrain.utils.fetching - Fetch hyperparams.yaml: Using existing file/symlink in pretrained_models/GraphemeToPhoneme-9b27d6eb840bf95c5aedf15ae8ed1172/hyperparams.yaml.
speechbrain.utils.fetching - Fetch custom.py: Delegating to Huggingface hub, source speechbrain/soundchoice-g2p.
speechbrain.utils.fetching - Fetch model.ckpt: Using existing file/symlink in pretrained_models/GraphemeToPhoneme-9b27d6eb840bf95c5aedf15ae8ed1172/model.ckpt.
speechbrain.utils.fetching - Fetch ctc_lin.ckpt: Using existing file/symlink in pretrained_models/GraphemeToPhoneme-9b27d6eb840bf95c5aedf15ae8ed1172/ctc_lin.ckpt.
speechbrain.utils.parameter_transfer - Loading pretrained files for: model, ctc