In [None]:
# Install dependencies

%pip install matplotlib

In [None]:
import time

import matplotlib.pyplot as plt
from datasets import DatasetDict, load_dataset
from IPython.display import Audio as IPythonAudio
from IPython.display import display
from tqdm.auto import tqdm


def play_sample(sample: dict) -> None:
    """Render an inline audio player for a single dataset sample.

    Args:
        sample: A dataset row whose ``"audio"`` entry is a dict holding the
            decoded waveform under ``"array"`` and its sampling rate under
            ``"sampling_rate"``.
    """
    audio = sample["audio"]
    # `display` is imported explicitly rather than relying on the name that
    # IPython injects into the notebook namespace.
    display(IPythonAudio(audio["array"], rate=audio["sampling_rate"]))


# Apply matplotlib's built-in "ggplot" style sheet to any figures drawn later.
plt.style.use("ggplot")

In [None]:
# Download the read-aloud subset of CoRal. Hub downloads can fail transiently,
# so retry with exponential backoff (capped at 60 s). Give up after
# MAX_LOAD_ATTEMPTS so that a persistent error (wrong dataset name, missing
# access rights, no network) is raised instead of retrying forever.
MAX_LOAD_ATTEMPTS = 10
for attempt in range(1, MAX_LOAD_ATTEMPTS + 1):
    try:
        coral = load_dataset("alexandrainst/coral", name="read_aloud", split="train")
        break
    except Exception as e:
        if attempt == MAX_LOAD_ATTEMPTS:
            raise
        delay = min(2**attempt, 60)
        print(
            f"Encountered error: {e}. Retrying in {delay} s "
            f"(attempt {attempt}/{MAX_LOAD_ATTEMPTS})..."
        )
        time.sleep(delay)

In [None]:
# Look at a sample: show the first row of the dataset, i.e. its decoded audio
# together with every metadata column.

coral[0]

In [None]:
# Listen to the samples whose automatic transcription disagrees most with the
# prompt text (highest ASR character error rate). Presumably this informed the
# CER cut-off used when building the splits — confirm.
num_worst = 100
worst_samples = coral.sort("asr_cer", reverse=True).select(range(num_worst))
for row in worst_samples:
    cer, text = row["asr_cer"], row["text"]
    print(f"CER: {cer:.0%}")
    print(f"Text: {text!r}")
    play_sample(row)
    print()

In [None]:
# Speakers whose recordings form the test split (21 speakers). The split is
# made by speaker: samples from these speakers are excluded from training.
TEST_SET_SPEAKER_IDS: list[str] = [
    "spe_ac145f72d3d37064bfe62d11c58a3cb4",
    "spe_e01017cbabe39aa19980d30b022947dc",
    "spe_046c02b65af055859e0f0a1885b2cc5c",
    "spe_6691832f1f170d2876ec2d99de3d0b8f",
    "spe_bffe31a07537d14e22eef5c5efcd4fe6",
    "spe_e2cfb324371dd8ce3a1038a27eb6fb5b",
    "spe_b440df30591b8175bafbc7a036c538c7",
    "spe_2937b289da4c0a7b9877c56ecead4794",
    "spe_168bc05c7a02a8360343eb5fadbba7ed",
    "spe_c82a8417dd9f495eb70c34235647a1b7",
    "spe_fbf3381f525dbe5ddf1a2a1d36e9c4b9",
    "spe_8948a0cc310c6fa8161665d4eda79846",
    "spe_b977ebc0a2ba961cbe158190fce0dc06",
    "spe_6617b4c7273b31fc161fc6e07e620743",
    "spe_4aa23a60464a18e3597cdeb3606ac572",
    "spe_6e7cb65603907f863e06d7a02e00fb67",
    "spe_e08ab8fdf0306e3f0478577b4c5805bc",
    "spe_10b8eb8c3ba5fd8405d1516b7b12f2de",
    "spe_953ed42510a8dd0b33d909d580875241",
    "spe_3a85fdec89b8deb698bba43485b54fd2",
    "spe_4dee2dc8f6fccc98115781683d45acdd",
]
# Speakers whose recordings form the validation split (10 speakers). Like the
# test speakers, their samples are excluded from training.
# NOTE(review): this list must not share any ID with TEST_SET_SPEAKER_IDS,
# otherwise those speakers' samples land in both test and validation.
VALIDATION_SET_SPEAKER_IDS: list[str] = [
    "spe_92fea6e4419210f4c4219e84ec89837e",
    "spe_9c4dc6be57f6c63860331813a71417e5",
    "spe_0dd042aee46edc27b2ba0155abdf3d54",
    "spe_f3a0b2f9a75fcfc793a3109d8fbd6c94",
    "spe_9cc1a2ef1b284863ffe37ed105257843",
    "spe_55028d05581a88a8655fa1f74ddfb5a1",
    "spe_a75a8f0e82dc860942b4cb3129f0af35",
    "spe_9b0f671d81679eac001b2d95729c4dc3",
    "spe_6a029298b9eaa3d7e7f8f74510f88e70",
    "spe_eee793fd109985a678edaba7134f0f3f",
]

In [None]:
# Drop metadata columns that are not carried into the published splits.
# `coral` is rebound here, so re-running this cell on the already-trimmed
# dataset raises, because the columns no longer exist.
columns_to_drop = [
    "id_validator",
    "datetime_start",
    "datetime_end",
    "language_native",
    "language_spoken",
    "zipcode_birth",
    "zip_school",
    "education",
    "occupation",
    "asr_label",
]
coral = coral.remove_columns(column_names=columns_to_drop)
coral

In [None]:
# Split CoRal by speaker so that no test or validation speaker is seen during
# training. Every split also applies a quality filter:
#   * drop samples whose ASR character error rate is >= MAX_ASR_CER;
#   * drop samples whose validation status is "rejected".
# Test and validation additionally drop samples with status "maybe"; the
# training split keeps them.
# NOTE(review): confirm that this train/eval asymmetry is intentional.
MAX_ASR_CER = 0.6
TEST_SPEAKERS = frozenset(TEST_SET_SPEAKER_IDS)
VAL_SPEAKERS = frozenset(VALIDATION_SET_SPEAKER_IDS)
HELD_OUT_SPEAKERS = TEST_SPEAKERS | VAL_SPEAKERS
# A speaker listed in both held-out sets would leak between test and validation.
assert not TEST_SPEAKERS & VAL_SPEAKERS, "Test and validation speakers overlap"

# Only these columns are handed to the predicates (via `input_columns`), so the
# audio column is never decoded just to decide whether to keep a row.
FILTER_COLUMNS = ["id_speaker", "asr_cer", "validated"]


def passes_quality_filter(asr_cer: float, validated: str, allow_maybe: bool) -> bool:
    """Return whether a sample is clean enough to keep.

    Args:
        asr_cer: Character error rate of the sample's automatic transcription.
        validated: The sample's validation status.
        allow_maybe: Whether samples with status "maybe" are kept.

    Returns:
        True if the CER is below MAX_ASR_CER and the status is acceptable.
    """
    # Written as `asr_cer < MAX_ASR_CER` (not `>=` negated) so a NaN CER is
    # rejected, as in the original filters.
    return (
        asr_cer < MAX_ASR_CER
        and validated != "rejected"
        and (allow_maybe or validated != "maybe")
    )


test = coral.filter(
    lambda id_speaker, asr_cer, validated: id_speaker in TEST_SPEAKERS
    and passes_quality_filter(asr_cer, validated, allow_maybe=False),
    input_columns=FILTER_COLUMNS,
    num_proc=8,
)
val = coral.filter(
    lambda id_speaker, asr_cer, validated: id_speaker in VAL_SPEAKERS
    and passes_quality_filter(asr_cer, validated, allow_maybe=False),
    input_columns=FILTER_COLUMNS,
    num_proc=8,
)
train = coral.filter(
    lambda id_speaker, asr_cer, validated: id_speaker not in HELD_OUT_SPEAKERS
    and passes_quality_filter(asr_cer, validated, allow_maybe=True),
    input_columns=FILTER_COLUMNS,
    num_proc=8,
)

In [None]:
def dataset_hours(dataset, desc: str = "") -> float:
    """Return the total duration of all audio in ``dataset``, in hours.

    Iterates over every sample and therefore decodes every audio file, which
    is slow on large splits.

    Args:
        dataset: An iterable of samples whose ``"audio"`` entry holds the
            decoded waveform under ``"array"`` and its ``"sampling_rate"``.
        desc: Label shown on the progress bar.

    Returns:
        Total duration in hours.
    """
    total_seconds = sum(
        sample["audio"]["array"].shape[0] / sample["audio"]["sampling_rate"]
        for sample in tqdm(dataset, desc=desc)
    )
    return total_seconds / 3600


test_hours = dataset_hours(test, desc="test")
val_hours = dataset_hours(val, desc="val")
train_hours = dataset_hours(train, desc="train")
print(f"Test hours: {test_hours:.2f}")
print(f"Val hours: {val_hours:.2f}")
print(f"Train hours: {train_hours:.2f}")

In [None]:
# Bundle the three speaker-disjoint splits under the names "train", "val" and "test".
new_coral = DatasetDict({"train": train, "val": val, "test": test})
new_coral

In [None]:
# Upload the train/val/test splits to the Hugging Face Hub repo
# "alexandrainst/coral" under the "read_aloud" configuration.
# NOTE(review): the second positional parameter is `config_name` only in
# recent `datasets` releases; in older releases it was a different parameter.
# Passing `config_name="read_aloud"` by keyword would be safer — confirm the
# installed version.
new_coral.push_to_hub("alexandrainst/coral", "read_aloud")