# TIMIT Test split performance analysis
This analyzes model performance on the test split TIMIT corpus, with special attention to performance on vowels. Reduction to the shared Buckeye/TIMIT symbol set is performed before analysis and performance metrics are calculated.

In [1]:
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any

import datasets
import matplotlib
import matplotlib.pyplot as plt
from phonecodes import phonecodes, phonecode_tables
import pandas as pd
import seaborn as sns


import multipa.evaluation
import multipa.evaluation_extras

# Visualiation settings
PALETTE = "gist_gray"
CONTEXT = "paper"
FONT_SCALE = 2

sns.color_palette(PALETTE)
sns.set_context(context=CONTEXT, font_scale=FONT_SCALE)
# font = {"size": 16}
# matplotlib.rc("font", **font)
# Remove the limits on the number of rows displayed in the notebook
pd.options.display.max_rows = None

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Data processing settings
TIMIT_DATASET = Path("../../data/TIMIT")
IPA_KEY = "ipa"
PREDICTION_KEY = "prediction"
NUM_PROC = 8

# Evaluation results
TIMIT_EVAL_DIR = Path("../../data/timit_results/")

DETAILED_PRED_DIR = TIMIT_EVAL_DIR / "detailed_predictions"
EDIT_DIST_DIR = TIMIT_EVAL_DIR / "edit_distances"

# Full original set, some may be removed after symbol reduction
TIMIT_VOWELS = ["ɑ", "æ", "ʌ", "ɔ", "ɛ", "ɪ", "i", "ʊ", "u", "ə", "ə̥", "ʉ", "ɨ", "ɹ̩", "ɚ"]
TIMIT_DIPHTHONGS = ["aʊ", "eɪ", "aɪ",  "oʊ", "ɔɪ"]

DIALECT_REGIONS = {
    "DR1": "DR1: New England",
    "DR2": "DR2: Northern",
    "DR3": "DR3: North Midland",
    "DR4": "DR4: South Midland",
    "DR5": "DR5: Southern",
    "DR6": "DR6: New York City",
    "DR7": "DR7: Western",
    "DR8": "DR8: Army Brat"
}


## Basic model performance comparisons
Show performance metrics for each model on TIMIT. 

In [3]:
# Manually define and join model source description

# Models fine-tuned on Buckeye, post process using BUCKEYE_REDUCED_MAPPING
wav2ipa_models = [
    ("ginic/full_dataset_train_1_wav2vec2-large-xlsr-53-buckeye-ipa", "Buckeye fine-tuned on full train split"),
    ("ginic/wav2vec2-large-lv60_phoneme-timit_english_timit-4k_buckeye-4k_bs32_3", "Lee 2025 fine-tuned again on Buckeye"),
    ]

# External models, post process using
external_models = [
    ("excalibur12/wav2vec2-large-lv60_phoneme-timit_english_timit-4k", "Lee 2025 Wav2Vec2.0 TIMIT fine-tuned"),

    ("openai_whisper-medium.en_to_epitran", "Whisper + Epitran"),
    ("facebook/wav2vec2-lv-60-espeak-cv-ft", "facebook/wav2vec2-lv-60-espeak-cv-ft"),
    ("facebook/wav2vec2-xlsr-53-espeak-cv-ft", "facebook/wav2vec2-xlsr-53-espeak-cv-ft"),
    ("allosaurus_eng2102_eng", "Allosaurus English"),
    ("ctaguchi/wav2vec2-large-xlsr-japlmthufielta-ipa1000-ns", "Taguchi et al. 2023"),
    # Intentionally omitting Whisper large - let's just keep the best in each category
    # ("openai_whisper-large-v3-turbo_to_epitran", "Whisper + Epitran"),
]

hue_order = [t[1] for t in wav2ipa_models + external_models]

model_sources_df = pd.DataFrame(wav2ipa_models + external_models, columns=["model", "Model source"])


# Read and preprocess TIMIT test set
Performs symbol reduction on the TIMIT test set.

In [4]:
timit_test_dataset = datasets.load_dataset("timit_asr", data_dir = TIMIT_DATASET, split="test")
print("TIMIT Test Size:", len(timit_test_dataset))
print("TIMIT Test Snippet:", timit_test_dataset[0])
print(timit_test_dataset[0]["phonetic_detail"]["utterance"])

TIMIT Test Size: 1680
TIMIT Test Snippet: {'file': '/Users/virginia/workspace/multipa/data/TIMIT/TEST/DR1/FAKS0/SA1.WAV', 'audio': {'path': '/Users/virginia/workspace/multipa/data/TIMIT/TEST/DR1/FAKS0/SA1.WAV', 'array': array([9.15527344e-05, 1.52587891e-04, 6.10351562e-05, ...,
       2.44140625e-04, 3.05175781e-04, 2.13623047e-04]), 'sampling_rate': 16000}, 'text': 'She had your dark suit in greasy wash water all year.', 'phonetic_detail': {'start': [0, 9640, 11240, 12783, 14078, 16157, 16880, 17103, 17587, 18760, 19720, 19962, 21514, 22680, 23800, 24104, 26280, 28591, 29179, 30337, 31880, 32500, 33170, 33829, 35150, 37370, 38568, 40546, 42357, 45119, 45624, 46855, 48680, 49240, 51033, 52378, 54500, 55461, 57395, 59179, 60600], 'stop': [9640, 11240, 12783, 14078, 16157, 16880, 17103, 17587, 18760, 19720, 19962, 21514, 22680, 23800, 24104, 26280, 28591, 29179, 30337, 31880, 32500, 33170, 33829, 35150, 37370, 38568, 40546, 42357, 45119, 45624, 46855, 48680, 49240, 51033, 52378, 54500, 

In [5]:
def batch_timit_to_ipa(dataset_entry:dict[str, Any]):
    timit_str = dataset_entry["phonetic_detail"]["utterance"]
    ipa_syms = phonecodes.timit2ipa(" ".join(timit_str)).split()
    dataset_entry[IPA_KEY] = "".join(ipa_syms)
    return dataset_entry

In [6]:
# Convert original TIMIT ARPABET to IPA
timit_test_with_ipa = timit_test_dataset.map(batch_timit_to_ipa, num_proc=NUM_PROC)
print(timit_test_with_ipa[0]["phonetic_detail"]["utterance"])
print(timit_test_with_ipa[0][IPA_KEY])

['h#', 'sh', 'iy', 'hv', 'ae', 'dcl', 'd', 'y', 'er', 'dcl', 'd', 'aa', 'r', 'kcl', 'k', 's', 'uw', 'dx', 'ih', 'ng', 'gcl', 'g', 'r', 'iy', 's', 'iy', 'w', 'aa', 'sh', 'epi', 'w', 'aa', 'dx', 'er', 'q', 'ao', 'l', 'y', 'iy', 'axr', 'h#']
ʃiɦædjɝdɑɹksuɾɪŋɡɹisiwɑʃwɑɾɝʔɔljiɚ


In [9]:
# Post process IPA to reduced TIMIT symbol set
timit_test_with_ipa = multipa.evaluation_extras.dataset_reduction_greedy_find_and_replace(timit_test_with_ipa, IPA_KEY, "timit", num_proc=NUM_PROC)
print(timit_test_with_ipa[0][IPA_KEY])


Map (num_proc=8): 100%|██████████| 1680/1680 [00:00<00:00, 9748.28 examples/s] 

ʃihædjɹ̩dɑɹksuɾɪŋɡɹisiwɑʃwɑɾɹ̩ʔɔljiɹ̩





# Dialect Region Performance Plots
This creates bar charts showing performance by dialect. Since the groupby and averaging was already done, we just need to read in the data and plot the results.

In [8]:
# Read in predictions and extract dialect region
detailed_results_dfs = []
for model, label in model_sources:
    clean_model_name = model.replace("/", "_")
    tmp_df = pd.read_csv(DETAILED_PRED_DIR / f"{clean_model_name}_detailed_predictions.csv")
    tmp_df["model_name"] = model
    tmp_df["Model source"] = label
    detailed_results_dfs.append(tmp_df)

detailed_preds_df = pd.concat(detailed_results_dfs)
detailed_preds_df["dialect"] = detailed_preds_df["filename"].apply(lambda x: x.split("/")[2].upper())
display(detailed_preds_df.head())


NameError: name 'model_sources' is not defined

In [None]:
# Show performance by dialect region
dialect_df = detailed_preds_df.groupby(["model_name", "Model source", "dialect"])["phone_error_rates"].mean().reset_index()
display(dialect_df.head())
dialect_df = dialect_df.merge(pd.DataFrame(DIALECT_REGIONS.items(), columns=["dialect", "Dialect Region"]), on="dialect")
dialect_df =  dialect_df.sort_values(by=["Dialect Region", "phone_error_rates"], ascending=[True, True])
display(dialect_df)

In [None]:
g = sns.FacetGrid(dialect_df, col="Dialect Region", col_wrap=4, height=4, aspect=0.75)
g.set_titles(col_template="{col_name}")
g.map_dataframe(sns.barplot, y="phone_error_rates", hue="Model source", palette=PALETTE, hue_order = hue_order)
g.add_legend(title="Model source")
g.set_ylabels("Average Phone Error Rate")
g.fig.suptitle("Models' Average Phone Error Rates by Dialect Region", fontsize=24, y=1.05)


In [None]:
# Dialect performance for just our model
our_model_dialect_df = dialect_df[dialect_df["Model source"] == "Our AutoIPA: fine-tuned on full train split"]
g = sns.barplot(data=our_model_dialect_df, y="Dialect Region", x="phone_error_rates", hue="Dialect Region", palette=PALETTE)
g.set_xlabel("Average Phone Error Rate")
g.set_xlim((0,0.5))
g.set(title="Our AutoIPA's TIMIT Performance by Dialect Region")
for bar in g.containers:
    g.bar_label(bar, fmt="%.2f", padding=5)


# Vowel Error Rate Analysis
How many instances of each vowel in the vocabulary are we getting wrong? 
$$ error\_rate(v) = \frac{count\_substitutions\_of(v) + count\_deletions(v)}{total\_count(v)}$$

In [None]:
gold_transcription_df = pd.read_csv(GOLD_TRANSCRIPTIONS_CSV)
gold_transcription_df["filename"] = gold_transcription_df["audio_filename"].str.lower()
gold_transcription_df["ipa_transcription"] = gold_transcription_df["ipa_transcription"].str.replace("ɝ", "ɹ̩")
vowel_counts = Counter()
for vowel in TIMIT_VOWELS + TIMIT_DIPHTHONGS:
    vowel_counts[vowel] += gold_transcription_df["ipa_transcription"].apply(lambda x: x.split().count(vowel)).sum()

vowel_counts

In [None]:
# Simple bar chart of vowel counts
plot_vowels, plot_counts = zip(*vowel_counts.most_common())
g = sns.barplot(y=plot_vowels, x=plot_counts, palette="colorblind")
g.set_xlim(0, 13500)
g.set_xlabel("count")
g.set(title="Counts of TIMIT Vowel Occurrences")
for bar in g.containers:
    g.bar_label(bar, fontsize='small')




In [None]:
EPS = "***"
def tally_edit_distance_errors(references, predictions):
    """Counts up edit distances from lists of already tokenized references and predictions."""
    subs = Counter()
    insertions = Counter()
    deletions = Counter()
    for ref_tokens, pred_tokens in zip(references, predictions):
        aligned_pairs = kaldialign.align(ref_tokens, pred_tokens, EPS)

        for r, p in aligned_pairs:
            if r == EPS:
                insertions[p] += 1
            elif p == EPS:
                deletions[r] += 1
            elif r != p:
                subs[(r, p)] += 1

    return subs, deletions, insertions

def diphthong_merge(t1, t2):
    """For merge detected diphthongs in predicted output when using ipatok.tokenise"""
    if t1+t2 in TIMIT_DIPHTHONGS:
        return True
    else:
        return False

In [None]:
# Re-do edit distance calculations with better tokenization, specifically turning
# on diphthong tokenization
our_model_detailed_preds_df = pd.read_csv(DETAILED_PRED_DIR / "ginic_full_dataset_train_1_wav2vec2-large-xlsr-53-buckeye-ipa_detailed_predictions.csv").drop(columns=["substitutions", "insertions", "deletions"])
full_edit_distance_analysis_df = pd.merge(gold_transcription_df, our_model_detailed_preds_df, on="filename")

full_edit_distance_analysis_df["ipa_tokens"] = full_edit_distance_analysis_df["ipa_transcription"].str.split()
full_edit_distance_analysis_df["predicted_ipa_tokens"] = full_edit_distance_analysis_df["prediction"].apply(lambda x: ipatok.tokenise(x, diphthongs=True, merge=diphthong_merge))
print(full_edit_distance_analysis_df["ipa_tokens"][:10])
print(full_edit_distance_analysis_df["predicted_ipa_tokens"][:10])
display(full_edit_distance_analysis_df.head())


In [None]:
sub_counter, del_counter, inserts_counter = tally_edit_distance_errors(full_edit_distance_analysis_df["ipa_tokens"], full_edit_distance_analysis_df["predicted_ipa_tokens"])


In [None]:
# Get subs and deletions in good format for analysis
detailed_error_counts = defaultdict(Counter)
subs_counts = Counter()
for (sub_tuple, count) in sub_counter.items():
    subs_counts[sub_tuple[0]] += count
    detailed_error_counts[sub_tuple[0]][sub_tuple[1]] += count

print("Substitution Counts:", subs_counts)

for (deleted, count) in del_counter.items():
    detailed_error_counts[deleted]["<deleted>"] += count

print("Detailed Error Counts:", detailed_error_counts)

In [None]:
# Compute vowel error rates
vowel_error_rates = {}
for v in TIMIT_VOWELS + TIMIT_DIPHTHONGS:
    subs_count = subs_counts[v]
    dels_count = del_counter[v]
    ver = (subs_count + dels_count)/ (vowel_counts[v])
    vowel_error_rates[v] = ver

ver_df = pd.DataFrame(vowel_error_rates.items(), columns=["Vowel", "Vowel Error Rate"]). sort_values(by="Vowel Error Rate", ascending=False)
error_ordering = ver_df[ver_df["Vowel Error Rate"] > 0]["Vowel"].tolist()
print("In descending frequency of errors:", error_ordering)

display(ver_df)


In [None]:
plt.figure(figsize=(4, 6))
sns.heatmap(
    ver_df.sort_values(by="Vowel Error Rate", ascending=False).set_index("Vowel"),
    cmap="rainbow",
    # cmap="spring_r",
    annot=True,
    fmt=".2f",
    yticklabels=True,
    # linewidths=1,
)
plt.title("AutoIPA TIMIT Vowel Error Rates\n(Descending worst to best)")
# plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.ylabel("")
plt.show()


In [None]:
interesting_vowels = ver_df[ver_df["Vowel Error Rate"] > 0.0]["Vowel"].tolist()
print(interesting_vowels)


interesting_errors = []
for v in interesting_vowels:
    for error, count in detailed_error_counts[v].items():
        interesting_errors.append((v, error, count))

interesting_errors_df = pd.DataFrame(interesting_errors, columns=["Vowel", "Error", "Count"])
interesting_errors_df["Ratio of Vowel's Errors"] = interesting_errors_df.groupby("Vowel", group_keys=False)["Count"].apply(lambda x: x / x.sum())
display(interesting_errors_df)


In [None]:
# Grab top ten errors for each vowel
top_errors_df = interesting_errors_df.groupby("Vowel").apply(lambda x: x.nlargest(5, "Count")).reset_index(drop=True)
top_errors_df["Vowel"] = pd.Categorical(top_errors_df["Vowel"], categories=error_ordering, ordered=True)
top_errors_df = top_errors_df.sort_values(by=["Vowel", "Count"], ascending=[True, False])
display(top_errors_df.head(20))

In [None]:
convention_errors = ["ɨ", "ʉ", "ə̥", "ə", "ɚ"]

convention_errors_df = top_errors_df[top_errors_df["Vowel"].isin(convention_errors)]
convention_errors_df["Vowel"] = pd.Categorical(convention_errors_df["Vowel"], categories=convention_errors, ordered=True)
g = sns.FacetGrid(convention_errors_df, col="Vowel", col_wrap=3, sharey=False, xlim=(0, 1), aspect=1.25)
g.map_dataframe(sns.barplot, x="Ratio of Vowel's Errors", y="Error", orient="h")
g.set_titles(col_template="{col_name}")
g.set_ylabels("Error or\nSubstitution")
g.set_xlabels("As ratio of total errors\naffecting the vowel")
g.fig.suptitle("Top 5 errors for vowels Wav2IPA always incorrectly transcribes", fontsize=24, y=1.05)


In [None]:
not_convention_errors = [v for v in interesting_vowels if v not in convention_errors]
print(not_convention_errors)
not_convention_errors_df = top_errors_df[top_errors_df["Vowel"].isin(not_convention_errors)]
not_convention_errors_df["Vowel"] = pd.Categorical(
not_convention_errors_df["Vowel"], categories=not_convention_errors, ordered=True
)
display(not_convention_errors_df.head(20))

g = sns.FacetGrid(not_convention_errors_df, col="Vowel", col_wrap=5, sharey=False, aspect=1.25, xlim=(0, 1))
g.map_dataframe(sns.barplot, x="Ratio of Vowel's Errors", y="Error", orient="h")
g.set_titles(col_template="{col_name}", fontsize=20)
g.set_ylabels("Error or\nSubstitution")
g.set_xlabels("As ratio of total errors\naffecting the vowel")
g.fig.suptitle("Remaining TIMIT Vowels: Top 5 Wav2IPA Errors for each vowel", fontsize=24, y=1.05)