In [None]:
# ============================================================
# üöÄ Fine-tune Whisper Small (OpenAI) with custom dataset ti·∫øng Vi·ªát
# Author: Mr.Jack (https://github.com/Mr-Jack-Tung)
# Date: 2025-11-02
# Description: H∆∞·ªõng d·∫´n fine-tune m√¥ h√¨nh Whisper Small c·ªßa OpenAI v·ªõi custom dataset ti·∫øng Vi·ªát
# ============================================================

In [1]:
# !pip install -q unsloth  # ‚ö° C√†i ƒë·∫∑t unsloth ƒë·ªÉ patch nhanh
!pip install -q "pyarrow<20.0.0" transformers datasets accelerate librosa jiwer evaluate

In [2]:
!pip install -q datasets soundfile torchcodec

In [3]:
# ‚ö° Install PEFT / LoRA dependencies (CPU-friendly)
!pip install -q peft accelerate safetensors
# Note: bitsandbytes is GPU-only and is not installed in this CPU-only environment

In [None]:
# 1Ô∏è‚É£ Import unsloth tr∆∞·ªõc (r·∫•t quan tr·ªçng)
# import unsloth  # ‚ö° b·∫≠t patch nhanh cho Trainer, torch, dataset

# 2Ô∏è‚É£ Import c√°c th∆∞ vi·ªán kh√°c
import torch
from datasets import load_dataset, Audio
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    TrainingArguments,
    Trainer,
)
import evaluate

In [6]:
# CPU-only notes
# - This notebook is configured for CPU-only fine-tuning. Training will be significantly slower than on GPU.
# - Keep datasets small, use small batch sizes, and prefer fewer epochs for experiments.
# - bitsandbytes and k-bit training are GPU-only and are not used here.
# - Run cells in order: installs -> imports -> model -> dataset -> training args -> LoRA prep -> training -> save

In [5]:
# 3Ô∏è‚É£ Khai b√°o model
model_name = "openai/whisper-small"
language = "vi"
task = "transcribe"

processor = WhisperProcessor.from_pretrained(model_name, language=language, task=task)
model = WhisperForConditionalGeneration.from_pretrained(model_name)

In [6]:
# 4Ô∏è‚É£ Load custom dataset
dataset = load_dataset("json", data_files="voice_label_data/metadata.json")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

print(dataset)

DatasetDict({
    train: Dataset({
        features: ['audio', 'transcription'],
        num_rows: 6
    })
})


In [7]:
import torchaudio
import numpy as np
from datasets import load_dataset

dataset = load_dataset("json", data_files="voice_label_data/metadata.json")

def load_and_resample(batch):
    path = batch["audio"]
    waveform, sr = torchaudio.load(path)
    # Chuy·ªÉn stereo -> mono
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Resample v·ªÅ 16kHz
    if sr != 16000:
        waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
        sr = 16000
    batch["audio"] = {"array": waveform.squeeze(0).numpy(), "sampling_rate": sr}
    return batch

dataset_16k = dataset.map(load_and_resample)

print("\nCustom dataset:", dataset_16k)

sample = dataset_16k["train"][0]
print("\nSample data", sample.keys())
print(sample["transcription"])
print(sample["audio"])



Custom dataset: DatasetDict({
    train: Dataset({
        features: ['audio', 'transcription'],
        num_rows: 6
    })
})

Sample data dict_keys(['audio', 'transcription'])
Ch√†o b·∫°n
{'array': [-6.505216788355028e-06, 6.598177424166352e-05, -5.4998716223053634e-05, -0.00043667491991072893, -0.00043235099292360246, -0.00013816145656164736, -0.00018468918278813362, -0.0007775566191412508, -0.001126836403273046, -0.0010278073605149984, -0.0007466065580956638, -0.0002358816418563947, 2.7162963306182064e-05, -2.0874767869827338e-05, 0.000113399961264804, 0.0003635763132479042, 0.0002679301251191646, 0.0003522004117257893, 0.0003144242218695581, 0.0002645383065100759, 0.00043802964501082897, 0.000504985626321286, 0.00023536420485470444, 9.094105189433321e-05, 0.00017525105795357376, 0.0006176510360091925, 0.0006924319313839078, 0.0007539874641224742, 0.0007014681468717754, 0.0005626931088045239, 0.00047894028830341995, 0.0004359032027423382, 0.0004471737192943692, 0.00072371854912489

In [10]:
from IPython.display import Audio

# L·∫•y waveform v√† sample rate
waveform = sample["audio"]["array"]
sr = sample["audio"]["sampling_rate"]

print(f"Sample rate: {sr}, \nWaveform shape: {waveform}")

# Ph√°t audio
Audio(data=waveform, rate=sr)


Sample rate: 16000, 
Waveform shape: [-6.505216788355028e-06, 6.598177424166352e-05, -5.4998716223053634e-05, -0.00043667491991072893, -0.00043235099292360246, -0.00013816145656164736, -0.00018468918278813362, -0.0007775566191412508, -0.001126836403273046, -0.0010278073605149984, -0.0007466065580956638, -0.0002358816418563947, 2.7162963306182064e-05, -2.0874767869827338e-05, 0.000113399961264804, 0.0003635763132479042, 0.0002679301251191646, 0.0003522004117257893, 0.0003144242218695581, 0.0002645383065100759, 0.00043802964501082897, 0.000504985626321286, 0.00023536420485470444, 9.094105189433321e-05, 0.00017525105795357376, 0.0006176510360091925, 0.0006924319313839078, 0.0007539874641224742, 0.0007014681468717754, 0.0005626931088045239, 0.00047894028830341995, 0.0004359032027423382, 0.0004471737192943692, 0.0007237185491248965, 0.0009377060923725367, 0.0007554700714536011, 0.0005048390594311059, 0.00038517473149113357, 0.00023606458853464574, 0.00045838140067644417, 0.00063386646797880

In [11]:
# 5Ô∏è‚É£ collate_fn ‚Äî x·ª≠ l√Ω d·ªØ li·ªáu on-the-fly
def collate_fn(batch):
    input_features = [
        processor.feature_extractor(
            sample["audio"]["array"], sampling_rate=16000
        ).input_features[0]
        for sample in batch
    ]
    labels = [
        processor.tokenizer(sample["transcription"]).input_ids
        for sample in batch
    ]

    labels = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(l) for l in labels],
        batch_first=True,
        padding_value=processor.tokenizer.pad_token_id,
    )

    return {
        "input_features": torch.tensor(input_features),
        "labels": labels,
    }

# 6Ô∏è‚É£ Metric
wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)
    wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

In [18]:
import torch
import torchaudio

def collate_fn(batch):
    input_features = []
    labels = []

    for sample in batch:
        # Load waveform (d√π l√† ƒë∆∞·ªùng d·∫´n hay dict ƒë·ªÅu x·ª≠ l√Ω ƒë∆∞·ª£c)
        audio_data = sample["audio"]
        if isinstance(audio_data, dict):
            waveform = torch.tensor(audio_data["array"])
            sr = audio_data["sampling_rate"]
        elif isinstance(audio_data, str):  # n·∫øu ch·ªâ l√† path
            waveform, sr = torchaudio.load(audio_data)
        else:
            raise TypeError(f"Unexpected audio type: {type(audio_data)}")

        # Chu·∫©n h√≥a sample rate
        if sr != 16000:
            waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
            sr = 16000

        # Chuy·ªÉn stereo -> mono
        if waveform.ndim > 1:
            waveform = waveform.mean(dim=0)

        # Tr√≠ch xu·∫•t ƒë·∫∑c tr∆∞ng √¢m thanh
        feats = processor.feature_extractor(
            waveform.numpy(), sampling_rate=sr
        ).input_features[0]
        input_features.append(feats)

        # X·ª≠ l√Ω nh√£n
        tokenized = processor.tokenizer(sample["transcription"]).input_ids
        labels.append(torch.tensor(tokenized))

    # Padding cho labels
    labels = torch.nn.utils.rnn.pad_sequence(
        labels,
        batch_first=True,
        padding_value=processor.tokenizer.pad_token_id,
    )

    return {
        # "input_features": torch.tensor(input_features, dtype=torch.float32),
        "input_features": torch.tensor(np.array(input_features), dtype=torch.float32),
        "labels": labels,
    }


In [None]:
# 7Ô∏è‚É£ TrainingArguments
training_args = TrainingArguments(
    output_dir="./whisper-small-vi",
    per_device_train_batch_size=1,  # small batch for CPU
    gradient_accumulation_steps=2,
    # eval_strategy="steps",
    # save_steps=100,
    # eval_steps=100,
    save_strategy="no",   # kh√¥ng cho Trainer t·ª± save gi·ªØa ch·ª´ng
    save_safetensors=False,  # t·∫Øt safe serialization
    logging_steps=1,
    num_train_epochs=1,
    learning_rate=1e-4,
    fp16=False,  # disabled on CPU
    no_cuda=True,  # force CPU
    dataloader_num_workers=0,
    # save_total_limit=1,
    report_to="none",
    remove_unused_columns=False,  # Allow custom batch keys for Whisper
 )

# 8Ô∏è‚É£ Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"].select(range(6)),
    # eval_dataset=dataset["test"].select(range(10)),
    data_collator=collate_fn,
    # compute_metrics=compute_metrics,
 )

In [20]:
model

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 768, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(768, 768, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 768)
      (layers): ModuleList(
        (0-11): 12 x WhisperEncoderLayer(
          (self_attn): WhisperAttention(
            (k_proj): lora.Linear(
              (base_layer): Linear(in_features=768, out_features=768, bias=False)
              (lora_dropout): ModuleDict(
                (default): Dropout(p=0.05, inplace=False)
              )
              (lora_A): ModuleDict(
                (default): Linear(in_features=768, out_features=8, bias=False)
              )
              (lora_B): ModuleDict(
                (default): Linear(in_features=8, out_features=768, bias=False)
              )
              (lora_embedding_A): ParameterDict()
              (lora_embedding_B): ParameterDict()
      

In [None]:
# 8Ô∏è‚É£ Prepare LoRA (PEFT) adapter and run training
# CPU-only path: do not use bitsandbytes or k-bit preparation
from peft import LoraConfig, get_peft_model, TaskType
import torch

# Ensure model is on CPU and uses float32
device = torch.device('cpu')
# model = model.to(device).to(torch.float32)

# Configure LoRA
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["k_proj", "v_proj", "q_proj", "out_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
    inference_mode=False,
 )

# Wrap the original model with PEFT's LoRA
peft_model = get_peft_model(model, lora_config)

# Print the number of trainable parameters (will be a small fraction of the total)
peft_model.print_trainable_parameters() 

# Attach the PEFT model to the Trainer and run training
trainer.model = peft_model
trainer.train()

In [19]:
# 9Ô∏è‚É£ L∆∞u model (g·ªëc)
trainer.save_model("./whisper-small-vi")
processor.save_pretrained("./whisper-small-vi")
# Save LoRA/PEFT adapter as well
try:
    peft_model.save_pretrained("./whisper-small-vi-lora")
    print("Saved LoRA adapter to ./whisper-small-vi-lora")
except NameError:
    print("peft_model not found ‚Äî if you ran the LoRA cell the adapter will be saved automatically.")

Saved LoRA adapter to ./whisper-small-vi-lora


In [30]:
sample["audio"]

'voice_label_data/20251102_075340_ab7da323.wav'

In [31]:
import torchaudio

audio_path = sample["audio"]
waveform, sr = torchaudio.load(audio_path)

# Chuy·ªÉn stereo -> mono n·∫øu c·∫ßn
if waveform.shape[0] > 1:
    waveform = waveform.mean(dim=0, keepdim=True)

# Resample v·ªÅ 16kHz n·∫øu ch∆∞a ƒë√∫ng
if sr != 16000:
    waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
    sr = 16000
waveform.shape, sr

(torch.Size([1, 11973]), 16000)

In [32]:
# üîü Ki·ªÉm th·ª≠ inference

# ƒê∆∞a v√†o processor
inputs = processor(
    waveform.squeeze().numpy(),
    sampling_rate=sr,
    return_tensors="pt"
)

# √âp ki·ªÉu v√† ƒë∆∞a l√™n ƒë√∫ng device
inputs = {k: v.to(model.device).to(model.dtype) for k, v in inputs.items()}

# G·ªçi generate k√®m attention_mask
predicted_ids = model.generate(
    input_features=inputs["input_features"],
    attention_mask=inputs.get("attention_mask", None),
    task="transcribe",
    language="vi"
)

transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

print("üó£Ô∏è Ground truth:", sample["transcription"])
print("‚ú® Whisper prediction:", transcription[0])

üó£Ô∏è Ground truth: Ch√†o b·∫°n
‚ú® Whisper prediction:  Ch√†o b·∫°c
