In [None]:
import time

import matplotlib.pyplot as plt
from datasets import Audio, DatasetDict, load_dataset
from IPython.display import Audio as IPythonAudio
from IPython.display import display
from tqdm.auto import tqdm


def play_sample(sample: dict):
    """Render an inline audio player for a single dataset sample.

    The playback rate is read from the sample itself instead of being hardcoded.
    This matters because samples may be decoded at their original sampling rate
    rather than at 16 kHz. Playing them at the wrong rate would distort their
    speed and pitch.

    Args:
        sample:
            A dataset row whose ``"audio"`` entry is a dict holding the waveform
            under ``"array"`` and, normally, its rate under ``"sampling_rate"``.
    """
    audio = sample["audio"]
    # Fall back to the previous fixed 16 kHz only if the sample carries no rate.
    rate = audio.get("sampling_rate", 16_000)
    display(IPythonAudio(audio["array"], rate=rate))


plt.style.use("ggplot")

In [None]:
# Load the raw `read_aloud` training split. Downloads can fail transiently, so
# retry a bounded number of times with exponential backoff. After the final
# attempt the last error is re-raised instead of looping forever, which would also
# hide permanent failures such as authentication errors or a misspelled repo name.
MAX_LOAD_ATTEMPTS = 10
MAX_RETRY_DELAY_SECONDS = 60

for attempt in range(1, MAX_LOAD_ATTEMPTS + 1):
    try:
        coral = load_dataset("alexandrainst/coral", name="read_aloud", split="train")
        break
    except Exception as e:
        if attempt == MAX_LOAD_ATTEMPTS:
            raise
        delay = min(2**attempt, MAX_RETRY_DELAY_SECONDS)
        print(
            f"Encountered error: {e}. Retrying in {delay}s "
            f"(attempt {attempt}/{MAX_LOAD_ATTEMPTS})..."
        )
        time.sleep(delay)

# Copy of the dataset whose audio column is cast to a 16 kHz sampling rate.
# NOTE(review): nothing visible in this notebook uses `coral_16k`; every cell
# below works on `coral` — confirm whether this is still needed.
coral_16k = coral.cast_column("audio", Audio(sampling_rate=16_000))

In [None]:
# Inspect the first row to see which fields every sample carries
first_sample = coral[0]
first_sample

In [None]:
# Plot the WER distribution of approved vs rejected samples. Samples with any
# other validation status (e.g. "maybe") are not shown.

# The predicate only needs the `validated` column. Passing it via
# `input_columns` lets `filter` skip decoding every audio array just to compare
# a string.
approved_samples = coral.filter(
    lambda validated: validated == "approved", input_columns="validated"
)
rejected_samples = coral.filter(
    lambda validated: validated == "rejected", input_columns="validated"
)

fig, ax = plt.subplots()
for label, samples in [("approved", approved_samples), ("rejected", rejected_samples)]:
    # density=True so the two groups are comparable despite different sizes
    ax.hist(samples["asr_wer"], bins=50, label=label, alpha=0.5, density=True)
ax.set_xlim(0, 1)
ax.set_xlabel("ASR word error rate")
ax.set_ylabel("Density")
ax.set_title("WER distribution of approved vs rejected samples")
ax.legend()
fig.savefig("wer-distribution-approved-rejected.png", dpi=200)
plt.show()

In [None]:
# Listen to the samples with the highest ASR word error rate, together with
# their reference transcriptions
NUM_WORST_SAMPLES = 100
worst_samples = coral.sort("asr_wer", reverse=True).select(range(NUM_WORST_SAMPLES))
for sample in worst_samples:
    wer = sample["asr_wer"]
    text = sample["text"]
    print(f"WER: {wer:.0%}")
    print(f"Text: {text!r}")
    play_sample(sample)
    print()

In [None]:
# Speakers whose recordings are reserved for the test split. All samples by
# these speakers are kept out of training.
TEST_SET_SPEAKER_IDS: list[str] = [
    "spe_c3c1fdae39d6bf6e462868f8f52b7e3e",
    "spe_d7da3d62a1a9885c1cb3280668437759",
    "spe_ce5c35bd408b296511ce0b05ecc33de1",
    "spe_e3013f96eed48bacc13dd8253609cf9b",
    "spe_4facb80c94341b25425ec1d8962b1f8d",
    "spe_20b91d51f72ee56930ca778cb16c29da",
    "spe_e33e46611f54ae91ed7b235c11ef2628",
    "spe_6b93da4530e853772df0fc8b2337142c",
    "spe_6617b4c7273b31fc161fc6e07e620743",
    "spe_1c32c35e35670f64d9e1a673c47aabd1",
    "spe_590765d7656376e83a33c54f9c2e3976",
    "spe_066938532ac270d527696f89d81f0de4",
    "spe_545558c5701d956a2c63057cd313ff50",
    "spe_af4e767c077909a95b9bd834ca224833",
    "spe_4b7ba1403d8540b3101c07b9c8a19474",
    "spe_e3742811d83011e22ec2ef5a7af32065",
    "spe_fbf3381f525dbe5ddf1a2a1d36e9c4b9",
    "spe_ef7f083c2097793e28388535a81e14ea",
    "spe_b9112f9327f2390093bbc082a1651bad",
    "spe_b788006083ced2efabc75a7907220250",
    "spe_6e7cb65603907f863e06d7a02e00fb67",
    "spe_003c825c9ad2f1496c22cc16a04b1598",
    "spe_935a99ce745c2c042a77f6e4c831fd94",
    "spe_55028d05581a88a8655fa1f74ddfb5a1",
    "spe_6efa16af1112af15ab482171e156d3f3",
    "spe_01fc2b156c7fe429f1b72bd3be5ad3c3",
    "spe_b19c9900784bdf8f8ef3ea3e78002011",
    "spe_02d28146c013111766f18f0d2198785e",
    "spe_3937cb6805e15326b37253d6148babb5",
    "spe_492647a87720047b55f4033d0df8082a",
    "spe_9b8d26599c6b7932dbac00832b73dcf8",
    "spe_6a029298b9eaa3d7e7f8f74510f88e70",
]
# Speakers whose recordings are reserved for the validation split.
VALIDATION_SET_SPEAKER_IDS: list[str] = [
    "spe_a8ffed9a90c0e89338892f23dfaac338",
    "spe_50d0664744aa2a241c084363b04e39c5",
    "spe_003971defa823a2eb98f079cdc91c634",
    "spe_0dd042aee46edc27b2ba0155abdf3d54",
    "spe_cd5174de19523f69dd8613ea311997d4",
    "spe_0cf8566481d05e4ed329e34211a36311",
    "spe_4aa23a60464a18e3597cdeb3606ac572",
    "spe_647d4e905427d45ab699abe73d80ef1d",
    "spe_bb1d0e9d3f2bca18658975b3073924cb",
    "spe_93f1d99433d997beeec289d60e074ed2",
    "spe_9c4dc6be57f6c63860331813a71417e5",
]

# Guard against copy-paste mistakes and speaker leakage between held-out splits.
# Each list must be free of duplicates, and no speaker may appear in both.
assert len(set(TEST_SET_SPEAKER_IDS)) == len(
    TEST_SET_SPEAKER_IDS
), "Duplicate speaker IDs in TEST_SET_SPEAKER_IDS"
assert len(set(VALIDATION_SET_SPEAKER_IDS)) == len(
    VALIDATION_SET_SPEAKER_IDS
), "Duplicate speaker IDs in VALIDATION_SET_SPEAKER_IDS"
assert set(TEST_SET_SPEAKER_IDS).isdisjoint(
    VALIDATION_SET_SPEAKER_IDS
), "A speaker is assigned to both the test and validation sets"

In [None]:
# Columns the split predicate reads. Passing them as `input_columns` makes
# `filter` hand the predicate just these values instead of whole rows, so the
# audio column is not decoded while filtering.
SPLIT_FILTER_COLUMNS = ["id_speaker", "asr_wer", "validated"]

# Held-out splits keep only confidently good samples. The training split is more
# lenient: it has a higher WER cap and keeps "maybe" samples.
HELD_OUT_MAX_WER = 0.4
HELD_OUT_EXCLUDED_STATUSES = frozenset({"rejected", "maybe"})
TRAIN_MAX_WER = 1.0
TRAIN_EXCLUDED_STATUSES = frozenset({"rejected"})


def keep_sample(
    id_speaker: str,
    asr_wer: float,
    validated: str,
    speakers: frozenset,
    speakers_included: bool,
    max_wer: float,
    excluded_statuses: frozenset,
) -> bool:
    """Decide whether a single sample belongs to a split.

    Args:
        id_speaker: The sample's speaker ID.
        asr_wer: The sample's ASR word error rate.
        validated: The sample's validation status.
        speakers: Speaker IDs that define the split.
        speakers_included:
            If True, keep only samples whose speaker is in ``speakers``. If False,
            keep only samples whose speaker is *not* in ``speakers``.
        max_wer: Keep only samples whose ``asr_wer`` is strictly below this value.
        excluded_statuses: Validation statuses that disqualify a sample.

    Returns:
        Whether the sample should be kept.
    """
    return (
        (id_speaker in speakers) == speakers_included
        and asr_wer < max_wer
        and validated not in excluded_statuses
    )


def filter_split(dataset, speakers, speakers_included, max_wer, excluded_statuses):
    """Return the rows of ``dataset`` for which :func:`keep_sample` is True."""
    return dataset.filter(
        keep_sample,
        input_columns=SPLIT_FILTER_COLUMNS,
        fn_kwargs=dict(
            # frozenset: O(1) membership instead of scanning a list per sample
            speakers=frozenset(speakers),
            speakers_included=speakers_included,
            max_wer=max_wer,
            excluded_statuses=excluded_statuses,
        ),
        num_proc=8,
    )


test = filter_split(
    coral,
    speakers=TEST_SET_SPEAKER_IDS,
    speakers_included=True,
    max_wer=HELD_OUT_MAX_WER,
    excluded_statuses=HELD_OUT_EXCLUDED_STATUSES,
)
val = filter_split(
    coral,
    speakers=VALIDATION_SET_SPEAKER_IDS,
    speakers_included=True,
    max_wer=HELD_OUT_MAX_WER,
    excluded_statuses=HELD_OUT_EXCLUDED_STATUSES,
)
# Training uses every speaker that is not held out for test or validation.
train = filter_split(
    coral,
    speakers=TEST_SET_SPEAKER_IDS + VALIDATION_SET_SPEAKER_IDS,
    speakers_included=False,
    max_wer=TRAIN_MAX_WER,
    excluded_statuses=TRAIN_EXCLUDED_STATUSES,
)

In [None]:
def count_hours(dataset, desc: str) -> float:
    """Return the total audio duration of ``dataset`` in hours.

    Durations use each sample's own ``sampling_rate`` rather than a hardcoded
    rate, so the total stays correct whatever rate the audio is decoded at.
    Every audio array is decoded, which can be slow on large splits.

    Args:
        dataset:
            Iterable of samples whose ``"audio"`` entry is a dict with ``"array"``
            (the waveform) and ``"sampling_rate"``.
        desc: Label shown on the progress bar.

    Returns:
        Total duration in hours.
    """
    total_seconds = 0.0
    for sample in tqdm(dataset, desc=desc):
        audio = sample["audio"]
        total_seconds += audio["array"].shape[0] / audio["sampling_rate"]
    return total_seconds / 3600


test_hours = count_hours(test, desc="test")
val_hours = count_hours(val, desc="val")
train_hours = count_hours(train, desc="train")
print(f"Test hours: {test_hours:.2f}")
print(f"Val hours: {val_hours:.2f}")
print(f"Train hours: {train_hours:.2f}")

In [None]:
# Bundle the three splits under the names they will be published with
new_coral = DatasetDict({"train": train, "val": val, "test": test})
new_coral

In [None]:
# Publish the splits as the `read_aloud` config of the private dataset repo.
# The config name is passed by keyword so the call states its intent explicitly
# and does not depend on the positional parameter order of `push_to_hub`.
new_coral.push_to_hub("alexandrainst/coral", config_name="read_aloud", private=True)