In [None]:
# Imports
import inspect
import os
from dataclasses import dataclass
from typing import Dict, List, Union

import datasets
import numpy as np
import pandas as pd
import torch
from datasets import Audio, DatasetDict, load_dataset
from jiwer import cer, wer
from tqdm import tqdm
from transformers import Trainer, TrainingArguments, Wav2Vec2ForCTC, Wav2Vec2Processor

In [None]:
# CONFIGURATION
# Everything language-specific is derived from LANG, so switching language
# only requires editing this one constant.
LANG = 'en'
DATA_BASE = f'../../data/asr_processed/{LANG}'  # holds train.csv / val.csv / test.csv
MODEL_NAME = 'facebook/wav2vec2-large-960h'
SAVE_DIR = f'../../models/asr/{LANG}'  # checkpoints, final model and processor
os.makedirs(SAVE_DIR, exist_ok=True)
SPLITS = ['train', 'val', 'test']

In [None]:
# LOAD CSVs AS DATASETS
# Folder the CSV "path" column is relative to, and the rate audio is decoded at.
# NOTE(review): every split (val/test included) resolves against the *train*
# audio folder, as before — presumably val/test were carved out of the raw
# train set; confirm. The existence check below fails fast if that is wrong.
AUDIO_DIR = f'../../data/asr/{LANG}/train'
SAMPLING_RATE = 16000

data_files = {split: os.path.join(DATA_BASE, f"{split}.csv") for split in SPLITS}
dataset = DatasetDict({
    split: datasets.load_dataset('csv', data_files={split: path}, split=split)
    for split, path in data_files.items()
})

for split in SPLITS:
    dataset[split] = dataset[split].cast_column("path", datasets.Value("string"))


def add_full_path(batch, audio_dir=AUDIO_DIR):
    """Add an "audio" column: each row's "path" joined onto ``audio_dir``.

    Used with ``map(batched=True)``, so ``batch["path"]`` is a list of strings.
    """
    batch["audio"] = [os.path.join(audio_dir, rel_path) for rel_path in batch["path"]]
    return batch


dataset = dataset.map(add_full_path, batched=True)

# Verify the files exist while "audio" still holds plain path strings; after the
# Audio cast below, reading the column would decode every file instead.
for split in SPLITS:
    missing = [p for p in dataset[split]["audio"] if not os.path.isfile(p)]
    if missing:
        raise FileNotFoundError(
            f"{len(missing)} audio file(s) for split '{split}' not found, e.g. {missing[0]}"
        )

dataset = dataset.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))

In [None]:
# PREPROCESSING FOR Wav2Vec2
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)


def normalize_transcript(text: str) -> str:
    """Canonical transcript form, used for BOTH the vocabulary and the label ids.

    Applying the same normalisation in both places guarantees every label
    character is a token the tokenizer knows about.
    """
    return text.lower()


# Build vocabulary from dataset transcripts. Reading the "sentence" column
# directly avoids decoding every audio file, which iterating whole rows does
# once "audio" is an Audio feature.
vocab_set = set()
for sentence in tqdm(dataset['train']['sentence'], desc="Building vocabulary"):
    vocab_set.update(normalize_transcript(sentence))

vocab_set = sorted(vocab_set)
vocab_dict = {v: k for k, v in enumerate(vocab_set)}  # not used in this cell

# Whitespace must not become a token: the CTC tokenizer maps spaces to its
# word-delimiter token, and an added " " token would bypass that mapping.
existing_vocab = processor.tokenizer.get_vocab()
new_tokens = [ch for ch in vocab_set if not ch.isspace() and ch not in existing_vocab]
processor.tokenizer.add_tokens(new_tokens)


def prepare_batch(batch):
    """Convert one example into model inputs ("input_values") and CTC label ids ("labels")."""
    audio = batch["audio"]
    input_values = processor.feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]
    ).input_values
    # Calling the tokenizer directly replaces the deprecated as_target_processor().
    labels = processor.tokenizer(normalize_transcript(batch["sentence"])).input_ids
    return {"input_values": input_values[0], "labels": labels}


print("Tokenizing dataset...")
dataset = dataset.map(prepare_batch, remove_columns=dataset["train"].column_names, num_proc=4)

In [None]:
# MODEL LOADING
# The output vocabulary only determines the size of the CTC head (lm_head), which
# is set here through vocab_size; ignore_mismatched_sizes re-initialises that head
# when the tokenizer grew. Wav2Vec2ForCTC has no input token embeddings, so
# resize_token_embeddings is not supported for it and must not be called.
model = Wav2Vec2ForCTC.from_pretrained(
    MODEL_NAME,
    vocab_size=len(processor.tokenizer),
    # The CTC blank is config.pad_token_id; keep it equal to the tokenizer pad
    # that the collator masks and the decoder strips.
    pad_token_id=processor.tokenizer.pad_token_id,
    ignore_mismatched_sizes=True,
)

In [None]:
# DATA COLLATOR
@dataclass
class DataCollatorCTC:
    """Pad audio inputs and CTC label ids into a training batch.

    Audio is padded with the processor's feature extractor and labels with its
    tokenizer; label padding is replaced by -100 so the CTC loss ignores it.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True  # forwarded to both pad() calls

    def __call__(self, features: List[Dict[str, list]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # Pad each modality with its own component instead of the deprecated
        # processor.pad() + as_target_processor() switch.
        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            return_tensors="pt",
        )

        label_ids = labels_batch["input_ids"]
        batch["labels"] = label_ids.masked_fill(
            label_ids == self.processor.tokenizer.pad_token_id, -100
        )
        return batch

In [None]:
# Training Arguments
# transformers renamed `evaluation_strategy` to `eval_strategy` (v4.41) and newer
# releases reject the old name, so pass whichever this installed version accepts.
_training_arg_names = inspect.signature(TrainingArguments.__init__).parameters
EVAL_STRATEGY_KWARG = (
    "eval_strategy" if "eval_strategy" in _training_arg_names else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir=SAVE_DIR,
    group_by_length=True,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
    **{EVAL_STRATEGY_KWARG: "steps"},
    num_train_epochs=10,
    save_steps=1000,  # must stay a multiple of eval_steps for load_best_model_at_end
    eval_steps=500,
    logging_steps=100,
    learning_rate=1e-4,
    warmup_steps=500,
    save_total_limit=2,
    fp16=torch.cuda.is_available(),
    load_best_model_at_end=True,
    metric_for_best_model="wer",  # key returned by compute_metrics
    greater_is_better=False,      # lower WER is better
    push_to_hub=False,
)

In [None]:
# METRICS
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    wer_score = wer(label_str, pred_str)
    cer_score = cer(label_str, pred_str)
    return {"wer": wer_score, "cer": cer_score}

In [None]:
# TRAIN & EVALUATE
data_collator = DataCollatorCTC(processor=processor, padding=True)

# Newer transformers take the preprocessing object as `processing_class` and
# deprecate `tokenizer=`; older versions only accept `tokenizer`.
_trainer_arg_names = inspect.signature(Trainer.__init__).parameters
_processor_kwarg = "processing_class" if "processing_class" in _trainer_arg_names else "tokenizer"

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=dataset['train'],
    eval_dataset=dataset['val'],
    **{_processor_kwarg: processor.feature_extractor},
)

trainer.train()
trainer.save_model(SAVE_DIR)
# The full processor (feature extractor + extended tokenizer) is needed to reload the model.
processor.save_pretrained(SAVE_DIR)