In [1]:
# to download the model checkpoints if they are not already present

# use this https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_ljspeech.ckpt   
# https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/generator_v1

# to download the dataset 
# https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2

In [2]:
import os 
os.chdir('..')
print(os.getcwd())
print(os.listdir())

c:\Users\arezk\Desktop\M2\this year\ML\Projet\matcha-tts-code\matcha-tts
['.git', '.gitignore', 'checkpoints', 'config.py', 'docs', 'figures', 'full_main.py', 'Guide', 'LJSpeech-1.1', 'matcha_env', 'Matcha_TTS_main', 'model', 'notebooks', 'output', 'struct', 'test', 'utils', '__pycache__']


In [3]:
import torch
import datetime as dt
import matplotlib.pyplot as plt
import numpy as np
import IPython.display as ipd
import soundfile as sf
import sys 
from utils.process_text import process_text

In [4]:
sys.path.append('Matcha_TTS_main/') # add the Matcha_TTS_main (code from the original repo) directory to the path

from matcha.models.matcha_tts import MatchaTTS 
# for Hifigan
from matcha.hifigan.config import v1
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.env import AttrDict
from matcha.hifigan.models import Generator as HiFiGAN

In [5]:
import os
os.environ['PHONEMIZER_ESPEAK_LIBRARY'] = r'C:\Program Files\eSpeak NG\libespeak-ng.dll' # path to espeak-ng dll on Windows

In [6]:
# load the model checkpoints
matcha_checkpoint_path = "checkpoints/matcha_ljspeech.ckpt"
hifigan_checkpoint_path = "checkpoints/generator_v1"

In [7]:
# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
# load the matcha tts model from
def load_model(checkpoint_path):
    model = MatchaTTS.load_from_checkpoint(
        checkpoint_path, 
        map_location=device,
        weights_only=False 
    )
    model.eval()
    return model

count_params = lambda x: f"{sum(p.numel() for p in x.parameters()):,}"

model = load_model(matcha_checkpoint_path)
print(f"Model loaded! Parameter count: {count_params(model)}")

  deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message)


Model loaded! Parameter count: 18,204,193


In [9]:
# load the hifigan vocoder model

def load_vocoder(checkpoint_path):
    h = AttrDict(v1)
    hifigan = HiFiGAN(h).to(device)
    hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)['generator'])
    _ = hifigan.eval()
    hifigan.remove_weight_norm()
    return hifigan

vocoder = load_vocoder(hifigan_checkpoint_path)
denoiser = Denoiser(vocoder, mode='zeros')

  WeightNorm.apply(module, name, dim)


Removing weight norm...


## the pipline is this : 
Text → process_text() → synthesise() → mel-spectrogram
                                      ↓
                              to_waveform() → audio waveform
                                      ↓
                              save_to_folder() → .wav file

In [10]:
# process_text   we use the same text as the inference example in the original repo

test_text ="The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent."
result = process_text(test_text, device)

print (result)

Processing complete!
{'x_orig': 'The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.', 'x_phones': 'ðə sˈiːkɹᵻt sˈɜːvɪs bᵻlˈiːvd ðˌɐɾɪt wʌz vˈɛɹi dˈaʊtfəl ðæt ˌɛni pɹˈɛzɪdənt wʊd ɹˈaɪd ɹˈɛɡjʊlɚli ɪn ɐ vˈiəkəl wɪð ɐ fˈɪkst tˈɑːp, ˈiːvən ðˌoʊ tɹænspˈæɹənt.', 'x': tensor([[ 81,   0,  83,   0,  16,   0,  61,   0, 156,   0,  51,   0, 158,   0,
          53,   0, 123,   0, 177,   0,  62,   0,  16,   0,  61,   0, 156,   0,
          87,   0, 158,   0,  64,   0, 102,   0,  61,   0,  16,   0,  44,   0,
         177,   0,  54,   0, 156,   0,  51,   0, 158,   0,  64,   0,  46,   0,
          16,   0,  81,   0, 157,   0,  70,   0, 125,   0, 102,   0,  62,   0,
          16,   0,  65,   0, 138,   0,  68,   0,  16,   0,  64,   0, 156,   0,
          86,   0, 123,   0,  51,   0,  16,   0,  46,   0, 156,   0,  43,   0,
         135,   0,  62,   0,  48,   0,  83,   0,  54,   0,  16,   0,  81,   0,
        

#### use the same hyperparameters 


In [11]:

## Number of ODE Solver steps
n_timesteps = 10

## Changes to the speaking rate
length_scale=1.0

## Sampling temperature
temperature = 0.667

In [12]:
# Generate mel-spectrogram using the model
start_time = dt.datetime.now()

with torch.inference_mode():  # Disable gradients for inference
    output = model.synthesise(
        result['x'],              # Phoneme IDs tensor
        result['x_lengths'],      # Length of sequence
        n_timesteps=n_timesteps,
        temperature=temperature,
        length_scale=length_scale
    )

end_time = dt.datetime.now()
synthesis_time = (end_time - start_time).total_seconds()

print(f" Mel-spectrogram generated in {synthesis_time:.2f} seconds!")


 Mel-spectrogram generated in 4.20 seconds!


In [13]:
print(output.keys())

dict_keys(['encoder_outputs', 'decoder_outputs', 'attn', 'mel', 'mel_lengths', 'rtf'])


In [14]:
# Extract the mel-spectrogram
mel=output['mel']
mel_to_draw = mel.squeeze().cpu().numpy()  # Remove batch dim and move to CPU


In [15]:
# plot the mel-spectrogram
plt.figure(figsize=(10, 4))
plt.imshow(mel_to_draw, aspect='auto', origin='lower', cmap='viridis')  
plt.title('Generated Mel-Spectrogram')
plt.xlabel('Time Frames')   
plt.ylabel('Mel Frequency Channels')
plt.colorbar(format='%+2.0f dB')   
plt.savefig('figures/generated_mel_spectrogram_synthes_example.png')  # Save the figure 
plt.show()

  plt.show()


### generate audioi from mel

In [16]:
with torch.inference_mode(): # Wrap vocoder and denoiser in inference mode
    audio = vocoder(mel).clamp(-1, 1)
    audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
audio.cpu().squeeze()

ipd.display(ipd.Audio(audio, rate=22050))
#save the audio to a wav file
sf.write('output/generated_audio_synthes_example_original_model.wav', audio.numpy(), 22050)


In [17]:
def to_waveform(mel, vocoder, denoiser, device):
    """Convert mel to waveform with HiFi-GAN, handling shape/device safely."""
    if mel.dim() == 2:
        mel = mel.unsqueeze(0)
    if mel.shape[1] != 80 and mel.shape[2] == 80:
        mel = mel.transpose(1, 2)

    mel = mel.to(device=device, dtype=torch.float32)
    vocoder = vocoder.to(device)

    with torch.inference_mode():
        audio = vocoder(mel).clamp(-1, 1)
        audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
    return audio

output['waveform'] = to_waveform(output['mel'], vocoder, denoiser, device)
rtf_w = synthesis_time * 22050 / (output['waveform'].shape[-1])

print(output['rtf'])
print(rtf_w)

0.4288888263620843
0.4293176491055522
