In [3]:
import sys, torch, transformers
# Record which interpreter and library builds produced the outputs below.
version_report = (
    ("PY:", sys.executable),
    ("torch:", torch.__version__, "| build CUDA:", torch.version.cuda),
    ("transformers:", transformers.__version__),
)
for report_fields in version_report:
    print(*report_fields)


PY: c:\Users\Dung\anaconda3\envs\aic2025\python.exe
torch: 2.8.0+cu129 | build CUDA: 12.9
transformers: 4.55.4


In [12]:
import os, re, json, sys
import torch, librosa
from tqdm import tqdm
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# === Path configuration ===
# AUDIO_ROOT is scanned for <part>/<Vxxx>.wav inputs, SHOTS_ROOT holds the
# per-audio shot JSON files, and OUT_ROOT receives one transcript JSON per audio.
AUDIO_ROOT = r"D:\VN_Multi_User_Video_Search\dataset_extraction\audio\audios"
SHOTS_ROOT = r"D:\VN_Multi_User_Video_Search\dataset_extraction\audio\audio_detection"
OUT_ROOT   = "./audio"
# Controls whether an existing output file is reused or recomputed.
REPROCESS_POLICY = "skip_if_exists"  # "skip_if_exists" | "recompute_if_input_newer"

# === Name-structure regexes ===
# Part folders look like K01 / L12; audio files look like V001.wav (case-insensitive).
FOLDER_RE = re.compile(r"^[KL]\d+$", re.IGNORECASE)
AUDIO_RE  = re.compile(r"^V\d+\.wav$", re.IGNORECASE)

# === Model ===
# NOTE(review): these variables are set after `transformers` has already been
# imported; presumably huggingface_hub reads them at import time, in which case
# they have no effect here — TODO confirm, and move them above the imports if so.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
repo = "nguyenvulebinh/wav2vec2-base-vietnamese-250h"
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = Wav2Vec2Processor.from_pretrained(repo, use_safetensors=True)
# The model is moved to `device` and switched to eval mode once, up front.
model = Wav2Vec2ForCTC.from_pretrained(repo, use_safetensors=True).to(device).eval()

# === Collect audio files laid out as AUDIO_ROOT/<Kxx|Lxx>/Vxxx.wav ===
if not os.path.isdir(AUDIO_ROOT):
    raise FileNotFoundError(f"Missing AUDIO_ROOT: {AUDIO_ROOT}")

# Maps part name -> {audio id (file name without extension) -> full wav path}.
all_audio = {}
for part_name in sorted(os.listdir(AUDIO_ROOT)):
    part_dir = os.path.join(AUDIO_ROOT, part_name)
    # Ignore plain files and folders whose name does not match the part pattern.
    if not os.path.isdir(part_dir) or not FOLDER_RE.match(part_name):
        continue
    wav_by_id = {
        os.path.splitext(wav_name)[0]: os.path.join(part_dir, wav_name)
        for wav_name in sorted(os.listdir(part_dir))
        if AUDIO_RE.match(wav_name)
    }
    # Parts that contain no matching wav file are left out entirely.
    if wav_by_id:
        all_audio[part_name] = wav_by_id

audio_total = sum(map(len, all_audio.values()))
print("Tổng audio:", audio_total, "| Số part:", len(all_audio))
os.makedirs(OUT_ROOT, exist_ok=True)

def should_skip(out_path, audio_path, shots_path, policy=None):
    """Decide whether the transcript at ``out_path`` can be reused as-is.

    Args:
        out_path: Path of the transcript JSON that would be written.
        audio_path: Path of the source audio file.
        shots_path: Path of the shot JSON for this audio (may not exist).
        policy: ``"skip_if_exists"`` or ``"recompute_if_input_newer"``.
            Defaults to the module-level ``REPROCESS_POLICY`` when ``None``.

    Returns:
        ``True`` when the existing output should be kept, ``False`` when the
        audio must be (re)processed. A missing output, an unknown policy, or a
        failure to read a modification time all return ``False``.
    """
    # Nothing to reuse if the output has never been written.
    if not os.path.isfile(out_path):
        return False

    if policy is None:
        policy = REPROCESS_POLICY

    if policy == "skip_if_exists":
        return True

    if policy == "recompute_if_input_newer":
        try:
            out_m = os.path.getmtime(out_path)
            aud_m = os.path.getmtime(audio_path)
            # A missing shot file counts as infinitely old, so it never forces a rerun.
            det_m = os.path.getmtime(shots_path) if os.path.isfile(shots_path) else 0.0
        except OSError:
            # Only filesystem errors mean "cannot tell, recompute"; anything else
            # (including KeyboardInterrupt) must propagate.
            return False
        return (out_m >= aud_m) and (out_m >= det_m)

    return False

def _load_shots(shots_path, duration):
    """Return the list of ``[start, end]`` shots (in seconds) for one audio file.

    Falls back to a single shot covering ``[0.0, duration]`` when the shot file
    is missing, cannot be read, is not valid JSON, or does not hold a JSON array.
    """
    whole_audio = [[0.0, duration]]
    if not os.path.isfile(shots_path):
        return whole_audio
    try:
        with open(shots_path, "r", encoding="utf-8") as f:
            shots = json.load(f)
    except (OSError, ValueError):
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        # A bare `except:` here would also swallow KeyboardInterrupt during a long run.
        return whole_audio
    return shots if isinstance(shots, (list, tuple)) else whole_audio


processed = skipped = 0

for part, mp in tqdm(all_audio.items(), desc="Parts"):
    out_dir = os.path.join(OUT_ROOT, part)
    os.makedirs(out_dir, exist_ok=True)

    for audio_id, audio_path in tqdm(mp.items(), leave=False, desc=part):
        out_path   = os.path.join(out_dir, f"{audio_id}.json")
        shots_path = os.path.join(SHOTS_ROOT, part, f"{audio_id}.json")
        if should_skip(out_path, audio_path, shots_path):
            skipped += 1
            continue

        # Load as mono audio resampled to 16 kHz.
        speech, sr = librosa.load(audio_path, mono=True, sr=16000)
        n = len(speech)
        speech = speech.astype("float32")

        shots = _load_shots(shots_path, n / sr)

        # One transcript string per shot, in shot order, so indices stay aligned.
        results = []
        for st, ed in shots:
            st = float(st)
            ed = float(ed)
            # Shots shorter than one second get an empty transcript.
            if ed - st < 1.0:
                results.append("")
                continue

            # Cut the shot into chunks of at most 10 s. A remainder shorter than
            # 1 s left after a full chunk fails the loop condition and is dropped.
            segs, s = [], st
            while (ed - s) >= 1.0:
                e = min(ed, s + 10.0)
                seg = speech[int(s*sr): int(e*sr)]
                if seg.size:
                    segs.append(seg)
                if e >= ed:
                    break
                s = e
            if not segs:
                results.append("")
                continue

            # All chunks of a shot go through the model as one padded batch.
            with torch.no_grad():
                batch = processor(segs, sampling_rate=sr, return_tensors="pt", padding="longest")
                logits = model(batch.input_values.to(device)).logits
                pred = torch.argmax(logits, dim=-1)
                texts = processor.batch_decode(pred, skip_special_tokens=True)
            results.append(" ".join(t.strip() for t in texts).strip())

        # Write to a temp file, then rename it over the final path.
        tmp = out_path + ".tmp"
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False)
        os.replace(tmp, out_path)
        processed += 1

print(f"Hoàn tất. Mới xử lý: {processed} | Bỏ qua: {skipped}")


Tổng số audio: 1478
Số part (Kxx/Lxx): 30


Parts: 100%|██████████| 30/30 [36:36<00:00, 73.23s/it] 

Hoàn tất.



