# Evaluate Performance of IPA Transcription Models

In [None]:
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from datasets import load_dataset
from eval_helpers import remove_diacritics
from evaluate import load

In [None]:
# Dataset identifier handed to `load_dataset` further down in this notebook.
HF_TIMIT_PATH = "kylelovesllms/timit_asr_ipa"

## Step 1) Prepare Dataset

In [None]:
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that dynamically pads the inputs and the CTC labels it receives.

    Inputs and labels are padded separately because they have different
    lengths and go through different padding routines. Label positions that
    are padding are set to ``-100`` in the returned ``labels`` tensor.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data. Its ``pad`` method pads
            the inputs and its ``tokenizer.pad`` method pads the labels.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding if
              only a single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length. This collator forwards no :obj:`max_length` argument,
              so the padding routine's own default length applies.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad one batch of features and return it with a ``labels`` entry.

        Args:
            features: One dict per example; each must provide ``"input_values"``
                and ``"labels"``.

        Returns:
            The padded input batch (PyTorch tensors) with an added ``"labels"``
            tensor in which padded positions are ``-100``.
        """
        # Split inputs and labels since they have to be of different lengths
        # and need different padding methods.
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        # Pad the label ids with the tokenizer directly. This is the call that
        # `processor.pad` dispatched to inside the deprecated
        # `as_target_processor()` context manager.
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # Replace padding with -100 so those positions are ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [None]:
# Load the dataset named by HF_TIMIT_PATH; later cells index the result by split (e.g. timit["test"]).
timit = load_dataset(HF_TIMIT_PATH)

## Step 2) Load Baseline Model Wav2Vec2XLSR-53-espeak
- HF Link: https://huggingface.co/facebook/wav2vec2-xlsr-53-espeak-cv-ft

In [None]:
# Single source of truth for the baseline checkpoint, so the processor and
# the model are always loaded from the same identifier.
XLSR_CHECKPOINT = "facebook/wav2vec2-xlsr-53-espeak-cv-ft"

xlsr_processor = Wav2Vec2Processor.from_pretrained(XLSR_CHECKPOINT)
xlsr_model = Wav2Vec2ForCTC.from_pretrained(XLSR_CHECKPOINT)

### Prep audio to be processed

In [None]:
# Collator built around the XLS-R processor, padding enabled.
# NOTE(review): `data_collator` is not referenced again in the cells visible here — confirm it is still needed.
data_collator = DataCollatorCTCWithPadding(processor=xlsr_processor, padding=True)

In [None]:
processor = xlsr_processor


def prepare_dataset(batch):
    """Convert one dataset entry into model inputs and label ids.

    Despite the parameter name, ``batch`` is a single entry (from either the
    train or the test split), not a batch of entries.

    Reads ``batch["audio"]`` (its ``"array"`` and ``"sampling_rate"``) and
    ``batch["ipa_transcription"]``; writes ``"input_values"``,
    ``"input_length"`` and ``"labels"`` into ``batch`` and returns it.
    """
    audio = batch["audio"]

    # The processor output is batched; take element 0 to "un-batch" it so the
    # value maps back onto this single entry.
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["input_length"] = len(batch["input_values"])

    # Join the IPA labels into one string and tokenize it. Calling the
    # tokenizer directly is what `processor(...)` dispatched to inside the
    # deprecated `as_target_processor()` context manager.
    ipa_transcription = "".join(batch["ipa_transcription"])
    batch["labels"] = processor.tokenizer(ipa_transcription).input_ids
    return batch

In [None]:
# Run prepare_dataset over the dataset with num_proc=4 (parallel workers) and keep the mapped result.
timit = timit.map(prepare_dataset, num_proc=4)

## Step 3) Apply Model (Get Transcriptions) 

In [None]:
def map_to_result(batch):
    """Run the XLS-R model on one entry and attach the decoded strings.

    Adds two keys to ``batch`` and returns it:

    * ``"pred_str"``: the argmax token ids of the model logits, decoded with
      ``xlsr_processor.batch_decode``.
    * ``"text"``: ``batch["labels"]`` decoded with ``group_tokens=False``.
    """
    waveform = torch.tensor(batch["input_values"], device="cpu")

    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        model_output = xlsr_model(waveform.unsqueeze(0))
        logits = model_output.logits

    predicted_ids = logits.argmax(dim=-1)
    decoded_predictions = xlsr_processor.batch_decode(predicted_ids)

    batch["pred_str"] = decoded_predictions[0]
    batch["text"] = xlsr_processor.decode(batch["labels"], group_tokens=False)
    return batch

In [None]:
def sanitize_transcription(batch):
    """Turn the prediction and the reference of one entry into plain strings.

    ``batch["pred_str"]`` is split on single spaces, passed through
    ``remove_diacritics`` and re-joined without separators.
    ``batch["ipa_transcription"]`` is joined into a single string.
    Both keys are overwritten in place and ``batch`` is returned.
    """
    # Clean up the Wav2Vec2 prediction.
    predicted_tokens = batch["pred_str"].split(" ")
    batch["pred_str"] = "".join(remove_diacritics(predicted_tokens))

    # Stringify the reference IPA transcription.
    reference_labels = batch["ipa_transcription"]
    batch["ipa_transcription"] = "".join(reference_labels)

    return batch

In [None]:
# Two passes over the test split: model inference first, then
# transcription clean-up on the inference output.
results = (
    timit["test"]
    .map(map_to_result)
    .map(sanitize_transcription)
)

In [None]:
# Sanity check: print the reference and the predicted transcription of the
# first entry in `results` on consecutive lines for a quick visual comparison.
print(f'Ground Truth: {results[0]["ipa_transcription"]}')
print(f'Prediction:   {results[0]["pred_str"]}')

## Step 4) Evaluate Model Output

In [None]:
# Load the "cer" metric through `evaluate.load` (imported as `load` at the top of the notebook).
cer_metric = load("cer")

In [None]:
# Score every prediction against its reference transcription in one call.
total_cer = cer_metric.compute(
    predictions=results["pred_str"],
    references=results["ipa_transcription"],
)

In [None]:
# Report the error rate on the test split, rounded to three decimals.
print(f"Test CER: {total_cer:.3f}")