In [None]:
import torch
import re
import soundfile as sf
from IPython.display import Audio
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

import os
import io
import librosa
import numpy as np
from datasets import load_dataset, VerificationMode
from speechbrain.pretrained import EncoderClassifier




print("Loading trained model (step 4000), processor, and vocoder...")
# NOTE(review): the message above says step 4000, while the output file
# written at the end of this script is named "...step3500" — confirm which
# training step the pinned revision actually corresponds to.

# Fine-tuned checkpoint, pinned to one revision so reruns load the same weights.
repo_id = "oopssuper96/speecht5_finetuned_emirhan_tr"
revision_id = "804163c"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading processor from {repo_id} (revision {revision_id})...")
processor = SpeechT5Processor.from_pretrained(repo_id, revision=revision_id)

print(f"Loading model from {repo_id} (revision {revision_id})...")
model = SpeechT5ForTextToSpeech.from_pretrained(repo_id, revision=revision_id)

# The vocoder is loaded from its own checkpoint, not from the fine-tuned repo.
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# Both networks must live on the same device as the inputs built later.
model.to(device)
vocoder.to(device)
print(f"Models loaded and moved to {device}.")

# === 2. Redefine the preprocessing functions EXACTLY ===
def normalize_text(text):
    """Lowercase *text*, strip punctuation, and collapse whitespace.

    Every character that is not a word character, whitespace, or an
    apostrophe is removed. Runs of whitespace are then reduced to single
    spaces and leading/trailing whitespace is dropped.

    NOTE(review): combining diacritical marks are not word characters, so
    input in decomposed (NFD) form would lose its accents here — confirm
    that callers always pass precomposed (NFC) text.
    """
    lowered = text.lower()
    without_punctuation = re.sub(r'[^\w\s\']', '', lowered)
    words = without_punctuation.split()
    return ' '.join(words)

# Ordered (source, target) pairs used by cleanup_text, which applies them one
# after another with str.replace. Order therefore matters: every pair operates
# on the output of all pairs listed before it.
# NOTE(review): because of that sequential application, letters emitted by the
# vowel rows are rewritten again by the consonant rows (e.g. the "r" of "uwr"
# is hit by ("r","zh"), the "x" of "ax" by ("x","s") and then ("s","sh"), and
# "đ" -> "d" is then hit by ("d","z")). ("tr","chr") can no longer match once
# ("r","zh") has run, and identity pairs such as ("ch","ch") are no-ops.
# Presumably the model was trained on exactly this output — confirm before
# changing any entry or its position.
replacements = [
    # Accented vowels -> ASCII base spelling plus a trailing tone letter.
    ("à","af"),("á","as"),("ả","ar"),("ã","ax"),("ạ","aj"),
    ("ă","ah"),("ằ","ahf"),("ắ","ahs"),("ẳ","ahr"),("ẵ","ahx"),("ặ","ahj"),
    ("â","ay"),("ầ","ayf"),("ấ","ays"),("ẩ","ayr"),("ẫ","ayx"),("ậ","ayj"),
    ("è","ef"),("é","es"),("ẻ","er"),("ẽ","ex"),("ẹ","ej"),
    ("ê","ee"),("ề","eef"),("ế","ees"),("ể","eer"),("ễ","eex"),("ệ","eej"),
    ("ì","if"),("í","is"),("ỉ","ir"),("ĩ","ix"),("ị","ij"),
    ("ò","of"),("ó","os"),("ỏ","or"),("õ","ox"),("ọ","oj"),
    ("ô","oh"),("ồ","ohf"),("ố","ohs"),("ổ","ohr"),("ỗ","ohx"),("ộ","ohj"),
    ("ư","uw"),("ừ","uwf"),("ứ","uws"),("ử","uwr"),("ữ","uwx"),("ự","uwj"),
    ("ơ","ow"),("ờ","owf"),("ớ","ows"),("ở","owr"),("ỡ","owx"),("ợ","owj"),
    ("ù","uf"),("ú","us"),("ủ","ur"),("ũ","ux"),("ụ","uj"),
    ("ỳ","yf"),("ý","ys"),("ỷ","yr"),("ỹ","yx"),("ỵ","yj"),
    # Consonants and consonant clusters.
    ("đ","d"),("gi","z"),("d","z"),("r","zh"),("x","s"),("s","sh"),
    ("tr","chr"),("ch","ch"),("th","th"),("ph","f"),("kh","kh"),("nh","nh"),("ng","ng"),("gh","g"),
]   
# ======================================

def cleanup_text(text):
    """Normalize *text*, then apply every ``replacements`` pair in list order.

    The pairs are applied sequentially to the running result, so a later
    pair also rewrites characters produced by an earlier one.
    """
    result = normalize_text(text)
    for pair in replacements:
        # Each pair is (source, target); unpack it straight into str.replace.
        result = result.replace(*pair)
    return result

# === 3. Prepare the input ===
text = "Xin chào, tôi tên là đạt, tôi đang thử nghiệm mô hình này"
print(f"Original text: {text}")

# Run the sentence through the same text pipeline defined above.
final_text = cleanup_text(text)
print(f"Processed text: {final_text}")

# Tokenize to PyTorch tensors and move them to the model's device.
inputs = processor(text=final_text, return_tensors="pt")
inputs = inputs.to(device)

# === 4. Load the speaker model and create a REAL speaker embedding ===
print("Loading SpeechBrain X-vector model...")
spk_model_name = "speechbrain/spkrec-xvect-voxceleb"
speaker_model = EncoderClassifier.from_hparams(
    source=spk_model_name,
    savedir=os.path.join("/tmp", spk_model_name),
    run_opts={"device": device},
)

def create_speaker_embedding(waveform):
    """Encode a waveform into a normalized speaker embedding.

    Args:
        waveform: Audio samples as a ``torch.Tensor`` or any array-like that
            ``torch.tensor`` accepts (e.g. a NumPy array). Presumably 16 kHz
            mono, which is what the caller in this file prepares — confirm
            against the speaker model's requirements.

    Returns:
        NumPy array with the output of ``speaker_model.encode_batch``,
        L2-normalized along dim 2 and with dim 0 squeezed when it has size 1.
    """
    if not isinstance(waveform, torch.Tensor):
        waveform = torch.tensor(waveform)
    # NumPy audio is float64 by default, which would yield a double tensor.
    # Cast to float32 so the input matches PyTorch's default module dtype.
    waveform = waveform.to(device=device, dtype=torch.float32)
    with torch.no_grad():
        speaker_embeddings = speaker_model.encode_batch(waveform)
        speaker_embeddings = torch.nn.functional.normalize(speaker_embeddings, dim=2)
        speaker_embeddings = speaker_embeddings.squeeze(0)
    return speaker_embeddings.cpu().numpy()

def get_real_speaker_embedding():
    """Build a speaker embedding from one utterance of the VietSpeech dataset.

    Loads a single row from one parquet shard, decodes its audio with
    soundfile, converts it to 16 kHz mono, and encodes it with
    ``create_speaker_embedding``.

    Returns:
        A 2-D tensor with a leading dimension of 1 on ``device``. This is the
        same rank as the fallback, which is a random ``(1, 512)`` tensor
        returned (after printing a warning) when any step fails.
    """
    print("Loading one sample from dataset for speaker embedding...")
    try:
        sample_dataset = load_dataset(
            "parquet",
            data_files="https://huggingface.co/datasets/NhutP/VietSpeech/resolve/main/data/train-00018-of-00027.parquet",
            split="train[1:2]",
            verification_mode=VerificationMode.NO_CHECKS,
        )
        audio_data = sample_dataset[0]["audio"]
        if audio_data.get("bytes") is not None:
            audio_source = io.BytesIO(audio_data["bytes"])
        elif audio_data.get("path") is not None:
            audio_source = audio_data["path"]
        else:
            raise ValueError("Audio data invalid")
        # Read as float32 rather than soundfile's float64 default, so the
        # waveform does not turn into a double tensor downstream.
        speech_array, sampling_rate = sf.read(audio_source, dtype="float32")
        # Down-mix BEFORE resampling: soundfile returns multi-channel audio as
        # (frames, channels), and librosa.resample works on the last axis, so
        # resampling first would resample across channels instead of time.
        if speech_array.ndim > 1:
            speech_array = speech_array.mean(axis=1)
        if sampling_rate != 16000:
            speech_array = librosa.resample(
                speech_array, orig_sr=sampling_rate, target_sr=16000
            )
        embedding = create_speaker_embedding(speech_array)
        print("Real speaker embedding created.")
        # Flatten to (1, dim) so the result has the same rank as the fallback
        # below; unsqueezing an already batched embedding would add a
        # spurious extra dimension.
        return torch.tensor(embedding).reshape(1, -1).to(device)
    except Exception as e:
        # Deliberate best-effort: keep the script running with a placeholder.
        print(f"WARNING: Could not load real speaker embedding: {e}")
        print("Falling back to random embedding (will likely be noise).")
        return torch.randn((1, 512)).to(device)

speaker_embeddings = get_real_speaker_embedding()

# === 5. Generate speech ===
print("Generating speech...")
with torch.no_grad():
    # Passing the vocoder makes generate_speech return a waveform directly.
    speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
print("Speech generated!")

# Move the result to host memory as a NumPy array for saving and playback.
speech_numpy = speech.cpu().numpy()

# === 6. Play back and save the file ===
output_filename = "vietnamese_output_step3500.wav"
sf.write(file=output_filename, data=speech_numpy, samplerate=16000)
print(f"Audio saved to {output_filename}")
print(f"Run 'Audio(data=sf.read('{output_filename}')[0], rate=16000)' in a notebook to listen.")
# Kept as the last expression so a notebook renders the audio player.
Audio(data=speech_numpy, rate=16000)

Loading trained model (step 4000), processor, and vocoder...
Loading processor from oopssuper96/speecht5_finetuned_emirhan_tr (revision 804163c)...
Loading model from oopssuper96/speecht5_finetuned_emirhan_tr (revision 804163c)...
Models loaded and moved to cuda.
Original text: Xin chào, tôi tên là đạt, tôi đang thử nghiệm mô hình này
Processed text: shin chafo tohi teen laf zajt tohi zang thuwzh ngieejm moh hifnh nafy
Loading SpeechBrain X-vector model...
Loading one sample from dataset for speaker embedding...
Falling back to random embedding (will likely be noise).
Generating speech...
Speech generated!
Audio saved to vietnamese_output_step3500.wav
Run 'Audio(data=sf.read('vietnamese_output_step3500.wav')[0], rate=16000)' in a notebook to listen.
