In [None]:
# ==== CELL 1: INSTALL & SETUP ====
# Add 'torchvision' here so its version stays in sync with torch and torchaudio.
# NOTE(review): most packages below are unpinned, so a fresh run may resolve to
# different versions; consider pinning, and consider `%pip` instead of `!pip`
# so the install is guaranteed to target the running kernel's environment.
!pip install -q "TTS>=0.22.0" "torch<2.6" torchaudio torchvision transformers datasets accelerate torchcodec

# Audio I/O, parquet, YAML and progress-bar packages imported by the later cells.
!pip install -q librosa soundfile pyarrow pyyaml tqdm

# System packages: espeak-ng and ffmpeg.
!apt-get update -qq && apt-get install -y espeak-ng ffmpeg

import os
import pathlib
import warnings

import torch

# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

# Kaggle cache dirs (reduces out-of-quota / write-permission errors): point each
# Hugging Face cache variable at a folder under /kaggle/working and make sure
# that folder exists.
for env_name, cache_dir in (
    ("HF_HOME", "/kaggle/working/hf_home"),
    ("HF_DATASETS_CACHE", "/kaggle/working/hf_datasets"),
    ("HUGGINGFACE_HUB_CACHE", "/kaggle/working/hf_hub"),
):
    os.environ[env_name] = cache_dir
    pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)

# Use the GPU when one is visible, otherwise fall back to the CPU, then report
# what was found.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if has_cuda else "cpu")
print(f"‚úÖ PyTorch: {torch.__version__}")
print(f"‚úÖ CUDA available: {has_cuda}")
if has_cuda:
    print(f"‚úÖ GPU: {torch.cuda.get_device_name(0)}")
    vram_gb = torch.cuda.get_device_properties(0).total_memory/1e9
    print(f"‚úÖ VRAM: {vram_gb:.2f} GB")
print("‚úÖ Setup done.")

In [None]:
# ==== CELL 3: SHARD DATA PREP ====
import os, gc, numpy as np, pyarrow.parquet as pq, soundfile as sf, librosa
from io import BytesIO
from huggingface_hub import hf_hub_download
from tqdm import tqdm

# Hugging Face dataset repository the shards are downloaded from.
REPO_ID = "NhutP/VietSpeech"
# Fixed dataset root used for XTTS; the WAV files live in its "wavs" subfolder.
DATASET_DIR = "/kaggle/working/xtts_dataset"
WAV_DIR = DATASET_DIR + "/wavs"
os.makedirs(WAV_DIR, exist_ok=True)

# NOTE(review): a later cell in this notebook defines a function with the same
# name; after "Run All" the later definition replaces this one. Consider
# deleting one of the two copies.


def _clean_transcript(batch_dict: dict, text_cols: list, row: int) -> str:
    """Return the first non-empty transcript for `row`, safe for a '|'-separated line."""
    for col in text_cols:
        candidate = batch_dict[col][row]
        if candidate:
            # Collapse newlines/tabs/runs of spaces and drop '|' so the value
            # cannot break the one-row-per-line, pipe-delimited metadata file.
            return " ".join(str(candidate).replace("|", " ").split())
    return ""


def _to_mono_pcm16(wav, sr: int, target_sr: int):
    """Downmix to mono, resample to `target_sr`, peak-normalise to 0.95, quantise to int16."""
    # Downmix BEFORE resampling: soundfile returns multi-channel audio as
    # (frames, channels) while librosa.resample works on the last axis, so
    # resampling first would resample across channels instead of across time.
    if wav.ndim > 1:
        wav = wav.mean(axis=1)
    if sr != target_sr:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
    peak = np.max(np.abs(wav))
    if peak > 0:
        wav = wav / peak * 0.95
    # 16-bit PCM keeps the files small.
    return (wav * 32767.0).astype(np.int16)


def build_metadata_for_shard(
    shard_idx: int,
    repo_id: str = REPO_ID,
    max_samples: int = 2000,     # samples kept per shard (tune to VRAM / disk budget)
    min_dur: float = 1.0,
    max_dur: float = 15.0,
    target_sr: int = 22050,
    num_shards: int = 27,
) -> int:
    """
    Download one parquet shard, convert its valid clips to WAV and rewrite metadata.csv.

    - Deletes every file currently in WAV_DIR, then downloads shard `shard_idx`.
    - Keeps rows whose duration is within [min_dur, max_dur] seconds and whose
      transcript ("transcription" column, falling back to "text") has 3..500
      characters after whitespace clean-up.
    - Saves each kept clip as 16-bit mono WAV at `target_sr` in WAV_DIR.
    - OVERWRITES DATASET_DIR/metadata.csv with one line per clip:
      audio_file|text|speaker_name|language
    - Removes the downloaded parquet file afterwards to free disk space.

    Args:
        shard_idx: zero-based shard index used to build the parquet filename.
        repo_id: Hugging Face dataset repository to download from.
        max_samples: stop after this many clips have been written.
        min_dur: minimum clip duration in seconds.
        max_dur: maximum clip duration in seconds.
        target_sr: output sample rate in Hz.
        num_shards: total shard count that appears in the parquet filename.

    Returns:
        Number of clips written.

    Raises:
        Whatever hf_hub_download / pq.ParquetFile raise (not caught here), and
        KeyError if the shard has no "audio" column. Per-row failures are
        counted and reported instead of raised.
    """
    # Delete WAVs left over from the previous shard to free disk space.
    for stale in os.listdir(WAV_DIR):
        try:
            os.remove(os.path.join(WAV_DIR, stale))
        except OSError:
            pass  # best effort: a leftover entry must not abort the shard

    metadata_lines = []
    written = 0
    skipped_errors = 0

    filename = f"data/train-{shard_idx:05d}-of-{num_shards:05d}.parquet"
    local_path = None
    try:
        print(f"\nüì• Shard {shard_idx+1}/{num_shards} ‚Üí {filename}")
        local_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
        pf = pq.ParquetFile(local_path)

        for batch in pf.iter_batches(batch_size=256):
            batch_dict = batch.to_pydict()
            audio_col = batch_dict["audio"]
            # Resolve the transcript columns once per batch; indexing a missing
            # column's fallback list would fail for every row after the first.
            text_cols = [c for c in ("transcription", "text") if c in batch_dict]

            for i in range(len(audio_col)):
                if written >= max_samples:
                    break
                try:
                    # Cheap text filter first, so rejected rows are never decoded.
                    text = _clean_transcript(batch_dict, text_cols, i)
                    if not (3 <= len(text) <= 500):
                        continue

                    wav, sr = sf.read(BytesIO(audio_col[i]["bytes"]))
                    dur = len(wav)/sr
                    if dur < min_dur or dur > max_dur:
                        continue

                    wav_i16 = _to_mono_pcm16(wav, sr, target_sr)

                    # save wav
                    out_name = f"vi_{shard_idx:02d}_{written:06d}.wav"
                    out_path = os.path.join(WAV_DIR, out_name)
                    sf.write(out_path, wav_i16, target_sr, subtype="PCM_16")

                    # metadata row (XTTS: audio|text|speaker|language)
                    speaker = f"spk_{shard_idx:02d}"
                    metadata_lines.append(f"{out_name}|{text}|{speaker}|vi")

                    written += 1
                except Exception:
                    # Best effort: one undecodable row must not abort the shard,
                    # but count it so the loss is visible in the summary.
                    skipped_errors += 1
                    continue
            if written >= max_samples:
                break
    finally:
        if local_path:
            # hf_hub_download normally returns a symlink inside the HF cache that
            # points at the real blob; removing only the link would leave the
            # data on disk, so remove the resolved target as well.
            for path in (os.path.realpath(local_path), local_path):
                try:
                    os.remove(path)
                except OSError:
                    pass  # already gone, or same path as the previous entry
        gc.collect()

    meta_path = os.path.join(DATASET_DIR, "metadata.csv")
    with open(meta_path, "w", encoding="utf-8") as f:
        f.write("\n".join(metadata_lines))

    print(f"‚úÖ Shard {shard_idx+1}: wrote {written} samples ‚Üí {meta_path}")
    if skipped_errors:
        print(f"   skipped {skipped_errors} rows because of decode/processing errors")
    return written

print("‚úÖ Shard prep functions ready.")


In [None]:
# ==== CELL 3: SHARD DATA PREP ====
import os, gc, numpy as np, pyarrow.parquet as pq, soundfile as sf, librosa
from io import BytesIO
from huggingface_hub import hf_hub_download
from tqdm import tqdm

# Hugging Face dataset repository the shards are downloaded from.
REPO_ID = "NhutP/VietSpeech"
# Fixed dataset root used for XTTS; the WAV files live in its "wavs" subfolder.
DATASET_DIR = "/kaggle/working/xtts_dataset"
WAV_DIR = DATASET_DIR + "/wavs"
os.makedirs(WAV_DIR, exist_ok=True)

# NOTE(review): an earlier cell in this notebook defines a function with the
# same name; this later definition is the one in effect after "Run All".
# Consider deleting one of the two copies.


def _clean_transcript(batch_dict: dict, text_cols: list, row: int) -> str:
    """Return the first non-empty transcript for `row`, safe for a '|'-separated line."""
    for col in text_cols:
        candidate = batch_dict[col][row]
        if candidate:
            # Collapse newlines/tabs/runs of spaces and drop '|' so the value
            # cannot break the one-row-per-line, pipe-delimited metadata file.
            return " ".join(str(candidate).replace("|", " ").split())
    return ""


def _to_mono_pcm16(wav, sr: int, target_sr: int):
    """Downmix to mono, resample to `target_sr`, peak-normalise to 0.95, quantise to int16."""
    # Downmix BEFORE resampling: soundfile returns multi-channel audio as
    # (frames, channels) while librosa.resample works on the last axis, so
    # resampling first would resample across channels instead of across time.
    if wav.ndim > 1:
        wav = wav.mean(axis=1)
    if sr != target_sr:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
    peak = np.max(np.abs(wav))
    if peak > 0:
        wav = wav / peak * 0.95
    # 16-bit PCM keeps the files small.
    return (wav * 32767.0).astype(np.int16)


def build_metadata_for_shard(
    shard_idx: int,
    repo_id: str = REPO_ID,
    max_samples: int = 2000,     # samples kept per shard (tune to VRAM / disk budget)
    min_dur: float = 1.0,
    max_dur: float = 15.0,
    target_sr: int = 22050,
    num_shards: int = 27,
) -> int:
    """
    Download one parquet shard, convert its valid clips to WAV and rewrite metadata.csv.

    - Deletes every file currently in WAV_DIR, then downloads shard `shard_idx`.
    - Keeps rows whose duration is within [min_dur, max_dur] seconds and whose
      transcript ("transcription" column, falling back to "text") has 3..500
      characters after whitespace clean-up.
    - Saves each kept clip as 16-bit mono WAV at `target_sr` in WAV_DIR.
    - OVERWRITES DATASET_DIR/metadata.csv with one line per clip:
      audio_file|text|speaker_name|language
    - Removes the downloaded parquet file afterwards to free disk space.

    Args:
        shard_idx: zero-based shard index used to build the parquet filename.
        repo_id: Hugging Face dataset repository to download from.
        max_samples: stop after this many clips have been written.
        min_dur: minimum clip duration in seconds.
        max_dur: maximum clip duration in seconds.
        target_sr: output sample rate in Hz.
        num_shards: total shard count that appears in the parquet filename.

    Returns:
        Number of clips written.

    Raises:
        Whatever hf_hub_download / pq.ParquetFile raise (not caught here), and
        KeyError if the shard has no "audio" column. Per-row failures are
        counted and reported instead of raised.
    """
    # Delete WAVs left over from the previous shard to free disk space.
    for stale in os.listdir(WAV_DIR):
        try:
            os.remove(os.path.join(WAV_DIR, stale))
        except OSError:
            pass  # best effort: a leftover entry must not abort the shard

    metadata_lines = []
    written = 0
    skipped_errors = 0

    filename = f"data/train-{shard_idx:05d}-of-{num_shards:05d}.parquet"
    local_path = None
    try:
        print(f"\nüì• Shard {shard_idx+1}/{num_shards} ‚Üí {filename}")
        local_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
        pf = pq.ParquetFile(local_path)

        for batch in pf.iter_batches(batch_size=256):
            batch_dict = batch.to_pydict()
            audio_col = batch_dict["audio"]
            # Resolve the transcript columns once per batch; indexing a missing
            # column's fallback list would fail for every row after the first.
            text_cols = [c for c in ("transcription", "text") if c in batch_dict]

            for i in range(len(audio_col)):
                if written >= max_samples:
                    break
                try:
                    # Cheap text filter first, so rejected rows are never decoded.
                    text = _clean_transcript(batch_dict, text_cols, i)
                    if not (3 <= len(text) <= 500):
                        continue

                    wav, sr = sf.read(BytesIO(audio_col[i]["bytes"]))
                    dur = len(wav)/sr
                    if dur < min_dur or dur > max_dur:
                        continue

                    wav_i16 = _to_mono_pcm16(wav, sr, target_sr)

                    # save wav
                    out_name = f"vi_{shard_idx:02d}_{written:06d}.wav"
                    out_path = os.path.join(WAV_DIR, out_name)
                    sf.write(out_path, wav_i16, target_sr, subtype="PCM_16")

                    # metadata row (XTTS: audio|text|speaker|language)
                    speaker = f"spk_{shard_idx:02d}"
                    metadata_lines.append(f"{out_name}|{text}|{speaker}|vi")

                    written += 1
                except Exception:
                    # Best effort: one undecodable row must not abort the shard,
                    # but count it so the loss is visible in the summary.
                    skipped_errors += 1
                    continue
            if written >= max_samples:
                break
    finally:
        if local_path:
            # hf_hub_download normally returns a symlink inside the HF cache that
            # points at the real blob; removing only the link would leave the
            # data on disk, so remove the resolved target as well.
            for path in (os.path.realpath(local_path), local_path):
                try:
                    os.remove(path)
                except OSError:
                    pass  # already gone, or same path as the previous entry
        gc.collect()

    meta_path = os.path.join(DATASET_DIR, "metadata.csv")
    with open(meta_path, "w", encoding="utf-8") as f:
        f.write("\n".join(metadata_lines))

    print(f"‚úÖ Shard {shard_idx+1}: wrote {written} samples ‚Üí {meta_path}")
    if skipped_errors:
        print(f"   skipped {skipped_errors} rows because of decode/processing errors")
    return written

print("‚úÖ Shard prep functions ready.")


In [None]:
# ==== CELL 4: WRITE XTTS FINETUNE CONFIG (YAML) ====
import yaml, pathlib

# Writes the fine-tuning configuration to a YAML file and echoes it back.
# Depends on DATASET_DIR and WAV_DIR defined by the shard-prep cell above.

OUT_DIR = "/kaggle/working/xtts_vietnamese"     # where checkpoints/logs are meant to go
CFG_PATH = f"{OUT_DIR}/config_ft.yaml"
pathlib.Path(OUT_DIR).mkdir(parents=True, exist_ok=True)

config = {
    "model": "xtts_v2",                # model family
    "output_path": OUT_DIR,
    "run_name": "xtts_vi_stream",
    "logger": "tensorboard",

    # ==== DATASET ====
    "dataset_config": [
        {
            # NOTE(review): the shard-prep cell writes 4-column rows
            # (audio_file|text|speaker_name|language) with a ".wav" suffix in the
            # first column — confirm the "ljspeech" formatter accepts that layout.
            "formatter": "ljspeech",   # use the ljspeech formatter: metadata + wavs
            "meta_file_train": f"{DATASET_DIR}/metadata.csv",
            "meta_file_val": f"{DATASET_DIR}/metadata.csv",   # shared with train for now
            "path": DATASET_DIR,       # base path
            "audio_dir": WAV_DIR,      # wavs folder
            "language": "vi",
            "n_val": 64,               # provisional: 64 validation samples per shard
        }
    ],

    # ==== AUDIO ====
    "audio": {
        "sample_rate": 22050,
        "num_mels": 80,          # XTTS uses 80 mels internally (no need to compute them yourself)
        "fft_size": 1024,
        "hop_length": 256
    },

    # ==== OPTIMIZATION ====
    "optimizer": {
        "type": "adamw",
        "lr": 5e-5,
        "betas": [0.9, 0.98],
        "eps": 1e-9,
        "weight_decay": 0.0
    },

    # ==== TRAINER ====
    "trainer": {
        "max_steps": 1500,       # meant to be overridden for each shard
        "eval_steps": 250,
        "save_step": 500,
        "log_step": 50,
        "precision": "fp16",
        "batch_size": 2,
        "grad_accum": 4,         # effective batch ~8
        "num_workers": 2,
        "pin_memory": True,
        "cudnn_benchmark": True,
        "seed": 42
    },

    # ==== LOSSES (let the trainer apply its XTTS defaults) ====
    "losses": {
        "use_xtts_defaults": True
    },

    # ==== FREEZE POLICY (gentle; can be dropped for a full fine-tune) ====
    "freeze": {
        # hint: let GPT + decoder learn; freeze the less important parts
        "freeze_encoder": False,
        "freeze_gpt": False,
        "freeze_decoder": False
    }
}

# Explicit UTF-8 so the file does not depend on the platform's default encoding
# (allow_unicode=True lets non-ASCII characters through unescaped).
with open(CFG_PATH, "w", encoding="utf-8") as f:
    yaml.safe_dump(config, f, sort_keys=False, allow_unicode=True)

print("‚úÖ Wrote YAML:", CFG_PATH)
# Echo the file back through a context manager so the handle is closed promptly.
with open(CFG_PATH, encoding="utf-8") as f:
    print(f.read())


In [None]:
# NOTE(review): looks like a leftover scratch/debug cell — nothing else in the
# visible notebook depends on it; consider deleting.
print("heellloo")
