# Persian StyleTTS2 Inference

This notebook runs inference with a trained Persian StyleTTS2 model. Input text goes through a full processing pipeline: normalization (pernorm) → word disambiguation (hamnevise) → diacritization (zirneshane) → phonemization (vaguye).

## 1. Install Dependencies

In [1]:
# !pip install git+ssh://git@github.com/SadeghKrmi/pernorm.git
# !pip install git+ssh://git@github.com/SadeghKrmi/zirneshane.git
# !pip install git+ssh://git@github.com/SadeghKrmi/vaguye.git
# !pip install git+ssh://git@github.com/SadeghKrmi/hamnevise.git
# !pip install munch pydub pyyaml librosa nltk matplotlib accelerate transformers phonemizer einops einops-exts tqdm typing-extensions

## 2. Imports and Setup

In [2]:
import torch
import yaml
import os
import librosa
import numpy as np
from munch import Munch
from nltk.tokenize import word_tokenize
import IPython.display as ipd

# Persian NLP tools
from pernorm import PersianNormalizer
from hamnevise import HamneviseModel
from zirneshane import HybridZirneshanModel
from vaguye import PersianPhonemizer

# StyleTTS2 modules
from models import *
from utils import *
from text_utils import TextCleaner
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

Using device: cpu


## 3. Load Configurations and Models

In [3]:
# Load Config (model hyper-parameters and paths to the pretrained helper models).
# A context manager closes the file handle instead of leaking it.
with open("Configs/config_fa.yml", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

# Load Utils Models: text aligner (ASR), pitch extractor (F0) and the Persian PL-BERT.
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)

F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

from Utils.PLBERT_fa.util import load_plbert
BERT_path = config.get('PLBERT_dir', False)
plbert = load_plbert(BERT_path)

# Build Model: every sub-module is switched to eval mode and moved to `device`.
model_params = recursive_munch(config['model_params'])
model = build_model(model_params, text_aligner, pitch_extractor, plbert)
for key in model:
    model[key].eval()
    model[key].to(device)

  WeightNorm.apply(module, name, dim)


In [4]:
# Load Checkpoint: the file stores one state dict per sub-module under the 'net' key.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
# NOTE(review): a Persian model under "Models/LJSpeech/" looks like a leftover path; confirm.
model_path = "Models/LJSpeech/epoch_2nd_00023.pth"
print(f"Loading model from {model_path}...")

params_whole = torch.load(model_path, map_location='cpu')
params = params_whole['net']

for key in model:
    if key not in params:
        print('%s not found in checkpoint, keeping initial weights' % key)
        continue
    print('%s loaded' % key)
    try:
        model[key].load_state_dict(params[key])
    except RuntimeError:
        # load_state_dict raises RuntimeError on key/shape mismatch. A common cause is
        # keys saved with a "module." prefix (e.g. from a DataParallel-wrapped model).
        # Strip the prefix only where it is actually present, instead of blindly
        # cutting 7 characters from every key.
        prefix = 'module.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in params[key].items()
        }
        result = model[key].load_state_dict(state_dict, strict=False)
        # strict=False tolerates mismatches; report them so a silently
        # half-loaded sub-module does not go unnoticed.
        if result.missing_keys or result.unexpected_keys:
            print('  %s: %d missing / %d unexpected keys after loading'
                  % (key, len(result.missing_keys), len(result.unexpected_keys)))

for key in model:
    model[key].eval()

Loading model from Models/LJSpeech/epoch_2nd_00023.pth...
bert loaded
bert_encoder loaded
predictor loaded
decoder loaded
text_encoder loaded
predictor_encoder loaded
style_encoder loaded
diffusion loaded
text_aligner loaded
pitch_extractor loaded
mpd loaded
msd loaded
wd loaded


In [5]:
# Initialize Sampler
# Diffusion sampler for the style vector. inference() calls it with noise, the BERT
# embeddings of the input tokens and the reference style (`features=ref_s`), then
# splits the sampled vector into decoder and predictor halves.
sampler = DiffusionSampler(
    model.diffusion.diffusion,
    sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
    clamp=False
)

## 4. Initialize Text Processing Pipeline

In [6]:
# Initialize NLP models
# One instance per stage of the text pipeline; process_text() uses them in this order.
# Text normalizer (pernorm).
normer = PersianNormalizer()
# Diacritization model (zirneshane); the cell output shows its weights being fetched from HuggingFace.
zmodel = HybridZirneshanModel.load()
# Word disambiguation model (hamnevise) and its tokenizer, placed on `device`.
hmodel, tokenizer = HamneviseModel.load(device=device)
# Grapheme-to-phoneme converter (vaguye).
phonemizer = PersianPhonemizer()
# Vocabulary of the hamnevise model: a cheap gate so disambiguation only runs
# when the text contains a word the model knows.
hamnevise_words = set(hmodel.word2idx.keys())
# Maps phoneme characters to the integer token ids fed to the text encoder.
text_cleaner = TextCleaner()

def needs_hamnevise(text, vocab=None):
    """Check if any word in text requires hamnevise disambiguation.

    Words are split on whitespace. Each word is checked both as-is and with
    surrounding Latin/Persian punctuation stripped, so a vocabulary word
    followed by "!" or "،" (or wrapped in «») is still detected.

    Args:
        text: normalized text to inspect.
        vocab: collection of words to match against. Defaults to the
            module-level ``hamnevise_words`` (the hamnevise model vocabulary).

    Returns:
        True if at least one word matches the vocabulary.
    """
    import string

    if vocab is None:
        vocab = hamnevise_words
    punctuation = string.punctuation + "«»،؛؟…"
    # The raw form is checked first so behaviour is a strict superset of the
    # old whitespace-only match.
    return any(
        word in vocab or word.strip(punctuation) in vocab
        for word in text.split()
    )

def process_text(text):
    """Turn raw Persian text into the phoneme string fed to the TTS model.

    Stages, in order:
      1. normalize with pernorm,
      2. disambiguate with hamnevise, only when the text contains a word it knows,
      3. add diacritics with zirneshane,
      4. phonemize with vaguye.

    Returns:
        Whatever ``phonemizer.phonemize`` produces for the diacritized text.
    """
    normalized = normer.normalize(text)

    if needs_hamnevise(normalized):
        # disambiguate() returns a pair; only the rewritten text is kept.
        normalized, _ = hmodel.disambiguate(normalized, tokenizer=tokenizer)

    diacritized = zmodel.predict(normalized)
    return phonemizer.phonemize(diacritized)

⚠️ CUDA not available, using CPU
📥 Downloading default model from HuggingFace...
📂 Loading model from: /root/.cache/huggingface/hub/models--SadeghK--zirneshane/snapshots/fa0b943ba9024e24fee59b9840daf29b89960ce8/zirneshan-word-char-parsbert-embedding-classifier-v2.0.pt
✅ Model loaded successfully!
   Epoch: 10
   F1 Score: 0.7417

🏗️  Model Architecture:
   - Shared encoder: HooshvareLab/bert-fa-base-uncased
   - Word-specific heads: 138
   - Character vocab: 67
📦 Files ready
  Model : /root/.cache/huggingface/hub/models--SadeghK--Hamnevise/snapshots/f89b2c2ac747866768c4cdd061eb6eb759e47a17/hamnevise-persian-word-disambigution-v1.0.pt
  Config: /root/.cache/huggingface/hub/models--SadeghK--Hamnevise/snapshots/f89b2c2ac747866768c4cdd061eb6eb759e47a17/tokenizer-config-v1.0.json
📥 Loading dictionary files from: /root/StyleTTS2/.venv/lib/python3.12/site-packages/vaguye/persian-dict
✅ Loaded persian-primary.json
✅ Loaded persian-secondary.json
📚 Total entries lo

## 5. Inference Function

In [7]:
def length_to_mask(lengths):
    """Build a padding mask from per-sequence lengths.

    Args:
        lengths: 1-D tensor of sequence lengths (one entry per batch item).

    Returns:
        Boolean tensor of shape ``(len(lengths), lengths.max())`` that is True
        at padded positions, i.e. where ``position + 1 > length``.
    """
    positions = torch.arange(lengths.max()).type_as(lengths)
    # Broadcasting a (1, T) row against a (B, 1) column yields the (B, T) mask.
    return torch.gt(positions.unsqueeze(0) + 1, lengths.unsqueeze(1))

def preprocess(wave):
    """Convert a waveform to the normalized log-mel spectrogram the style encoders take.

    Args:
        wave: numpy waveform; assumed mono at 24 kHz, as produced by
            ``compute_style`` (TODO confirm for other callers).

    Returns:
        torch.Tensor ``(log(1e-5 + mel) - mean) / std`` with a leading batch
        axis added via ``unsqueeze(0)``.
    """
    # Import explicitly rather than relying on `from utils import *`
    # happening to re-export torchaudio.
    import torchaudio

    # NOTE(review): these mel settings and the -4 / 4 normalization must match
    # the training feature extractor. sample_rate is left at the torchaudio
    # default, presumably mirroring training; confirm before changing.
    to_mel = torchaudio.transforms.MelSpectrogram(n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
    mean, std = -4, 4

    wave_tensor = torch.from_numpy(wave).float()
    mel_tensor = to_mel(wave_tensor)
    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
    return mel_tensor

def compute_style(path, top_db=30):
    """Extract a reference style vector from an audio file.

    The clip is loaded and resampled by librosa to 24 kHz. Leading and
    trailing silence is trimmed. The normalized log-mel spectrogram is then
    passed through both style encoders.

    Args:
        path: path to the reference audio file.
        top_db: silence threshold in dB passed to ``librosa.effects.trim``
            (default 30, the value that was previously hard-coded).

    Returns:
        torch.Tensor: ``model.style_encoder`` output concatenated with
        ``model.predictor_encoder`` output along dim 1. ``inference`` uses
        columns [:128] for the decoder style and [128:] for the predictor style.
    """
    # librosa.load resamples to the requested rate, so the returned rate is
    # always 24000. The old `if sr != 24000: librosa.resample(audio, sr, 24000)`
    # branch was dead code, and its positional signature is rejected by
    # librosa >= 0.10 (orig_sr/target_sr are keyword-only).
    wave, _ = librosa.load(path, sr=24000)
    audio, _ = librosa.effects.trim(wave, top_db=top_db)
    mel_tensor = preprocess(audio).to(device)

    with torch.no_grad():
        ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
        ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))

    return torch.cat([ref_s, ref_p], dim=1)

def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):
    """Synthesize speech for Persian ``text`` in the style of a reference clip.

    Args:
        text: raw Persian text. It goes through the full ``process_text``
            pipeline (normalize -> disambiguate -> diacritize -> phonemize).
        ref_s: reference style from ``compute_style``. Columns [:128] are
            blended into the style given to the decoder; columns [128:] into
            the style given to the duration/F0 predictor.
        alpha: weight of the diffusion-sampled style in the decoder half.
            ``1 - alpha`` goes to the reference, so 1.0 ignores the reference.
        beta: the same blend weight for the predictor half.
        diffusion_steps: number of sampling steps for the style diffusion.
        embedding_scale: forwarded to the sampler as ``embedding_scale``.

    Returns:
        numpy array of decoder output samples with the last 50 samples
        dropped. The notebook plays it at 24000 Hz.

    Raises:
        ValueError: if ``text`` is blank, or phonemization yields no symbols
            known to ``text_cleaner``. Previously a lone padding token made
            ``duration.squeeze()`` 0-d and the alignment loop failed with an
            obscure IndexError.
    """
    text = text.strip()
    if not text:
        raise ValueError("inference() needs non-empty text")

    # Full pipeline: Normalize -> Disambiguate -> Diacritize -> Phonemize
    ps = process_text(text)

    # Tokenize phonemes.
    # NOTE(review): every symbol the phonemizer emits should be in TextCleaner's
    # vocabulary; the sample run below warns about '!'.
    tokens = text_cleaner(ps)
    if not tokens:
        raise ValueError(f"No known phoneme symbols in {ps!r}")
    # Prepend token id 0 (presumably the pad/boundary symbol; verify in text_utils).
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        text_mask = length_to_mask(input_lengths).to(device)

        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        # Sample a style vector conditioned on the BERT embeddings and the reference style.
        s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
                         embedding=bert_dur,
                         embedding_scale=embedding_scale,
                         features=ref_s,
                         num_steps=diffusion_steps).squeeze(1)

        s = s_pred[:, 128:]
        ref = s_pred[:, :128]

        # Blend sampled and reference styles: `ref` conditions the decoder,
        # `s` conditions the predictor.
        ref = alpha * ref + (1 - alpha)  * ref_s[:, :128]
        s = beta * s + (1 - beta)  * ref_s[:, 128:]

        # Predict a per-token duration (in frames); every token gets at least one frame.
        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # Hard alignment matrix (tokens x frames): token i covers pred_dur[i] consecutive frames.
        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)

        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        if model_params.decoder.type == "hifigan":
            # Shift one frame to the right along time (frame 0 duplicated);
            # presumably mirrors the training-time alignment for HiFi-GAN; confirm.
            asr_new = torch.zeros_like(en)
            asr_new[:, :, 0] = en[:, :, 0]
            asr_new[:, :, 1:] = en[:, :, 0:-1]
            en = asr_new

        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
        if model_params.decoder.type == "hifigan":
            asr_new = torch.zeros_like(asr)
            asr_new[:, :, 0] = asr[:, :, 0]
            asr_new[:, :, 1:] = asr[:, :, 0:-1]
            asr = asr_new

        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

    # Drop the last 50 samples of the decoder output.
    return out.squeeze().cpu().numpy()[..., :-50]

## 6. Run Inference

In [11]:
# Demo sentence: "This is a test sentence based on a new artificial intelligence model!"
# NOTE(review): the literal starts with "." and ends with "!". The leading period looks
# like an RTL-editing slip, and the sample run reports "!" as a non-IPA character;
# confirm which punctuation the model vocabulary supports.
text = ".این یک جمله تستی بر اساس یک مدل هوش مصنوعی جدید است!"
# Reference audio for style
ref_path = "Data/test.wav"

# inference() draws random diffusion noise with torch.randn; fix the seed so
# re-running the notebook reproduces the same audio.
torch.manual_seed(0)

print(f"Using reference: {ref_path}")
ref_s = compute_style(ref_path)

wav = inference(text, ref_s, alpha=0.9, beta=0.9, diffusion_steps=10, embedding_scale=1)
ipd.Audio(wav, rate=24000, normalize=False)

Using reference: Data/test.wav
❌ Non-IPA char: '!' (0x21) — EXCLAMATION MARK
