In [1]:
# Imports
import os
from datasets import load_dataset, DatasetDict, Audio
import pandas as pd
import datasets
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, TrainingArguments, Trainer
from transformers.training_args import TrainingArguments
import torch
import numpy as np
from jiwer import wer, cer
from tqdm import tqdm
from dataclasses import dataclass
from typing import Dict, List, Union
import json
from inspect import signature

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# CONFIGURATION
LANG = 'en'
DATA_BASE = '../../data/asr_processed/en'
MODEL_NAME = 'facebook/wav2vec2-large-960h'
SAVE_DIR = '../../models/asr/en'
os.makedirs(SAVE_DIR, exist_ok=True)
SPLITS = ['train', 'val', 'test']

In [3]:
# LOAD CSVs AS DATASETS
data_files = {split: os.path.join(DATA_BASE, f"{split}.csv") for split in SPLITS}
dataset = DatasetDict({
    split: datasets.load_dataset('csv', data_files={split: path}, split=split)
    for split, path in data_files.items()
})

for split in SPLITS:
    dataset[split] = dataset[split].cast_column("path", datasets.Value("string"))

# Attach raw audio using relative path
def add_full_path(batch):
    batch["audio"] = [os.path.join('../../data/asr/en/train', x) for x in batch["path"]]
    return batch

dataset = dataset.map(add_full_path, batched=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

In [4]:
# PREPROCESSING FOR Wav2Vec2
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)

# Build vocabulary from dataset transcripts
vocab_set = set()
for item in tqdm(dataset['train'], desc="Building vocabulary"):
    vocab_set.update(list(item['sentence'].lower()))

vocab_set = sorted(vocab_set)
vocab_dict = {v: k for k, v in enumerate(vocab_set)}
processor.tokenizer.add_tokens(list(vocab_set))

def prepare_batch(batch):
    audio = batch["audio"]
    input_values = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    
    # processor.tokenizer directly for labels
    labels = processor.tokenizer(batch["sentence"]).input_ids
    return {
        "input_values": input_values,
        "labels": labels
    }

print("Tokenizing dataset...")
dataset = dataset.map(prepare_batch, remove_columns=dataset["train"].column_names, num_proc=4)

with open(os.path.join(SAVE_DIR, "vocab.json"), "w", encoding="utf-8") as f:
    json.dump(processor.tokenizer.get_vocab(), f, ensure_ascii=False, indent=2)

Building vocabulary: 100%|████████████████████████████████| 35997/35997 [10:09<00:00, 59.06it/s]


Tokenizing dataset...


In [5]:
# MODEL LOADING
model = Wav2Vec2ForCTC.from_pretrained(
    MODEL_NAME,
    vocab_size=len(processor.tokenizer),
    ignore_mismatched_sizes=True,
)
# model.resize_token_embeddings(len(processor.tokenizer))

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized because the shapes did not match:
- lm_head.weight: found shape torch.Size([32, 1024]) in the checkpoint and torch.Size([94, 1024]) in the model instantiated
- lm_head.bias: found shape torch.Size([32]) in the checkpoint and torch.Size([94]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# DATA COLLATOR
@dataclass
class DataCollatorCTC:
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features):
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt"
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt"
            )
        batch["labels"] = labels_batch["input_ids"].masked_fill(labels_batch["input_ids"] == self.processor.tokenizer.pad_token_id, -100)
        return batch

In [13]:
# Training Arguments
training_args = TrainingArguments(
    output_dir=SAVE_DIR,
    group_by_length=False,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
    logging_dir=os.path.join(SAVE_DIR, "logs"),
    logging_strategy="steps",
    logging_steps=100,
    save_strategy="steps",
    save_steps=1000,
    eval_steps=500,
    num_train_epochs=5,
    learning_rate=1e-4,
    warmup_steps=500,
    save_total_limit=2,
    fp16=torch.cuda.is_available(),
    load_best_model_at_end=False,
    push_to_hub=False,
)

In [14]:
# METRICS
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    wer_score = wer(label_str, pred_str)
    cer_score = cer(label_str, pred_str)
    return {"wer": wer_score, "cer": cer_score}

In [12]:
# TRAIN & EVALUATE
data_collator = DataCollatorCTC(processor=processor, padding=True)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=dataset['train'],
    eval_dataset=dataset['val']
)

print("Starting training...")
trainer.train()

# Save best model
trainer.save_model(SAVE_DIR)
processor.save_pretrained(SAVE_DIR)

Starting training...




Step,Training Loss


KeyboardInterrupt: 