# In [None]:
import torch
import torchaudio
from datasets import Audio, load_dataset
from transformers import (
    Trainer,
    TrainingArguments,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)

# In [None]:
                       
class Wav2VecTrainer:
    """Fine-tune a pre-trained wav2vec2 checkpoint for CTC speech recognition.

    Typical use: construct with a character vocabulary file, call
    ``load_data`` to build ``self.dataset``, then call ``train``.
    """

    def __init__(self, vocab_path, model_name="facebook/wav2vec2-large-xlsr-53"):
        """Build the tokenizer, feature extractor, processor and CTC model.

        Args:
            vocab_path: Path to the vocabulary file for the CTC tokenizer.
            model_name: Model id or local path of the pre-trained wav2vec2
                checkpoint to fine-tune.
        """
        # Character-level CTC tokenizer; "|" is used as the word delimiter.
        self.tokenizer = Wav2Vec2CTCTokenizer(
            vocab_path,
            unk_token="<unk>",
            pad_token="<pad>",
            word_delimiter_token="|",
        )
        # Raw-waveform feature extractor configured for 16 kHz mono input.
        self.feature_extractor = Wav2Vec2FeatureExtractor(
            feature_size=1,
            sampling_rate=16000,
            padding_value=0.0,
            do_normalize=True,
            return_attention_mask=True,
        )
        self.processor = Wav2Vec2Processor(
            feature_extractor=self.feature_extractor, tokenizer=self.tokenizer
        )

        # Load the pre-trained model; pad_token_id and vocab_size come from the
        # tokenizer so the CTC output layer matches the vocabulary.
        self.model = Wav2Vec2ForCTC.from_pretrained(
            model_name,
            attention_dropout=0.1,
            hidden_dropout=0.1,
            feat_proj_dropout=0.0,
            mask_time_prob=0,
            layerdrop=0.1,
            ctc_loss_reduction="mean",
            pad_token_id=self.processor.tokenizer.pad_token_id,
            vocab_size=len(self.processor.tokenizer),
        )

    def load_data(self, dataset_name, split="train", *, config_name=None,
                  audio_column="audio", text_column="transcript"):
        """Load one split of an audio dataset and convert it to model inputs.

        The audio column is cast to the feature extractor's sampling rate, then
        every example is mapped to ``input_values`` and ``labels``. All original
        columns are dropped. The result is stored in ``self.dataset``.

        Args:
            dataset_name: Dataset id or path passed to ``load_dataset``.
            split: Dataset split to load.
            config_name: Optional dataset configuration name (for example a
                language code for multilingual corpora). ``None`` loads the
                dataset without a configuration, as before.
            audio_column: Column holding decoded audio dicts
                (``{"array": ..., "sampling_rate": ...}``).
            text_column: Column holding the transcript string.
        """
        dataset = load_dataset(dataset_name, name=config_name, split=split)
        # The Wav2Vec2 feature extractor raises if the audio's sampling rate
        # differs from its own. Casting makes `datasets` resample on access.
        dataset = dataset.cast_column(
            audio_column,
            Audio(sampling_rate=self.feature_extractor.sampling_rate),
        )
        self.dataset = dataset.map(
            self._prepare_example,
            fn_kwargs={"audio_column": audio_column, "text_column": text_column},
            remove_columns=dataset.column_names,
        )

    def _prepare_example(self, batch, audio_column="audio", text_column="transcript"):
        """Turn one raw example into ``input_values`` and ``labels``.

        Args:
            batch: A single (non-batched) example dict from the dataset.
            audio_column: Key of the decoded audio dict in ``batch``.
            text_column: Key of the transcript string in ``batch``.

        Returns:
            ``batch`` with ``input_values`` (normalized waveform) and ``labels``
            (token ids of the transcript) added.
        """
        audio = batch[audio_column]
        batch["input_values"] = self.processor.feature_extractor(
            audio["array"], sampling_rate=audio["sampling_rate"]
        ).input_values[0]
        batch["labels"] = self.processor.tokenizer(batch[text_column]).input_ids
        return batch

    def train(self, output_dir="./wav2vec2-output", batch_size=8, epochs=3):
        """Fine-tune ``self.model`` on ``self.dataset`` with the HF Trainer.

        ``load_data`` must have been called first.

        Args:
            output_dir: Directory for checkpoints and logs.
            batch_size: Per-device training batch size.
            epochs: Number of training epochs.
        """
        training_args = TrainingArguments(
            output_dir=output_dir,
            group_by_length=True,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=2,
            # No eval dataset is passed to the Trainer, so periodic evaluation
            # is left disabled; enabling it without one makes Trainer fail.
            num_train_epochs=epochs,
            # Mixed precision requires a CUDA device; use fp32 otherwise.
            fp16=torch.cuda.is_available(),
            save_steps=400,
            logging_steps=400,
            learning_rate=1e-4,
            warmup_steps=500,
            save_total_limit=2,
        )

        trainer = Trainer(
            model=self.model,
            data_collator=self._data_collator,
            args=training_args,
            train_dataset=self.dataset,
            tokenizer=self.processor.feature_extractor,
        )

        trainer.train()

    def _data_collator(self, features):
        """Pad a list of prepared examples into one training batch.

        Audio inputs and label ids have different lengths and padding rules,
        so they are padded separately. The feature extractor pads
        ``input_values``, and the tokenizer pads ``labels``. Label padding is
        then replaced with -100 so it is ignored by the CTC loss.

        Args:
            features: Examples produced by ``_prepare_example``.

        Returns:
            A batch mapping with padded ``input_values`` tensors (plus whatever
            the feature extractor's ``pad`` adds) and a ``labels`` tensor.
        """
        input_features = [{"input_values": f["input_values"]} for f in features]
        label_features = [{"input_ids": f["labels"]} for f in features]

        batch = self.processor.feature_extractor.pad(
            input_features, padding=True, return_tensors="pt"
        )
        labels_batch = self.processor.tokenizer.pad(
            label_features, padding=True, return_tensors="pt"
        )
        batch["labels"] = self._mask_label_padding(
            labels_batch["input_ids"], labels_batch["attention_mask"]
        )
        return batch

    @staticmethod
    def _mask_label_padding(input_ids, attention_mask):
        """Return a copy of ``input_ids`` with padded positions set to -100.

        Padding is detected from ``attention_mask`` (positions != 1). Matching
        on the pad token id would also mask a real token that shares that id.
        """
        return input_ids.masked_fill(attention_mask.ne(1), -100)

# Example usage: fine-tune on the training split of a speech dataset.
if __name__ == "__main__":
    # NOTE(review): "common_voice" is usually published per language and may
    # need a configuration name when loaded, and its transcript column may not
    # be called "transcript" -- confirm against the dataset before running.
    wav2vec_trainer = Wav2VecTrainer(vocab_path="./vocab.json")
    wav2vec_trainer.load_data(dataset_name="common_voice", split="train")
    wav2vec_trainer.train()
