In [1]:
# Imports
import os
import json
import pandas as pd
from datasets import load_dataset, DatasetDict, Audio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, TrainingArguments, Trainer
from dataclasses import dataclass
from typing import Union
from jiwer import wer, cer
from tqdm import tqdm
from pathlib import Path
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# CONFIGURATION
# Every location below is derived from LANG, so pointing the notebook at another
# language only requires changing this one constant.
LANG = "en"
# Directory holding the per-split CSV manifests ("<split>.csv").
DATA_BASE = f"../../data/asr_processed/{LANG}"
# Output directory: vocab.json, training checkpoints, final model and processor.
SAVE_DIR = f"../../models/asr/{LANG}"
# Prefix joined onto each manifest row's "path" value.
# NOTE(review): the same "train" directory is used for all three splits --
# confirm that the val/test clips are stored there as well.
AUDIO_BASE = Path(f"../../data/asr/{LANG}/train")
# Pretrained checkpoint the processor and model are loaded from.
MODEL_NAME = "facebook/wav2vec2-base-960h"
SPLITS = ["train", "val", "test"]

# Created up front because vocab.json is written here before training starts.
os.makedirs(SAVE_DIR, exist_ok=True)

In [3]:
# Load Dataset + Fix paths
# Each split has its own CSV manifest under DATA_BASE; every manifest is read
# separately and stored under the split's name.
data_files = {}
for split_name in SPLITS:
    data_files[split_name] = os.path.join(DATA_BASE, f"{split_name}.csv")

loaded_splits = {}
for split_name, csv_path in data_files.items():
    loaded_splits[split_name] = load_dataset(
        "csv",
        data_files={split_name: csv_path},
        split=split_name,
    )
dataset = DatasetDict(loaded_splits)

def add_full_path(example):
    """Prefix the row's "path" value with AUDIO_BASE.

    The row is modified in place and the same mapping is returned, which is
    the shape of function that ``dataset.map`` is given below.
    """
    relative_path = example["path"]
    example["path"] = os.path.join(AUDIO_BASE, relative_path)
    return example

# Rewrite the "path" column of every split so it includes the AUDIO_BASE prefix.
dataset = dataset.map(
    function=add_full_path,
    desc="Attaching audio paths",
)

Attaching audio paths: 100%|███████████████████| 9004/9004 [00:01<00:00, 4992.34 examples/s]
Attaching audio paths: 100%|█████████████████████| 493/493 [00:00<00:00, 8724.23 examples/s]
Attaching audio paths: 100%|█████████████████████| 503/503 [00:00<00:00, 6278.56 examples/s]


In [4]:
# Build Vocabulary
print("Building vocabulary from training data...")
# Every distinct character of the lower-cased training transcripts.
vocab_set = {char for text in dataset["train"]["sentence"] for char in text.lower()}

# A CTC tokenizer writes the word boundary as its word-delimiter token, so the
# raw space must not receive an id of its own: it is replaced by "|".
vocab_set.discard(" ")
vocab_set.add("|")

# "[UNK]" and "[PAD]" are the special-token names the processor cell refers to;
# a tokenizer built from this file needs both to be real vocabulary entries.
# They are appended after the sorted characters so character ids stay contiguous.
vocab_list = sorted(vocab_set) + ["[UNK]", "[PAD]"]
vocab_dict = {char: idx for idx, char in enumerate(vocab_list)}

vocab_path = os.path.join(SAVE_DIR, "vocab.json")
with open(vocab_path, "w", encoding="utf-8") as f:
    # ensure_ascii=False keeps non-ASCII characters readable in the file.
    json.dump(vocab_dict, f, indent=2, ensure_ascii=False)

Building vocabulary from training data...


In [5]:
# Load Processor
# `from_pretrained` does not unpack a `tokenizer_kwargs` dict: passing the vocab
# file and special-token names that way never reached the tokenizer, so the
# checkpoint's own tokenizer was what actually got loaded (the recorded model
# output, "vocab size 32", is consistent with that). The call now states what
# really happens instead of suggesting a custom vocabulary is in use.
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)

# The former `processor.tokenizer.pad_token = "[PAD]"` override is removed.
# "[PAD]" is not an entry of the checkpoint vocabulary, so renaming the pad
# token to it detaches label padding from the id the pretrained CTC head uses
# as its blank. The checkpoint's own pad token is kept instead.
assert processor.tokenizer.pad_token in processor.tokenizer.get_vocab(), (
    "pad token must be a real vocabulary entry"
)

# NOTE(review): to train on the generated vocab.json instead, build the tokenizer
# explicitly -- Wav2Vec2CTCTokenizer(vocab_path, unk_token="[UNK]",
# pad_token="[PAD]", word_delimiter_token="|") -- wrap it in a Wav2Vec2Processor
# together with the checkpoint's feature extractor, and load the model with
# ignore_mismatched_sizes=True and pad_token_id=<tokenizer pad id>, because the
# CTC head then no longer matches the checkpoint's shape.

In [6]:
# Preprocess Audio/Text Inputs
def prepare_batch(batch, processor):
    """Turn one manifest row into model inputs and CTC label ids.

    Despite its name this receives a single example: the ``map`` call that
    uses it does not pass ``batched=True``.

    Args:
        batch: row with a "path" (audio file location) and a "sentence"
            (transcript) entry.
        processor: Wav2Vec2Processor providing the feature extractor and the
            tokenizer.

    Returns:
        dict with "input_values" (features of the first audio channel) and
        "labels" (token ids of the transcript). The input row is not modified.
    """
    # Local imports kept from the original cell.
    # NOTE(review): presumably so each map worker imports torchaudio itself --
    # confirm, otherwise move them to the imports cell.
    import torchaudio
    from torchaudio.transforms import Resample

    # Use the rate the feature extractor is configured for instead of a
    # hard-coded 16000, so resampling and feature extraction cannot disagree.
    target_sr = processor.feature_extractor.sampling_rate

    waveform, sr = torchaudio.load(batch["path"])
    if sr != target_sr:
        waveform = Resample(orig_freq=sr, new_freq=target_sr)(waveform)

    # Only the first channel is used (unchanged behaviour); it is handed over
    # as a numpy array, the input type the feature extractor documents.
    speech = waveform[0].numpy()
    input_values = processor.feature_extractor(
        speech, sampling_rate=target_sr
    ).input_values[0]

    # Fold the transcript to the letter case the tokenizer vocabulary contains.
    # With a single-case vocabulary, letters in the other case would otherwise
    # all be encoded as the unknown token.
    vocab = processor.tokenizer.get_vocab()
    has_upper = any(len(tok) == 1 and tok.isupper() for tok in vocab)
    has_lower = any(len(tok) == 1 and tok.islower() for tok in vocab)
    text = batch["sentence"]
    if has_upper != has_lower:
        text = text.upper() if has_upper else text.lower()

    labels = processor.tokenizer(text).input_ids
    return {"input_values": input_values, "labels": labels}

print("Preprocessing dataset...")
# The manifest columns are dropped, so each split keeps only what
# prepare_batch returns ("input_values" and "labels").
manifest_columns = dataset["train"].column_names
dataset = dataset.map(
    prepare_batch,
    remove_columns=manifest_columns,
    fn_kwargs={"processor": processor},
    num_proc=4,
)

Preprocessing dataset...


Map (num_proc=4): 100%|██████████████████████████| 9004/9004 [09:10<00:00, 16.36 examples/s]
Map (num_proc=4): 100%|████████████████████████████| 493/493 [00:27<00:00, 18.02 examples/s]
Map (num_proc=4): 100%|████████████████████████████| 503/503 [00:32<00:00, 15.53 examples/s]


In [7]:
# Load & Resize Model
# vocab_size follows the tokenizer. Whenever that differs from the checkpoint's
# CTC head, loading raises a size-mismatch error unless
# ignore_mismatched_sizes=True is given; with the flag the head is freshly
# initialised at the requested size. When the sizes already agree the flag has
# no effect and the pretrained head is loaded as usual.
# NOTE: the CTC loss uses config.pad_token_id as its blank. If the tokenizer is
# ever replaced by one with a different vocabulary, also pass
# pad_token_id=processor.tokenizer.pad_token_id here.
model = Wav2Vec2ForCTC.from_pretrained(
    MODEL_NAME,
    vocab_size=len(processor.tokenizer),
    ignore_mismatched_sizes=True,
)
print(f"Loaded model with vocab size {model.config.vocab_size}")

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded model with vocab size 32


In [8]:
# Data Collator
@dataclass
class DataCollatorCTC:
    """Pad a list of preprocessed examples into one CTC training batch.

    Audio features and label ids have different lengths and padding rules, so
    they are padded separately: audio by the feature extractor, labels by the
    tokenizer.

    Calling the collator with a list of examples (each holding "input_values"
    and "labels") returns the padded audio batch with an added "labels"
    tensor in which every padded position is -100.
    """

    # Supplies the feature extractor (audio padding) and tokenizer (label padding).
    processor: Wav2Vec2Processor
    # Padding strategy forwarded unchanged to both pad() calls.
    padding: Union[bool, str] = True

    def __call__(self, features):
        audio_features = [{"input_values": f["input_values"]} for f in features]
        label_features = [{"input_ids": f["labels"]} for f in features]

        batch = self.processor.feature_extractor.pad(
            audio_features, padding=self.padding, return_tensors="pt"
        )
        # The tokenizer is addressed directly, so the deprecated
        # `as_target_processor()` context manager is not needed.
        labels_batch = self.processor.tokenizer.pad(
            label_features, padding=self.padding, return_tensors="pt"
        )

        # Mask by attention mask rather than by comparing with pad_token_id:
        # only positions added by padding become -100, even if a genuine label
        # happens to share the pad token's id.
        batch["labels"] = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].ne(1), -100
        )

        return batch

In [9]:
# Training Arguments
# Settings are grouped by concern and merged into a single TrainingArguments call.

# Step-based cadence: log every 100 steps, evaluate every 500, checkpoint
# every 1000, keeping at most two checkpoints.
step_schedule = {
    "eval_strategy": "steps",
    "save_strategy": "steps",
    "logging_strategy": "steps",
    "logging_steps": 100,
    "eval_steps": 500,
    "save_steps": 1000,
    "save_total_limit": 2,
}

# Optimisation budget and batch sizes.
optimisation = {
    "num_train_epochs": 3,
    "per_device_train_batch_size": 4,
    "per_device_eval_batch_size": 2,
    "learning_rate": 1e-4,
    "warmup_steps": 500,
    "fp16": False,
}

# Batching / data-loader behaviour.
data_loading = {
    "group_by_length": False,
    "dataloader_num_workers": 0,
}

# After training, reload the checkpoint with the lowest validation WER.
model_selection = {
    "load_best_model_at_end": True,
    "metric_for_best_model": "wer",
    "greater_is_better": False,
}

# No experiment tracker and no upload to the Hub.
reporting = {
    "report_to": "none",
    "push_to_hub": False,
}

training_args = TrainingArguments(
    output_dir=SAVE_DIR,
    **step_schedule,
    **optimisation,
    **data_loading,
    **model_selection,
    **reporting,
)

In [10]:
# Metrics
def compute_metrics(pred):
    """Compute word and character error rate for one evaluation pass.

    Args:
        pred: evaluation output with ``predictions`` (the model's logits, or a
            tuple whose first element is the logits) and ``label_ids``
            (reference ids, with -100 at padded positions).

    Returns:
        dict with "wer" and "cer", references first and hypotheses second.
    """
    logits = pred.predictions
    if isinstance(logits, tuple):
        logits = logits[0]
    # The predictions are per-frame scores over the vocabulary; decoding needs
    # the most likely token id of each frame, not the raw scores.
    pred_ids = logits.argmax(axis=-1)

    # Work on a copy so the caller's label array keeps its -100 markers.
    label_ids = pred.label_ids.copy()
    # -100 is not a token id; map it to the pad token so it decodes to nothing.
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    # References are plain token sequences, not CTC frame outputs: they must not
    # be run through repeat-collapsing, which would turn "hello" into "helo".
    label_str = processor.batch_decode(
        label_ids, skip_special_tokens=True, group_tokens=False
    )
    return {"wer": wer(label_str, pred_str), "cer": cer(label_str, pred_str)}

In [11]:
# Trainer
# Wire model, arguments, collator and metric function together; training uses
# the "train" split and periodic evaluation uses "val".
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["val"],
    data_collator=DataCollatorCTC(processor=processor),
    compute_metrics=compute_metrics,
)
# Kept as a module-level name, as in the original cell.
data_collator = trainer.data_collator

In [None]:
# Start Training
print("Starting training...")
# Keep a handle on the result and leave it as the cell's last expression so the
# notebook still displays it.
train_output = trainer.train()
train_output

Starting training...




Step,Training Loss,Validation Loss


In [None]:
# Evaluate
print("Evaluating on test set...")
# Metric names come back prefixed with "eval_", hence the keys read below.
test_result = trainer.evaluate(eval_dataset=dataset["test"])
print(f"Final Test WER: {test_result['eval_wer']:.4f} | CER: {test_result['eval_cer']:.4f}")

In [None]:
# Save Processor and Model
# Model weights and processor files go to the same directory so the pair can
# be reloaded from one location.
trainer.save_model(output_dir=SAVE_DIR)
processor.save_pretrained(save_directory=SAVE_DIR)

print(f"Model and processor saved to {SAVE_DIR}")