In [1]:
%%capture
pip install speechbrain

In [2]:
import re
import logging
import torch
import torchaudio
import random
import speechbrain as sb

In [10]:
%%file hyperparams.yaml


# Define the vocabulary size
# Used as `num_embeddings` for both `encoder_emb` and `decoder_emb` below.
# NOTE(review): the `lexicon` in this file lists 40 symbols; presumably the
# extra rows leave room for special tokens -- confirm against the checkpoint.
vocab_size: 70
# Passed as `padding_idx` to both embedding tables below.
blank_index: 0  # For padding


# Training hyperparameters
number_of_epochs: 10
batch_size: 8
learning_rate: 0.0001

weight_decay: 0.000001
betas: [0.9, 0.98]


lexicon:
  - AA
  - AE
  - AH
  - AO
  - AW
  - AY
  - B
  - CH
  - D
  - DH
  - EH
  - ER
  - EY
  - F
  - G
  - HH
  - IH
  - IY
  - JH
  - K
  - L
  - M
  - N
  - NG
  - OW
  - OY
  - P
  - R
  - S
  - SH
  - T
  - TH
  - UH
  - UW
  - V
  - W
  - Y
  - Z
  - ZH
  - ' '


# Model hyperparameters
d_model: 512
nhead: 8
num_encoder_layers: 6
num_decoder_layers: 6
dim_feedforward: 2048
dropout: 0.1

sample_rate: 22050
hop_length: 256
win_length: null
n_mel_channels: 80
n_fft: 1024
mel_fmin: 0.0
mel_fmax: 8000.0
power: 1
norm: "slaney"
mel_scale: "slaney"
dynamic_range_compression: True
mel_normalized: False
min_max_energy_norm: True
min_f0: 65  #(torchaudio pyin values)
max_f0: 2093 #(torchaudio pyin values)


# Epoch counter
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

# Transformer model
Seq2SeqTransformer: !new:torch.nn.Transformer
    d_model: !ref <d_model>
    nhead: !ref <nhead>
    num_encoder_layers: !ref <num_encoder_layers>
    num_decoder_layers: !ref <num_decoder_layers>
    dim_feedforward: !ref <dim_feedforward>
    dropout: !ref <dropout>
    batch_first: True

# Embeddings
encoder_emb: !new:torch.nn.Embedding
    num_embeddings: !ref <vocab_size>
    embedding_dim: !ref <d_model>
    padding_idx: !ref <blank_index>

decoder_emb: !new:torch.nn.Embedding
    num_embeddings: !ref <vocab_size>
    embedding_dim: !ref <d_model>
    padding_idx: !ref <blank_index>

# Positional embeddings and custom prenets
pos_emb: !new:custom_model.ScaledPositionalEncoding
    d_model: !ref <d_model>
    max_len: 5000

encoder_prenet: !new:custom_model.EncoderPrenet
    embedding_dim: !ref <d_model>
    num_channels: 512
    kernel_size: 5
    dropout_rate: 0.5

# Prenet applied to the mel frames fed to the decoder.
decoder_prenet: !new:custom_model.DecoderPrenet
    input_dim: !ref <n_mel_channels>  # Number of mel channels
    hidden_dim: 256
    output_dim: !ref <d_model>
    final_dim: !ref <d_model>
    dropout_rate: 0.5

# Tacotron2 specific modules from SpeechBrain
postnet: !new:speechbrain.lobes.models.Tacotron2.Postnet

mel_spectogram: !name:speechbrain.lobes.models.FastSpeech2.mel_spectogram
    sample_rate: !ref <sample_rate>
    hop_length: !ref <hop_length>
    win_length: !ref <win_length>
    n_fft: !ref <n_fft>
    n_mels: !ref <n_mel_channels>
    f_min: !ref <mel_fmin>
    f_max: !ref <mel_fmax>
    power: !ref <power>
    normalized: !ref <mel_normalized>
    min_max_energy_norm: !ref <min_max_energy_norm>
    norm: !ref <norm>
    mel_scale: !ref <mel_scale>
    compression: !ref <dynamic_range_compression>


# Projects decoder states (d_model) to one mel frame per time step.
mel_linear: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <d_model>
    n_neurons: !ref <n_mel_channels>  # Number of mel channels

stop_linear: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <d_model>
    n_neurons: 1

input_encoder: !new:speechbrain.dataio.encoder.TextEncoder

# Masks
lookahead_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_lookahead_mask
padding_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_key_padding_mask


modules:
    Seq2SeqTransformer: !ref <Seq2SeqTransformer>
    encoder_emb: !ref <encoder_emb>
    decoder_emb: !ref <decoder_emb>
    pos_emb: !ref <pos_emb>
    encoder_prenet: !ref <encoder_prenet>
    decoder_prenet: !ref <decoder_prenet>
    postnet: !ref <postnet>
    mel_linear: !ref <mel_linear>
    stop_linear: !ref <stop_linear>

# All trainable modules collected in one ModuleList so the pretrainer can
# load them from a single checkpoint. The order of the list must not change,
# otherwise saved parameters no longer line up with the modules.
model: !new:torch.nn.ModuleList
    - [!ref <Seq2SeqTransformer>, !ref <encoder_emb>, !ref <decoder_emb>, !ref <pos_emb>, !ref <encoder_prenet>, !ref <decoder_prenet>, !ref <postnet>, !ref <mel_linear>, !ref <stop_linear>]



pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
    loadables:
        model: !ref <model>
        input_encoder: !ref <input_encoder>




Overwriting hyperparams.yaml


In [19]:
%%file TextToSpeech.py
import logging

import torch
from speechbrain.inference.interfaces import Pretrained
from speechbrain.inference.text import GraphemeToPhoneme
from speechbrain.utils.fetching import fetch


logger = logging.getLogger(__name__)


class TextToSpeech(Pretrained):
    """A ready-to-use wrapper for Transformer TTS (text -> mel spectrogram).

    Text is converted to phonemes with a pretrained grapheme-to-phoneme model,
    encoded to integer labels and passed through the encoder of the
    transformer. The mel spectrogram is then generated frame by frame by the
    decoder until the stop predictor fires or a step limit is reached.

    Arguments
    ---------
    hparams
        Hyperparameters (from HyperPyYAML)
    """

    HPARAMS_NEEDED = [
        "model",
        "blank_index",
        "padding_mask",
        "lookahead_mask",
        "mel_spectogram",
        "input_encoder",
        "lexicon",  # read in __init__ to populate the label encoder
    ]
    # NOTE(review): this names the hparams key rather than the individual
    # modules used through `self.mods` -- confirm against Pretrained's contract.
    MODULES_NEEDED = ["modules"]

    # Spoken forms substituted for common abbreviations before running G2P.
    _ABBREVIATION_EXPANSIONS = {
        "Mr.": "Mister",
        "Mrs.": "Misess",
        "Dr.": "Doctor",
        "No.": "Number",
        "St.": "Saint",
        "Co.": "Company",
        "Jr.": "Junior",
        "Maj.": "Major",
        "Gen.": "General",
        "Drs.": "Doctors",
        "Rev.": "Reverend",
        "Lt.": "Lieutenant",
        "Hon.": "Honorable",
        "Sgt.": "Sergeant",
        "Capt.": "Captain",
        "Esq.": "Esquire",
        "Ltd.": "Limited",
        "Col.": "Colonel",
        "Ft.": "Fort",
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_encoder = self.hparams.input_encoder
        # Register every lexicon symbol with the label encoder so that
        # `encode_sequence` can map phonemes to integer ids.
        self.input_encoder.update_from_iterable(
            self.hparams.lexicon, sequence_input=False
        )
        self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")

    def text_to_phoneme(self, text):
        """Converts a text into a tensor of encoded phoneme labels.

        Known abbreviations are expanded first, then the G2P model produces
        the phoneme sequence, which is mapped to integer labels by the
        input encoder.

        Arguments
        ---------
        text: str
            The input text.

        Returns
        -------
        phoneme_seq: torch.LongTensor
            One-dimensional tensor of encoded phoneme labels.
        length: int
            Number of labels in ``phoneme_seq``.
        """
        for abbreviation, expansion in self._ABBREVIATION_EXPANSIONS.items():
            text = text.replace(abbreviation, expansion)

        phonemes = self.g2p(text)
        phoneme_ids = self.input_encoder.encode_sequence(phonemes)
        phoneme_seq = torch.LongTensor(phoneme_ids)

        return phoneme_seq, len(phoneme_seq)

    def encode_batch(self, texts, max_decoder_steps=1000, stop_threshold=0.5):
        """Computes mel spectrograms for a list of texts.

        Decoding is autoregressive: at every step the decoder sees all frames
        generated so far (under a causal mask) and the prediction for the last
        position becomes the next frame. This assumes the model was trained
        with teacher forcing and a look-ahead mask.

        Arguments
        ---------
        texts: List[str]
            Texts to be encoded into spectrograms. A single str is accepted
            and treated as a batch of one.
        max_decoder_steps: int
            Upper bound on the number of generated frames.
        stop_threshold: float
            Stop probability above which an utterance counts as finished.

        Returns
        -------
        torch.Tensor
            Mel spectrograms of shape (batch, n_mels, frames). The all-zero
            start frame fed to the decoder is not part of the output. In a
            batch, generation continues until every item has signalled stop,
            so shorter items carry extra trailing frames.
        """
        if isinstance(texts, str):
            texts = [texts]
        if len(texts) == 0:
            raise ValueError("encode_batch needs at least one text.")
        if max_decoder_steps < 1:
            raise ValueError("max_decoder_steps must be at least 1.")

        with torch.no_grad():
            phonemes = [self.text_to_phoneme(text)[0] for text in texts]
            phoneme_padded, lengths = self.pad_sequences(
                phonemes, pad_value=self.hparams.blank_index
            )
            # The modules live on self.device; the inputs must follow.
            phoneme_padded = phoneme_padded.to(self.device)
            lengths = lengths.to(self.device)

            # True at padded encoder positions. Built from the lengths rather
            # than from token values so that a real label sharing the padding
            # id can never be masked by accident.
            positions = torch.arange(phoneme_padded.size(1), device=self.device)
            src_key_padding_mask = positions.unsqueeze(0) >= lengths.unsqueeze(1)

            phonemes_emb = self.mods.encoder_emb(phoneme_padded)
            encoder_prenet_out = self.mods.encoder_prenet(phonemes_emb)
            enc_emb = encoder_prenet_out + self.mods.pos_emb(encoder_prenet_out)

            # The encoder output does not depend on the decoding step, so it
            # is computed once instead of once per generated frame.
            transformer = self.mods.Seq2SeqTransformer
            memory = transformer.encoder(
                enc_emb, src_key_padding_mask=src_key_padding_mask
            )

            batch_size = phoneme_padded.size(0)
            n_mels = getattr(self.hparams, "n_mel_channels", 80)
            # All-zero "go" frame that starts the autoregressive generation.
            decoder_input = torch.zeros(
                batch_size, 1, n_mels, device=self.device
            )
            finished = torch.zeros(
                batch_size, dtype=torch.bool, device=self.device
            )
            frames = []

            for _ in range(max_decoder_steps):
                decoder_prenet_out = self.mods.decoder_prenet(decoder_input)
                dec_emb = decoder_prenet_out + self.mods.pos_emb(
                    decoder_prenet_out
                )
                tgt_mask = self.hparams.lookahead_mask(decoder_input)

                decoder_outputs = transformer.decoder(
                    dec_emb,
                    memory,
                    tgt_mask=tgt_mask,
                    memory_key_padding_mask=src_key_padding_mask,
                )

                # The postnet works on (batch, n_mels, time).
                mel_pred = self.mods.mel_linear(decoder_outputs).transpose(1, 2)
                mel_predictions = mel_pred + self.mods.postnet(mel_pred)

                # Only the prediction for the last position is new.
                next_frame = mel_predictions[:, :, -1:].transpose(1, 2)
                frames.append(next_frame)

                stop_logits = self.mods.stop_linear(
                    decoder_outputs[:, -1, :]
                ).squeeze(-1)
                finished |= self._stop_mask(stop_logits, stop_threshold)
                if finished.all():
                    break

                decoder_input = torch.cat([decoder_input, next_frame], dim=1)
            else:
                logger.warning(
                    "Reached max_decoder_steps=%d without a stop prediction "
                    "for every item; the output may be truncated.",
                    max_decoder_steps,
                )

            return torch.cat(frames, dim=1).transpose(1, 2)

    @staticmethod
    def _stop_mask(stop_token_pred, threshold=0.5):
        """Element-wise stop decision for a tensor of stop-token logits."""
        return torch.sigmoid(stop_token_pred) > threshold

    def should_stop(self, stop_token_pred, threshold=0.5):
        """Tells whether any stop-token logit signals the end of generation.

        Arguments
        ---------
        stop_token_pred: torch.Tensor
            Stop-token logits (before the sigmoid).
        threshold: float
            Probability above which a prediction counts as "stop".

        Returns
        -------
        bool
            True if at least one element exceeds the threshold.
        """
        return self._stop_mask(stop_token_pred, threshold).any().item()

    def pad_sequences(self, sequences, pad_value=0):
        """Pads sequences to the length of the longest one in the batch.

        Arguments
        ---------
        sequences: List[torch.Tensor]
            The one-dimensional label sequences to pad.
        pad_value: int
            Value written to the padded positions.

        Returns
        -------
        padded: torch.LongTensor
            Tensor of shape (batch, max_length).
        lengths: torch.Tensor
            Original length of every sequence.
        """
        if len(sequences) == 0:
            raise ValueError("pad_sequences needs at least one sequence.")

        length_list = [len(seq) for seq in sequences]
        padded = torch.full(
            (len(sequences), max(length_list)), pad_value, dtype=torch.long
        )
        for row, (seq, length) in enumerate(zip(sequences, length_list)):
            padded[row, :length] = seq
        return padded, torch.tensor(length_list)

    def encode_text(self, text):
        """Runs inference for a single text str (a list of str also works)."""
        if isinstance(text, str):
            text = [text]
        return self.encode_batch(text)

    def forward(self, texts):
        "Encodes the input texts."
        return self.encode_batch(texts)


Writing TextToSpeech.py


In [18]:
from speechbrain.inference.vocoders import HIFIGAN

# `%%file` only writes TextToSpeech.py to disk and does not define the class
# in the kernel, so it has to be imported from the written module.
from TextToSpeech import TextToSpeech

# Same value as `sample_rate` in hyperparams.yaml.
SAMPLE_RATE = 22050

texts = ["This is a example for synthesis."]

# NOTE(review): "/content/" is the Colab working directory that holds
# hyperparams.yaml; adjust when running elsewhere.
my_tts_model = TextToSpeech.from_hparams(source="/content/")
hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir="tmpdir_vocoder")
mel_output = my_tts_model.encode_text(texts)

# Running Vocoder (spectrogram-to-waveform)
waveforms = hifi_gan.decode_batch(mel_output)

# Save the waveform (torchaudio expects a CPU tensor of shape (channels, time))
torchaudio.save('example_TTS.wav', waveforms.squeeze(1).cpu(), SAMPLE_RATE)

