# Dataset Challenges
> “Whoever fights monsters should see to it that in the process he does not become a monster. And if you gaze long enough into an abyss, the abyss will gaze back into you.”
>
> ― Friedrich Nietzsche

In [None]:
# default_exp dataset_challenges

In [None]:
# hide
import blackhc.project.script
from nbdev.showdoc import *

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


To model real-world use cases better, we need:
* redundant/duplicated data;
* noisy labels (emulating noisy oracles);
* class imbalance;
* out-of-distribution data/outliers included in the unlabelled data.

RepeatedMNIST takes care of the first challenge.
This chapter takes care of the remaining three: noisy labels, class imbalance, and out-of-distribution data.

## Noisy Labels


In [None]:
# exports

import numpy as np
import torch
import torch.utils.data as data

from batchbald_redux.repeated_mnist import TransformedDataset

In [None]:
# exports


def override_labels(dataset: data.Dataset, indices: list, new_labels: list):
    """Return a view of `dataset` whose labels at `indices` are replaced by `new_labels`.

    Args:
        dataset: source dataset whose samples unpack as `(x, y)`.
        indices: dataset positions whose label should be overridden.
        new_labels: replacement labels; `new_labels[k]` becomes the label of `indices[k]`.

    Returns:
        A `TransformedDataset` around `dataset`. Its transformer returns `(x, new_label)` for
        overridden positions and the untouched sample for all other positions.

    Raises:
        ValueError: if `indices` and `new_labels` differ in length.
    """
    if len(indices) != len(new_labels):
        raise ValueError(
            f"indices and new_labels must have the same length, got {len(indices)} and {len(new_labels)}"
        )

    # Dataset position -> position in `new_labels`. One dict serves both the membership test and
    # the lookup. If an index is repeated, its last occurrence wins.
    rank_by_index = {idx: rank for rank, idx in enumerate(indices)}

    # `sample` (not `data`) so the `torch.utils.data` module alias is not shadowed.
    def override_label(idx, sample):
        rank = rank_by_index.get(idx)
        if rank is None:
            return sample

        x, _ = sample
        return x, new_labels[rank]

    return TransformedDataset(dataset, transformer=override_label)


def corrupt_labels(
    dataset: data.Dataset, *, percentage: int, num_classes: int, generator: np.random.Generator, device=None
):
    """Replace the labels of a random `percentage` percent of `dataset` with random classes.

    The positions to corrupt are drawn without replacement. Each new label is drawn uniformly
    (with replacement) from `range(num_classes)`, so a "corrupted" sample can keep its original
    label by chance. New labels are integer tensors (from NumPy's integer draws) on `device`,
    whatever the dtype of the original labels.

    Args:
        dataset: dataset whose samples unpack as `(x, y)`.
        percentage: share of samples to corrupt, in `[0, 100]`. The count is rounded down.
        num_classes: number of classes to draw replacement labels from.
        generator: NumPy random generator used for both draws.
        device: optional torch device for the replacement-label tensor.

    Returns:
        The relabelled dataset produced by `override_labels`.

    Raises:
        ValueError: if `percentage` is outside `[0, 100]`.
    """
    if not 0 <= percentage <= 100:
        raise ValueError(f"percentage must be in [0, 100], got {percentage}")

    N = len(dataset)
    # int() so a float percentage still yields an integer sample size for `generator.choice`.
    num_corrupted = int(N * percentage // 100)

    indices = generator.choice(N, size=num_corrupted, replace=False)
    new_labels = generator.choice(num_classes, size=num_corrupted, replace=True)

    return override_labels(dataset, indices, torch.as_tensor(new_labels, device=device))


def corrupt_all_labels(dataset: data.Dataset, *, num_classes: int, generator: np.random.Generator, device=None):
    """Give every sample of `dataset` a uniformly random label from `range(num_classes)`.

    All random labels are drawn once, up front, and stored in a tensor on `device`. Accessing the
    same index repeatedly therefore always yields the same label. Inputs are passed through unchanged.
    """
    sampled = generator.choice(num_classes, size=len(dataset), replace=True)
    label_table = torch.as_tensor(sampled, device=device)

    def relabel(position, sample):
        inputs, _ = sample
        return inputs, label_table[position]

    return TransformedDataset(dataset, transformer=relabel)

### Example

In [None]:
# Ten samples whose inputs and (float) labels are all zero.
zero_dataset = data.TensorDataset(torch.zeros(10), torch.zeros(10))

# Corrupt 50% of the samples (5 of 10): their labels become random classes in 0..9, returned as
# integer tensors. A corrupted label can coincide with the original one.
noisy_dataset = corrupt_labels(zero_dataset, percentage=50, num_classes=10, generator=np.random.default_rng())

list(noisy_dataset)

[(tensor(0.), tensor(7)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(8)),
 (tensor(0.), tensor(5)),
 (tensor(0.), tensor(4))]

In [None]:
# Replace every label with a uniformly random class in 0..9.
corrupted_dataset = corrupt_all_labels(zero_dataset, num_classes=10, generator=np.random.default_rng())

list(corrupted_dataset)

[(tensor(0.), tensor(5)),
 (tensor(0.), tensor(6)),
 (tensor(0.), tensor(2)),
 (tensor(0.), tensor(5)),
 (tensor(0.), tensor(3)),
 (tensor(0.), tensor(7)),
 (tensor(0.), tensor(7)),
 (tensor(0.), tensor(9)),
 (tensor(0.), tensor(0)),
 (tensor(0.), tensor(3))]

## Class Imbalances

In [None]:
# exports


def get_class_indices(dataset: data.Dataset, *, class_counts: list, generator: np.random.Generator):
    """Randomly pick dataset indices so that class `c` appears `class_counts[c]` times.

    Dataset positions are visited in a random order (a permutation drawn from `generator`). An index
    is kept while its class still has quota left. The scan stops as soon as every quota is filled,
    so only as many samples are loaded as necessary.

    Args:
        dataset: dataset whose samples unpack as `(x, y)`. `int(y)` must give the class index, so
            Python/NumPy ints and integral-valued 0-d tensors (including float ones) are accepted.
        class_counts: non-negative number of samples wanted per class.
        generator: NumPy random generator that controls which samples are picked.

    Returns:
        List of selected indices, in the random order they were visited. If the dataset holds
        fewer samples of some class than requested, every sample of that class is returned.

    Raises:
        ValueError: if any entry of `class_counts` is negative.
    """
    remaining_per_class = list(class_counts)
    if any(count < 0 for count in remaining_per_class):
        raise ValueError(f"class_counts must be non-negative, got {remaining_per_class}")

    remaining_total = sum(remaining_per_class)
    subset_indices = []
    if remaining_total == 0:
        return subset_indices

    # Scan a full random permutation rather than a pre-drawn sample of size sum(class_counts).
    # A fixed-size sample fills every quota only if it happens to contain exactly the right class mix.
    for index in generator.permutation(len(dataset)):
        _, y = dataset[index]
        label = int(y)

        if remaining_per_class[label] > 0:
            subset_indices.append(index)
            remaining_per_class[label] -= 1
            remaining_total -= 1
            if remaining_total == 0:
                break

    return subset_indices


def get_imbalanced_dataset(dataset: data.Dataset, *, class_counts: list, generator: np.random.Generator):
    """Return a `Subset` of `dataset` restricted to indices chosen per `class_counts`.

    The index selection is delegated to `get_class_indices`.
    """
    return data.Subset(
        dataset,
        get_class_indices(dataset, class_counts=class_counts, generator=generator),
    )

### Example

In [None]:
# Nine samples: inputs 0..8, labels cycling 0, 1, 2 (three samples per class).
dataset = data.TensorDataset(torch.arange(9), torch.as_tensor(list(range(3)) * 3))

print(list(dataset))

# Request 3 samples of class 0, 2 of class 1 and 1 of class 2.
imbalanced_indices = get_class_indices(dataset, class_counts=[3, 2, 1], generator=np.random.default_rng())

# Indexing a TensorDataset with a list of indices returns the stacked (inputs, labels) tensors.
dataset[imbalanced_indices]

[(tensor(0), tensor(0)), (tensor(1), tensor(1)), (tensor(2), tensor(2)), (tensor(3), tensor(0)), (tensor(4), tensor(1)), (tensor(5), tensor(2)), (tensor(6), tensor(0)), (tensor(7), tensor(1)), (tensor(8), tensor(2))]


(tensor([1, 3, 6, 0, 7, 5]), tensor([1, 0, 0, 0, 1, 2]))

## Mixing in OOD data

In [None]:
# exports


def add_ood_dataset(
    *,
    dataset: data.Dataset,
    ood_dataset: data.Dataset,
    ood_percentage: int,
    ood_random_labels: bool,
    generator: np.random.Generator,
    num_classes=None,
    device=None
):
    """Append a random subset of `ood_dataset` to `dataset`.

    The number of OOD samples is `ood_percentage` percent of `len(dataset)` (not of the combined
    size), rounded down. They are drawn from `ood_dataset` without replacement.

    Args:
        dataset: in-distribution dataset; kept in full and placed first.
        ood_dataset: source of out-of-distribution samples.
        ood_percentage: OOD sample count as a percentage of `len(dataset)`; must be non-negative.
        ood_random_labels: if True, the OOD samples get uniformly random labels
            (via `corrupt_all_labels`) instead of their own labels.
        generator: NumPy random generator used to pick (and optionally relabel) OOD samples.
        num_classes: number of classes for random labels; required when `ood_random_labels` is True.
        device: optional torch device for the subset indices and any random-label tensor.

    Returns:
        A `ConcatDataset` of `dataset` followed by the selected OOD samples.

    Raises:
        ValueError: if `ood_percentage` is negative, if `ood_dataset` has too few samples, or if
            `ood_random_labels` is set without `num_classes`.
    """
    # Explicit exceptions instead of `assert`, which is stripped under `python -O`.
    if ood_percentage < 0:
        raise ValueError(f"ood_percentage must be non-negative, got {ood_percentage}")
    if ood_random_labels and not num_classes:
        raise ValueError("num_classes is required when ood_random_labels=True")

    # int() so a float percentage still yields an integer sample size for `generator.choice`.
    subset_ood_N = int(len(dataset) * ood_percentage // 100)

    ood_N = len(ood_dataset)
    if subset_ood_N > ood_N:
        raise ValueError(f"requested {subset_ood_N} OOD samples but ood_dataset only has {ood_N}")

    ood_indices = generator.choice(ood_N, size=subset_ood_N, replace=False)

    ood_subset = data.Subset(ood_dataset, torch.as_tensor(ood_indices, device=device))

    if ood_random_labels:
        ood_subset = corrupt_all_labels(ood_subset, num_classes=num_classes, generator=generator, device=device)

    return data.ConcatDataset((dataset, ood_subset))

### Example

In [None]:
# In-distribution data: ten all-zero samples. OOD data: inputs and labels 1..5.
dataset = data.TensorDataset(torch.zeros(10), torch.zeros(10))
ood_dataset = data.TensorDataset(torch.arange(1, 6), torch.arange(1, 6))

# Append 50% of len(dataset) = 5 OOD samples, keeping their original labels.
mixed_dataset = add_ood_dataset(
    dataset=dataset,
    ood_dataset=ood_dataset,
    ood_percentage=50,
    generator=np.random.default_rng(),
    ood_random_labels=False,
)

list(mixed_dataset)

[(tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(3), tensor(3)),
 (tensor(1), tensor(1)),
 (tensor(2), tensor(2)),
 (tensor(4), tensor(4)),
 (tensor(5), tensor(5))]

In [None]:
# Same mix, but the appended OOD samples get uniformly random labels from 100 classes.
mixed_dataset2 = add_ood_dataset(
    dataset=dataset,
    ood_dataset=ood_dataset,
    ood_percentage=50,
    generator=np.random.default_rng(),
    ood_random_labels=True,
    num_classes=100,
)

list(mixed_dataset2)

[(tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(0.), tensor(0.)),
 (tensor(3), tensor(74)),
 (tensor(5), tensor(56)),
 (tensor(4), tensor(49)),
 (tensor(2), tensor(58)),
 (tensor(1), tensor(31))]