In [9]:
from datasets import load_dataset, Dataset
import pandas as pd

In [6]:
def even_sampling(dataset):
    """Downsample every label class to the size of the rarest class.

    Parameters
    ----------
    dataset
        Any object accepted by ``pd.DataFrame(...)`` whose resulting frame
        has a ``label`` column.

    Returns
    -------
    Dataset
        A dataset built from the balanced frame via ``Dataset.from_pandas``
        and shuffled with ``seed=42``. Each distinct label contributes the
        same number of rows (the count of the least represented label).

    Notes
    -----
    Labels are taken from the values actually present in the ``label``
    column, so they do not need to be the contiguous integers ``0..k-1``.
    They are visited in sorted order, which for ``0..k-1`` labels is the
    same order as a plain ``range`` loop.
    """
    df = pd.DataFrame(dataset)

    # Rows per label; the smallest count is the per-class sample size.
    label_counts = df['label'].value_counts()

    if label_counts.empty:
        # No labelled rows: keep the columns but return zero rows rather
        # than failing on an empty concat.
        balanced_df = df.iloc[0:0]
    else:
        min_count = int(label_counts.min())
        # Sample each class with the same fixed seed, then concatenate once
        # (avoids repeatedly copying a growing frame inside the loop).
        per_label_samples = [
            df[df['label'] == label].sample(min_count, random_state=42)
            for label in sorted(label_counts.index)
        ]
        balanced_df = pd.concat(per_label_samples)

    print(f"Original distribution: {label_counts}")
    print(f"New distribution: {balanced_df['label'].value_counts()}")

    # Convert back to a Hugging Face dataset and shuffle so classes are interleaved.
    # NOTE(review): balanced_df keeps the sampled rows' original index; from_pandas
    # may store a non-range index as an extra column — confirm that is wanted.
    balanced_dataset = Dataset.from_pandas(balanced_df).shuffle(seed=42)

    return balanced_dataset

In [7]:
# Identifier of the source dataset that is passed to load_dataset.
DATASET_NAME = "amang1802/wildeweb_cls_1M"

In [13]:
# Build the balanced, split classification dataset.
N_ROWS = 1_000_000  # number of leading rows of the train split to use
SPLIT_SEED = 42     # fixed seed so the train/test split is reproducible on re-run

raw_dataset = load_dataset(DATASET_NAME)['train'].select(range(N_ROWS))

# Derive "label" from classification_score: subtract 1 and clamp at 0,
# so a score of 0 or 1 becomes label 0.
labeled_dataset = raw_dataset.map(
    lambda score: {"label": max(0, score - 1)},
    input_columns=["classification_score"],
)

# Downsample every label to the size of the rarest one, then hold out 20% for testing.
dataset = even_sampling(labeled_dataset)
split_dataset = dataset.train_test_split(test_size=0.2, seed=SPLIT_SEED)

Original distribution: label
0    856617
1     71438
2     33667
3     20140
4     18138
Name: count, dtype: int64
New distribution: label
0    18138
1    18138
2    18138
3    18138
4    18138
Name: count, dtype: int64


In [15]:
# Upload the train/test splits; the call stays the cell's last expression
# so its return value is shown as the cell output.
TARGET_REPO_ID = "amang1802/wildeweb_cls_labels_v1"
split_dataset.push_to_hub(TARGET_REPO_ID)

Creating parquet from Arrow format: 100%|██████████| 37/37 [00:03<00:00,  9.99ba/s]
Creating parquet from Arrow format: 100%|██████████| 37/37 [00:03<00:00, 10.27ba/s]
Uploading the dataset shards: 100%|██████████| 2/2 [00:15<00:00,  7.78s/it]
Creating parquet from Arrow format: 100%|██████████| 19/19 [00:01<00:00, 10.66ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:06<00:00,  6.39s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/amang1802/wildeweb_cls_labels_v1/commit/75019e3f76d5f472dd5a8ab5acbde41c06a6490f', commit_message='Upload dataset', commit_description='', oid='75019e3f76d5f472dd5a8ab5acbde41c06a6490f', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/amang1802/wildeweb_cls_labels_v1', endpoint='https://huggingface.co', repo_type='dataset', repo_id='amang1802/wildeweb_cls_labels_v1'), pr_revision=None, pr_num=None)