In [None]:
!git clone https://github.com/HZhalex/F5-TTS-Vietnamese

In [None]:
import os
os.makedirs("ckpts/F5-TTS-Vietnamese", exist_ok=True)
!wget -O ckpts/F5-TTS-Vietnamese/model_500000.pt https://huggingface.co/hynt/F5-TTS-Vietnamese-ViVoice/resolve/main/model_last.pt

In [None]:
!pip install torchcodec 

In [None]:
%cd F5-TTS-Vietnamese

!pip install -e .

In [None]:
import os
import queue
import threading
import time
import warnings
from concurrent.futures import ThreadPoolExecutor

import soundfile as sf
import torch
import torchaudio
from tqdm import tqdm
from vocos import Vocos

from f5_tts.model import DiT, CFM
from f5_tts.model.utils import get_tokenizer, convert_char_to_pinyin
warnings.filterwarnings("ignore")

In [None]:
# Compute device: use the GPU whenever one is visible.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Model artifacts (Colab layout under /content).
MODEL_PATH = "/content/ckpts/F5-TTS-Vietnamese/model_500000.pt"
VOCAB_PATH = "/content/F5-TTS-Vietnamese/vocab.txt"

# Voice references: a folder of clips, all sharing one transcript.
REF_AUDIO_DIR = "/content/ref_audios"
REF_TEXT = "Anh chỉ muốn được nhìn nhận như là một huấn luyện viên."

# Dataset root; each split folder holds labels.txt and receives an audio/ folder.
BASE_DIR = "/content/drive/MyDrive/task data audio/tinh_ban"
SPLITS = ["train", "val", "test"]

In [None]:
# --- Generation -------------------------------------------------------------
# Samples are generated one at a time on a single thread
# (original note: the model does not support parallel generation).
INFERENCE_STEPS = 16  # ODE steps per sample, passed to model.sample(steps=...)
NUM_REF_AUDIOS = 104  # reference clips are looked up as id_1.wav ... id_104.wav

# --- Asynchronous saving ----------------------------------------------------
# Disk writes run on worker threads so they overlap with generation.
NUM_SAVE_WORKERS = 16   # number of file-writing threads
SAVE_BUFFER_SIZE = 500  # max generated waves waiting in the save queue

# --- PyTorch performance flags ----------------------------------------------
torch.backends.cudnn.deterministic = False
# NOTE(review): benchmark mode re-tunes cuDNN conv algorithms for every new
# input shape; with variable-length utterances this may cost more than it
# saves — worth measuring both ways.
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')
if DEVICE == "cuda":
    # Permit TF32 math in cuDNN convolutions and CUDA matmuls.
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
    # Turn autograd off globally; this notebook only runs inference.
    torch.set_grad_enabled(False)

# Report the effective settings.
print(f"   - FP16: {DEVICE == 'cuda'}")
print(f"   - TF32: {DEVICE == 'cuda'}")
print(f"   - Async save: {NUM_SAVE_WORKERS} workers")
print(f"   - Inference steps: {INFERENCE_STEPS}")

In [None]:
print("\n Loading model")
vocab_char_map, vocab_size = get_tokenizer(VOCAB_PATH, tokenizer="custom")

# Architecture hyper-parameters must match the checkpoint (load_state_dict below is strict).
model = CFM(
    transformer=DiT(
        dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False,
        text_num_embeds=vocab_size, mel_dim=100, conv_layers=4, pe_attn_head=1
    ),
    mel_spec_kwargs=dict(
        n_fft=1024, hop_length=256, win_length=1024,
        n_mel_channels=100, target_sample_rate=24000, mel_spec_type="vocos"
    ),
    odeint_kwargs=dict(method="euler"),
    vocab_char_map=vocab_char_map,
).to(DEVICE)

# Load the checkpoint on the CPU. With map_location=DEVICE every tensor stored in
# the file was materialised on the GPU and then kept alive by the module-level
# `checkpoint` variable for the whole session; load_state_dict copies the weights
# into the already on-device model anyway.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
# which can execute code. Only acceptable for a trusted checkpoint; try
# weights_only=True if this file loads with it.
checkpoint = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
if "ema_model_state_dict" in checkpoint:
    # Use the EMA weights: strip the "ema_model." prefix and drop the
    # "initted"/"step" entries so they are not passed to the model.
    state_dict = {k.replace("ema_model.", ""): v for k, v in checkpoint["ema_model_state_dict"].items()
                  if k not in ["initted", "step"]}
elif "model_state_dict" in checkpoint:
    state_dict = checkpoint["model_state_dict"]
else:
    state_dict = checkpoint  # the file is a bare state dict
model.load_state_dict(state_dict)
# The weights now live in `model`; release the CPU copy of the checkpoint.
del checkpoint, state_dict
model.eval()

# FP16 conversion
if DEVICE == "cuda":
    model = model.half()
    print(" Model converted to FP16")

vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(DEVICE)
if DEVICE == "cuda":
    vocoder = vocoder.half()

print(" Model loaded successfully!")

In [None]:
print("\n Loading reference audios")
ref_audios = {}  # ref id -> waveform tensor on DEVICE (fp16 on CUDA)
ref_lens = {}    # ref id -> reference length in mel frames

for i in range(1, NUM_REF_AUDIOS + 1):
    ref_path = os.path.join(REF_AUDIO_DIR, f"id_{i}.wav")
    if not os.path.exists(ref_path):
        continue

    audio, sr = torchaudio.load(ref_path)
    if audio.shape[0] > 1:
        # Down-mix multi-channel audio to mono by averaging the channels.
        audio = torch.mean(audio, dim=0, keepdim=True)

    # Boost quiet references up to an RMS of 0.1. A completely silent file
    # (RMS == 0) would be divided by zero and turn into NaNs, which poisons
    # every sample generated from it, so it is skipped like a missing file.
    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms == 0:
        print(f"  Skipping silent reference: {ref_path}")
        continue
    if rms < 0.1:
        audio = audio * 0.1 / rms

    if sr != 24000:
        audio = torchaudio.transforms.Resample(sr, 24000)(audio)

    ref_audios[i] = audio.to(DEVICE)
    if DEVICE == "cuda":
        ref_audios[i] = ref_audios[i].half()
    # Length in mel frames: hop length 256, as in the model's mel_spec_kwargs.
    ref_lens[i] = audio.shape[-1] // 256

print(f" Loaded {len(ref_audios)}/{NUM_REF_AUDIOS} reference audios")

missing_refs = [i for i in range(1, NUM_REF_AUDIOS + 1) if i not in ref_audios]
if missing_refs:
    print(f"  Missing {len(missing_refs)} refs: {missing_refs[:5]}...")
else:
    print(f" All {NUM_REF_AUDIOS} reference audios present!")

# UTF-8 byte length of the reference transcript; used to scale output duration.
ref_text_len = len(REF_TEXT.encode("utf-8"))

In [None]:
@torch.inference_mode()
def generate_audio_optimized(ref_audio, ref_len, gen_text, steps=None,
                             cfg_strength=2.0, sway_sampling_coef=-1.0):
    """Synthesize ``gen_text`` in the voice of one reference clip.

    Args:
        ref_audio: reference waveform tensor as prepared by the reference-loading
            cell (mono, 24 kHz, on DEVICE; fp16 on CUDA).
        ref_len: reference length in mel frames (samples // 256).
        gen_text: text to synthesize.
        steps: ODE steps for ``model.sample``; ``None`` uses the current
            global ``INFERENCE_STEPS``.
        cfg_strength: classifier-free guidance strength forwarded to ``model.sample``.
        sway_sampling_coef: sway sampling coefficient forwarded to ``model.sample``.

    Returns:
        The generated waveform (reference part removed) as a squeezed float32
        numpy array. Assumes ``vocoder.decode`` returns (1, T) for one sample, so
        the result is 1-D — TODO confirm.
    """
    if steps is None:
        steps = INFERENCE_STEPS

    # The model is conditioned on "reference transcript + target text".
    final_text = convert_char_to_pinyin([REF_TEXT + " " + gen_text])

    # Total frames = reference frames + target frames, with target frames
    # scaled by the target/reference ratio of UTF-8 byte lengths.
    gen_text_len = len(gen_text.encode("utf-8"))
    duration = ref_len + int(ref_len / ref_text_len * gen_text_len)

    generated, _ = model.sample(
        cond=ref_audio,
        text=final_text,
        duration=duration,
        steps=steps,
        cfg_strength=cfg_strength,
        sway_sampling_coef=sway_sampling_coef,
    )

    # Drop the reference frames, then reorder to (batch, mel, time) for the vocoder.
    gen_mel_spec = generated[:, ref_len:, :].permute(0, 2, 1)

    # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
    with torch.autocast(device_type=DEVICE, enabled=(DEVICE == "cuda")):
        wave = vocoder.decode(gen_mel_spec)

    return wave.squeeze().cpu().float().numpy()



# Bounded hand-off from the generation loop (producer) to the save workers
# (consumers): put() blocks once SAVE_BUFFER_SIZE items are pending.
save_queue = queue.Queue(maxsize=SAVE_BUFFER_SIZE)
# Counters updated by the save workers and reported in the final summary.
save_stats = dict(saved=0, errors=0)

In [None]:
# `save_stats[...] += 1` is a read-modify-write that is not atomic across the
# NUM_SAVE_WORKERS threads, so counter updates are serialised with this lock.
save_stats_lock = threading.Lock()


def save_worker():
    """Consume ``(path, wave)`` items from ``save_queue`` and write 24 kHz wav files.

    Exits on a ``None`` sentinel. Each wave is written to ``<path>.part`` and
    renamed onto ``path`` only after the write succeeded. An interrupted run
    therefore cannot leave a truncated ``.wav`` that the generation loop's
    skip-if-exists check would treat as finished. ``task_done()`` runs in
    ``finally`` so ``save_queue.join()`` cannot hang on a failed write.
    """
    while True:
        item = save_queue.get()
        if item is None:
            break
        path, wave = item
        tmp_path = path + ".part"
        try:
            # format is explicit because ".part" is not an extension soundfile recognises.
            sf.write(tmp_path, wave, 24000, format="WAV")
            os.replace(tmp_path, path)
        except Exception as e:
            with save_stats_lock:
                save_stats['errors'] += 1
                error_number = save_stats['errors']
            if error_number <= 5:  # surface the first few failures instead of hiding them
                tqdm.write(f" Save error for {path}: {str(e)[:70]}")
            try:
                os.remove(tmp_path)  # best-effort cleanup of a partial file
            except OSError:
                pass
        else:
            with save_stats_lock:
                save_stats['saved'] += 1
        finally:
            save_queue.task_done()


save_executor = ThreadPoolExecutor(max_workers=NUM_SAVE_WORKERS)
for _ in range(NUM_SAVE_WORKERS):
    save_executor.submit(save_worker)

print(f"\n Save workers started ({NUM_SAVE_WORKERS} threads)")

In [None]:
def _partition_labels(total_samples, ref_ids):
    """Split label indices ``0..total_samples-1`` into contiguous chunks, one per ref.

    Chunk sizes differ by at most one: the first ``total_samples % len(ref_ids)``
    refs (in the given order) get one extra label. Refs that would receive no
    labels (only possible with fewer labels than refs) are omitted.

    Returns:
        list of ``(ref_id, start, end)`` with half-open ``[start, end)`` ranges
        that together cover every label exactly once.
    """
    base, extra = divmod(total_samples, len(ref_ids))
    chunks = []
    start = 0
    for rank, ref_id in enumerate(ref_ids):
        size = base + (1 if rank < extra else 0)
        if size == 0:
            break  # sizes never increase, so every later ref would also get 0
        chunks.append((ref_id, start, start + size))
        start += size
    return chunks


total_start = time.time()
total_generated_all = 0

# Distribute work over the refs that actually loaded (sorted by id). Sizing chunks
# by NUM_REF_AUDIOS while skipping missing refs left the shares of missing refs
# unassigned, so the trailing labels of every split were silently never generated.
# With all refs present the assignment is unchanged.
available_ref_ids = sorted(ref_audios)
if not available_ref_ids:
    raise RuntimeError(f"No reference audios loaded from {REF_AUDIO_DIR}")

for split in SPLITS:
    split_start = time.time()

    labels_path = os.path.join(BASE_DIR, split, "labels.txt")
    audio_dir = os.path.join(BASE_DIR, split, "audio")

    if not os.path.exists(labels_path):
        print(f"{labels_path} not found, skipping...")
        continue
    # Create the output folder only for splits that are actually processed.
    os.makedirs(audio_dir, exist_ok=True)

    with open(labels_path, "r", encoding="utf-8") as f:
        labels = [line.strip() for line in f if line.strip()]

    total_samples = len(labels)
    print(f" Total samples: {total_samples:,}")

    chunks = _partition_labels(total_samples, available_ref_ids)
    samples_per_ref = total_samples // len(available_ref_ids)
    print(f" Distribution: ~{samples_per_ref} samples per ref audio")
    print(f" Using all {len(ref_audios)} reference audios\n")

    generated_count = 0
    skipped_count = 0
    error_count = 0
    refs_used = len(chunks)  # refs that received at least one label

    pbar = tqdm(
        total=total_samples,
        desc=f"  {split.upper()}",
        ncols=100,
        bar_format='{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'
    )

    # Speed tracking: one instantaneous reading every >= 5 s, averaged over
    # the last 12 readings (~60 s).
    speed_window = []
    last_update = time.time()
    last_count = 0

    for ref_id, start_idx, end_idx in chunks:
        ref_audio = ref_audios[ref_id]
        ref_len = ref_lens[ref_id]

        # Output files are numbered by 1-based position among the non-blank lines of labels.txt.
        for idx, text in enumerate(labels[start_idx:end_idx], start=start_idx + 1):
            output_path = os.path.join(audio_dir, f"{idx}.wav")

            # Resume support: a sample whose wav already exists is not regenerated.
            if os.path.exists(output_path):
                skipped_count += 1
                pbar.update(1)
                continue

            try:
                wave = generate_audio_optimized(ref_audio, ref_len, text)

                # Hand off to the save workers (blocks while the queue is full).
                save_queue.put((output_path, wave))
                generated_count += 1
                pbar.update(1)

                now = time.time()
                if now - last_update >= 5.0:
                    interval = now - last_update
                    instant_speed = (generated_count - last_count) / interval

                    speed_window.append(instant_speed)
                    if len(speed_window) > 12:
                        speed_window.pop(0)
                    avg_speed = sum(speed_window) / len(speed_window)

                    pbar.set_postfix({
                        'speed': f'{instant_speed:.1f}f/s',
                        'avg': f'{avg_speed:.1f}f/s',
                        'queue': save_queue.qsize()
                    })

                    last_update = now
                    last_count = generated_count

            except Exception as e:
                error_count += 1
                if error_count <= 5:  # only show the first 5 errors of each split
                    pbar.write(f" Error at sample {idx}: {str(e)[:70]}")
                pbar.update(1)

    # Wait until every wave queued for this split has been written.
    save_queue.join()
    pbar.close()

    split_time = time.time() - split_start
    speed = generated_count / split_time if split_time > 0 else 0
    total_generated_all += generated_count

    print(f"\n{'─'*70}")
    print(f" {split.upper()} Completed:")
    print(f"   Refs used:      {refs_used}/{NUM_REF_AUDIOS}")
    print(f"   Generated:      {generated_count:,} files")
    print(f"   Skipped:        {skipped_count:,} files")
    print(f"   Errors:         {error_count:,} files")
    print(f"   Time:           {split_time/60:.1f} minutes ({split_time:.0f}s)")
    print(f"   Average speed:  {speed:.2f} files/sec")

    if speed > 0 and generated_count > 0:
        # Rough projection: assumes every remaining split is as large as this one.
        remaining_splits = len(SPLITS) - SPLITS.index(split) - 1
        if remaining_splits > 0:
            estimated_remaining = (total_samples * remaining_splits) / speed / 60
            print(f"   Est. remaining: ~{estimated_remaining:.0f} minutes")

    print(f"{'─'*70}")

In [None]:
# Stop the save workers: one None sentinel per worker. The queue is FIFO, so
# everything queued before the sentinels is still written before they exit.
for _ in range(NUM_SAVE_WORKERS):
    save_queue.put(None)
save_executor.shutdown(wait=True)

# Overall statistics across all splits.
total_time = time.time() - total_start
if total_time > 0:
    avg_speed = total_generated_all / total_time
else:
    avg_speed = 0

print("\n" + "=" * 70)
print(" ALL PROCESSING COMPLETE!")
print("=" * 70)
print(f"   Total generated:  {total_generated_all:,} files")
print(f"   Total time:       {total_time/60:.1f} minutes ({total_time/3600:.2f} hours)")
print(f"   Average speed:    {avg_speed:.2f} files/sec")
print(f"   Files saved:      {save_stats['saved']:,}")
print(f"   Save errors:      {save_stats['errors']:,}")
print("=" * 70)