In [None]:
# Inference Notebook for Persian VITS (Single Speaker)
%matplotlib inline
import matplotlib.pyplot as plt
import IPython.display as ipd

import os
import json
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import commons
import utils
from data_utils import TextAudioLoader, TextAudioCollate
from models import SynthesizerTrn
from text import text_to_sequence
from scipy.io.wavfile import write

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

def get_text(text, hps):
    """Turn ``text`` into a ``torch.LongTensor`` of symbol IDs.

    The text is passed through ``text_to_sequence`` with the cleaners named in
    ``hps.data.text_cleaners``. When ``hps.data.add_blank`` is truthy, ID ``0``
    is interleaved into the sequence via ``commons.intersperse``.
    """
    token_ids = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        token_ids = commons.intersperse(token_ids, 0)
    return torch.LongTensor(token_ids)

## Persian Single Speaker Model

In [None]:
# Read the hyper-parameters from the JSON config and echo the fields
# that matter for inference.
config_path = "./configs/fa_single_speaker.json"  # Update path as needed
hps = utils.get_hparams_from_file(config_path)
print(
    f"Loaded config from: {config_path}",
    f"Text cleaners: {hps.data.text_cleaners}",
    f"Sampling rate: {hps.data.sampling_rate}",
    sep="\n",
)

In [None]:
# Import symbols after config is loaded (to ensure correct language symbols)
from text.symbols import symbols
# Report the size of the symbol table and preview its first 20 entries.
print(
    f"Number of symbols: {len(symbols)}",
    f"First 20 symbols: {symbols[:20]}",
    sep="\n",
)

In [None]:
# Build the generator from the loaded hyper-parameters, move it to the
# selected device and switch it to evaluation mode.
net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model,
)
net_g = net_g.to(device)
_ = net_g.eval()

print(
    f"Model initialized with {len(symbols)} symbols",
    f"Model parameters: {sum(p.numel() for p in net_g.parameters())}",
    sep="\n",
)

In [None]:
# Load the trained generator weights into net_g.
checkpoint_path = "./logs/fa_amir/G_10000.pth"  # Update with your actual checkpoint path
try:
    _ = utils.load_checkpoint(checkpoint_path, net_g, None)
except Exception as e:
    print(f"Error loading checkpoint: {e}")
    print("Please check the checkpoint path and make sure training has produced some checkpoints.")
    # Stop here: continuing would run every later cell on randomly
    # initialised weights and produce meaningless audio without any warning.
    raise
else:
    print(f"Successfully loaded checkpoint: {checkpoint_path}")

## Text-to-Speech Inference

In [None]:
def synthesize_speech(text, noise_scale=0.667, noise_scale_w=0.8, length_scale=1.0):
    """Synthesize speech from Persian text and return it as a NumPy array.

    Uses the notebook-level ``hps``, ``net_g`` and ``device``.

    Args:
        text: Input string; converted to an ID tensor with ``get_text``.
        noise_scale: Forwarded unchanged to ``net_g.infer``.
        noise_scale_w: Forwarded unchanged to ``net_g.infer``.
        length_scale: Forwarded unchanged to ``net_g.infer``.

    Returns:
        float32 NumPy array taken from element ``[0][0, 0]`` of the value
        returned by ``net_g.infer``.
    """
    print(f"Input text: {text}")

    # Text -> tensor of symbol IDs.
    token_ids = get_text(text, hps)
    print(f"Text sequence length: {len(token_ids)}")

    infer_kwargs = dict(
        noise_scale=noise_scale,
        noise_scale_w=noise_scale_w,
        length_scale=length_scale,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        batch = token_ids.to(device).unsqueeze(0)  # add a leading batch axis
        batch_lengths = torch.LongTensor([token_ids.size(0)]).to(device)

        outputs = net_g.infer(batch, batch_lengths, **infer_kwargs)
        waveform = outputs[0][0, 0]
        audio = waveform.data.cpu().float().numpy()

    n_samples = len(audio)
    print(f"Generated audio length: {n_samples} samples ({n_samples/hps.data.sampling_rate:.2f} seconds)")
    return audio

In [None]:
# Smoke-test the model on a few short Persian sentences.
test_texts = [
    "سلام دنیا",  # Hello world
    "امروز روز خوبی است",  # Today is a good day
    "این یک آزمایش است",  # This is a test
]

for i, text in enumerate(test_texts, start=1):
    print(f"\n=== Test {i} ===")
    try:
        audio = synthesize_speech(text)
        print("Audio generated successfully!")

        # Inline audio player for the generated waveform.
        player = ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False)
        ipd.display(player)

        # Keep a WAV copy on disk, numbered by sample.
        output_file = f"output_sample_{i}.wav"
        write(output_file, hps.data.sampling_rate, audio)
        print(f"Saved to: {output_file}")

    except Exception as e:
        # Report the failure and move on to the next sentence.
        print(f"Error generating audio: {e}")

## Interactive Text Input

In [None]:
# Edit custom_text below and re-run this cell to synthesize your own sentence.
custom_text = "متن فارسی خود را اینجا بنویسید"  # Write your Persian text here

print("Generating speech for custom text...")
try:
    audio = synthesize_speech(custom_text, noise_scale=0.667, length_scale=1.0)

    # Inline audio player for the result.
    player = ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False)
    ipd.display(player)

    # Write the waveform to disk as well.
    write("custom_output.wav", hps.data.sampling_rate, audio)
    print("Custom audio saved to: custom_output.wav")

except Exception as e:
    print(f"Error: {e}")

## Parameter Tuning

In [None]:
# Experiment with different synthesis parameters
sample_text = "این یک متن نمونه برای آزمایش پارامترهای مختلف است"

# Each entry pairs a display "name" with the noise_scale / length_scale
# values passed to synthesize_speech for that run.
parameters = [
    {"noise_scale": 0.3, "length_scale": 1.0, "name": "Low noise, normal speed"},
    {"noise_scale": 0.667, "length_scale": 1.0, "name": "Default settings"},
    {"noise_scale": 1.0, "length_scale": 1.0, "name": "High noise, normal speed"},
    {"noise_scale": 0.667, "length_scale": 0.8, "name": "Default noise, fast speed"},
    {"noise_scale": 0.667, "length_scale": 1.2, "name": "Default noise, slow speed"},
]

# Iterate the settings directly; no positional index is needed here.
for params in parameters:
    print(f"\n=== {params['name']} ===")
    try:
        audio = synthesize_speech(
            sample_text,
            noise_scale=params['noise_scale'],
            length_scale=params['length_scale'],
        )
        ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))
    except Exception as e:
        # Report the failure and continue with the next setting.
        print(f"Error: {e}")

## Model Information

In [None]:
# Display model and configuration information
def _hparam(section, key, default):
    """Return ``section``'s value for ``key``, or ``default`` when it is absent.

    Works for plain dicts and for attribute-style containers, so it does not
    depend on the hparams object providing a ``get`` method.
    """
    if isinstance(section, dict):
        return section.get(key, default)
    return getattr(section, key, default)


print("=== Model Configuration ===")
print(f"Model type: {_hparam(hps.model, 'type', 'SynthesizerTrn')}")
print(f"Hidden channels: {_hparam(hps.model, 'hidden_channels', 'N/A')}")
print(f"Filter channels: {_hparam(hps.model, 'filter_channels', 'N/A')}")
print(f"Number of heads: {_hparam(hps.model, 'n_heads', 'N/A')}")
print(f"Number of layers: {_hparam(hps.model, 'n_layers', 'N/A')}")
print(f"Kernel size: {_hparam(hps.model, 'kernel_size', 'N/A')}")

print("\n=== Data Configuration ===")
print(f"Sampling rate: {hps.data.sampling_rate}")
print(f"Filter length: {hps.data.filter_length}")
print(f"Hop length: {hps.data.hop_length}")
print(f"Win length: {hps.data.win_length}")
print(f"Text cleaners: {hps.data.text_cleaners}")
print(f"Add blank: {_hparam(hps.data, 'add_blank', False)}")

print("\n=== Symbol Set ===")
print(f"Total symbols: {len(symbols)}")
print(f"Symbols: {symbols}")