# Dataset Challenges
> “Whoever fights monsters should see to it that in the process he does not become a monster. And if you gaze long enough into an abyss, the abyss will gaze back into you.”
>
> ― Friedrich Nietzsche

In [None]:
# default_exp dataset_challenges

In [None]:
# hide
import blackhc.project.script
from nbdev.showdoc import *

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


To model real-world use cases better, we need:
* redundant/duplicated data;
* noisy labels (emulating noisy oracles);
* class imbalance;
* out-of-distribution data/outliers included in the unlabelled data;
* noisy or ambiguous samples.

RepeatedMNIST addresses the first and the last of these challenges, albeit in a very specific way.
This chapter addresses the remaining ones.

## Noisy Labels

In [None]:
# exports

import bisect
from typing import Optional, Union

import numpy as np
import torch
import torch.utils.data as data

In [None]:
# exports


class _NamedDataset(data.Dataset):
    """Wrap a dataset and give it a human-readable name, which is used as its repr.

    Operators:
        `a + b`        -- concatenation via `data.ConcatDataset`, named "<a> + <b>".
        `a * factor`   -- random subset via `SubsetDataset` drawn with seed 0.
    Indexing and `len` are forwarded to the wrapped dataset.
    """

    dataset: data.Dataset
    name: str

    def __init__(self, dataset: data.Dataset, name: str):
        self.dataset, self.name = dataset, name

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        inner = self.dataset
        return inner[idx]

    def __repr__(self):
        return self.name

    def __add__(self, other):
        combined = data.ConcatDataset([self, other])
        return _NamedDataset(combined, f"{self} + {other}")

    def __mul__(self, factor):
        return SubsetDataset(self, seed=0, factor=factor)


class NamedDataset(_NamedDataset):
    """A user-named leaf dataset.

    `get_base_dataset` stops descending when it reaches an instance of this class.
    """


def get_base_dataset(dataset, index):
    """Find the dataset that actually provides sample `index` of `dataset`.

    Descends through `data.ConcatDataset`s (selecting the constituent that holds
    `index`) and `_NamedDataset` wrappers until it reaches a `NamedDataset`,
    which is returned. A dataset that is neither of these is returned as-is.

    Wrappers that expose an `indices` attribute (e.g. `SubsetDataset`,
    `ImbalancedDataset`) serve `self.dataset[self.indices[idx]]`, so `index`
    is translated through `indices` before descending. Without this, a nested
    composite such as `(A + B) * 0.5` would resolve to the wrong constituent.

    Args:
        dataset: the (possibly composite) dataset.
        index: position within `dataset`. When a `ConcatDataset` is reached,
            negative values count from its end.

    Returns:
        The base dataset the sample comes from, e.g. to check
        `get_base_dataset(MNIST + OOD * 0.5, i) == OOD`.

    Raises:
        ValueError: if a negative index's absolute value exceeds the length of
            a `ConcatDataset` on the path.
    """
    if isinstance(dataset, NamedDataset):
        return dataset

    if isinstance(dataset, data.ConcatDataset):
        # Same index arithmetic as `ConcatDataset.__getitem__`.
        if index < 0:
            if -index > len(dataset):
                raise ValueError("absolute value of index should not exceed dataset length")
            index = len(dataset) + index
        dataset_idx = bisect.bisect_right(dataset.cumulative_sizes, index)
        if dataset_idx == 0:
            sample_idx = index
        else:
            sample_idx = index - dataset.cumulative_sizes[dataset_idx - 1]
        return get_base_dataset(dataset.datasets[dataset_idx], sample_idx)

    if isinstance(dataset, _NamedDataset):
        # Subset-style wrappers re-map positions; translate before descending.
        indices = getattr(dataset, "indices", None)
        if indices is not None:
            index = indices[index]
        return get_base_dataset(dataset.dataset, index)

    return dataset

In [None]:
from batchbald_redux.fast_mnist import FastMNIST

# NOTE(review): `device="cuda"` presumably places the MNIST tensors on the GPU,
# so this cell likely requires a CUDA device — confirm against FastMNIST.
MNIST = NamedDataset(FastMNIST(root="data/", download=True, device="cuda"), name="MNISTDataset")

MNIST

MNISTDataset

In [None]:
class OverridenTargetDataset(data.Dataset):
    """View of `dataset` in which the labels of selected samples are replaced.

    Args:
        dataset: wrapped dataset; `dataset[idx]` must yield an `(input, label)` pair.
        indices: positions whose label is overridden.
        new_targets: replacement labels aligned with `indices`
            (`new_targets[k]` is the label for position `indices[k]`).

    Samples whose position is not in `indices` are returned unchanged.
    """

    indices_set: set
    reverse_indices: dict
    new_targets: list

    def __init__(self, dataset: data.Dataset, *, indices: list, new_targets: list):
        self.dataset = dataset
        # Set used for the "is this position overridden?" membership check.
        self.indices_set = set(indices)
        # Maps a dataset position to its slot in `new_targets`; for duplicated
        # positions the last occurrence wins.
        self.reverse_indices = {idx: rank for rank, idx in enumerate(indices)}
        self.new_targets = new_targets

    def __getitem__(self, idx):
        # `sample` rather than `data` so the `torch.utils.data` alias is not shadowed.
        sample, label = self.dataset[idx]

        if idx not in self.indices_set:
            return sample, label

        return sample, self.new_targets[self.reverse_indices[idx]]

    def __len__(self):
        return len(self.dataset)

    def __repr__(self):
        # `vars` (the builtin) returns the instance's attribute dict; `var` does not exist.
        return f"OverridenTargetDataset({vars(self)})"


class CorruptedLabelsDataset(_NamedDataset):
    """Replace the labels of a random subset of `dataset` with uniformly random classes.

    `size_corrupted` is an absolute number of samples when it is greater than 1.
    Otherwise it is a fraction of `len(dataset)`, so `1` / `1.0` corrupts every sample.
    The corrupted positions are drawn without replacement. Their new labels are drawn
    with replacement from `range(num_classes)` and may coincide with the original
    label. Both draws come from `np.random.default_rng(seed)`.

    Args:
        dataset: dataset yielding `(input, label)` pairs.
        size_corrupted: count (> 1) or fraction (<= 1) of samples to corrupt.
        num_classes: number of classes to draw replacement labels from.
        seed: seed for the numpy generator.
        device: device for the tensor of replacement labels.
    """

    dataset: data.Dataset
    options: dict
    implementation: OverridenTargetDataset

    def __init__(
        self, dataset: data.Dataset, *, size_corrupted: Union[float, int], num_classes: int, seed: int, device=None
    ):
        options = dict(size_corrupted=size_corrupted, num_classes=num_classes, seed=seed)

        super().__init__(dataset, f"CorruptedLabelsDataset(dataset={dataset}, {options})")
        self.options = options

        generator = np.random.default_rng(seed)

        N = len(dataset)

        if size_corrupted > 1:
            # Absolute count. Cast so that float counts such as `100.0` work:
            # `Generator.choice` requires an integer `size`.
            num_corrupted = int(size_corrupted)
        else:
            num_corrupted = int(N * size_corrupted)

        indices = generator.choice(N, size=num_corrupted, replace=False)
        new_targets = generator.choice(num_classes, size=num_corrupted, replace=True)

        self.implementation = OverridenTargetDataset(
            dataset, indices=indices, new_targets=torch.as_tensor(new_targets, device=device)
        )

    def __getitem__(self, idx):
        return self.implementation[idx]

    def __len__(self):
        return len(self.implementation)


class RandomLabelsDataset(_NamedDataset):
    """Give every sample of `dataset` a uniformly random label in `range(num_classes)`.

    All labels are drawn once at construction from `np.random.default_rng(seed)`
    and kept as a tensor (on `device` if given), so lookups are stable.
    The inputs are still read from the wrapped dataset on each access.
    """

    options: dict
    new_labels: list

    def __init__(self, dataset: data.Dataset, *, num_classes: int, seed: int, device=None):
        options = dict(num_classes=num_classes, seed=seed)
        super().__init__(dataset, f"RandomLabelsDataset({dataset}, {options})")
        self.options = options

        rng = np.random.default_rng(seed)
        sampled = rng.choice(num_classes, size=len(dataset), replace=True)
        self.new_labels = torch.as_tensor(sampled, device=device)

    def __getitem__(self, idx):
        sample, _ = self.dataset[idx]
        return sample, self.new_labels[idx]

### Example

In [None]:
# Ten samples whose inputs and labels are all 0.0, so any other label below comes from the corruption.
zero_dataset = NamedDataset(data.TensorDataset(torch.zeros(10), torch.zeros(10)), "ZeroDataset")

# Corrupt half (5 of 10) of the labels; seed=1 makes the expected list below reproducible.
corrupted_labels_dataset = CorruptedLabelsDataset(zero_dataset, size_corrupted=0.5, num_classes=10, seed=1)

# Overridden labels are integer tensors drawn by numpy; untouched ones keep the original float 0.0.
assert list(corrupted_labels_dataset) == [
    (torch.tensor(0.0), torch.tensor(8)),
    (torch.tensor(0.0), torch.tensor(0.0)),
    (torch.tensor(0.0), torch.tensor(8)),
    (torch.tensor(0.0), torch.tensor(3)),
    (torch.tensor(0.0), torch.tensor(0.0)),
    (torch.tensor(0.0), torch.tensor(0.0)),
    (torch.tensor(0.0), torch.tensor(4)),
    (torch.tensor(0.0), torch.tensor(0.0)),
    (torch.tensor(0.0), torch.tensor(2)),
    (torch.tensor(0.0), torch.tensor(0.0)),
]

corrupted_labels_dataset

CorruptedLabelsDataset(dataset=ZeroDataset, {'size_corrupted': 0.5, 'num_classes': 10, 'seed': 1})

In [None]:
# Replace every label with a random class; seed=2 fixes the expected labels below.
corrupted_dataset = RandomLabelsDataset(zero_dataset, num_classes=10, seed=2)

assert list(corrupted_dataset) == [
    (torch.tensor(0.0), torch.tensor(8)),
    (torch.tensor(0.0), torch.tensor(2)),
    (torch.tensor(0.0), torch.tensor(1)),
    (torch.tensor(0.0), torch.tensor(2)),
    (torch.tensor(0.0), torch.tensor(4)),
    (torch.tensor(0.0), torch.tensor(8)),
    (torch.tensor(0.0), torch.tensor(4)),
    (torch.tensor(0.0), torch.tensor(0)),
    (torch.tensor(0.0), torch.tensor(3)),
    (torch.tensor(0.0), torch.tensor(6)),
]

corrupted_dataset

RandomLabelsDataset(ZeroDataset, {'num_classes': 10, 'seed': 2})

## Class Imbalances

In [None]:
# exports


def get_class_indices(dataset: data.Dataset, *, class_counts: list, generator: np.random.Generator):
    """Pick up to `class_counts[c]` sample positions of each class `c`, in random order.

    The dataset is scanned in the order given by `generator.permutation(len(dataset))`.
    A position is kept while its class still has budget left. The scan stops as soon as
    the total budget `sum(class_counts)` has been used up. If the dataset runs out first,
    fewer positions are returned. The caller's `class_counts` is not modified.

    Returns:
        The selected positions, in the order they were encountered.
    """
    budget = list(class_counts)  # private copy that we decrement
    total_left = sum(budget)
    chosen = []

    for candidate in generator.permutation(len(dataset)):
        _, label = dataset[candidate]
        if budget[label] <= 0:
            continue

        chosen.append(candidate)
        budget[label] -= 1
        total_left -= 1
        if total_left <= 0:
            break

    return chosen


class ImbalancedDataset(_NamedDataset):
    """Subsample `dataset` so that class `c` contributes at most `class_counts[c]` samples.

    The kept positions are chosen by `get_class_indices` using a generator
    seeded with `seed`. Item `idx` of this dataset is `dataset[self.indices[idx]]`.
    """

    options: dict
    indices: list

    def __init__(self, dataset: data.Dataset, *, class_counts: list, seed: int):
        options = dict(class_counts=class_counts, seed=seed)
        name = f"ImbalancedDataset(dataset={dataset}, {options})"
        super().__init__(dataset, name)
        self.options = options

        rng = np.random.default_rng(seed)
        self.indices = get_class_indices(dataset, class_counts=class_counts, generator=rng)

    def __getitem__(self, idx):
        base_idx = self.indices[idx]
        return self.dataset[base_idx]

    def __len__(self):
        return len(self.indices)


class ImbalancedClassSplitDataset(_NamedDataset):
    """Make half of the classes majority classes and the other half minority classes.

    Each class is budgeted from an equal share `len(dataset) // num_classes`:
    majority classes keep `majority_percentage`% of it and minority classes keep
    the remaining `100 - majority_percentage`% (integer arithmetic). The
    seed-`seed` generator shuffles which classes are majority and then picks the
    samples via `get_class_indices`. `num_classes` must be even.
    """

    dataset: data.Dataset
    options: dict
    indices: list

    def __init__(self, dataset: data.Dataset, *, num_classes: int, majority_percentage: int, seed: int):
        assert num_classes % 2 == 0

        # The name depends on the shuffled class counts; it is set below.
        super().__init__(dataset, None)

        per_class = len(dataset) // num_classes
        majority_size = per_class * majority_percentage // 100
        minority_size = per_class * (100 - majority_percentage) // 100
        half = num_classes // 2

        rng = np.random.default_rng(seed)
        class_counts = rng.permuted([majority_size] * half + [minority_size] * half)

        self.options = dict(
            num_classes=num_classes, majority_percentage=majority_percentage, seed=seed, class_counts=class_counts
        )
        # NOTE(review): the name is prefixed "ImbalancedDataset", not the class name — intentional?
        self.name = f"ImbalancedDataset(dataset={self.dataset}, {self.options})"

        # The same generator continues here, after the permutation above.
        self.indices = get_class_indices(dataset, class_counts=class_counts, generator=rng)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        base_idx = self.indices[idx]
        return self.dataset[base_idx]

### Example

In [None]:
# Inputs 0..8 with labels cycling 0, 1, 2, so the input value equals the sample position.
three_dataset = NamedDataset(data.TensorDataset(torch.arange(9), torch.as_tensor(list(range(3)) * 3)), "123")

# Select the three class-0 samples. The generator is unseeded, so their order
# (and the output below) can differ between runs.
imbalanced_indices = get_class_indices(three_dataset, class_counts=[3, 0, 0], generator=np.random.default_rng())

three_dataset[imbalanced_indices][0]

tensor([6, 0, 3])

In [None]:
# Keep 1 sample of class 0, 2 of class 1 and all 3 of class 2, in seed-2 random order.
imbalanced_dataset = ImbalancedDataset(three_dataset, class_counts=[1, 2, 3], seed=2)
print(imbalanced_dataset[:][1])
imbalanced_dataset

tensor([2, 1, 0, 2, 2, 1])


ImbalancedDataset(dataset=123, {'class_counts': [1, 2, 3], 'seed': 2})

In [None]:
# Five randomly chosen classes keep 80% of an equal per-class share, the other five keep 20%.
ImbalancedClassSplitDataset(MNIST, num_classes=10, majority_percentage=80, seed=1)

ImbalancedDataset(dataset=MNISTDataset, {'num_classes': 10, 'majority_percentage': 80, 'seed': 1, 'class_counts': array([1200, 4800, 1200, 4800, 4800, 4800, 1200, 1200, 1200, 4800])})

## Mixing in OOD data

In [None]:
# exports

# Convert label dataset to one hot
class OneHotDataset(_NamedDataset):
    """Replace each integer label of `dataset` with its one-hot encoding.

    All targets are materialised once at construction into a
    `(len(dataset), num_classes)` tensor of `dtype` on `device`. The inputs are
    still read from the wrapped dataset on each access.

    Args:
        dataset: dataset yielding `(input, label)` pairs with integer-like labels.
        num_classes: width of the one-hot vectors.
        dtype: dtype of the target tensor (torch's default when None).
        device: device of the target tensor.
    """

    options: dict
    targets: list

    def __init__(self, dataset: data.Dataset, *, num_classes: int, dtype=None, device=None):
        options = dict(num_classes=num_classes)

        super().__init__(dataset, f"OneHotDataset({dataset}, {options})")
        self.options = options

        num_samples = len(dataset)
        targets = torch.zeros(num_samples, num_classes, dtype=dtype, device=device)
        # Index explicitly instead of `for ... in dataset`. `data.Dataset` defines no
        # `__iter__`, so plain iteration falls back to calling `__getitem__` until an
        # IndexError. A dataset that does not raise exactly at `len(dataset)` would
        # then overrun the preallocated `targets`.
        for i in range(num_samples):
            _, label = dataset[i]
            targets[i, label] = 1.0

        self.targets = targets

    def __getitem__(self, idx):
        sample, _ = self.dataset[idx]
        return sample, self.targets[idx]


class SubsetDataset(_NamedDataset):
    """Random subset of `dataset`, or a with-replacement resample when it is larger.

    Sizing (at least one argument is required):
        size   -- absolute number of samples; takes precedence when both are given.
        factor -- multiple of `len(dataset)`, truncated to int.
    Positions are drawn from `np.random.default_rng(seed)`, with replacement only
    when the requested size exceeds `len(dataset)`. With `factor` and `seed == 0`
    (the `dataset * factor` operator) the name is shortened to "<dataset> * <factor>".
    """

    dataset: data.Dataset
    options: dict
    indices: list

    def __init__(self, dataset: data.Dataset, *, size: Optional[int] = None, factor: Optional[float] = None, seed: int):
        options = dict(size=size, factor=factor, seed=seed)
        super().__init__(dataset, f"SubsetDataset(dataset={dataset}, {options})")
        self.options = options

        rng = np.random.default_rng(seed)

        assert size is not None or factor is not None
        if size is None:
            subset_size = int(len(dataset) * factor)
            if seed == 0:
                self.name = f"{dataset} * {factor}"
        else:
            subset_size = size

        num_available = len(dataset)
        self.indices = rng.choice(num_available, size=subset_size, replace=subset_size > num_available)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        base_idx = self.indices[idx]
        return self.dataset[base_idx]


class ConstantTargetDataset(_NamedDataset):
    """Keep the inputs of `dataset` but pair every sample with the same `target` object."""

    target: object

    def __init__(self, dataset: data.Dataset, target: object):
        name = f"ConstantTargetDataset({dataset}, {target})"
        super().__init__(dataset, name)
        self.target = target

    def __getitem__(self, idx):
        sample, _ = self.dataset[idx]
        return sample, self.target


def UniformTargetDataset(dataset: data.Dataset, *, num_classes: int, device: str = None):
    """Pair every input of `dataset` with the uniform distribution over `num_classes` classes.

    Returns a `ConstantTargetDataset` whose shared target is
    `torch.ones(num_classes) / num_classes` (on `device`). Its `options` and `name`
    are overridden to describe the uniform target.
    """
    options = dict(num_classes=num_classes)
    uniform = torch.ones(num_classes, device=device) / num_classes

    wrapped = ConstantTargetDataset(dataset, uniform)
    wrapped.options = options
    wrapped.name = f"UniformTargetDataset({dataset}, {options})"
    return wrapped

To mix an OOD dataset into the unlabelled data, one can either use:
```
MNIST+OOD*0.5
```
and then use `get_base_dataset(dataset, index) == OOD` to check whether a picked sample is OOD.

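Alternatively, we can give every OOD sample a uniform target distribution over all classes instead of a hard label:
```
OneHotDataset(MNIST, num_classes=10) + UniformTargetDataset(OOD * 0.5, num_classes=10)
```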

### Example

In [None]:
# `dataset * factor` draws a seed-0 random subset of int(len(dataset) * factor) samples.
MNIST*0.1

MNISTDataset * 0.1

In [None]:
# One-hot encode the labels of a 10% MNIST subset.
one_hot_MNIST = OneHotDataset(MNIST * 0.1, num_classes=10)
print(one_hot_MNIST[0][1])

one_hot_MNIST

tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0.])


OneHotDataset(MNISTDataset * 0.1, {'num_classes': 10})

In [None]:
# Replace the one-hot targets with one constant uniform distribution over the 10 classes.
UniformTargetDataset(one_hot_MNIST, num_classes=10)

UniformTargetDataset(OneHotDataset(MNISTDataset * 0.1, {'num_classes': 10}), {'num_classes': 10})

## Noisy Samples

TODO: add AdditiveGaussianNoise.

The problem is that for large datasets, this can use up a lot of memory. Creating a new random generator for each sample could be too slow (as we are creating batches), so it might be worth