In [1]:
# Download the model checkpoints (if they are not already present) into the ./checkpoints/ folder:

# Matcha-TTS acoustic model (LJSpeech): https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_ljspeech.ckpt
# HiFi-GAN vocoder:                     https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/generator_v1

In [1]:
import os

# phonemizer needs to be told where the eSpeak NG shared library lives on Windows.
# Only do this on Windows (the DLL path is meaningless elsewhere), and never
# override a value the user has already exported in their environment.
if os.name == "nt":
    os.environ.setdefault(
        "PHONEMIZER_ESPEAK_LIBRARY",
        r"C:\Program Files\eSpeak NG\libespeak-ng.dll",  # path to espeak-ng dll on Windows; adjust for your install
    )

In [18]:
import sys
import torch
import datetime as dt
import matplotlib.pyplot as plt
import numpy as np
import IPython.display as ipd
import soundfile as sf


from process_text import process_text 

sys.path.insert(0, os.path.join(os.getcwd(), 'Matcha_TTS_main')) # Matcha_TTS_main is the code from the original repo 
from matcha.models.matcha_tts import MatchaTTS


# for Hifigan
from matcha.hifigan.config import v1
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.env import AttrDict
from matcha.hifigan.models import Generator as HiFiGAN

In [3]:
# load the model checkpoints
# Local paths of the pretrained checkpoints.
CHECKPOINT_DIR = "checkpoints"
matcha_checkpoint_path = f"{CHECKPOINT_DIR}/matcha_ljspeech.ckpt"  # Matcha-TTS acoustic model (LJSpeech)
hifigan_checkpoint_path = f"{CHECKPOINT_DIR}/generator_v1"          # HiFi-GAN vocoder generator

In [4]:
# set device
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
# load the matcha tts model

def load_model(checkpoint_path):
    """Load a pretrained Matcha-TTS model via ``MatchaTTS.load_from_checkpoint``.

    The weights are mapped onto the notebook-global ``device``. The model is
    switched to eval mode before it is returned.

    Args:
        checkpoint_path: path to the ``.ckpt`` checkpoint file.

    Returns:
        The loaded ``MatchaTTS`` model, in eval mode.
    """
    # NOTE(review): weights_only=False allows the checkpoint to unpickle arbitrary
    # objects -- only load checkpoints that come from a trusted source.
    load_kwargs = {"map_location": device, "weights_only": False}
    matcha = MatchaTTS.load_from_checkpoint(checkpoint_path, **load_kwargs)
    matcha.eval()
    return matcha

def count_params(x):
    """Return the total number of parameters of module ``x`` as a comma-grouped string.

    Example: a model with 18204193 parameters yields ``"18,204,193"``.

    Args:
        x: any object exposing ``.parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        The parameter count formatted with thousands separators.
    """
    return f"{sum(p.numel() for p in x.parameters()):,}"

# Load the acoustic model and report its size.
model = load_model(matcha_checkpoint_path)
n_params = count_params(model)
print(f"Model loaded! Parameter count: {n_params}")

  deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message)


Model loaded! Parameter count: 18,204,193


In [7]:
# load the hifigan vocoder model

def load_vocoder(checkpoint_path):
    """Build the HiFi-GAN generator from config ``v1`` and load its pretrained weights.

    Args:
        checkpoint_path: path to a checkpoint whose ``'generator'`` entry holds the
            generator's state dict.

    Returns:
        The generator on ``device``, in eval mode, with weight norm removed.
    """
    generator = HiFiGAN(AttrDict(v1)).to(device)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator'])
    generator.eval()
    generator.remove_weight_norm()
    return generator

# Instantiate the vocoder once, then wrap it in a denoiser used on its output.
vocoder = load_vocoder(hifigan_checkpoint_path)
denoiser = Denoiser(vocoder, mode="zeros")

Removing weight norm...


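## Pipeline

```text
Text → process_text() → synthesise() → mel-spectrogram
                                      ↓
                              to_waveform() → audio waveform
                                      ↓
                              save_to_folder() → .wav file
```

In this notebook the `to_waveform()` stage is performed inline (HiFi-GAN vocoder followed by the denoiser), and `save_to_folder()` corresponds to writing the waveform with `sf.write`.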

In [8]:
# process_text

# Turn raw text into the model inputs: phonemised string, phoneme-ID tensor and its length.
test_text = "Hello, Mr. Arezki! How are you?"
result = process_text(test_text, device)
print(result)

Processing complete!
{'x_orig': 'Hello, Mr. Arezki! How are you?', 'x_phones': 'həlˈoʊ, mˈɪstɚɹ ˈæɹɛzki! hˈaʊ ɑːɹ juː?', 'x': tensor([[ 50,   0,  83,   0,  54,   0, 156,   0,  57,   0, 135,   0,   3,   0,
          16,   0,  55,   0, 156,   0, 102,   0,  61,   0,  62,   0,  85,   0,
         123,   0,  16,   0, 156,   0,  72,   0, 123,   0,  86,   0,  68,   0,
          53,   0,  51,   0,   5,   0,  16,   0,  50,   0, 156,   0,  43,   0,
         135,   0,  16,   0,  69,   0, 158,   0, 123,   0,  16,   0,  52,   0,
          63,   0, 158,   0,   6]]), 'x_lengths': tensor([75]), 'sequence': [50, 83, 54, 156, 57, 135, 3, 16, 55, 156, 102, 61, 62, 85, 123, 16, 156, 72, 123, 86, 68, 53, 51, 5, 16, 50, 156, 43, 135, 16, 69, 158, 123, 16, 52, 63, 158, 6]}


#### Use the same inference hyperparameters as the original Matcha-TTS synthesis example


In [9]:

# Inference hyperparameters passed to model.synthesise()
n_timesteps = 10      # number of ODE solver steps
length_scale = 1.0    # changes the speaking rate
temperature = 0.667   # sampling temperature

In [10]:
# Generate the mel-spectrogram with the model and time the call.
start_time = dt.datetime.now()

with torch.inference_mode():  # Disable gradients for inference
    output = model.synthesise(
        result['x'],              # Phoneme IDs tensor
        result['x_lengths'],      # Length of sequence
        n_timesteps=n_timesteps,
        temperature=temperature,
        length_scale=length_scale,
    )

# CUDA kernels run asynchronously: wait for queued GPU work to finish before
# stopping the clock, otherwise the measured time can under-report synthesis cost.
if device.type == "cuda":
    torch.cuda.synchronize(device)

end_time = dt.datetime.now()
synthesis_time = (end_time - start_time).total_seconds()

print(f" Mel-spectrogram generated in {synthesis_time:.2f} seconds!")


 Mel-spectrogram generated in 0.89 seconds!


In [11]:
# Inspect which fields synthesise() returned.
returned_keys = output.keys()
print(returned_keys)

dict_keys(['encoder_outputs', 'decoder_outputs', 'attn', 'mel', 'mel_lengths', 'rtf'])


In [12]:
# Extract the mel-spectrogram
mel = output['mel']
# Drop size-1 dims (e.g. a batch of one), then move to CPU and convert to a NumPy array for plotting.
mel_to_draw = mel.squeeze()
mel_to_draw = mel_to_draw.cpu().numpy()


In [13]:
# plot the mel-spectrogram
# savefig raises if the target folder does not exist, so create it first.
os.makedirs('figures', exist_ok=True)

fig, ax = plt.subplots(figsize=(10, 4))
im = ax.imshow(mel_to_draw, aspect='auto', origin='lower', cmap='viridis')
ax.set_title('Generated Mel-Spectrogram')
ax.set_xlabel('Time Frames')
ax.set_ylabel('Mel Frequency Channels')
# NOTE(review): Matcha's mel output appears to be (natural) log-amplitude, not
# decibels, so the ticks are no longer suffixed with "dB" -- confirm against the
# model's mel extraction before relabelling.
cbar = fig.colorbar(im, ax=ax, format='%+2.0f')
cbar.set_label('Log amplitude')
fig.savefig('figures/generated_mel_spectrogram_1.png')  # Save the figure
plt.show()

  plt.show()


### Generate audio from the mel-spectrogram

In [21]:
SAMPLE_RATE = 22050          # Hz; used for both playback and the saved .wav file
DENOISER_STRENGTH = 0.00025  # strength passed to the denoiser

# Vocode the mel-spectrogram into a waveform, then denoise it.
with torch.inference_mode():  # Wrap vocoder and denoiser in inference mode
    audio = vocoder(mel).clamp(-1, 1)
    audio = denoiser(audio.squeeze(0), strength=DENOISER_STRENGTH).cpu().squeeze()

ipd.display(ipd.Audio(audio, rate=SAMPLE_RATE))

# Save the audio to a wav file; the target folder must exist before writing.
os.makedirs('output', exist_ok=True)
sf.write('output/generated_audio_1.wav', audio.numpy(), SAMPLE_RATE)


In [22]:
# 'rtf' as reported by synthesise() (presumably the real-time factor).
rtf = output['rtf']
print(rtf)

0.3764221370281559
