In [None]:
# Clone the F5-TTS-Vietnamese source into the current working directory;
# the `%cd F5-TTS-Vietnamese` cell further down expects it there.
!git clone https://github.com/HZhalex/F5-TTS-Vietnamese

In [None]:
import os
os.makedirs("ckpts/F5-TTS-Vietnamese", exist_ok=True)
!wget -O ckpts/F5-TTS-Vietnamese/model_500000.pt https://huggingface.co/hynt/F5-TTS-Vietnamese-ViVoice/resolve/main/model_last.pt

In [None]:
# Install torchcodec as an extra audio dependency (presumably required by the
# installed torchaudio version for torchaudio.load — verify for your runtime).
!pip install torchcodec 

In [None]:
# Enter the cloned repository and install it in editable mode; this presumably
# provides the `f5_tts` package imported by the inference cell below.
# NOTE: `%cd` changes the working directory for the rest of the session, so
# relative paths used afterwards (e.g. OUTPUT_PATH) resolve inside the repo.
%cd F5-TTS-Vietnamese

!pip install -e .

In [None]:
"""
Simple F5-TTS Inference Script
"""

import os
import torch
import torchaudio
import soundfile as sf

from f5_tts.model import DiT, CFM
from f5_tts.model.utils import get_tokenizer, convert_char_to_pinyin

# ==================== CONFIGURATION ====================

# Absolute locations of the checkpoint and the custom character vocabulary.
MODEL_PATH = "/content/ckpts/F5-TTS-Vietnamese/model_500000.pt"
VOCAB_PATH = "/content/ckpts/F5-TTS-Vietnamese/vocab.txt"

# Voice prompt: the reference clip and its transcript. The transcript is
# prepended to GEN_TEXT, and its length drives the output duration estimate.
REF_AUDIO_PATH = "/content/Bình (nam miền Bắc).wav"
REF_TEXT = "Anh chỉ muốn được nhìn nhận như là một huấn luyện viên."

# Text to synthesize in the reference voice.
GEN_TEXT = "Xin chào, đây là văn bản tôi muốn chuyển thành giọng nói"

# Relative path: resolved against the current working directory at run time.
OUTPUT_PATH = "output_generated.wav"

# Keyword arguments for the DiT backbone.
MODEL_CONFIG = dict(
    dim=1024,
    depth=22,
    heads=16,
    ff_mult=2,
    text_dim=512,
    text_mask_padding=False,
    conv_layers=4,
    pe_attn_head=1,
)

# Mel-spectrogram front-end settings; also used for frame/sample conversions.
MEL_SPEC_CONFIG = dict(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    mel_spec_type="vocos",
)

# Sampling parameters forwarded to model.sample().
NFE_STEP = 32               # number of denoising steps
CFG_STRENGTH = 2.0          # passed as cfg_strength
SWAY_SAMPLING_COEF = -1.0   # passed as sway_sampling_coef
SPEED = 1.0                 # divides the generated-part duration estimate
TARGET_RMS = 0.1            # quieter references are scaled up to this RMS

# Run on the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ==================== FUNCTIONS ====================

def load_vocoder(device=DEVICE):
    """Return the pretrained ``charactr/vocos-mel-24khz`` Vocos vocoder on *device*.

    ``vocos`` is imported inside the function, so the package is only needed
    once a vocoder is actually requested.
    """
    from vocos import Vocos

    print("Đang tải Vocos vocoder...")
    pretrained = Vocos.from_pretrained("charactr/vocos-mel-24khz")
    return pretrained.to(device)


def load_f5tts_model(model_path, vocab_path, model_config, mel_spec_config, device):
    """Build the DiT-backed CFM model and load its weights from a checkpoint.

    Args:
        model_path: Path to the ``.pt`` checkpoint file.
        vocab_path: Path to the custom character vocabulary file.
        model_config: Keyword arguments forwarded to ``DiT``.
        mel_spec_config: Mel settings; ``n_mel_channels`` becomes the DiT
            ``mel_dim`` and the whole dict is passed to ``CFM`` as
            ``mel_spec_kwargs``.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        ``(model, vocab_char_map)`` with ``model`` switched to eval mode.

    Raises:
        RuntimeError: from ``load_state_dict`` when checkpoint keys/shapes do
            not match the configured architecture.
    """
    print(f"Đang load vocab từ: {vocab_path}")
    vocab_char_map, vocab_size = get_tokenizer(vocab_path, tokenizer="custom")
    print(f"Vocab size: {vocab_size}")

    print("Đang khởi tạo model DiT...")
    transformer = DiT(
        **model_config,
        text_num_embeds=vocab_size,
        mel_dim=mel_spec_config["n_mel_channels"],
    )

    model = CFM(
        transformer=transformer,
        mel_spec_kwargs=mel_spec_config,
        odeint_kwargs=dict(method="euler"),
        vocab_char_map=vocab_char_map,
    ).to(device)

    print(f"Đang load checkpoint từ: {model_path}")
    # The checkpoint is downloaded from the internet: restrict unpickling to
    # tensors and plain containers instead of arbitrary Python objects.
    checkpoint = torch.load(model_path, map_location=device, weights_only=True)

    if "ema_model_state_dict" in checkpoint:
        # Prefer EMA weights: strip the wrapper prefix and drop the EMA
        # bookkeeping entries that are not model parameters.
        state_dict = {
            k.replace("ema_model.", ""): v
            for k, v in checkpoint["ema_model_state_dict"].items()
            if k not in ("initted", "step")
        }
    elif "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        # Bare state dict saved without a wrapper.
        state_dict = checkpoint

    # Backward compatibility (mirrors upstream F5-TTS): some checkpoints carry
    # mel front-end buffers that newer model code no longer registers, which
    # strict loading would reject as unexpected keys. Only drop them when this
    # model does not expect them, so older model code still loads them.
    expected_keys = model.state_dict().keys()
    for legacy_key in (
        "mel_spec.mel_stft.mel_scale.fb",
        "mel_spec.mel_stft.spectrogram.window",
    ):
        if legacy_key not in expected_keys:
            state_dict.pop(legacy_key, None)

    model.load_state_dict(state_dict)
    model.eval()

    print("✓ Model đã được load thành công!")
    return model, vocab_char_map


def preprocess_audio(audio_path, target_sample_rate=24000, target_rms=TARGET_RMS, device=DEVICE):
    """Load a reference clip as mono, boost it if quiet, and resample it.

    Args:
        audio_path: Audio file to read with ``torchaudio.load``.
        target_sample_rate: Sample rate of the returned audio.
        target_rms: Clips whose RMS is below this value are scaled up to it.
            Completely silent clips (RMS == 0) are left unscaled.
        device: Device the returned audio tensor is moved to.

    Returns:
        ``(audio, target_sample_rate, rms)``. ``audio`` has a single channel
        in dimension 0. ``rms`` is the RMS of the mono signal *before*
        scaling, computed before the tensor is moved to ``device``.
    """
    print(f"Đang load audio từ: {audio_path}")
    audio, sr = torchaudio.load(audio_path)

    # Downmix multi-channel audio to mono, keeping the channel dimension.
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    rms = torch.sqrt(torch.mean(torch.square(audio)))
    # Boost quiet references to target_rms. A silent clip (rms == 0) is
    # skipped: dividing by zero would fill the audio with NaN/inf.
    if 0 < rms < target_rms:
        audio = audio * target_rms / rms

    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)

    audio = audio.to(device)
    print(f"✓ Audio shape: {audio.shape}, Sample rate: {target_sample_rate}")
    print(f"✓ Original RMS: {rms.item():.4f}")

    return audio, target_sample_rate, rms


@torch.inference_mode()
def generate_speech(model, vocoder, ref_audio, ref_text, gen_text, vocab_char_map, ref_audio_rms):
    """Synthesize *gen_text* in the voice of the reference clip.

    The model is conditioned on ``ref_audio`` and on the combined text
    ``ref_text + " " + gen_text``. The mel frames that cover the reference
    are cut off before vocoding, so only the newly generated speech is
    returned.

    Args:
        model: Loaded CFM model; ``model.sample`` produces the mel frames.
        vocoder: Object whose ``decode`` turns mel frames into a waveform.
        ref_audio: Reference waveform, expected at
            ``MEL_SPEC_CONFIG["target_sample_rate"]``, with samples in the
            last dimension.
        ref_text: Transcript of ``ref_audio``; must not be blank.
        gen_text: Text to synthesize; must not be blank.
        vocab_char_map: Not used here (the model already holds the
            vocabulary); kept for interface compatibility.
        ref_audio_rms: RMS of the reference before loudness boosting; used to
            scale the output back down when the reference was quiet.

    Returns:
        ``(generated_wave, sample_rate)`` with ``generated_wave`` on the CPU.

    Raises:
        ValueError: if ``ref_text`` or ``gen_text`` is empty or whitespace.
    """
    print("\n" + "="*50)
    print("BẮT ĐẦU GENERATE SPEECH")
    print("="*50)

    print(f"\nReference text: {ref_text}")
    print(f"Generate text: {gen_text}")

    # ref_text's length is the divisor of the duration estimate below, and an
    # empty gen_text would leave no frames to decode once the reference part
    # is cut off.
    if not ref_text.strip():
        raise ValueError("ref_text must be a non-empty transcript of the reference audio")
    if not gen_text.strip():
        raise ValueError("gen_text must not be empty")

    text_list = [ref_text + " " + gen_text]
    final_text_list = convert_char_to_pinyin(text_list)
    print(f"Processed text: {final_text_list}")

    hop_length = MEL_SPEC_CONFIG["hop_length"]
    sample_rate = MEL_SPEC_CONFIG["target_sample_rate"]

    # Total mel frames = reference frames + an estimate for the new text. The
    # estimate assumes frames are proportional to the UTF-8 byte length of
    # the text, and divides it by SPEED.
    ref_audio_len = ref_audio.shape[-1] // hop_length
    ref_text_len = len(ref_text.encode("utf-8"))
    gen_text_len = len(gen_text.encode("utf-8"))
    duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / SPEED)

    print(f"\nRef audio length: {ref_audio_len} frames")
    print(f"Target duration: {duration} frames")
    print(f"Estimated audio length: {duration * hop_length / sample_rate:.2f}s")

    print(f"\nGenerating với {NFE_STEP} denoising steps...")
    generated, _ = model.sample(
        cond=ref_audio,
        text=final_text_list,
        duration=duration,
        steps=NFE_STEP,
        cfg_strength=CFG_STRENGTH,
        sway_sampling_coef=SWAY_SAMPLING_COEF,
    )

    print(f"✓ Generated mel shape: {generated.shape}")

    # Drop the frames that reproduce the reference; keep only new speech.
    generated = generated[:, ref_audio_len:, :]

    print("Đang chuyển mel spectrogram sang waveform...")
    # Move the frame axis last before handing the mel to the vocoder.
    gen_mel_spec = generated.permute(0, 2, 1)
    generated_wave = vocoder.decode(gen_mel_spec).cpu()

    # Quiet references were boosted to TARGET_RMS during preprocessing; scale
    # the output back to the reference's original loudness.
    if ref_audio_rms < TARGET_RMS:
        generated_wave = generated_wave * ref_audio_rms / TARGET_RMS

    print(f"✓ Generated wave shape: {generated_wave.shape}")

    return generated_wave, sample_rate


def main():
    """Run the full inference pipeline using the module-level configuration.

    Verifies that the model, vocab and reference audio exist. Then it loads
    the model and vocoder, preprocesses the reference clip, synthesizes
    ``GEN_TEXT`` and writes the waveform to ``OUTPUT_PATH``.

    Raises:
        FileNotFoundError: for the first missing input, checked in the order
            model, vocab, reference audio.
    """
    banner = "=" * 50
    print("\n" + banner)
    print("F5-TTS SIMPLE INFERENCE")
    print(banner)
    print(f"Device: {DEVICE}\n")

    required_inputs = {
        "Model": MODEL_PATH,
        "Vocab": VOCAB_PATH,
        "Reference audio": REF_AUDIO_PATH,
    }
    for label, required_path in required_inputs.items():
        if not os.path.exists(required_path):
            raise FileNotFoundError(f"{label} không tồn tại: {required_path}")

    model, vocab_char_map = load_f5tts_model(
        model_path=MODEL_PATH,
        vocab_path=VOCAB_PATH,
        model_config=MODEL_CONFIG,
        mel_spec_config=MEL_SPEC_CONFIG,
        device=DEVICE,
    )
    vocoder = load_vocoder(DEVICE)

    ref_audio, _sample_rate, ref_audio_rms = preprocess_audio(
        REF_AUDIO_PATH,
        target_sample_rate=MEL_SPEC_CONFIG["target_sample_rate"],
        target_rms=TARGET_RMS,
        device=DEVICE,
    )

    generated_wave, sample_rate = generate_speech(
        model=model,
        vocoder=vocoder,
        ref_audio=ref_audio,
        ref_text=REF_TEXT,
        gen_text=GEN_TEXT,
        vocab_char_map=vocab_char_map,
        ref_audio_rms=ref_audio_rms,
    )

    print(f"\nĐang lưu audio vào: {OUTPUT_PATH}")
    waveform = generated_wave.squeeze().numpy()
    sf.write(OUTPUT_PATH, waveform, sample_rate)

    print("\n" + banner)
    print("✓ HOÀN THÀNH!")
    print(f"✓ File audio đã được lưu tại: {OUTPUT_PATH}")
    print(banner + "\n")


if __name__ == "__main__":
    # Run the whole pipeline when executed as a script. A notebook cell
    # normally also runs with __name__ == "__main__", so running it starts main().
    main()