In [None]:
# Install dependencies. Only matplotlib is installed here; the other third-party
# packages imported below (datasets, hydra, IPython, tqdm) are assumed to already be
# available in the kernel's environment. NOTE(review): consider pinning the version.

%pip install matplotlib

In [None]:
import numbers
from time import sleep
from typing import NamedTuple

import matplotlib.pyplot as plt
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
from hydra import compose, initialize
from IPython.display import Audio as IPythonAudio
from IPython.display import display
from tqdm.auto import tqdm

# Load the split-creation config. Using `initialize` as a context manager clears Hydra's
# global state on exit, so re-running this cell does not fail because Hydra is already
# initialized. The composed `config` object stays usable after the block exits.
with initialize(config_path="../config", version_base=None):
    config = compose(config_name="split_creation")


def play_sample(sample: dict) -> None:
    """Render an inline audio player for a single dataset sample.

    Args:
        sample:
            A dataset row whose "audio" entry holds the audio samples under "array"
            and their sampling rate under "sampling_rate".
    """
    audio = sample["audio"]
    display(IPythonAudio(audio["array"], rate=audio["sampling_rate"]))


# Use the ggplot look for every matplotlib figure in this notebook.
plt.style.use(style="ggplot")

In [None]:
# Load every split of the "read_aloud" subset and merge them into a single dataset, so
# that it can be re-split by speaker further down. Downloads from the Hub can fail
# transiently, so retry with a pause between attempts, but re-raise the last error
# after a bounded number of attempts instead of looping forever.
MAX_LOAD_ATTEMPTS = 10
LOAD_RETRY_DELAY_SECONDS = 10

for attempt in range(1, MAX_LOAD_ATTEMPTS + 1):
    try:
        coral_splits = load_dataset("alexandrainst/coral", name="read_aloud")
        coral = concatenate_datasets(
            dsets=[split for split in coral_splits.values() if split is not None]
        )
        break
    except Exception as e:
        if attempt == MAX_LOAD_ATTEMPTS:
            raise
        print(f"Encountered error: {e}. Retrying...")
        sleep(LOAD_RETRY_DELAY_SECONDS)

In [None]:
# Inspect a single sample: the first row of the merged dataset.

first_sample = coral[0]
first_sample

In [None]:
# Listen to the samples with the highest ASR character error rate (CER), to inspect
# what high-CER recordings sound like.
NUM_WORST_SAMPLES = 100

samples_by_cer = coral.sort("asr_cer", reverse=True)
# Clamp to the dataset size so `select` does not raise on small datasets.
worst_samples = samples_by_cer.select(
    range(min(NUM_WORST_SAMPLES, len(samples_by_cer)))
)
for sample in worst_samples:
    print(f"CER: {sample['asr_cer']:.0%}")
    print(f"Text: {sample['text']!r}")
    play_sample(sample)
    print()

In [None]:
# Speakers reserved for the test and validation splits. All samples from these speakers
# are excluded from the training split when the splits are built below.
TEST_SET_SPEAKER_IDS: list[str] = [
    "spe_55028d05581a88a8655fa1f74ddfb5a1",
    "spe_2937b289da4c0a7b9877c56ecead4794",
    "spe_4aa23a60464a18e3597cdeb3606ac572",
    "spe_fbf3381f525dbe5ddf1a2a1d36e9c4b9",
    "spe_deedf738efa054ae460989be3033a3cf",
    "spe_d19f3558739cb61e2cc2d8be52c19141",
    "spe_741ba3dd1acd26458718a591a980d743",
    "spe_3fbe6022caf3c819597d4a28eafc092e",
    "spe_d772d2cc8215cdcf3a962c7757156deb",
    "spe_04be495fbbcf0187a0f17708b556ea13",
    "spe_fa639f5932359117682753884585d883",
    "spe_066938532ac270d527696f89d81f0de4",
    "spe_c4ece6eb8bf41ab959af9e2f57a5aae6",
    "spe_ab9690f6ac8dd2226f4bb14699444ed5",
    "spe_71d9860fe866f922740368df660bd1d4",
    "spe_290d17059a29fe3df395be2311c96fc1",
    "spe_040c7192f1c56491f9b00c558ce87d83",
    "spe_7b8398c898a828791c0fc40d6d146b3f",
    "spe_6e7cb65603907f863e06d7a02e00fb67",
    "spe_df3293886215084f5fd6a447bb379b11",
    "spe_2f15ff95f96e7e173ffd77d5ce867858",
    "spe_199e03b334b15576a69be73ea39a34d5",
    "spe_dbf8b55bf5364a9d6eaed082697b36fc",
    "spe_9f6f2d21463e94f5403f67754913fabc",
    "spe_8948a0cc310c6fa8161665d4eda79846",
    "spe_5e319f90767d47e11731d95e314e4670",
    "spe_6e67cbe51a49d9e4abbd7699a4a89d91",
    "spe_03e8b9d0ee8d3192e113ff62c61e4916",
    "spe_fade5754bc6e205fcce917e85dd8def1",
    "spe_de430b1197cf26cb5f4011656a728ee5",
    "spe_65c05e58f399d854594d4716454a806b",
    "spe_26b2833fc94cadba302aba2a631da193",
    "spe_fa6a417205bd632f6832f42120d291ea",
    "spe_01fc2b156c7fe429f1b72bd3be5ad3c3",
    "spe_7d6e87835e35371cba677fffefb10fb1",
]
VALIDATION_SET_SPEAKER_IDS: list[str] = [
    "spe_a55cdc8a6a4230777bbe421825db705a",
    "spe_046c02b65af055859e0f0a1885b2cc5c",
    "spe_2dd1aa67190b348710f31482d291418c",
    "spe_8685f47cbde80df2b261c1dff5649f22",
    "spe_dabbad0be26f953503dcf196440eb7a7",
    "spe_3dd62e87b39a71dc50aaf90199dad34b",
    "spe_51b02c4d372de72ba1cab851642ab363",
    "spe_4b7ba1403d8540b3101c07b9c8a19474",
    "spe_6aeb15b456086536f45918dbdfc63ec6",
    "spe_af2f2d470c277174a74583322f89c8bd",
    "spe_349834612439f09df8374bd3016ba57e",
]

# The splits are only speaker-disjoint if these lists are duplicate-free and share no
# speaker, so check that here rather than discovering leakage after pushing.
assert len(set(TEST_SET_SPEAKER_IDS)) == len(
    TEST_SET_SPEAKER_IDS
), "Duplicate speaker IDs in TEST_SET_SPEAKER_IDS."
assert len(set(VALIDATION_SET_SPEAKER_IDS)) == len(
    VALIDATION_SET_SPEAKER_IDS
), "Duplicate speaker IDs in VALIDATION_SET_SPEAKER_IDS."
assert set(TEST_SET_SPEAKER_IDS).isdisjoint(
    VALIDATION_SET_SPEAKER_IDS
), "A speaker is in both the test and the validation set."

In [None]:
# Split the merged dataset by speaker, so that no speaker appears in more than one split.
# Every split drops samples with an ASR character error rate (CER) of MAX_ASR_CER or more,
# and samples whose "validated" status is "rejected". The test and validation splits are
# stricter and also drop samples marked "maybe"; the training split keeps those.
MAX_ASR_CER = 0.6
NUM_PROC = 8

test_speaker_ids = frozenset(TEST_SET_SPEAKER_IDS)
val_speaker_ids = frozenset(VALIDATION_SET_SPEAKER_IDS)
# Built once here instead of concatenating both lists for every sample inside the filter.
held_out_speaker_ids = test_speaker_ids | val_speaker_ids

test = coral.filter(
    lambda sample: sample["id_speaker"] in test_speaker_ids
    and sample["asr_cer"] < MAX_ASR_CER
    and sample["validated"] not in {"rejected", "maybe"},
    num_proc=NUM_PROC,
)
val = coral.filter(
    lambda sample: sample["id_speaker"] in val_speaker_ids
    and sample["asr_cer"] < MAX_ASR_CER
    and sample["validated"] not in {"rejected", "maybe"},
    num_proc=NUM_PROC,
)
train = coral.filter(
    lambda sample: sample["id_speaker"] not in held_out_speaker_ids
    and sample["asr_cer"] < MAX_ASR_CER
    and sample["validated"] != "rejected",
    num_proc=NUM_PROC,
)

# An empty held-out split most likely means the speaker IDs no longer match the dataset.
assert len(test) > 0, "No samples found for the test speakers."
assert len(val) > 0, "No samples found for the validation speakers."

In [None]:
# Gather the three speaker-disjoint splits into one DatasetDict.
splits = {"train": train, "val": val, "test": test}
new_coral = DatasetDict(splits)
new_coral

In [None]:
class AgeGroup(NamedTuple):
    """Named tuple to represent an age group."""

    min: int
    max: int | None

    def __repr__(self) -> str:
        """Return the string representation of the AgeGroup class."""
        if self.max is None:
            return f"{self.min}-"
        return f"{self.min}-{self.max - 1}"

    def __contains__(self, age: object) -> bool:
        """Check if an age is in the age group.

        Args:
            age:
                The age to check.

        Returns:
            Whether the age is in the age group.
        """
        if not isinstance(age, int):
            return False
        return self.min <= age and (self.max is None or age < self.max)


def age_to_group(age: int, age_groups: list[AgeGroup]) -> str:
    """Return the age group of a given age.

    Args:
        age:
            The age of the speaker.
        age_groups:
            A list of the possible age groups.

    Returns:
        The age group of the speaker.

    Raises:
        ValueError:
            If the age is not in any age group.
    """
    for age_group in age_groups:
        if age in age_group:
            return str(age_group)
    raise ValueError(f"Age {age} not in any age group, out of {age_groups}.")


def print_stats(split: Dataset) -> None:
    """Print summary statistics about a dataset split.

    Prints the number of samples, the sampling rate of the first sample, the total audio
    duration in hours, the number of unique speakers and sentences, and the relative
    frequencies of gender, dialect and age group.

    Args:
        split:
            The dataset split to describe. Besides the "audio" column it must have the
            "id_speaker", "id_sentence", "gender", "dialect", "country_birth" and "age"
            columns.
    """
    print(f"Number of samples: {len(split):,}")
    if len(split) == 0:
        # Nothing else can be computed on an empty split (and `split[0]` would fail).
        return
    print(f"Sample rate: {split[0]['audio']['sampling_rate']:,}")

    # Each sample lasts (number of audio frames / sampling rate) seconds.
    hours = sum(
        sample["audio"]["array"].shape[0] / sample["audio"]["sampling_rate"] / 60 / 60
        for sample in tqdm(split, desc="Counting number of hours")
    )
    print(f"Number of hours: {hours:,.2f}")

    df = split.remove_columns("audio").to_pandas()

    print(f"Number of unique speakers: {df.id_speaker.nunique():,}")
    print(f"Number of unique sentences: {df.id_sentence.nunique():,}")
    print()

    print(df.gender.value_counts(normalize=True))
    print()

    # Map each sub-dialect to its dialect using the config. Speakers whose birth country
    # is not "DK" are labelled "Non-native"; a missing birth country counts as "DK".
    df["dialect"] = df.dialect.map(config.sub_dialect_to_dialect)
    df["country_birth"] = df.country_birth.map(lambda x: "DK" if x is None else x)
    df.loc[df.country_birth != "DK", "dialect"] = "Non-native"
    print(df.dialect.value_counts(normalize=True))
    print()

    # Build the age groups once, rather than once per row inside the lambda.
    age_groups = [
        AgeGroup(min=min_age, max=max_age) for min_age, max_age in config.age_groups
    ]
    df["age_group"] = df.age.apply(
        lambda age: age_to_group(age=age, age_groups=age_groups)
    )
    print(df.age_group.value_counts(normalize=True))
    print(df.age_group.value_counts(normalize=True))

In [None]:
test_sentence_ids = test.remove_columns("audio").to_pandas()["id_sentence"]
test_sentence_ids.nunique()

In [None]:
# Statistics for the training split
print_stats(split=train)

In [None]:
# Statistics for the validation split
print_stats(split=val)

In [None]:
# Statistics for the test split
print_stats(split=test)

In [None]:
# Upload the new splits to the Hub as the "read_aloud" config of "alexandrainst/coral".
# Uploads can fail transiently, so retry with a pause between attempts, but re-raise the
# last error after a bounded number of attempts instead of looping forever.
MAX_PUSH_ATTEMPTS = 10
PUSH_RETRY_DELAY_SECONDS = 10

for attempt in range(1, MAX_PUSH_ATTEMPTS + 1):
    try:
        new_coral.push_to_hub(
            "alexandrainst/coral",
            "read_aloud",
            commit_message="Update test/val tests to have better dialectal representation",
        )
        break
    except Exception as e:
        if attempt == MAX_PUSH_ATTEMPTS:
            raise
        print(f"Failed to push to hub ({e}) - retrying...")
        sleep(PUSH_RETRY_DELAY_SECONDS)
new_coral