# Evaluate Performance of IPA Transcription Models

In [92]:
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from datasets import load_dataset
from eval_helpers import remove_diacritics
from evaluate import load

In [93]:
HF_TIMIT_PATH = "kylelovesllms/timit_asr_ipa"
HF_MODEL_PATH = "kylelovesllms/Wav2Vec2IpaFullTIMIT_6E"

## Step 1) Prepare Dataset

In [94]:
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lenghts and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [95]:
timit = load_dataset(HF_TIMIT_PATH)

In [96]:
# Calculate the amount of time 
def secsToHours(secs):
    min_per_sec = 1 / 60
    hour_per_min = 1 / 60
    return secs * min_per_sec * hour_per_min

train_time = sum(timit["train"]["duration"])
valid_time = sum(timit["validation"]["duration"])
test_time = sum(timit["test"]["duration"])
total_time = train_time + valid_time + test_time

print(f"total time = {secsToHours(total_time)}")
print(f"\ttrain time = {secsToHours(train_time)}")
print(f"\tvalid time = {secsToHours(valid_time)}")
print(f"\ttest time = {secsToHours(test_time)}")

total time = 4.270059236111111
	train time = 3.1197584722222222
	valid time = 0.5639583159722222
	test time = 0.5863424479166666


## Step 2) Load Baseline Model Wav2Vec2XLSR-53-espeak
- HF Link: https://huggingface.co/facebook/wav2vec2-xlsr-53-espeak-cv-ft

In [97]:
xlsr_processor = Wav2Vec2Processor.from_pretrained(HF_MODEL_PATH)
xlsr_model = Wav2Vec2ForCTC.from_pretrained(HF_MODEL_PATH)

### Prep audio to be processed

In [98]:
data_collator = DataCollatorCTCWithPadding(processor=xlsr_processor, padding=True)

In [99]:
processor = xlsr_processor
def prepare_dataset(batch):
    # Each batch corresponds with a specific entry (either in train or test)
    audio = batch["audio"]

    # batched output is "un-batched" to ensure mapping is correct

    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]

    batch["input_length"] = len(batch["input_values"])

    with processor.as_target_processor():
        timit_ipa_labels_arr = batch["ipa_transcription"]
        ipa_transcription = "".join(timit_ipa_labels_arr)
        batch["labels"] = processor(ipa_transcription).input_ids
    return batch

In [100]:
timit = timit.map(prepare_dataset, num_proc=4)

## Step 3) Apply Model (Get Transcriptions) 

In [101]:
def map_to_result(batch):
  with torch.no_grad():
    input_values = torch.tensor(batch["input_values"], device="cpu").unsqueeze(0)
    logits = xlsr_model(input_values).logits

  pred_ids = torch.argmax(logits, dim=-1)
  batch["pred_str"] = xlsr_processor.batch_decode(pred_ids)[0]
  batch["text"] = xlsr_processor.decode(batch["labels"], group_tokens=False)
  
  return batch

In [102]:
def sanitize_transcription(batch):
    # Cleanup Wav2Vec2 Transcription
    # cleanedup_w2v2p2_transcription = remove_diacritics(batch["pred_str"].split(" "))
    # transcription_as_str = "".join([symbol for symbol in cleanedup_w2v2p2_transcription if symbol != "[UNK]"])
    # print([symbol for symbol in cleanedup_w2v2p2_transcription if symbol != "[UNK]"])
    
    batch["pred_str"] = batch["pred_str"].replace("[UNK]", "_")
    print(batch["pred_str"])

    # Stringify the IPA transcription 
    batch["ipa_transcription"] = "".join([symbol for symbol in batch["ipa_transcription"] if symbol != "[UNK]"])
    return batch

In [103]:
# Run model inference and sanitize transcriptions
results = timit["validation"].map(map_to_result).map(sanitize_transcription)

Map: 100%|██████████| 670/670 [01:30<00:00,  7.40 examples/s]
Map:   0%|          | 0/670 [00:00<?, ? examples/s]

ðəbʌŋgə_oʊwɪ_p_ɛ_ə_ʔ_i_ɪttʃ_weɪɾɪd_ɪ_ðɪʃɔ_
doʊ_æ_k_iɾɪkɪ_iɪ_əʔɔɪ_i_æg_aɪkðæt
ʔɑ_j__ʊkɪŋ_ɚɪ_p_ɔɪ_ɪ_t
ʃihɛdjɚdɑ_k__t_g_i_iwɛ_ʃwɔɾɚɔ_jɪɚ
ʔɛtwaɪ_aɪɾɔ_ðətwɛ_θdeɪwɪ_hæ_ʃɪb_i
ʔiɾɪŋ_pɪ_ɪttʃ_aɪt_iɪ_k_i_ɪ_t_ɪŋθ_ɚ_ækə_ɪ__ɪ
gʌɾɪhɛkə_əbaɪɔ_ðɪ_dɛ_ttʃip
ð_kæ_əpdæddʒɪ_pɚtɪkɪ_iəpi_ɪŋ
əbɪggoʊtʔaɪd_iʔæ_b_dθ__ðɪ_ɑ__jɑ_d
ðɪ_g__pɪ_ɛkjɪ_ɚ_ɪ_tʔɛ_ðɛ_pɛ_g_ɪ_tɛ__tɪbitɛk_ə_ɑddʒɪk_
ðɪbʌŋgə_oʊwʌ_p_ɛ_ə_ʔ_i_ɪttʃ_eɪɾɪd_ɪ_ðəʃɔɚ
ʔeɪ_oʊtə_ʔɔ_keɪ_ɪ_tɪhɪ__ɔɪ_
doʊ_tʔæ_k_itəkɛ_ɪɪ_ɔɪ_i_æg_aɪkðæt
ʔɑ_j__ʊkɪŋ_ɔ__p_ɔɪ_ɛ_t
ʃihɛdjɛ_dɑ_k__tʔɪ_g_i_iwɔʃwɔɾɚʔɔ_jɪɛ_
_ɛɾɪ_aɪdtɪd_aɪwəθ_ɪɾɔ_tʃʊgɚbɔ_
ʔɪtwaɪ_aɪtʔə_ətwɛ__θdeɪwɪ_hæ_ʃʌb_i
ʔitɪŋ_pɪ_ɪdd_aɪʔ_iɪŋk_i_ɪ_t_ɛŋkθ_ɚækjɪ_ə__i
dʌ_ə_aɪɪɾi_ɪ_iɪg_ɪ_tæ__ʔɛ_tɪtiʔoʊ_ɛ_ɪ_əbʌ_ðiɪg_ɑ_ɛeɪʃɪ_ə__ɛ_
əbɪggoʊtʔaɪd_iʔæ_b_dθ__ðɪ_ɑ__jɑ_d
ʔɪ_pip_wɛ__ɔ_ddʒɛ_ɚ_ɪ_ðɛ_wʊdbi_oʊ_id_ɔ_wɛ__ɛ_
ðə_ɛ_ə_ɪkeɪʃɪ__ʔə_ðiɪʃ_ʔɑ_ʔɪ_ɔ__ə_
bɑ_bbɛ__dpeɪpɚʔɛ__i__ʔɪ_eɪbɪgbɑ__aɪɚ
doʊ_tʔæ_k_it_kɪ_iɛ_ʔɔɪ_i_æg_aɪkðæt
ʃihɛddʒʊ_dɑ_k__tɪŋg_i_iwɔʃwɔɾɚʔɔ_jɪɚ
bæ_kɪtbɔkɛ_biʔɛ_ʔɪ_ɚteɪ_iŋ_pɔ_t
t_ɪ_ʔɛk_ɛ_k_i__ə_ɚ_aʊ_d_ɪpʔə__oʊ_dʔæ__ʌɾɪ_aɪdwaɪ_ʔə_

Map: 100%|██████████| 670/670 [00:01<00:00, 446.29 examples/s]


In [104]:
results = results.map(sanitize_transcription)

Map:   0%|          | 0/670 [00:00<?, ? examples/s]

ðəbʌŋgə_oʊwɪ_p_ɛ_ə_ʔ_i_ɪttʃ_weɪɾɪd_ɪ_ðɪʃɔ_
doʊ_æ_k_iɾɪkɪ_iɪ_əʔɔɪ_i_æg_aɪkðæt
ʔɑ_j__ʊkɪŋ_ɚɪ_p_ɔɪ_ɪ_t
ʃihɛdjɚdɑ_k__t_g_i_iwɛ_ʃwɔɾɚɔ_jɪɚ
ʔɛtwaɪ_aɪɾɔ_ðətwɛ_θdeɪwɪ_hæ_ʃɪb_i
ʔiɾɪŋ_pɪ_ɪttʃ_aɪt_iɪ_k_i_ɪ_t_ɪŋθ_ɚ_ækə_ɪ__ɪ
gʌɾɪhɛkə_əbaɪɔ_ðɪ_dɛ_ttʃip
ð_kæ_əpdæddʒɪ_pɚtɪkɪ_iəpi_ɪŋ
əbɪggoʊtʔaɪd_iʔæ_b_dθ__ðɪ_ɑ__jɑ_d
ðɪ_g__pɪ_ɛkjɪ_ɚ_ɪ_tʔɛ_ðɛ_pɛ_g_ɪ_tɛ__tɪbitɛk_ə_ɑddʒɪk_
ðɪbʌŋgə_oʊwʌ_p_ɛ_ə_ʔ_i_ɪttʃ_eɪɾɪd_ɪ_ðəʃɔɚ
ʔeɪ_oʊtə_ʔɔ_keɪ_ɪ_tɪhɪ__ɔɪ_
doʊ_tʔæ_k_itəkɛ_ɪɪ_ɔɪ_i_æg_aɪkðæt
ʔɑ_j__ʊkɪŋ_ɔ__p_ɔɪ_ɛ_t
ʃihɛdjɛ_dɑ_k__tʔɪ_g_i_iwɔʃwɔɾɚʔɔ_jɪɛ_
_ɛɾɪ_aɪdtɪd_aɪwəθ_ɪɾɔ_tʃʊgɚbɔ_
ʔɪtwaɪ_aɪtʔə_ətwɛ__θdeɪwɪ_hæ_ʃʌb_i
ʔitɪŋ_pɪ_ɪdd_aɪʔ_iɪŋk_i_ɪ_t_ɛŋkθ_ɚækjɪ_ə__i
dʌ_ə_aɪɪɾi_ɪ_iɪg_ɪ_tæ__ʔɛ_tɪtiʔoʊ_ɛ_ɪ_əbʌ_ðiɪg_ɑ_ɛeɪʃɪ_ə__ɛ_
əbɪggoʊtʔaɪd_iʔæ_b_dθ__ðɪ_ɑ__jɑ_d
ʔɪ_pip_wɛ__ɔ_ddʒɛ_ɚ_ɪ_ðɛ_wʊdbi_oʊ_id_ɔ_wɛ__ɛ_
ðə_ɛ_ə_ɪkeɪʃɪ__ʔə_ðiɪʃ_ʔɑ_ʔɪ_ɔ__ə_
bɑ_bbɛ__dpeɪpɚʔɛ__i__ʔɪ_eɪbɪgbɑ__aɪɚ
doʊ_tʔæ_k_it_kɪ_iɛ_ʔɔɪ_i_æg_aɪkðæt
ʃihɛddʒʊ_dɑ_k__tɪŋg_i_iwɔʃwɔɾɚʔɔ_jɪɚ
bæ_kɪtbɔkɛ_biʔɛ_ʔɪ_ɚteɪ_iŋ_pɔ_t
t_ɪ_ʔɛk_ɛ_k_i__ə_ɚ_aʊ_d_ɪpʔə__oʊ_dʔæ__ʌɾɪ_aɪdwaɪ_ʔə_

Map: 100%|██████████| 670/670 [00:00<00:00, 1254.44 examples/s]


In [105]:
# Sanity Check
# print(f"Sentence: {results[0]['wrd']}")
print(results[1]["text"])
print(f'Ground Truth: {results[1]["ipa_transcription"]}')
print(f'Prediction:   {results[1]["pred_str"]}')

doʊ[UNK]æ[UNK]k[UNK]iɾɪkɛ[UNK]ɪɪ[UNK]əʔɔɪ[UNK]ɪ[UNK]æg[UNK]aɪkðæt
Ground Truth: doʊnæskmiɾɪkɛrɪɪnəʔɔɪlɪræglaɪkðæt
Prediction:   doʊ_æ_k_iɾɪkɪ_iɪ_əʔɔɪ_i_æg_aɪkðæt


## Step 4) Evaluate Model Output

In [106]:
cer_metric = load("cer")

In [107]:
total_cer = cer_metric.compute(predictions=results["pred_str"], references=results["ipa_transcription"])

In [108]:
print("Test CER: {:.3f}".format(total_cer))

Test CER: 0.383
