In [None]:
from flwr_datasets.partitioner import GroupedNaturalIdPartitioner
from flwr_datasets.visualization import plot_label_distributions
import matplotlib.pyplot as plt
from datasets import load_dataset

In [None]:
# Fetch the "train" split of SpeechCommands, configuration "v0.02"
sc = load_dataset(
    path="speech_commands",
    name="v0.02",
    split="train",
    token=False,
)

# FlowerDatasets' "Grouped partitioner" buckets the samples by speaker:
# each partition is built from a group of 30 unique speaker ids
partitioner = GroupedNaturalIdPartitioner(
    group_size=30,
    partition_by="speaker_id",
)

### Removing _silence_ clips

In [None]:
# The dataset ships with 5 long silence recordings that we don't want to show in the plot below.
# They are exactly the entries whose `speaker_id` is None, so they can be dropped with a simple predicate.
# (At training time, each client will get 10% new data samples containing 1s-long silence clips.)
def filter_none_speaker(example):
    """Return True when `example` has a speaker id, False when `speaker_id` is None."""
    speaker = example["speaker_id"]
    return speaker is not None


# Keep only the rows accepted by the predicate (i.e. those with a non-None `speaker_id`)
filtered_dataset = sc.filter(function=filter_none_speaker)

# Hand the filtered dataset over to the partitioner
partitioner.dataset = filtered_dataset

### Making a plot

In [None]:
# Prepare a wide canvas and let FlowerDatasets draw the label distribution onto it
fig, axis = plt.subplots(figsize=(16, 6))
fig, ax, df = plot_label_distributions(
    partitioner,
    # what to plot: the original "label" column
    label_name="label",
    verbose_labels=True,
    # how to plot it
    plot_type="bar",
    size_unit="absolute",
    partition_id_axis="x",
    axis=axis,
    # decorations
    title="Per Partition Labels Distribution",
    legend=True,
    legend_kwargs={"ncols": 2, "bbox_to_anchor": (1.05, 0.5)},
)

In [None]:
# Write the current figure to disk as a PNG, using a tight bounding box
fig.savefig(
    "whisper_flower_data.png",
    bbox_inches="tight",
    format="png",
)

### Process dataset into 12 classes

To go from 35 classes to 12, we need to apply the following changes:
- All audio clips that have `is_unknown` set will be assigned the same "target" label `11`
- Silence audio clips will be assigned label `10`

We achieve this 35:12 mapping by means of the function below (similar to the one used in the code).

In [None]:
def prepare_dataset(batch):
    """Compute the 12-class target for one example.

    Reads `batch["is_unknown"]` and `batch["label"]` and returns a dict with a
    single "targets" entry:

    - 11 when `is_unknown` is truthy (all unknown keywords share one class),
    - 10 when the label equals 35 (the silence clips),
    - the unchanged label otherwise.

    In this way we end up with 12 classes, labelled 0-11.
    """
    if batch["is_unknown"]:
        target = 11
    elif batch["label"] == 35:
        target = 10
    else:
        target = batch["label"]
    return {"targets": target}

In [None]:
# Run `prepare_dataset` over the filtered dataset (4 worker processes) to obtain the "targets" values
dataset_12cls = filtered_dataset.map(
    function=prepare_dataset,
    num_proc=4,
)

In [None]:
# Build a fresh partitioner (same grouping: 30 speaker ids per group)
# and attach the dataset that went through `prepare_dataset`
partitioner = GroupedNaturalIdPartitioner(
    group_size=30,
    partition_by="speaker_id",
)
partitioner.dataset = dataset_12cls

In [None]:
# Same plot as before, but driven by the new "targets" key instead of "label"
fig, axis = plt.subplots(figsize=(16, 6))
fig, ax, df = plot_label_distributions(
    partitioner,
    # what to plot: the remapped "targets" column
    label_name="targets",
    verbose_labels=False,
    # how to plot it
    plot_type="bar",
    size_unit="absolute",
    partition_id_axis="x",
    axis=axis,
    # decorations
    title="Per Partition Labels Distribution",
    legend=True,
    legend_kwargs={"ncols": 2, "bbox_to_anchor": (1.0, 0.5)},
)

In [None]:
# Classes 0-9 correspond to the keywords: 'yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go'
# Class 10 is 'silence' and class 11 is 'other' (the remaining classes of the original 35-class representation, combined)