In [None]:
# Install dependencies.
# NOTE(review): only matplotlib is installed here; `datasets`, `IPython` and `tqdm`
# (imported in the next cell) are assumed to already be installed in the kernel's
# environment. Consider pinning a version (e.g. `%pip install -q matplotlib==<version>`)
# so re-runs are reproducible.

%pip install matplotlib

In [None]:
from time import sleep

import matplotlib.pyplot as plt
from datasets import DatasetDict, load_dataset
from IPython.display import Audio as IPythonAudio
from tqdm.auto import tqdm


def play_sample(sample: dict) -> None:
    """Render an inline audio player for a dataset sample.

    Args:
        sample:
            A dataset row whose "audio" entry holds the waveform under "array" and its
            sampling rate under "sampling_rate".
    """
    # Import explicitly instead of relying on the `display` builtin that IPython injects
    # into notebook namespaces, so this helper also resolves `display` outside a notebook.
    from IPython.display import display

    audio = sample["audio"]
    display(IPythonAudio(audio["array"], rate=audio["sampling_rate"]))


# Use matplotlib's "ggplot" style for any figures drawn in this notebook.
# NOTE(review): no visible cell currently plots anything with `plt`.
plt.style.use("ggplot")

In [None]:
# Load the `read_aloud` subset of CoRal from the Hugging Face Hub. Downloads can fail
# transiently, so retry with a pause between attempts. Re-raise the last error after
# MAX_LOAD_ATTEMPTS rather than looping forever on a persistent failure.
# NOTE(review): per the original note, with the internet connection disabled a locally
# cached copy is loaded instead, which may not match what is on the Hub - confirm.
MAX_LOAD_ATTEMPTS = 20
LOAD_RETRY_DELAY_SECONDS = 10
for attempt in range(1, MAX_LOAD_ATTEMPTS + 1):
    try:
        coral = load_dataset("alexandrainst/coral", name="read_aloud", split="train")
        break
    except Exception as e:
        if attempt == MAX_LOAD_ATTEMPTS:
            raise
        print(f"Encountered error: {e}. Retrying ({attempt}/{MAX_LOAD_ATTEMPTS})...")
        sleep(LOAD_RETRY_DELAY_SECONDS)

In [None]:
# Inspect the first row to see which fields each sample carries.
first_sample = coral[0]
first_sample

In [None]:
# Listen to the 100 samples with the highest `asr_cer` (CER). Each sample's CER and
# transcript are printed above its audio player, followed by a blank separator line.
worst_samples = coral.sort("asr_cer", reverse=True).select(range(100))
for row in worst_samples:
    cer_line = f"CER: {row['asr_cer']:.0%}"
    text_line = f"Text: {row['text']!r}"
    print(cer_line, text_line, sep="\n")
    play_sample(row)
    print()

In [None]:
# Speakers held out for evaluation. In the filtering cell further down, samples from
# TEST_SET_SPEAKER_IDS may only end up in the test split and samples from
# VALIDATION_SET_SPEAKER_IDS only in the validation split. The training split excludes
# every speaker in both lists, so the splits are speaker-disjoint.
# NOTE(review): per the push commit message, this selection aims at better dialectal
# representation. The two lists are assumed to be disjoint; nothing here enforces it.
TEST_SET_SPEAKER_IDS: list[str] = [
    "spe_f52921e1e787609ab99623340c5dd212",
    "spe_cdb17db7c331cbb89e40dcbbecf4d560",
    "spe_e3742811d83011e22ec2ef5a7af32065",
    "spe_42de8f7200a57e1d28ae5b415ba5b934",
    "spe_d84ebbfd3b0a3fbc12df4f960fe44ae3",
    "spe_199e03b334b15576a69be73ea39a34d5",
    "spe_0f8c666aaf602dfc580d99254e37ac77",
    "spe_e33e46611f54ae91ed7b235c11ef2628",
    "spe_37a526e88c934c7966038d34af9debf0",
    "spe_07c0276e66e920209cf22266b24fa5e4",
    "spe_9b8d26599c6b7932dbac00832b73dcf8",
    "spe_6a029298b9eaa3d7e7f8f74510f88e70",
    "spe_7b8398c898a828791c0fc40d6d146b3f",
    "spe_5e319f90767d47e11731d95e314e4670",
    "spe_4b7ba1403d8540b3101c07b9c8a19474",
    "spe_436e439616edf662c232486b3face2f1",
    "spe_647d4e905427d45ab699abe73d80ef1d",
    "spe_51b02c4d372de72ba1cab851642ab363",
    "spe_50cddf66f739637c1b3c534938649b8e",
    "spe_6e7cb65603907f863e06d7a02e00fb67",
    "spe_f1d26280a22ad55b85083b19d61f243a",
    "spe_9f92cb4d6feb94dab9c691811656e33e",
    "spe_55028d05581a88a8655fa1f74ddfb5a1",
]
# Speakers whose samples form the validation split (see the note above TEST_SET_SPEAKER_IDS).
VALIDATION_SET_SPEAKER_IDS: list[str] = [
    "spe_9c4dc6be57f6c63860331813a71417e5",
    "spe_4a7e760bd0a2775337880155e8ac0ec2",
    "spe_03e8b9d0ee8d3192e113ff62c61e4916",
    "spe_92fea6e4419210f4c4219e84ec89837e",
    "spe_b977ebc0a2ba961cbe158190fce0dc06",
    "spe_4aa23a60464a18e3597cdeb3606ac572",
    "spe_20b91d51f72ee56930ca778cb16c29da",
    "spe_fbf3381f525dbe5ddf1a2a1d36e9c4b9",
    "spe_4d03787c2092b6bee053e75e2cfa4aa3",
    "spe_877ac9c88e53b43ebfe464da79aa6da3",
    "spe_ffc1068fc082deac40144691e1ae754c",
]

In [None]:
# Metadata columns dropped before the splits are built and pushed to the Hub below.
# NOTE(review): re-running this cell on an already-stripped `coral` raises, because
# these columns no longer exist.
DROPPED_COLUMNS = [
    "id_validator",
    "datetime_start",
    "datetime_end",
    "language_native",
    "language_spoken",
    "zipcode_birth",
    "zip_school",
    "education",
    "occupation",
    "asr_label",
]
coral = coral.remove_columns(DROPPED_COLUMNS)
coral

In [None]:
# Build speaker-disjoint splits. Test and validation samples come only from their
# held-out speaker lists, and train uses every other speaker. Every split drops samples
# with `asr_cer` >= MAX_ASR_CER. The evaluation splits also drop samples validated as
# "maybe"; train keeps those and only drops "rejected" ones.
MAX_ASR_CER = 0.6
EVAL_EXCLUDED_STATUSES = ("rejected", "maybe")
TRAIN_EXCLUDED_STATUSES = ("rejected",)

# A speaker present in both held-out lists would leak between the test and val splits.
assert set(TEST_SET_SPEAKER_IDS).isdisjoint(VALIDATION_SET_SPEAKER_IDS), (
    "test and validation speaker lists overlap"
)


def _make_split_filter(
    speaker_ids,
    *,
    keep_listed: bool,
    excluded_statuses: tuple[str, ...],
    max_cer: float = MAX_ASR_CER,
):
    """Return a `Dataset.filter` predicate selecting the samples of one split.

    Args:
        speaker_ids:
            Speaker IDs defining the split.
        keep_listed:
            If True, keep only samples whose "id_speaker" is in `speaker_ids`; if False,
            keep only samples whose speaker is NOT in `speaker_ids`.
        excluded_statuses:
            Values of the "validated" column whose samples are dropped.
        max_cer:
            Samples with "asr_cer" greater than or equal to this value are dropped.

    Returns:
        A function mapping a sample dict to True if the sample belongs in the split.
    """
    # Build the lookup set once instead of scanning (or concatenating) lists per sample.
    speaker_set = frozenset(speaker_ids)

    def keep(sample: dict) -> bool:
        # Same short-circuit order as the original lambdas: speaker check first, so
        # `asr_cer` / `validated` are only read for samples of the right speakers.
        return (
            (sample["id_speaker"] in speaker_set) == keep_listed
            and sample["asr_cer"] < max_cer
            and sample["validated"] not in excluded_statuses
        )

    return keep


test = coral.filter(
    _make_split_filter(
        TEST_SET_SPEAKER_IDS,
        keep_listed=True,
        excluded_statuses=EVAL_EXCLUDED_STATUSES,
    ),
    num_proc=8,
)
val = coral.filter(
    _make_split_filter(
        VALIDATION_SET_SPEAKER_IDS,
        keep_listed=True,
        excluded_statuses=EVAL_EXCLUDED_STATUSES,
    ),
    num_proc=8,
)
train = coral.filter(
    _make_split_filter(
        TEST_SET_SPEAKER_IDS + VALIDATION_SET_SPEAKER_IDS,
        keep_listed=False,
        excluded_statuses=TRAIN_EXCLUDED_STATUSES,
    ),
    num_proc=8,
)

In [None]:
# Total audio duration per split, in hours. This loads every sample's audio array, so it
# takes a while on large splits.
def _total_hours(split, desc: str) -> float:
    """Return the summed duration of all audio in `split`, in hours.

    Args:
        split:
            Iterable of samples whose "audio" entry holds the waveform under "array"
            (its first axis is taken as the number of frames) and the sampling rate
            under "sampling_rate".
        desc:
            Label shown on the progress bar.

    Returns:
        Total duration in hours (0.0 for an empty split).
    """
    total_seconds = sum(
        sample["audio"]["array"].shape[0] / sample["audio"]["sampling_rate"]
        for sample in tqdm(split, desc=desc)
    )
    return total_seconds / 3600


test_hours = _total_hours(test, desc="test")
val_hours = _total_hours(val, desc="val")
train_hours = _total_hours(train, desc="train")
print(f"Test hours: {test_hours:.2f}")
print(f"Val hours: {val_hours:.2f}")
print(f"Train hours: {train_hours:.2f}")

In [None]:
# Bundle the three splits under the split names "train", "val" and "test".
new_coral = DatasetDict({"train": train, "val": val, "test": test})
new_coral

In [None]:
# Push the new splits to the Hub as the "read_aloud" config. Uploads can fail
# transiently, so retry with a pause between attempts. Give up and re-raise the last
# error after MAX_PUSH_ATTEMPTS rather than looping forever on a persistent failure
# (e.g. an authentication error).
MAX_PUSH_ATTEMPTS = 30
PUSH_RETRY_DELAY_SECONDS = 10
for attempt in range(1, MAX_PUSH_ATTEMPTS + 1):
    try:
        new_coral.push_to_hub(
            "alexandrainst/coral",
            config_name="read_aloud",
            commit_message="Update test/val sets to have better dialectal representation",
        )
        break
    except Exception as e:
        if attempt == MAX_PUSH_ATTEMPTS:
            raise
        print(f"Failed to push to hub ({e}) - retrying ({attempt}/{MAX_PUSH_ATTEMPTS})...")
        sleep(PUSH_RETRY_DELAY_SECONDS)
new_coral