In [67]:
import os
import torch
from TTS.tts.models.glow_tts import GlowTTS, GlowTTSConfig
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

import torchaudio
from speechbrain.inference.TTS import Tacotron2
from speechbrain.inference.vocoders import HIFIGAN

INFO:speechbrain.utils.quirks:Applied quirks (see `speechbrain.utils.quirks`): [disable_jit_profiling, allow_tf32]
INFO:speechbrain.utils.quirks:Excluded quirks specified by the `SB_DISABLE_QUIRKS` environment (comma-separated list): []


In [68]:
# Inference device: use the GPU when CUDA is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Checkpoint of the trained GlowTTS model that is loaded further down.
model_path = "train/run-February-22-2025_02+19AM-9b6e3e6/best_model.pth"

# Folder for generated audio and the wav file the synthesized speech is written to.
output_path = "tts_playground"
output_audio_path = os.path.join(output_path, "output.wav")

In [69]:
# Start from a default GlowTTS configuration object ...
config = GlowTTSConfig()

# ... and overwrite its fields in place with the values saved by the training run.
config_path = "train/run-February-22-2025_02+19AM-9b6e3e6/config.json"
config.load_json(config_path)

In [70]:
# Pretrained SpeechBrain HiFi-GAN vocoder: `source` is the HuggingFace Hub repo
# the files are fetched from, `savedir` is the local folder they are stored in.
hifi_gan = HIFIGAN.from_hparams(
    source="speechbrain/tts-hifigan-ljspeech",
    savedir="pretrained_model/tts-hifigan-ljspeech",
)

INFO:speechbrain.utils.fetching:Fetch hyperparams.yaml: Fetching from HuggingFace Hub 'speechbrain/tts-hifigan-ljspeech' if not cached


hyperparams.yaml:   0%|          | 0.00/1.16k [00:00<?, ?B/s]

INFO:speechbrain.utils.fetching:Fetch custom.py: Fetching from HuggingFace Hub 'speechbrain/tts-hifigan-ljspeech' if not cached
  WeightNorm.apply(module, name, dim)
INFO:speechbrain.utils.fetching:Fetch generator.ckpt: Fetching from HuggingFace Hub 'speechbrain/tts-hifigan-ljspeech' if not cached


generator.ckpt:   0%|          | 0.00/55.8M [00:00<?, ?B/s]

INFO:speechbrain.utils.parameter_transfer:Loading pretrained files for: generator
  state_dict = torch.load(path, map_location=device)


In [71]:
# Build the audio processor from the loaded training config (sample rate,
# number of mel bands, normalization settings, ... are logged on creation).
ap = AudioProcessor.init_from_config(config)

# init_from_config returns a (tokenizer, config) pair; `config` is deliberately
# rebound to the returned object, so later cells use that version of the config.
tokenizer, config = TTSTokenizer.init_from_config(config)

INFO:TTS.utils.audio.processor:Setting up Audio Processor...
INFO:TTS.utils.audio.processor: | sample_rate: 22050
INFO:TTS.utils.audio.processor: | resample: False
INFO:TTS.utils.audio.processor: | num_mels: 80
INFO:TTS.utils.audio.processor: | log_func: np.log10
INFO:TTS.utils.audio.processor: | min_level_db: -100
INFO:TTS.utils.audio.processor: | frame_shift_ms: None
INFO:TTS.utils.audio.processor: | frame_length_ms: None
INFO:TTS.utils.audio.processor: | ref_level_db: 20
INFO:TTS.utils.audio.processor: | fft_size: 1024
INFO:TTS.utils.audio.processor: | power: 1.5
INFO:TTS.utils.audio.processor: | preemphasis: 0.0
INFO:TTS.utils.audio.processor: | griffin_lim_iters: 60
INFO:TTS.utils.audio.processor: | signal_norm: True
INFO:TTS.utils.audio.processor: | symmetric_norm: True
INFO:TTS.utils.audio.processor: | mel_fmin: 0
INFO:TTS.utils.audio.processor: | mel_fmax: None
INFO:TTS.utils.audio.processor: | pitch_fmin: 1.0
INFO:TTS.utils.audio.processor: | pitch_fmax: 640.0
INFO:TTS.utils.a

In [73]:
# Instantiate GlowTTS with the audio processor and tokenizer built above;
# no speaker manager is supplied.
model = GlowTTS(config, ap, tokenizer, speaker_manager=None)

# Restore the trained weights from the checkpoint (eval=True is forwarded to the loader).
model.load_checkpoint(config, model_path, eval=True)

# Module.to() and Module.eval() both return the module, so chaining them moves
# the model to the target device, switches it to eval mode, and still leaves
# the model as the cell's displayed value.
model.to(device).eval()

GlowTTS(
  (encoder): Encoder(
    (emb): Embedding(132, 192)
    (prenet): ResidualConv1dLayerNormBlock(
      (conv_layers): ModuleList(
        (0-2): 3 x Conv1d(192, 192, kernel_size=(5,), stride=(1,), padding=(2,))
      )
      (norm_layers): ModuleList(
        (0-2): 3 x LayerNorm()
      )
      (proj): Conv1d(192, 192, kernel_size=(1,), stride=(1,))
    )
    (encoder): RelativePositionTransformer(
      (dropout): Dropout(p=0.1, inplace=False)
      (attn_layers): ModuleList(
        (0-5): 6 x RelativePositionMultiHeadAttention(
          (conv_q): Conv1d(192, 192, kernel_size=(1,), stride=(1,))
          (conv_k): Conv1d(192, 192, kernel_size=(1,), stride=(1,))
          (conv_v): Conv1d(192, 192, kernel_size=(1,), stride=(1,))
          (conv_o): Conv1d(192, 192, kernel_size=(1,), stride=(1,))
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (norm_layers_1): ModuleList(
        (0-5): 6 x LayerNorm()
      )
      (ffn_layers): ModuleList(
      

In [95]:
# Sentence to synthesize. Adjacent string literals are concatenated by Python,
# so this is a single string with no added or removed characters.
# NOTE(review): "created used" looks like a typo in the prompt; it is left
# untouched here because changing it changes the generated audio.
text = (
    "Mary had a little lamb. "
    "The lamb became super smart at math 145 and created used Sunzi's theorem "
    "to perform RSA so that the lamb could escape Mary's clutches."
)

In [96]:
# Turn the input sentence into the integer token IDs the model consumes.
phoneme_seq = tokenizer.text_to_ids(text)

# Wrap the ID list in an outer list to get a leading batch axis of size 1, and
# create the tensor directly on the inference device.
phoneme_tensor = torch.tensor([phoneme_seq], dtype=torch.long, device=device)

print(phoneme_tensor)

tensor([[ 15,  51,  78,  11, 131,  10,  29,   7, 131,  49, 131,  14,  64,  22,
          49,  14, 131,  14,  29,  15, 127, 131,  31,  49, 131,  14,  29,  15,
         131,   5,  64,  13,   8,  64,  15, 131,  21,  23,  18,  50, 131,  21,
          15,  41,  78,  22, 131,  29,  22, 131,  15,  29, 117, 131,  25,  92,
          16, 131,  10,  92,  16,   7,  78,  64,   7, 131,   9,  44,  78,  22,
          11, 131,   9,   4,  64,  24, 131,  29,  16,   7, 131,  13,  78,  11,
           8,  64,  22,  64,   7, 131,  12,  23,  21,  22, 131,  21,  92,  16,
          28,  11,  28, 131, 117,  64,  78,  49,  15, 131,  22,  49, 131,  18,
          50,   9,  44,  78,  15, 131,  41,  78, 131,  51,  21, 131,   8,  64,
         131,  21,  17,  90, 131,  31,  29,  22, 131,  31,  49, 131,  14,  29,
          15, 131,  13,  90,   7, 131,  64,  21,  13,   8,  64,  18, 131,  15,
          51,  78,  11,  28, 131,  13,  14,  92,  22,  86,  64,  28, 127]],
       device='cuda:0')


In [97]:
# One length entry for the single sequence in the batch, allocated on the same
# device as the token tensor.
x_lengths = torch.tensor(
    [phoneme_tensor.size(1)],
    dtype=torch.long,
    device=phoneme_tensor.device,
)

# Pure inference: gradients are not needed, so autograd tracking is disabled.
with torch.no_grad():
    outputs = model.inference(
        phoneme_tensor,
        aux_input={"x_lengths": x_lengths},
    )

# The predicted spectrogram is stored under the "model_outputs" key.
mel_output = outputs["model_outputs"]
print(mel_output)


tensor([[[-2.1818, -1.6928, -0.7711,  ..., -3.8520, -3.9608, -3.9626],
         [-2.1097, -1.5116, -0.5922,  ..., -3.9377, -4.0046, -3.9939],
         [-2.1965, -1.4620, -0.5542,  ..., -3.9659, -4.0233, -4.0048],
         ...,
         [-3.4247, -2.5844, -2.1146,  ..., -3.9564, -4.0126, -4.0046],
         [-3.3328, -2.5142, -2.1522,  ..., -3.9475, -3.9948, -3.9962],
         [-3.0474, -2.4216, -2.1282,  ..., -3.9408, -3.9800, -3.9884]]],
       device='cuda:0')


In [98]:
# Bring the GlowTTS output into the (batch, n_mels, frames) layout and vocode it.
#
# The mel-channel count is read from the audio processor instead of a
# hard-coded 80, so the cell keeps working if the model is retrained with a
# different number of mel bands.
n_mels = ap.num_mels

# Always start from the untouched model output rather than a previously
# permuted `mel_output`, so re-running this cell gives the same result.
mel_raw = outputs["model_outputs"]
print(mel_raw.shape)

if mel_raw.ndim != 3:
    raise ValueError(
        f"Expected a 3-D mel tensor, got shape {tuple(mel_raw.shape)}"
    )

if mel_raw.shape[-1] == n_mels:
    # Channels-last output, i.e. (batch, frames, n_mels) as printed above:
    # move the mel axis to position 1. Checking the last axis first also
    # handles the corner case where the number of frames equals n_mels.
    mel_output = mel_raw.permute(0, 2, 1)
elif mel_raw.shape[1] == n_mels:
    # Already (batch, n_mels, frames); nothing to do.
    mel_output = mel_raw
else:
    raise ValueError(
        f"Neither axis of the mel tensor {tuple(mel_raw.shape)} matches "
        f"the configured number of mel bands ({n_mels})"
    )

print(mel_output.shape)

# NOTE(review): the audio processor log reports signal_norm=True,
# symmetric_norm=True and log_func=np.log10; confirm that the pretrained
# SpeechBrain vocoder expects mel features on that same scale — TODO verify.
# Decode the batch with the correct shape
waveforms = hifi_gan.decode_batch(mel_output)

torch.Size([1, 1028, 80])
torch.Size([1, 80, 1028])


In [99]:
# torchaudio.save does not create missing parent folders, so make sure the
# output directory exists first (falls back to "." for a bare filename).
os.makedirs(os.path.dirname(output_audio_path) or ".", exist_ok=True)

# squeeze(1) removes axis 1 of the vocoder output when it has size 1
# (presumably the channel axis — TODO confirm against decode_batch).
# detach().cpu() is a no-op for a CPU tensor without gradients, but keeps the
# save working if the vocoder is ever run on the GPU.
# 22050 Hz matches the sample rate reported by the audio processor.
torchaudio.save(output_audio_path, waveforms.squeeze(1).detach().cpu(), 22050)