# Dataset Challenges
> “Whoever fights monsters should see to it that in the process he does not become a monster. And if you gaze long enough into an abyss, the abyss will gaze back into you.”
>
> ― Friedrich Nietzsche

In [None]:
# default_exp dataset_challenges

In [None]:
# hide
import blackhc.project.script
from nbdev.showdoc import *

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


To model real-world use cases better, we need:
* redundant/duplicated data;
* noisy labels (emulating noisy oracles);
* class imbalance;
* out-of-distribution data/outliers included in the unlabelled data;
* noisy or ambiguous samples.

RepeatedMNIST addresses the first and last challenges in a very specific way.
This chapter addresses the others.

## Noisy Labels

In [None]:
# exports

import bisect
from typing import Optional, Union, List

import numpy as np
import torch
import torch.utils.data as data
import torchvision.datasets

from batchbald_redux.fast_mnist import FastMNIST

In [None]:
# exports


def _wrap_alias(dataset: data.Dataset) -> str:
    """
    Render ``dataset`` as one operand of a composite alias such as ``'A' + ('B'x2)``.

    ``NamedDataset`` aliases are already a single quoted token and are used as-is. Aliases of
    other ``_AliasDataset`` instances are parenthesized to keep the composition unambiguous.
    Datasets without an alias (e.g. plain torch datasets) fall back to their parenthesized
    ``repr`` instead of raising ``AttributeError``.
    """
    if isinstance(dataset, NamedDataset):
        return repr(dataset)
    if isinstance(dataset, _AliasDataset):
        return f"({dataset.alias})"
    return f"({dataset!r})"


class _AliasDataset(data.Dataset):
    """
    A dataset wrapper with an easier to understand alias, which is used as its ``repr``.

    Indexing and ``len`` are forwarded to the wrapped ``dataset``. Convenience operators:

    * ``a + b``: concatenation (``data.ConcatDataset``) with alias ``<a> + <b>``;
    * ``a * k`` or ``k * a`` with an integral ``k``: ``RepeatedDataset`` with ``k`` repeats;
    * ``a * f`` or ``f * a`` with a non-integral ``f``: ``SubsetDataset`` with ``factor=f``
      and seed 0.
    """

    dataset: data.Dataset
    alias: str

    def __init__(self, dataset: data.Dataset, alias: str):
        self.dataset = dataset
        self.alias = alias

    def __getitem__(self, idx):
        return self.dataset[idx]

    def __len__(self):
        return len(self.dataset)

    def __repr__(self):
        return self.alias

    def __add__(self, other):
        return _AliasDataset(data.ConcatDataset([self, other]), f"{_wrap_alias(self)} + {_wrap_alias(other)}")

    def __mul__(self, factor):
        if int(factor) == factor:
            # Cast so that e.g. ``dataset * 2.0`` yields an integer number of repeats; a float
            # would make ``RepeatedDataset.__len__`` return a float and ``len()`` would fail.
            return RepeatedDataset(self, num_repeats=int(factor))

        return SubsetDataset(self, factor=factor, seed=0)

    # ``k * dataset`` behaves exactly like ``dataset * k``.
    __rmul__ = __mul__


class NamedDataset(_AliasDataset):
    """A root dataset whose alias is ``repr(name)``, i.e. the name in quotes (e.g. ``'MNISTDataset'``)."""

    def __init__(self, dataset: data.Dataset, name: str):
        super().__init__(dataset, repr(name))

In [None]:
# Same import as in the exports cell above.
from batchbald_redux.fast_mnist import FastMNIST

# FastMNIST (downloaded to "data/" if missing, kept on the CPU) wrapped as a NamedDataset; its
# alias is the quoted name, as shown in the output below. Used by the examples in this notebook.
MNIST = NamedDataset(FastMNIST(root="data/", download=True, device="cpu"), name="MNISTDataset")

MNIST

'MNISTDataset'

In [None]:
# exports


class OverridenTargetDataset(_AliasDataset):
    """
    Wraps a dataset and replaces the targets at selected indices.

    ``new_targets[k]`` becomes the target of sample ``indices[k]``; every other sample keeps its
    original target. If an index occurs more than once in ``indices``, its last occurrence wins.

    Raises ``ValueError`` if ``indices`` and ``new_targets`` differ in length.
    """

    # Maps an overridden dataset index to its position in ``new_targets``.
    reverse_indices: dict
    new_targets: list

    def __init__(self, dataset: data.Dataset, *, indices: list, new_targets: list):
        if len(indices) != len(new_targets):
            raise ValueError(
                f"indices and new_targets must have the same length, got {len(indices)} and {len(new_targets)}"
            )
        self.reverse_indices = {idx: rank for rank, idx in enumerate(indices)}
        self.new_targets = new_targets

        super().__init__(dataset, f"{dataset} | override_targets{dict(indices=indices, new_targets=new_targets)}")

    def __getitem__(self, idx):
        sample, label = self.dataset[idx]

        rank = self.reverse_indices.get(idx)
        if rank is None:
            return sample, label
        return sample, self.new_targets[rank]


class CorruptedLabelsDataset(_AliasDataset):
    """
    Replaces the labels of a random subset of samples with uniformly random labels.

    ``size_corrupted`` is an absolute number of samples when it is greater than 1 and a fraction
    of ``len(dataset)`` otherwise. Distinct indices are chosen, and each gets a label drawn from
    ``range(num_classes)``, so a corrupted sample can keep its original label by chance. Both
    draws use a NumPy generator seeded with ``seed``; the new labels are stored as a tensor on
    ``device`` (if given).
    """

    options: dict
    implementation: OverridenTargetDataset

    def __init__(
        self, dataset: data.Dataset, *, size_corrupted: Union[float, int], num_classes: int, seed: int, device=None
    ):
        options = dict(size_corrupted=size_corrupted, num_classes=num_classes, seed=seed)

        super().__init__(dataset, f"{dataset} | corrupt_labels{options}")
        self.options = options

        generator = np.random.default_rng(seed)

        N = len(dataset)

        if size_corrupted > 1:
            # Absolute count; cast so that e.g. ``100.0`` is accepted as a sample size by NumPy.
            num_corrupted = int(size_corrupted)
        else:
            num_corrupted = int(N * size_corrupted)

        indices = generator.choice(N, size=num_corrupted, replace=False)
        new_targets = generator.choice(num_classes, size=num_corrupted, replace=True)

        self.implementation = OverridenTargetDataset(
            dataset, indices=indices, new_targets=torch.as_tensor(new_targets, device=device)
        )

    def __getitem__(self, idx):
        return self.implementation[idx]

    def __len__(self):
        return len(self.implementation)


class RandomLabelsDataset(_AliasDataset):
    """
    Replaces every label of ``dataset`` by a uniformly random class in ``range(num_classes)``.

    All ``len(dataset)`` labels are drawn once, at construction, from a NumPy generator seeded
    with ``seed`` and kept as a tensor (on ``device`` when given); samples pass through unchanged.
    """

    options: dict
    new_labels: list

    def __init__(self, dataset: data.Dataset, *, num_classes: int, seed: int, device=None):
        label_options = dict(num_classes=num_classes, seed=seed)

        super().__init__(dataset, f"{dataset} | randomize_labels{label_options}")
        self.options = label_options

        rng = np.random.default_rng(seed)
        random_labels = rng.choice(num_classes, size=len(dataset), replace=True)
        self.new_labels = torch.as_tensor(random_labels, device=device)

    def __getitem__(self, idx):
        sample, _original_label = self.dataset[idx]
        return sample, self.new_labels[idx]

### Example

In [None]:
# Ten (0.0, 0.0) pairs. Corrupting 50% of the labels with seed 1 must yield exactly the labels
# asserted below; a corrupted label may coincide with the original 0 by chance.
zero_dataset = NamedDataset(data.TensorDataset(torch.zeros(10), torch.zeros(10)), "ZeroDataset")

corrupted_labels_dataset = CorruptedLabelsDataset(zero_dataset, size_corrupted=0.5, num_classes=10, seed=1)

assert list(corrupted_labels_dataset) == [
    (torch.tensor(0.0), torch.tensor(8)),
    (torch.tensor(0.0), torch.tensor(0.0)),
    (torch.tensor(0.0), torch.tensor(8)),
    (torch.tensor(0.0), torch.tensor(3)),
    (torch.tensor(0.0), torch.tensor(0.0)),
    (torch.tensor(0.0), torch.tensor(0.0)),
    (torch.tensor(0.0), torch.tensor(4)),
    (torch.tensor(0.0), torch.tensor(0.0)),
    (torch.tensor(0.0), torch.tensor(2)),
    (torch.tensor(0.0), torch.tensor(0.0)),
]

corrupted_labels_dataset

'ZeroDataset' | corrupt_labels{'size_corrupted': 0.5, 'num_classes': 10, 'seed': 1}

In [None]:
# Replace all ten labels with random labels drawn with seed 2 (pinned by the assertion below).
corrupted_dataset = RandomLabelsDataset(zero_dataset, num_classes=10, seed=2)

assert list(corrupted_dataset) == [
    (torch.tensor(0.0), torch.tensor(8)),
    (torch.tensor(0.0), torch.tensor(2)),
    (torch.tensor(0.0), torch.tensor(1)),
    (torch.tensor(0.0), torch.tensor(2)),
    (torch.tensor(0.0), torch.tensor(4)),
    (torch.tensor(0.0), torch.tensor(8)),
    (torch.tensor(0.0), torch.tensor(4)),
    (torch.tensor(0.0), torch.tensor(0)),
    (torch.tensor(0.0), torch.tensor(3)),
    (torch.tensor(0.0), torch.tensor(6)),
]

corrupted_dataset

'ZeroDataset' | randomize_labels{'num_classes': 10, 'seed': 2}

## Class Imbalances

In [None]:
# exports


def get_class_indices(dataset: data.Dataset, *, class_counts: list, generator: np.random.Generator) -> List[int]:
    """
    Randomly select up to ``class_counts[c]`` sample indices of each class ``c``.

    The dataset is scanned in the order of a random permutation drawn from ``generator``. Each
    ``dataset[i]`` is unpacked as ``(sample, label)``, and the index is taken while the label's
    class still needs samples. The scan stops as soon as all counts are satisfied. If a class
    has fewer samples than requested, all of its samples are taken.

    Returns the selected indices in the order they were picked.
    """
    remaining_per_class = list(class_counts)
    remaining_samples = sum(remaining_per_class)

    # Draw the permutation unconditionally so the generator advances the same way regardless of
    # the requested counts (a caller may keep using the generator afterwards).
    permutation = generator.permutation(len(dataset))

    subset_indices = []
    if all(count <= 0 for count in remaining_per_class):
        # Nothing can ever be selected: avoid reading every sample of the dataset.
        return subset_indices

    for index in permutation:
        _, y = dataset[index]

        if remaining_per_class[y] > 0:
            subset_indices.append(index)
            remaining_per_class[y] -= 1
            remaining_samples -= 1

            if remaining_samples <= 0:
                break

    return subset_indices


def get_balanced_sample_indices(dataset: data.Dataset, *, num_classes, samples_per_class, seed: int) -> List[int]:
    """
    Pick ``samples_per_class`` random sample indices for each of the ``num_classes`` classes.

    Thin wrapper around ``get_class_indices`` with equal per-class counts and a NumPy generator
    seeded with ``seed``.
    """
    rng = np.random.default_rng(seed)
    return get_class_indices(
        dataset,
        class_counts=[samples_per_class for _ in range(num_classes)],
        generator=rng,
    )

In [None]:
# exports


class ImbalancedDataset(_AliasDataset):
    """
    Subsamples ``dataset`` to at most ``class_counts[c]`` samples of each class ``c``.

    The kept indices come from ``get_class_indices`` with a NumPy generator seeded with ``seed``.
    Index ``i`` of this dataset maps to ``dataset[indices[i]]``.
    """

    options: dict
    indices: list

    def __init__(self, dataset: data.Dataset, *, class_counts: list, seed: int):
        settings = dict(class_counts=class_counts, seed=seed)
        super().__init__(dataset, f"ImbalancedDataset(dataset={dataset}, {settings})")
        self.options = settings

        rng = np.random.default_rng(seed)
        self.indices = get_class_indices(dataset, class_counts=class_counts, generator=rng)

    def __getitem__(self, idx):
        base_idx = self.indices[idx]
        return self.dataset[base_idx]

    def __len__(self):
        return len(self.indices)


class ImbalancedClassSplitDataset(_AliasDataset):
    """
    Makes half of the classes "majority" classes and the other half "minority" classes.

    Each class is given a target count of ``len(dataset) // num_classes`` scaled by
    ``majority_percentage`` or ``minority_percentage`` (integer percentages, floor division).
    NOTE(review): this presumably assumes a roughly class-balanced ``dataset`` — confirm.
    Which classes are majority or minority is decided by a permutation drawn from a NumPy
    generator seeded with ``seed``. The same generator then picks the samples via
    ``get_class_indices``.
    """

    options: dict
    indices: list

    def __init__(
        self, dataset: data.Dataset, *, num_classes: int, majority_percentage: int, minority_percentage: int, seed: int
    ):
        # The classes are split into two equally sized halves.
        assert (num_classes % 2) == 0, f"num_classes must be even, got {num_classes}"

        # The alias depends on the randomly permuted class counts, so it is filled in below.
        super().__init__(dataset, None)

        num_samples_per_class = len(dataset) // num_classes
        num_samples_majority = num_samples_per_class * majority_percentage // 100
        num_samples_minority = num_samples_per_class * minority_percentage // 100

        generator = np.random.default_rng(seed)

        half = num_classes // 2
        class_counts = [num_samples_majority] * half + [num_samples_minority] * half
        class_counts = generator.permuted(class_counts)

        # ``minority_percentage`` is included so that differently configured datasets get
        # different aliases.
        self.options = dict(
            num_classes=num_classes,
            majority_percentage=majority_percentage,
            minority_percentage=minority_percentage,
            seed=seed,
            class_counts=class_counts,
        )
        self.alias = f"ImbalancedDataset(dataset={self.dataset}, {self.options})"

        self.indices = get_class_indices(dataset, class_counts=class_counts, generator=generator)

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

### Example

In [None]:
# Samples 0..8 with labels 0, 1, 2, 0, 1, 2, ...; select the three samples of class 0. The
# generator is unseeded, so the order of the selected samples varies between runs.
three_dataset = NamedDataset(data.TensorDataset(torch.arange(9), torch.as_tensor(list(range(3)) * 3)), "123")

imbalanced_indices = get_class_indices(three_dataset, class_counts=[3, 0, 0], generator=np.random.default_rng())

three_dataset[imbalanced_indices][0]

tensor([0, 6, 3])

In [None]:
# At most 1, 2 and 3 samples of classes 0, 1 and 2 (each class has exactly 3 samples here).
imbalanced_dataset = ImbalancedDataset(three_dataset, class_counts=[1, 2, 3], seed=2)
print(imbalanced_dataset[:][1])
imbalanced_dataset

tensor([2, 1, 0, 2, 2, 1])


ImbalancedDataset(dataset='123', {'class_counts': [1, 2, 3], 'seed': 2})

In [None]:
# Five classes keep 80% and five keep 20% of len(MNIST) // 10 samples (see class_counts below).
ImbalancedClassSplitDataset(MNIST, num_classes=10, majority_percentage=80, minority_percentage=20, seed=1)

ImbalancedDataset(dataset='MNISTDataset', {'num_classes': 10, 'majority_percentage': 80, 'seed': 1, 'class_counts': array([1200, 4800, 1200, 4800, 4800, 4800, 1200, 1200, 1200, 4800])})

## Mixing in OOD data

In [None]:
# exports


class OneHotDataset(_AliasDataset):
    """
    Replaces each integer label of ``dataset`` with a one-hot vector of length ``num_classes``.

    All one-hot targets are materialized once, at construction, in a
    ``(len(dataset), num_classes)`` tensor with the given ``dtype``/``device``. This reads every
    sample of ``dataset`` once.
    """

    options: dict
    targets: list

    def __init__(self, dataset: data.Dataset, *, num_classes: int, dtype=None, device=None):
        options = dict(num_classes=num_classes)

        super().__init__(dataset, f"{dataset} | one_hot_targets{options}")
        self.options = options

        num_samples = len(dataset)
        targets = torch.zeros(num_samples, num_classes, dtype=dtype, device=device)
        # TODO: use get_targets() here, which will require a refactoring
        # Index explicitly instead of iterating, so exactly ``len(dataset)`` rows are filled even
        # for wrappers whose iteration does not stop at ``len(dataset)``.
        for i in range(num_samples):
            _, label = dataset[i]
            targets[i, label] = 1.0

        self.targets = targets

    def __getitem__(self, idx):
        sample, _ = self.dataset[idx]
        return sample, self.targets[idx]


class RepeatedDataset(_AliasDataset):
    """
    Presents ``dataset`` concatenated with itself ``num_repeats`` times.

    Index ``i`` maps to ``dataset[i % len(dataset)]``. Indices outside
    ``[-len(self), len(self))`` raise ``IndexError``. This also makes Python's sequence-protocol
    iteration (``list(ds)``, ``for x in ds``) stop after exactly ``len(self)`` items.
    """

    num_repeats: int

    def __init__(self, dataset: data.Dataset, *, num_repeats: int):
        self.num_repeats = num_repeats

        super().__init__(dataset, f"{dataset}x{num_repeats}")

    def __getitem__(self, idx):
        length = len(self)
        if idx >= length or idx < -length:
            raise IndexError(f"index {idx} is out of range for a dataset of length {length}")

        return self.dataset[idx % len(self.dataset)]

    def __len__(self):
        return len(self.dataset) * self.num_repeats


class SubsetDataset(_AliasDataset):
    """
    A random subset of ``dataset``, or a random oversample if more samples than available are requested.

    Exactly one of ``size`` (absolute number of samples) or ``factor`` (fraction of
    ``len(dataset)``, truncated with ``int``) must be given. Indices are drawn from a NumPy
    generator seeded with ``seed``. Sampling is without replacement unless the requested size
    exceeds ``len(dataset)``. A seed of 0 is left out of the alias.
    """

    options: dict
    indices: list

    def __init__(self, dataset: data.Dataset, *, size: Optional[int] = None, factor: Optional[float] = None, seed: int):
        # Exactly one of ``size`` and ``factor`` must be specified.
        assert (size is None) != (factor is None), "specify exactly one of `size` and `factor`"

        options = dict(size=size, factor=factor, seed=seed)
        self.options = options

        generator = np.random.default_rng(seed)

        if size is not None:
            subset_size = size
            if seed == 0:
                alias = f"{dataset}[:{size}]"
            else:
                alias = f"{dataset}[:{size};seed={seed}]"
        else:
            subset_size = int(len(dataset) * factor)
            if seed == 0:
                alias = f"{dataset}~x{factor}"
            else:
                alias = f"{dataset}~x{factor} (seed={seed})"

        # Drawing more samples than the dataset holds is only possible with replacement.
        self.indices = generator.choice(len(dataset), size=subset_size, replace=subset_size > len(dataset))

        super().__init__(dataset, alias)

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)


class ConstantTargetDataset(_AliasDataset):
    """
    Keeps the samples of ``dataset`` and pairs each of them with the same ``target`` object.

    The target is returned as-is (not copied) for every index.
    """

    target: object

    def __init__(self, dataset: data.Dataset, target: object):
        alias = f"{dataset} | constant_target{target}"
        super().__init__(dataset, alias)
        self.target = target

    def __getitem__(self, idx):
        sample, _discarded = self.dataset[idx]
        return sample, self.target


def UniformTargetDataset(dataset: data.Dataset, *, num_classes: int, dtype=None, device=None):
    """
    Wrap ``dataset`` so that every sample's target is the uniform distribution over the classes.

    Returns a ``ConstantTargetDataset`` whose target is a length-``num_classes`` tensor with all
    entries equal to ``1 / num_classes`` (created with ``dtype``/``device``). Its ``options`` and
    alias are replaced so that they describe the uniform target compactly.
    """
    uniform_target = torch.ones(num_classes, dtype=dtype, device=device) / num_classes
    wrapped = ConstantTargetDataset(dataset, uniform_target)
    wrapped.options = {"num_classes": num_classes}
    wrapped.alias = f"{dataset} | uniform_targets{wrapped.options}"
    return wrapped

To mix an OOD dataset into the unlabelled data, one can either use:
```
MNIST+OOD*0.5
```
and then use `get_base_dataset(dataset, index) is OOD` to check whether a picked sample is OOD (see below; this works when `OOD` is a `NamedDataset`).

Alternatively, we can use:
```
OneHotDataset(MNIST, num_classes=10) + UniformTargetDataset(OOD * 0.5, num_classes=10)
```

### Example

In [None]:
# A non-integral factor selects a random 10% subset (seed 0).
MNIST * 0.1

'MNISTDataset'~x0.1

In [None]:
# An integral factor repeats the dataset three times.
MNIST * 3

'MNISTDataset'x3

In [None]:
# One-hot encode the labels of a random 10% subset of MNIST.
one_hot_MNIST = OneHotDataset(MNIST * 0.1, num_classes=10)
print(one_hot_MNIST[0][1])

one_hot_MNIST

tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0.])


'MNISTDataset'~x0.1 | one_hot_targets{'num_classes': 10}

In [None]:
# Every sample gets the same uniform target over the 10 classes.
UniformTargetDataset(MNIST, num_classes=10)

'MNISTDataset' | uniform_targets{'num_classes': 10}

## Noisy Samples

The problem is that, for large datasets, precomputing the noise to be added uses a lot of memory (it doubles the dataset size). Creating a new random generator for each sample instead would be too slow, so for large datasets it may be worth creating an entirely new, noisy dataset once.

Since we support exporting and importing datasets (see below), storing the noisy dataset and loading it later is a simple option.

> Tip: Do not use this on very large datasets (e.g. ImageNet)... :)

In [None]:
# exports


class AdditiveGaussianNoise(_AliasDataset):
    """
    Adds fixed, pre-sampled Gaussian noise with mean 0 and standard deviation ``sigma`` to every sample.

    One noise tensor of the sample's shape is drawn per index at construction time, so a sample
    gets the same noise every time it is read. The noise lives on the device of
    ``dataset[0][0]`` and takes roughly as much memory as the samples themselves.
    """

    noise: torch.Tensor
    options: dict

    def __init__(self, dataset: data.Dataset, sigma: float):
        sample = dataset[0][0]
        # Scale the standard-normal draws so the noise really has standard deviation ``sigma``.
        self.noise = sigma * torch.randn(len(dataset), *sample.shape, device=sample.device)
        self.options = dict(sigma=sigma)

        super().__init__(dataset, f"{dataset} + 𝓝(0;σ={sigma})")

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]
        return sample + self.noise[idx], target

In [None]:
# Gaussian noise added to the all-zero samples; torch.randn is not seeded here, so the printed
# values differ between runs.
noisy_zero = AdditiveGaussianNoise(zero_dataset, sigma=1)

print(noisy_zero)

list(noisy_zero)

'ZeroDataset' + 𝓝(0;σ=1)


[(tensor(0.7870), tensor(0.)),
 (tensor(1.0932), tensor(0.)),
 (tensor(-1.5405), tensor(0.)),
 (tensor(1.2409), tensor(0.)),
 (tensor(0.2093), tensor(0.)),
 (tensor(-0.1016), tensor(0.)),
 (tensor(-0.3672), tensor(0.)),
 (tensor(0.4202), tensor(0.)),
 (tensor(-1.1063), tensor(0.)),
 (tensor(-0.7468), tensor(0.))]

## Exporting Datasets

Finally, to make it easier to use datasets across Python versions, we allow dataset exports.

In [None]:
# exports


def dataset_to_tensors(dataset):
    """
    Materialize an iterable of ``(sample, target)`` pairs into two stacked CPU tensors.

    Samples and targets are converted with ``torch.as_tensor``, so plain Python or NumPy values
    (e.g. ``int`` labels) are accepted. They are copied to the CPU and stacked along a new first
    dimension.

    Returns ``(samples, targets)``.
    """
    samples = []
    targets = []

    for sample, target in dataset:
        # Blocking copies: a non-blocking device-to-host copy may not have finished by the time
        # ``torch.stack`` reads the CPU tensor.
        samples.append(torch.as_tensor(sample).to(device="cpu"))
        targets.append(torch.as_tensor(target).to(device="cpu"))

    samples = torch.stack(samples)
    targets = torch.stack(targets)

    return samples, targets


def get_dataset_state_dict(dataset):
    """
    Build a serializable snapshot of ``dataset``.

    Returns a dict with the dataset's ``repr`` under ``"alias"`` and the stacked CPU tensors from
    ``dataset_to_tensors`` under ``"samples"`` and ``"targets"``.
    """
    alias = repr(dataset)
    samples, targets = dataset_to_tensors(dataset)
    return {"alias": alias, "samples": samples, "targets": targets}


class ImportedDataset(_AliasDataset):
    """
    A dataset rebuilt from a state dict with ``"alias"``, ``"samples"`` and ``"targets"`` entries.

    The samples and targets are wrapped in a ``TensorDataset``, and the stored alias is reused so
    the imported dataset prints like the exported one. If ``device`` is given, both tensors are
    moved there first.
    """

    def __init__(self, state_dict, device=None):
        samples = state_dict["samples"]
        targets = state_dict["targets"]
        if device is not None:
            samples = samples.to(device)
            targets = targets.to(device)

        tensor_dataset = data.TensorDataset(samples, targets)
        super().__init__(tensor_dataset, state_dict["alias"])


def save_dataset(dataset: data.Dataset, f, **kwargs):
    """Write ``dataset``'s alias, samples and targets to ``f`` with ``torch.save`` (extra kwargs are forwarded)."""
    state_dict = get_dataset_state_dict(dataset)
    torch.save(state_dict, f, **kwargs)


def load_dataset(f, map_location=None, **kwargs):
    """
    Load a dataset written by ``save_dataset`` and return it as an ``ImportedDataset``.

    ``map_location`` and any extra keyword arguments are forwarded to ``torch.load``.
    """
    # NOTE(security): ``torch.load`` is pickle-based; only load files from trusted sources.
    loaded_state = torch.load(f, map_location=map_location, **kwargs)
    return ImportedDataset(loaded_state)

In [None]:
# Samples 0..9 with targets 90..99, materialized into two stacked tensors.
linear_dataset = NamedDataset(data.TensorDataset(torch.arange(0, 10), torch.arange(90, 100)), "LinearDataset")

samples, targets = dataset_to_tensors(linear_dataset)

assert all(samples == torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
assert all(targets == torch.tensor([90, 91, 92, 93, 94, 95, 96, 97, 98, 99]))

In [None]:
get_dataset_state_dict(linear_dataset)

{'alias': "'LinearDataset'",
 'samples': tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
 'targets': tensor([90, 91, 92, 93, 94, 95, 96, 97, 98, 99])}

In [None]:
# Round trip through a file named "linear.dataset" in the current working directory.
save_dataset(linear_dataset, "linear.dataset")

loaded_linear_dataset = load_dataset("linear.dataset")

loaded_samples, loaded_targets = dataset_to_tensors(loaded_linear_dataset)

assert all(loaded_samples == samples)
assert all(loaded_targets == targets)

## Auxiliary Methods

We sometimes want to:

* obtain the base dataset for a certain index (for OOD detection);
* obtain the index of a sample within its base dataset;
* get only the target for a certain index; or
* get all targets for a dataset (without loading the samples).

In [None]:
# exports


def get_base_dataset(dataset, index):
    """
    Return the innermost dataset that sample ``index`` of ``dataset`` comes from.

    This module's wrapper datasets, ``data.ConcatDataset`` and ``data.Subset`` are unwrapped
    recursively, translating ``index`` along the way. The descent stops at the first
    ``NamedDataset``, so e.g. ``get_base_dataset(MNIST + OOD, i) is OOD`` for an OOD sample. Any
    unrecognized dataset type is returned as-is.

    Raises ``ValueError`` if a negative ``index`` into a ``ConcatDataset`` exceeds its length.
    """
    if isinstance(dataset, NamedDataset):
        return dataset
    elif isinstance(dataset, data.ConcatDataset):
        if index < 0:
            if -index > len(dataset):
                raise ValueError("absolute value of index should not exceed dataset length")
            index = len(dataset) + index
        # Same lookup as ``ConcatDataset.__getitem__``: find the sub-dataset and the local index.
        dataset_idx = bisect.bisect_right(dataset.cumulative_sizes, index)
        if dataset_idx == 0:
            sample_idx = index
        else:
            sample_idx = index - dataset.cumulative_sizes[dataset_idx - 1]
        return get_base_dataset(dataset.datasets[dataset_idx], sample_idx)
    elif isinstance(dataset, (ImbalancedDataset, ImbalancedClassSplitDataset, SubsetDataset, data.Subset)):
        return get_base_dataset(dataset.dataset, dataset.indices[index])
    elif isinstance(dataset, RepeatedDataset):
        return get_base_dataset(dataset.dataset, index % len(dataset.dataset))
    elif isinstance(dataset, _AliasDataset):
        return get_base_dataset(dataset.dataset, index)
    return dataset


def get_base_index(dataset, index):
    """
    Translate ``index`` into ``dataset`` to the index of the same sample in the innermost leaf dataset.

    ``data.ConcatDataset``, ``data.Subset`` and this module's wrapper datasets are unwrapped
    recursively. Leaf datasets (``TensorDataset``, torchvision ``MNIST``/``CIFAR10``) return the
    translated index.

    Raises ``ValueError`` if a negative ``index`` into a ``ConcatDataset`` exceeds its length,
    and ``NotImplementedError`` for unrecognized dataset types.
    """
    # Plain torch types are checked first. They are disjoint from this module's wrapper classes,
    # so the order does not change the result.
    if isinstance(dataset, data.ConcatDataset):
        if index < 0:
            if -index > len(dataset):
                raise ValueError("absolute value of index should not exceed dataset length")
            index = len(dataset) + index
        # Same lookup as ``ConcatDataset.__getitem__``: find the sub-dataset and the local index.
        dataset_idx = bisect.bisect_right(dataset.cumulative_sizes, index)
        if dataset_idx == 0:
            sample_idx = index
        else:
            sample_idx = index - dataset.cumulative_sizes[dataset_idx - 1]
        return get_base_index(dataset.datasets[dataset_idx], sample_idx)
    elif isinstance(dataset, data.Subset):
        return get_base_index(dataset.dataset, dataset.indices[index])
    elif isinstance(dataset, data.TensorDataset):
        return index
    elif isinstance(dataset, (ImbalancedDataset, ImbalancedClassSplitDataset, SubsetDataset)):
        return get_base_index(dataset.dataset, dataset.indices[index])
    elif isinstance(dataset, RepeatedDataset):
        return get_base_index(dataset.dataset, index % len(dataset.dataset))
    elif isinstance(dataset, _AliasDataset):
        return get_base_index(dataset.dataset, index)
    elif isinstance(dataset, (torchvision.datasets.MNIST, torchvision.datasets.CIFAR10)):
        return index

    raise NotImplementedError(f"Unrecognized dataset {dataset}!")


def get_target(dataset, index):
    """
    Return the target of sample ``index`` of ``dataset`` without loading the sample itself.

    ``data.ConcatDataset``, ``data.Subset`` and this module's wrapper datasets are unwrapped
    recursively. Label-changing wrappers (corrupted, overridden, random, one-hot and constant
    targets) answer from their stored targets. Leaf datasets (``TensorDataset``, torchvision
    ``MNIST``/``CIFAR10``) are read from their target storage.

    Raises ``ValueError`` if a negative ``index`` into a ``ConcatDataset`` exceeds its length,
    and ``NotImplementedError`` for unrecognized dataset types.
    """
    # Plain torch types are checked first. They are disjoint from this module's wrapper classes,
    # so the order does not change the result.
    if isinstance(dataset, data.ConcatDataset):
        if index < 0:
            if -index > len(dataset):
                raise ValueError("absolute value of index should not exceed dataset length")
            index = len(dataset) + index
        # Same lookup as ``ConcatDataset.__getitem__``: find the sub-dataset and the local index.
        dataset_idx = bisect.bisect_right(dataset.cumulative_sizes, index)
        if dataset_idx == 0:
            sample_idx = index
        else:
            sample_idx = index - dataset.cumulative_sizes[dataset_idx - 1]
        return get_target(dataset.datasets[dataset_idx], sample_idx)
    elif isinstance(dataset, data.Subset):
        return get_target(dataset.dataset, dataset.indices[index])
    elif isinstance(dataset, data.TensorDataset):
        return dataset.tensors[1][index]
    elif isinstance(dataset, CorruptedLabelsDataset):
        return get_target(dataset.implementation, index)
    elif isinstance(dataset, OverridenTargetDataset):
        if index not in dataset.reverse_indices:
            return get_target(dataset.dataset, index)

        ridx = dataset.reverse_indices[index]
        return dataset.new_targets[ridx]
    elif isinstance(dataset, RandomLabelsDataset):
        return dataset.new_labels[index]
    elif isinstance(dataset, (ImbalancedDataset, ImbalancedClassSplitDataset, SubsetDataset)):
        return get_target(dataset.dataset, dataset.indices[index])
    elif isinstance(dataset, RepeatedDataset):
        return get_target(dataset.dataset, index % len(dataset.dataset))
    elif isinstance(dataset, OneHotDataset):
        return dataset.targets[index]
    elif isinstance(dataset, ConstantTargetDataset):
        return dataset.target
    elif isinstance(dataset, _AliasDataset):
        return get_target(dataset.dataset, index)
    elif isinstance(dataset, (torchvision.datasets.MNIST, torchvision.datasets.CIFAR10)):
        return dataset.targets[index]

    raise NotImplementedError(f"Unrecognized dataset {dataset}!")


def get_targets(dataset):
    """
    Return the targets of all samples of ``dataset`` as one tensor, without loading the samples.

    This is the whole-dataset counterpart of ``get_target``. Targets of the sub-datasets of a
    ``ConcatDataset`` are moved to the CPU before concatenation, because the sub-datasets may
    live on different devices.

    Raises ``NotImplementedError`` for unrecognized dataset types.
    """
    # Plain torch types are checked first. They are disjoint from this module's wrapper classes,
    # so the order does not change the result.
    if isinstance(dataset, data.ConcatDataset):
        return torch.cat([get_targets(sub_dataset).to("cpu") for sub_dataset in dataset.datasets])
    elif isinstance(dataset, data.Subset):
        # ``dtype=torch.long`` keeps an empty index list usable as an index tensor.
        return get_targets(dataset.dataset)[torch.as_tensor(dataset.indices, dtype=torch.long)]
    elif isinstance(dataset, data.TensorDataset):
        return torch.as_tensor(dataset.tensors[1])
    elif isinstance(dataset, CorruptedLabelsDataset):
        return get_targets(dataset.implementation)
    elif isinstance(dataset, OverridenTargetDataset):
        targets = torch.clone(get_targets(dataset.dataset))
        # ``reverse_indices`` maps each overridden index to its position in ``new_targets``.
        # Gathering by those positions stays correct even if an index was listed twice.
        override_indices = torch.as_tensor(
            [int(i) for i in dataset.reverse_indices.keys()], dtype=torch.long, device=targets.device
        )
        ranks = torch.as_tensor(list(dataset.reverse_indices.values()), dtype=torch.long, device=targets.device)
        new_targets = torch.as_tensor(dataset.new_targets, device=targets.device)
        targets[override_indices] = new_targets[ranks]
        return targets
    elif isinstance(dataset, RandomLabelsDataset):
        return dataset.new_labels
    elif isinstance(dataset, (ImbalancedDataset, ImbalancedClassSplitDataset, SubsetDataset)):
        return get_targets(dataset.dataset)[torch.as_tensor(dataset.indices, dtype=torch.long)]
    elif isinstance(dataset, RepeatedDataset):
        base_targets = get_targets(dataset.dataset)
        # Repeat along the sample dimension only, so multi-dimensional (e.g. one-hot) targets work.
        return base_targets.repeat(dataset.num_repeats, *([1] * (base_targets.dim() - 1)))
    elif isinstance(dataset, OneHotDataset):
        return dataset.targets
    elif isinstance(dataset, ConstantTargetDataset):
        return dataset.target.expand(len(dataset), *dataset.target.shape)
    elif isinstance(dataset, _AliasDataset):
        return get_targets(dataset.dataset)
    elif isinstance(dataset, (torchvision.datasets.MNIST, torchvision.datasets.CIFAR10)):
        return torch.as_tensor(dataset.targets)

    raise NotImplementedError(f"Unrecognized dataset {dataset} with type {type(dataset)}!")

## Repeated MNIST

In [None]:
# exports


def create_repeated_MNIST_dataset(
    *,
    device=None,
    num_repetitions: int = 3,
    add_noise: bool = True,
    noise_sigma: float = 0.1,
    root: str = "data",
):
    """
    Create the Repeated MNIST training set and the MNIST test set.

    Args:
        device: device passed to ``FastMNIST`` for both splits.
        num_repetitions: how often the training set is repeated; values ``<= 1`` keep it as is.
        add_noise: whether to wrap the (repeated) training set in ``AdditiveGaussianNoise``.
        noise_sigma: ``sigma`` passed to ``AdditiveGaussianNoise`` (default 0.1).
        root: root directory passed to ``FastMNIST`` for both splits; the training split is
            created with ``download=True``.

    Returns:
        ``(train_dataset, test_dataset)``. The test set is neither repeated nor noised.
    """
    # num_classes = 10, input_size = 28

    train_dataset = NamedDataset(FastMNIST(root, train=True, download=True, device=device), "FastMNIST (Train)")

    rmnist_train_dataset = train_dataset
    if num_repetitions > 1:
        rmnist_train_dataset = train_dataset * num_repetitions

    if add_noise:
        rmnist_train_dataset = AdditiveGaussianNoise(rmnist_train_dataset, noise_sigma)

    test_dataset = NamedDataset(FastMNIST(root, train=False, device=device), "FastMNIST (Test)")

    return rmnist_train_dataset, test_dataset


def create_MNIST_dataset(device=None):
    """Plain MNIST: ``create_repeated_MNIST_dataset`` without repetitions or noise. Returns ``(train, test)``."""
    return create_repeated_MNIST_dataset(device=device, num_repetitions=1, add_noise=False)

In [None]:
# Two noisy copies of the MNIST training set, plus the untouched test set.
rmnist_example = create_repeated_MNIST_dataset(device="cpu", num_repetitions=2, add_noise=True)
rmnist_example

('FastMNIST (Train)'x2 + 𝓝(0;σ=0.1), 'FastMNIST (Test)')

In [None]:
# get_targets reads only the targets of both repetitions.
len(get_targets(rmnist_example[0]))

120000

In [None]:
# Concatenation parenthesizes the non-named operand in the alias.
rmnist_example[0] + rmnist_example[1]

('FastMNIST (Train)'x2 + 𝓝(0;σ=0.1)) + 'FastMNIST (Test)'