In [None]:
import re
import numpy as np
import soundfile as sf
from IPython.display import Audio as play
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf

from f5_tts.infer.utils_infer import (
    cfg_strength,
    cross_fade_duration,
    fix_duration,
    infer_process,
    load_model,
    load_vocoder,
    nfe_step,
    speed,
    sway_sampling_coef,
    target_rms,
)

from f5_tts.infer.utils_xeus import ApplyKmeans, load_xeus_model, extract_units

In [None]:
device = "cuda"  # Set to your device (e.g. "cpu" when no GPU is available)
vocoder_name = "bigvgan"
config_file = "../configs/F5TTS_Base_EZ-VC.yaml"
# cached_path resolves the hf:// URIs to local files; str() turns the returned
# path object into a plain string for the loaders below.
ckpt_file = str(cached_path("hf://SPRINGLab/EZ-VC/model_2700000.safetensors"))
vocab_file = str(cached_path("hf://SPRINGLab/EZ-VC/vocab.txt"))

In [None]:
# Load the XEUS model (put into eval mode via .eval()) together with the
# ApplyKmeans quantiser; both are passed to extract_units() further down.
xeus_model = load_xeus_model(device)
xeus_model = xeus_model.eval()
apply_kmeans = ApplyKmeans(device)

In [None]:
# Load the vocoder selected by `vocoder_name` on the configured device.
vocoder = load_vocoder(device=device, vocoder_name=vocoder_name)

In [None]:
# Load the TTS model: the backbone class and its architecture kwargs come from
# the YAML config; the checkpoint, vocab and mel type come from the config cell.
model_cfg = OmegaConf.load(config_file)
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch

load_kwargs = dict(
    mel_spec_type=vocoder_name,
    vocab_file=vocab_file,
    device=device,
)
ema_model = load_model(model_cls, model_arc, ckpt_file, **load_kwargs)

In [None]:
# inference process

def infer(ref_audio, ref_text, gen_text: str, spd: float = speed, nfe: int = nfe_step):
    """Synthesise ``gen_text`` in the voice of ``ref_audio``.

    ``gen_text`` is split in front of every ``[tag]`` marker (``[`` + word
    characters + ``]``). The tags are removed, and each remaining chunk is
    synthesised separately with ``infer_process``, using the globally loaded
    ``ema_model``, ``vocoder`` and inference settings. Chunks that are empty
    after stripping are skipped. Such chunks arise from the leading ``""`` that
    ``re.split`` yields when ``gen_text`` starts with a tag, or from a bare tag.

    Args:
        ref_audio: Reference (target-speaker) audio; a file path in this notebook.
        ref_text: Transcript of ``ref_audio``; in this notebook, the units
            returned by ``extract_units``.
        gen_text: Text/unit sequence to synthesise.
        spd: Forwarded to ``infer_process`` as ``speed``.
        nfe: Forwarded to ``infer_process`` as ``nfe_step``.

    Returns:
        The concatenated waveform (``np.concatenate`` of all segments), or
        ``None`` if no chunk produced audio.
    """
    tag_boundary = r"(?=\[\w+\])"  # zero-width split point before each [tag]
    tag_marker = r"\[(\w+)\]"

    generated_audio_segments = []
    for chunk in re.split(tag_boundary, gen_text):
        gen_text_ = re.sub(tag_marker, "", chunk).strip()
        if not gen_text_:
            # Nothing left to synthesise; don't call the model on empty input.
            continue

        # infer_process also returns the sample rate and spectrogram; unused here.
        audio_segment, _final_sample_rate, _spectrogram = infer_process(
            ref_audio,
            ref_text,
            gen_text_,
            ema_model,
            vocoder,
            mel_spec_type=vocoder_name,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=spd,
            fix_duration=fix_duration,
            device=device,
        )
        generated_audio_segments.append(audio_segment)

    if generated_audio_segments:
        return np.concatenate(generated_audio_segments)
    print("No audio segments generated.")
    return None

In [None]:
# Target speaker: the reference clip whose voice the output should carry.
ref_audio = "examples/wavs/14_208_000042_000000.wav"
# Its XEUS units stand in for the reference transcript passed to infer().
ref_text = extract_units(ref_audio, xeus_model, apply_kmeans, device)

In [None]:
# Source speech: its units are handed to infer() as the sequence to re-synthesise.
src_wav = "examples/wavs/237_134493_000015_000004.wav"
src_text = extract_units(
    src_wav,
    xeus_model,
    apply_kmeans,
    device,
)

In [None]:
# Re-synthesise the source units in the target speaker's voice.
# spd is forwarded to infer_process as `speed`, nfe as `nfe_step`.
gen_wav = infer(ref_audio, ref_text, src_text, spd=1.0, nfe=12)
if gen_wav is None:
    # infer() returns None when no chunk produced audio; stop here instead of
    # failing later inside the playback / saving cells with an unclear error.
    raise RuntimeError("infer() generated no audio; check that src_text is non-empty.")

In [None]:
# Target speaker
# For file input, Audio reads the file's own bytes (a `rate` argument would be
# ignored); filename= also raises a clear FileNotFoundError on a bad path.
play(filename=ref_audio)

In [None]:
# Source speech
# For file input, Audio reads the file's own bytes (a `rate` argument would be
# ignored); filename= also raises a clear FileNotFoundError on a bad path.
play(filename=src_wav)

In [None]:
# Generated audio (an in-memory array, so the sample rate must be given).
# NOTE(review): 16000 is hard-coded; infer_process also returns the actual
# sample rate (discarded inside infer()) — confirm the two match.
play(data=gen_wav, rate=16000)

In [None]:
# Write the converted waveform to gen_audio.wav in the working directory.
# NOTE(review): samplerate is hard-coded to 16000 — confirm it matches the
# sample rate infer_process reports for this model/vocoder.
sf.write(file="gen_audio.wav", data=gen_wav, samplerate=16000)