In [1]:
# Relative directory passed as `cache_dir` when loading datasets in this notebook.
DATASET_CACHE_DIR: str = '.datasets/'

In [2]:
from datasets import load_dataset

# Load the SNLI dataset, keeping the downloaded files under DATASET_CACHE_DIR.
snli = load_dataset(
    path='stanfordnlp/snli',
    cache_dir=DATASET_CACHE_DIR,
)

In [53]:
from math import ceil
import random
from typing import Any, Dict, List
from tqdm.contrib import tenumerate
from datasets import Dataset, concatenate_datasets


def _proportional_quotas(sizes: List[int], target: int) -> List[int]:
    """Split ``target`` across groups proportionally to ``sizes``.

    Uses the largest-remainder method with integer arithmetic only, so the
    returned quotas always sum to exactly ``target`` (when ``target`` is
    smaller than ``sum(sizes)``) and no quota exceeds its group size.
    Remainder ties are broken in favour of the earlier group, which keeps
    the allocation deterministic.
    """
    total = sum(sizes)
    if total == 0 or target >= total:
        # Nothing to downsample: every group keeps all of its items.
        return list(sizes)

    quotas = [size * target // total for size in sizes]
    remainders = [size * target % total for size in sizes]
    leftover = target - sum(quotas)

    # `sorted` is stable, so equal remainders keep their original order.
    by_remainder = sorted(
        range(len(sizes)), key=lambda i: remainders[i], reverse=True
    )
    for i in by_remainder[:leftover]:
        quotas[i] += 1
    return quotas


def binarize_labels(
      dataset: Dataset
    , labels_to_pos: List[Any]
    , labels_to_neg: List[Any]
    , pos_label: int = 1
    , neg_label: int = 0
    , sample_seed: int = 42
    , shuffle_seed: int = 42
    , num_proc: int = 4
) -> Dataset:
    """Collapse a multi-class ``label`` column into a balanced binary one.

    Rows whose label is in ``labels_to_pos`` are relabelled ``pos_label``,
    rows whose label is in ``labels_to_neg`` are relabelled ``neg_label``,
    and rows with any other label are dropped. The larger of the two
    classes is downsampled, proportionally to the size of each original
    label inside it, so that both classes end up with exactly
    ``min(#pos, #neg)`` rows. Because the total is exact, an original label
    that is very rare inside the larger class may receive no rows at all.

    Args:
        dataset: Dataset with a ``label`` column.
        labels_to_pos: Original label values mapped to ``pos_label``.
        labels_to_neg: Original label values mapped to ``neg_label``.
            Must be disjoint from ``labels_to_pos``.
        pos_label: Value written for the positive class.
        neg_label: Value written for the negative class.
        sample_seed: Seed for the downsampling. A private RNG is used, so
            the global ``random`` state is left untouched.
        shuffle_seed: Seed for the final shuffle.
        num_proc: Number of processes passed to ``Dataset.map``.

    Returns:
        A shuffled dataset with the same columns as ``dataset``. It is
        empty when either class has no matching rows.

    Raises:
        AssertionError: If ``dataset`` has no ``label`` column or the two
            label lists overlap.
    """
    # NOTE(review): if 'label' is a ClassLabel feature, pos_label/neg_label
    # presumably need to be valid class ids of that feature -- confirm.
    assert 'label' in dataset.features, \
        "dataset must contain a 'label' column"

    # Build the lookup sets once; they are queried for every row below.
    pos_source_labels = set(labels_to_pos)
    neg_source_labels = set(labels_to_neg)
    assert pos_source_labels.isdisjoint(neg_source_labels), \
        'labels_to_pos and labels_to_neg must not overlap'

    # Local generator: reproducible sampling without reseeding global state.
    rng = random.Random(sample_seed)

    # Group row indices by their original label, separately per target class.
    pos_label2indices: Dict[Any, List[int]] = {}
    neg_label2indices: Dict[Any, List[int]] = {}
    for index, label in tenumerate(dataset['label']):
        if label in pos_source_labels:
            pos_label2indices.setdefault(label, []).append(index)
        elif label in neg_source_labels:
            neg_label2indices.setdefault(label, []).append(index)

    pos_num = sum(len(indices) for indices in pos_label2indices.values())
    neg_num = sum(len(indices) for indices in neg_label2indices.values())
    target = min(pos_num, neg_num)

    if target == 0:
        # At least one class has no rows, so the only balanced result is
        # the empty dataset (this also avoids a division by zero).
        return dataset.select([])

    if pos_num != neg_num:
        majority = neg_label2indices if neg_num > pos_num else pos_label2indices
        groups = list(majority.items())
        quotas = _proportional_quotas(
            [len(indices) for _, indices in groups], target
        )
        for (label, indices), quota in zip(groups, quotas):
            majority[label] = rng.sample(indices, quota)

    def _relabel_with(new_label):
        """Return a batched map function that overwrites every label."""
        def _apply(batch):
            batch['label'] = [new_label] * len(batch['label'])
            return batch
        return _apply

    parts = []
    for new_label, label2indices in (
        (pos_label, pos_label2indices),
        (neg_label, neg_label2indices),
    ):
        for indices in label2indices.values():
            if not indices:
                # A label whose quota rounded down to zero contributes nothing.
                continue
            parts.append(
                dataset.select(indices)
                       .map(_relabel_with(new_label), batched=True, num_proc=num_proc)
            )

    dataset_balanced_binarized = concatenate_datasets(parts)

    return dataset_balanced_binarized.shuffle(seed=shuffle_seed)

In [None]:
from collections import Counter

# Toy data: four original classes with 4, 5, 6 and 2 rows respectively.
# 'original_label' keeps an untouched copy so the per-class sampling can be
# inspected after 'label' has been binarized.
toy_labels = [0] * 4 + [1] * 5 + [2] * 6 + [3] * 2
test_dataset = Dataset.from_dict({
    'original_label': list(toy_labels),
    'label': list(toy_labels),
})

print(test_dataset)

# Classes 2 and 3 become the positive class, classes 0 and 1 the negative one.
test_binarized = binarize_labels(
    test_dataset,
    labels_to_pos=[2, 3],
    labels_to_neg=[0, 1],
    pos_label=1,
    neg_label=0,
)

# Class balance after binarization, then which original classes survived.
for column in ('label', 'original_label'):
    print(Counter(test_binarized[column]))

Dataset({
    features: ['original_label', 'label'],
    num_rows: 15
})


  0%|          | 0/15 [00:00<?, ?it/s]

Map (num_proc=4):   0%|          | 0/5 [00:00<?, ? examples/s]

num_proc must be <= 2. Reducing num_proc to 2 for dataset of size 2.


Map (num_proc=2):   0%|          | 0/2 [00:00<?, ? examples/s]

num_proc must be <= 3. Reducing num_proc to 3 for dataset of size 3.


Map (num_proc=3):   0%|          | 0/3 [00:00<?, ? examples/s]

Counter({0: 5, 1: 5})
Counter({1: 5, 2: 3, 0: 2})
