In [None]:
# Install dependencies
# NOTE(review): only matplotlib is installed here; `datasets`, `IPython` and `tqdm`
# are imported in the next cell and presumably come from the project environment
# — confirm. Consider pinning a version (`%pip install matplotlib==X.Y`) so the
# notebook is reproducible.

%pip install matplotlib

In [None]:
from time import sleep

import matplotlib.pyplot as plt
from datasets import DatasetDict, load_dataset
from IPython.display import Audio as IPythonAudio
from tqdm.auto import tqdm


def play_sample(sample: dict):
    """Render an inline audio player for a single dataset sample.

    Args:
        sample: A dataset row whose "audio" entry holds the audio samples under
            "array" and the matching sampling rate under "sampling_rate".
    """
    audio_info = sample["audio"]
    player = IPythonAudio(audio_info["array"], rate=audio_info["sampling_rate"])
    display(player)


# Use the ggplot look for any matplotlib figures drawn in this notebook.
# NOTE(review): no cell visible here draws a plot — matplotlib may be unused.
plt.style.use(style="ggplot")

In [None]:
# If you disable the internet connection, this loads the locally cached version
# of the dataset instead (which doesn't exist on the Hub).
# NOTE(review): the last cell pushes the re-split dataset back to this same
# repo/config, so loading from the Hub after a push presumably returns the
# already-split "train" split rather than the raw data — confirm which version
# you are loading before re-running the notebook.
MAX_LOAD_ATTEMPTS = 10
LOAD_RETRY_DELAY_SECONDS = 10

for attempt in range(1, MAX_LOAD_ATTEMPTS + 1):
    try:
        coral = load_dataset("alexandrainst/coral", name="read_aloud", split="train")
        break
    except Exception as e:
        # Give up after the final attempt instead of retrying forever, so a
        # persistent error (e.g. a wrong config name) surfaces as a traceback.
        if attempt == MAX_LOAD_ATTEMPTS:
            raise
        print(f"Encountered error: {str(e)}. Retrying...")
        # Back off between attempts rather than re-requesting in a tight loop.
        sleep(LOAD_RETRY_DELAY_SECONDS)

In [None]:
# Inspect the first sample to see which fields each row carries
first_sample = coral[0]
first_sample

In [None]:
# Play the samples with the highest ASR CER, printing the CER and reference text
# next to each audio player.
NUM_WORST_SAMPLES = 100

# `select` rejects out-of-range indices, so never ask for more rows than exist.
num_worst = min(NUM_WORST_SAMPLES, len(coral))
worst_samples = coral.sort("asr_cer", reverse=True).select(range(num_worst))
for sample in worst_samples:
    print(f"CER: {sample['asr_cer']:.0%}")
    print(f"Text: {sample['text']!r}")
    play_sample(sample)
    print()

In [None]:
# Speakers held out for the evaluation splits. Samples from these speakers are
# kept out of the training split when the splits are built below.
TEST_SET_SPEAKER_IDS: list[str] = [
    "spe_6e7cb65603907f863e06d7a02e00fb67",
    "spe_55028d05581a88a8655fa1f74ddfb5a1",
    "spe_c883de44acfc7d9cfd32d5c9fa162342",
    "spe_3a85fdec89b8deb698bba43485b54fd2",
    "spe_e530efaab83b53fb59942e984b57e5cb",
    "spe_a3d4edeab8ea4c9bb67e847103c4b5f7",
    "spe_01fc2b156c7fe429f1b72bd3be5ad3c3",
    "spe_4aa23a60464a18e3597cdeb3606ac572",
    "spe_19b1d393bbe3ad9db3457ccda9bda5ea",
    "spe_aaa7d1aa7c5e1df17aa33156b8b54677",
    "spe_ae8bb53db7e325a8ecbb3238f4578d38",
    "spe_915ba6768904307f02be61ac0afe6366",
    "spe_63f165c8164332ea0791ffe410b40f0b",
    "spe_10b8eb8c3ba5fd8405d1516b7b12f2de",
    "spe_18c97571c93826474f71a51f48c2debd",
    "spe_9a2fee70cd58727ea1dbcf056de688b7",
    "spe_0a531653d15a79f97c28519ee3d024a0",
    "spe_8baebafddb83768cc4abe225511c5e0d",
    "spe_d147969fc79b1ba50c56e5adb9f662af",
    "spe_c1ae98ac0db97166130c016e5a91000b",
    "spe_4facb80c94341b25425ec1d8962b1f8d",
    "spe_6a029298b9eaa3d7e7f8f74510f88e70",
]
VALIDATION_SET_SPEAKER_IDS: list[str] = [
    "spe_b977ebc0a2ba961cbe158190fce0dc06",
    "spe_33b9071834e38bd3a2829add3bbcccb3",
    "spe_4c0f1933310dc0958cedf35fcc92fdc5",
    "spe_abad80046cdb930d9e11c61018000313",
    "spe_9c4dc6be57f6c63860331813a71417e5",
    "spe_51b02c4d372de72ba1cab851642ab363",
    "spe_6e67cbe51a49d9e4abbd7699a4a89d91",
    "spe_bb1d0e9d3f2bca18658975b3073924cb",
    "spe_65c05e58f399d854594d4716454a806b",
    "spe_1aec105b85238c5ebbc4bfc72d0569e1",
]

# These lists are maintained by hand, so check their integrity: a speaker listed
# in both would end up in both the test and the validation split.
assert len(set(TEST_SET_SPEAKER_IDS)) == len(
    TEST_SET_SPEAKER_IDS
), "duplicate speaker IDs in TEST_SET_SPEAKER_IDS"
assert len(set(VALIDATION_SET_SPEAKER_IDS)) == len(
    VALIDATION_SET_SPEAKER_IDS
), "duplicate speaker IDs in VALIDATION_SET_SPEAKER_IDS"
assert set(TEST_SET_SPEAKER_IDS).isdisjoint(
    VALIDATION_SET_SPEAKER_IDS
), "a speaker appears in both the test and validation speaker lists"

In [None]:
# Drop metadata columns before splitting; they are therefore absent from every
# split built from `coral` below.
columns_to_drop = [
    "id_validator",
    "datetime_start",
    "datetime_end",
    "language_native",
    "language_spoken",
    "zipcode_birth",
    "zip_school",
    "education",
    "occupation",
    "asr_label",
]
coral = coral.remove_columns(column_names=columns_to_drop)
coral

In [None]:
# Build speaker-disjoint splits: samples from the test/validation speakers only
# go into those splits, and every other speaker goes into train. All splits drop
# samples with an ASR CER of MAX_ASR_CER or more and samples marked "rejected";
# the evaluation splits additionally drop samples marked "maybe".
MAX_ASR_CER = 0.6
NUM_PROC = 8

# Only these columns decide membership. Passing them via `input_columns` means
# the filter functions never touch the audio column.
FILTER_COLUMNS = ["id_speaker", "asr_cer", "validated"]

# Sets give O(1) membership checks, and the held-out union is built once
# instead of being re-concatenated for every sample.
test_speaker_ids = set(TEST_SET_SPEAKER_IDS)
val_speaker_ids = set(VALIDATION_SET_SPEAKER_IDS)
held_out_speaker_ids = test_speaker_ids | val_speaker_ids


def _keep_eval_sample(
    speaker_id: str, asr_cer: float, validated: str, speaker_ids: set[str]
) -> bool:
    """Return whether a sample belongs in an evaluation (test/val) split.

    A sample is kept if its speaker is in `speaker_ids`, its ASR CER is below
    `MAX_ASR_CER`, and its validation status is neither "rejected" nor "maybe".
    The speaker check runs first, so the CER is only compared for samples from
    the requested speakers.
    """
    return (
        speaker_id in speaker_ids
        and asr_cer < MAX_ASR_CER
        and validated not in ("rejected", "maybe")
    )


def _keep_train_sample(
    speaker_id: str, asr_cer: float, validated: str, held_out_ids: set[str]
) -> bool:
    """Return whether a sample belongs in the training split.

    A sample is kept if its speaker is not in `held_out_ids`, its ASR CER is
    below `MAX_ASR_CER`, and it is not "rejected" ("maybe" samples are kept).
    """
    return (
        speaker_id not in held_out_ids
        and asr_cer < MAX_ASR_CER
        and validated != "rejected"
    )


test = coral.filter(
    _keep_eval_sample,
    input_columns=FILTER_COLUMNS,
    fn_kwargs=dict(speaker_ids=test_speaker_ids),
    num_proc=NUM_PROC,
)
val = coral.filter(
    _keep_eval_sample,
    input_columns=FILTER_COLUMNS,
    fn_kwargs=dict(speaker_ids=val_speaker_ids),
    num_proc=NUM_PROC,
)
train = coral.filter(
    _keep_train_sample,
    input_columns=FILTER_COLUMNS,
    fn_kwargs=dict(held_out_ids=held_out_speaker_ids),
    num_proc=NUM_PROC,
)

In [None]:
SECONDS_PER_HOUR = 60 * 60


def audio_hours(dataset, desc: str) -> float:
    """Return the total duration of the audio in `dataset`, in hours.

    Each sample's duration is the length of the first axis of
    `sample["audio"]["array"]` divided by `sample["audio"]["sampling_rate"]`.
    This reads the audio array of every sample, so it can be slow on large splits.

    Args:
        dataset: An iterable of samples with an "audio" dict (e.g. a split).
        desc: Label shown on the progress bar.
    """
    total_seconds = sum(
        sample["audio"]["array"].shape[0] / sample["audio"]["sampling_rate"]
        for sample in tqdm(dataset, desc=desc)
    )
    return total_seconds / SECONDS_PER_HOUR


test_hours = audio_hours(test, desc="test")
val_hours = audio_hours(val, desc="val")
train_hours = audio_hours(train, desc="train")
print(f"Test hours: {test_hours:.2f}")
print(f"Val hours: {val_hours:.2f}")
print(f"Train hours: {train_hours:.2f}")

In [None]:
# Bundle the filtered splits into a single DatasetDict keyed by split name
new_coral = DatasetDict({"train": train, "val": val, "test": test})
new_coral

In [None]:
# This push writes the "read_aloud" config of the same repo the raw data was
# loaded from. Refuse to push empty splits (which could come from re-running
# this notebook on data that was already split), since they would replace the
# published data.
empty_splits = [name for name, split in new_coral.items() if len(split) == 0]
if empty_splits:
    raise ValueError(f"Refusing to push to the hub: empty split(s) {empty_splits}")

MAX_PUSH_ATTEMPTS = 10
PUSH_RETRY_DELAY_SECONDS = 10

for attempt in range(1, MAX_PUSH_ATTEMPTS + 1):
    try:
        new_coral.push_to_hub("alexandrainst/coral", "read_aloud")
        break
    except Exception as e:
        # Re-raise on the final attempt so a persistent failure (e.g. missing
        # credentials) is reported instead of looping forever.
        if attempt == MAX_PUSH_ATTEMPTS:
            raise
        print(f"Failed to push to hub ({e}) - retrying...")
        sleep(PUSH_RETRY_DELAY_SECONDS)
new_coral