In [None]:
!git clone https://github.com/HZhalex/F5-TTS-Vietnamese

In [None]:
import os
os.makedirs("ckpts/F5-TTS-Vietnamese", exist_ok=True)
!wget -O ckpts/F5-TTS-Vietnamese/model_500000.pt https://huggingface.co/hynt/F5-TTS-Vietnamese-ViVoice/resolve/main/model_last.pt

In [None]:
!pip install torchcodec

In [None]:
%cd F5-TTS-Vietnamese

!pip install -e .

In [None]:
import warnings
warnings.filterwarnings("ignore", category=SyntaxWarning)

In [None]:
import torch
import torchaudio
import soundfile as sf
from vocos import Vocos
from f5_tts.model import DiT, CFM
from f5_tts.model.utils import get_tokenizer, convert_char_to_pinyin


# --- File locations (Colab layout: everything lives under /content) ---
MODEL_PATH = "/content/ckpts/F5-TTS-Vietnamese/model_500000.pt"  # pretrained checkpoint
VOCAB_PATH = "/content/F5-TTS-Vietnamese/vocab.txt"             # character vocabulary
REF_AUDIO_PATH = "/content/Bình (nam miền Bắc).wav"              # reference voice clip
OUTPUT_PATH = "output.wav"                                      # synthesized speech

# --- Texts ---
# REF_TEXT should match the speech in REF_AUDIO_PATH: its length, together with
# the clip's length, sets the speaking rate used for GEN_TEXT.
REF_TEXT = "Anh chỉ muốn được nhìn nhận như là một huấn luyện viên."
GEN_TEXT = "Xin chào, đây là văn bản tôi muốn chuyển thành giọng nói"

# --- Compute device: GPU when available, otherwise CPU ---
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


# Build the CFM model around a DiT backbone. load_state_dict() below is strict, so
# these hyperparameters must match the ones the checkpoint was trained with.
vocab_char_map, vocab_size = get_tokenizer(VOCAB_PATH, tokenizer="custom")
model = CFM(
    transformer=DiT(
        dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False,
        text_num_embeds=vocab_size, mel_dim=100, conv_layers=4, pe_attn_head=1
    ),
    # Mel front-end settings (100 mels, hop 256, 24 kHz, "vocos" type); the Vocos
    # vocoder loaded below decodes the generated mels, so they need to agree.
    mel_spec_kwargs=dict(
        n_fft=1024, hop_length=256, win_length=1024,
        n_mel_channels=100, target_sample_rate=24000, mel_spec_type="vocos"
    ),
    odeint_kwargs=dict(method="euler"),
    vocab_char_map=vocab_char_map,
).to(DEVICE)

# Load onto the CPU: the file may hold several state dicts (EMA and/or plain
# weights, both handled below) and only one of them is needed. load_state_dict()
# copies the chosen tensors into the model, which is already on DEVICE.
# weights_only=True refuses arbitrary pickled objects from the downloaded file.
checkpoint = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
if "ema_model_state_dict" in checkpoint:
    # EMA weights carry an "ema_model." prefix plus EMA bookkeeping entries
    # ("initted", "step") that are not model parameters.
    state_dict = {
        k.replace("ema_model.", ""): v
        for k, v in checkpoint["ema_model_state_dict"].items()
        if k not in ["initted", "step"]
    }
elif "model_state_dict" in checkpoint:
    state_dict = checkpoint["model_state_dict"]
else:
    state_dict = checkpoint

# Upstream F5-TTS's checkpoint loader drops these mel front-end buffers for
# backward compatibility with older checkpoints. Drop them only when this model
# does not expect them, so strict loading works with either model version.
model_keys = set(model.state_dict())
for legacy_key in ("mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"):
    if legacy_key in state_dict and legacy_key not in model_keys:
        del state_dict[legacy_key]
model.load_state_dict(state_dict)
model.eval()

# The weights now live in `model`; release the loaded copy so it does not stay
# referenced from the notebook namespace.
del checkpoint, state_dict


# Vocoder that turns the generated mel spectrograms into waveforms.
vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(DEVICE)


# Load the reference clip, downmix it to one channel, boost quiet clips to an RMS
# of 0.1 and resample to the model's 24 kHz. `rms` stays at module level because
# the synthesis step uses it to undo the boost on the output.
audio, sr = torchaudio.load(REF_AUDIO_PATH)
if audio.shape[0] > 1:
    # Average the channels, keeping a leading channel dimension of size 1.
    audio = torch.mean(audio, dim=0, keepdim=True)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms == 0:
    # The gain below would divide by zero and turn the reference into NaN audio.
    raise ValueError(f"Reference audio is silent: {REF_AUDIO_PATH}")
if rms < 0.1:
    audio = audio * 0.1 / rms
if sr != 24000:
    audio = torchaudio.transforms.Resample(sr, 24000)(audio)
audio = audio.to(DEVICE)


# Synthesize GEN_TEXT in the reference speaker's voice. The model is conditioned on
# the reference audio and the concatenated reference + target text. The frames
# covering the reference are then cropped from its output.
text_list = [REF_TEXT + " " + GEN_TEXT]
final_text = convert_char_to_pinyin(text_list)

# Mel frames to generate: the reference's own frames (hop length 256) plus an
# estimate for GEN_TEXT. The estimate assumes the same frames-per-UTF-8-byte rate
# as the reference.
ref_len = audio.shape[-1] // 256
ref_text_len = len(REF_TEXT.encode("utf-8"))
gen_text_len = len(GEN_TEXT.encode("utf-8"))
if ref_text_len == 0:
    # Otherwise the rate estimate below fails with a bare ZeroDivisionError.
    raise ValueError("REF_TEXT must not be empty: it is used to estimate the speaking rate")
duration = ref_len + int(ref_len / ref_text_len * gen_text_len)

with torch.inference_mode():
    # sample() returns a pair; only element 0 (the mel output) is used. Indexing
    # avoids unpacking the other element into `_`. Assigning `_` shadows IPython's
    # last-output variable and keeps that tensor referenced from the notebook.
    generated = model.sample(
        cond=audio,
        text=final_text,
        duration=duration,
        steps=32,
        cfg_strength=2.0,
        sway_sampling_coef=-1.0,
    )[0]
    # Drop the reference frames (dim 1), then swap dims 1 and 2 for the vocoder.
    # NOTE(review): presumably (batch, frames, n_mels) -> (batch, n_mels, frames) — confirm.
    generated = generated[:, ref_len:, :]
    gen_mel_spec = generated.permute(0, 2, 1)
    wave = vocoder.decode(gen_mel_spec).cpu()
    # Undo the loudness boost applied to a quiet reference clip.
    if rms < 0.1:
        wave = wave * rms / 0.1


sf.write(OUTPUT_PATH, wave.squeeze().numpy(), 24000)
print(f"saved: {OUTPUT_PATH}")