# Experiment Data (Setup)
> No empirical experiments without data.

In [None]:
# default_exp experiment_data

In [None]:
# hide
import blackhc.project.script
from nbdev.showdoc import *

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


Import the modules and functions we are going to use.

In [None]:
# exports

from dataclasses import dataclass
from typing import List, Optional, Set

import numpy as np
import torch
import torch.utils.data
from torch import nn
from torch.utils.data import Dataset

from batchbald_redux.active_learning import ActiveLearningData
from batchbald_redux.dataset_operations import (
    AdditiveGaussianNoise,
    AliasDataset,
    NamedDataset,
    get_balanced_sample_indices,
    get_balanced_sample_indices_by_class,
    get_class_indices,
    get_class_indices_by_class,
    get_targets,
)
from batchbald_redux.datasets.factories import get_dataset

In [None]:
# exports


@dataclass
class ExperimentData:
    """Bundle of datasets and bookkeeping returned by the `load_*_experiment_data` loaders."""

    # Active learning wrapper around the train dataset; loaders such as
    # `load_standard_experiment_data` acquire `initial_training_set_indices`
    # on it before returning.
    active_learning: ActiveLearningData
    validation_dataset: Dataset
    # Dataset extracted from the active learning pool via `evaluation_set_indices`.
    evaluation_dataset: Dataset
    test_dataset: Dataset

    # Taken from the split dataset's `train_augmentations` by the loaders.
    train_augmentations: nn.Module

    # Base indices into the train dataset: the samples acquired up front ...
    initial_training_set_indices: List[int]
    # ... and the samples moved into `evaluation_dataset`.
    evaluation_set_indices: List[int]

    # The OOD dataset before any target rewriting, or None if no OOD data is used.
    ood_dataset: Optional[NamedDataset]

    # TODO: replace this with dataset info on the targets
    ood_exposure: bool

    # Device reported by the loaded split dataset (may differ from the requested hint).
    device: str


class ExperimentDataConfig:
    """Common interface of all experiment data configs: they know how to `load` themselves."""

    def load(self, device) -> ExperimentData:
        """Load the `ExperimentData` for `device`; concrete configs must override this."""
        raise NotImplementedError

In [None]:
# exports


@dataclass
class OoDDatasetConfig:
    """Describes the separate OOD dataset that `load_standard_experiment_data` adds to the train set."""

    # Name handed to `get_dataset` to load the OOD dataset.
    ood_dataset_name: str
    # Factor the OOD dataset is multiplied with (`ood_dataset * ood_repetitions`) whenever it is != 1.
    ood_repetitions: float
    # True: ID targets become one-hot and OOD samples get uniform targets.
    # False: OOD samples get the constant target -1.
    ood_exposure: bool


@dataclass
class StandardExperimentDataConfig(ExperimentDataConfig):
    """Config for one in-distribution dataset with an optional separate OOD dataset.

    `load` forwards every field unchanged to `load_standard_experiment_data`.
    """

    id_dataset_name: str
    id_repetitions: float

    initial_training_set_size: int

    validation_set_size: int
    validation_split_random_state: int

    evaluation_set_size: int

    add_dataset_noise: bool

    ood_dataset_config: Optional[OoDDatasetConfig]

    def load(self, device) -> ExperimentData:
        """Build the `ExperimentData` described by this config for `device`."""
        loader_kwargs = dict(
            id_dataset_name=self.id_dataset_name,
            id_repetitions=self.id_repetitions,
            initial_training_set_size=self.initial_training_set_size,
            validation_set_size=self.validation_set_size,
            validation_split_random_state=self.validation_split_random_state,
            evaluation_set_size=self.evaluation_set_size,
            add_dataset_noise=self.add_dataset_noise,
            ood_dataset_config=self.ood_dataset_config,
        )
        return load_standard_experiment_data(device=device, **loader_kwargs)


def load_standard_experiment_data(
    *,
    id_dataset_name: str,
    id_repetitions: float,
    initial_training_set_size: int,
    validation_set_size: int,
    validation_split_random_state: int,
    evaluation_set_size: int,
    add_dataset_noise: bool,
    ood_dataset_config: Optional[OoDDatasetConfig],
    device: str,
) -> ExperimentData:
    """Load an in-distribution dataset (plus an optional OOD dataset) for active learning.

    Args:
        id_dataset_name: Name passed to `get_dataset` for the in-distribution data.
        id_repetitions: Factor the train set is multiplied with. Values < 1 are
            applied before the initial/evaluation indices are drawn, values > 1
            afterwards (so the initial training set contains no duplicates).
        initial_training_set_size: Requested size of the initial training set.
            It is floor-divided by the number of classes, so the actual size can
            be smaller.
        validation_set_size: Size of the validation split requested from `get_dataset`.
        validation_split_random_state: Seed for the validation split; also used
            as seed for drawing the balanced initial/evaluation indices.
        evaluation_set_size: Requested size of the evaluation set (floor-divided
            by the number of classes like the initial training set).
        add_dataset_noise: Wrap the final train set in `AdditiveGaussianNoise(..., 0.1)`.
        ood_dataset_config: Optional OOD dataset that is appended to the train set.
        device: Device hint passed to `get_dataset`.

    Returns:
        The assembled `ExperimentData`; its `ood_dataset` is the OOD train split
        before target rewriting (or None without OOD data).

    Raises:
        AssertionError: If the ID and OOD datasets end up on different devices.
        RuntimeError: If ID or OOD data is repeated (> 1) without dataset noise.
    """
    split_dataset = get_dataset(
        id_dataset_name,
        root="data",
        validation_set_size=validation_set_size,
        validation_split_random_state=validation_split_random_state,
        normalize_like_cifar10=True,
        device_hint=device,
    )

    train_dataset = split_dataset.train

    # TODO: add hook here to further process the train dataset?

    # If we reduce the train set, we need to do so before picking the initial train set.
    if id_repetitions < 1:
        train_dataset = train_dataset * id_repetitions

    targets = train_dataset.get_targets()
    num_classes = train_dataset.get_num_classes()
    initial_samples_per_class = initial_training_set_size // num_classes
    evaluation_set_samples_per_class = evaluation_set_size // num_classes
    samples_per_class = initial_samples_per_class + evaluation_set_samples_per_class

    # Draw the initial and evaluation samples together so the two sets are disjoint.
    balanced_samples_indices = get_balanced_sample_indices_by_class(
        targets=targets,
        num_classes=num_classes,
        samples_per_class=samples_per_class,
        seed=validation_split_random_state,
    )

    # Per class: the first `initial_samples_per_class` indices seed the training
    # set, the remaining ones form the evaluation set.
    initial_training_set_indices = [
        idx for by_class in balanced_samples_indices.values() for idx in by_class[:initial_samples_per_class]
    ]
    evaluation_set_indices = [
        idx for by_class in balanced_samples_indices.values() for idx in by_class[initial_samples_per_class:]
    ]

    # If we over-sample the train set, we do so after picking the initial train set to avoid duplicates
    # (duplicates within the initial train set).
    if id_repetitions > 1:
        train_dataset = train_dataset * id_repetitions

    if ood_dataset_config:
        ood_exposure = ood_dataset_config.ood_exposure
        ood_split_dataset = get_dataset(
            ood_dataset_config.ood_dataset_name, root="data", normalize_like_cifar10=True, device_hint=device
        )
        assert split_dataset.device == ood_split_dataset.device, (
            f"ID dataset resides on {split_dataset.device}, while OOD dataset is on {ood_split_dataset.device}; "
            'try to put both on "cpu"!'
        )
        original_ood_dataset = ood_split_dataset.train
        if ood_exposure:
            # With exposure, ID targets become one-hot and OOD samples get uniform targets.
            train_dataset = train_dataset.one_hot(device=split_dataset.device)
            ood_dataset = original_ood_dataset.uniform_target(
                device=split_dataset.device, num_classes=train_dataset.get_num_classes()
            )
        else:
            # Without exposure, OOD samples are marked with the constant target -1.
            ood_dataset = original_ood_dataset.constant_target(
                target=torch.tensor(-1, device=split_dataset.device), num_classes=train_dataset.get_num_classes()
            )

        if ood_dataset_config.ood_repetitions != 1:
            ood_dataset = ood_dataset * ood_dataset_config.ood_repetitions

        # The ID data comes first, so the ID base indices drawn above stay valid.
        train_dataset = train_dataset + ood_dataset
    else:
        original_ood_dataset = None
        ood_exposure = False

    ood_repeated = ood_dataset_config is not None and ood_dataset_config.ood_repetitions > 1

    if add_dataset_noise:
        train_dataset = AdditiveGaussianNoise(train_dataset, 0.1)
    elif id_repetitions > 1 or ood_repeated:
        # Repeating samples without noise would only add exact duplicates to the pool.
        raise RuntimeError("`add_dataset_noise`==False, even though repeated id or ood data!")

    active_learning_data = ActiveLearningData(train_dataset)

    active_learning_data.acquire_base_indices(initial_training_set_indices)

    evaluation_dataset = AliasDataset(
        active_learning_data.extract_dataset_from_base_indices(evaluation_set_indices),
        f"Evaluation Set ({len(evaluation_set_indices)} samples)",
    )

    return ExperimentData(
        active_learning=active_learning_data,
        validation_dataset=split_dataset.validation,
        test_dataset=split_dataset.test,
        evaluation_dataset=evaluation_dataset,
        train_augmentations=split_dataset.train_augmentations,
        initial_training_set_indices=initial_training_set_indices,
        evaluation_set_indices=evaluation_set_indices,
        ood_dataset=original_ood_dataset,
        ood_exposure=ood_exposure,
        device=split_dataset.device,
    )

## `load_standard_experiment_data` tests

In [None]:
# slow

# Smoke test: MNIST without OOD data. The requested sizes are floor-divided by
# the number of classes, so the recorded output shows 10 evaluation samples
# although 16 were requested.
load_standard_experiment_data(
    id_dataset_name="MNIST",
    initial_training_set_size=20,
    validation_set_size=32,
    evaluation_set_size=16,
    id_repetitions=1.0,
    add_dataset_noise=False,
    validation_split_random_state=0,
    ood_dataset_config=None,
    device="cuda",
)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)


ExperimentData(active_learning=ActiveLearningData(base_dataset='MNIST (Train, seed=0, 59968 samples)', num_training_samples=20, num_pool_samples=59938), validation_dataset='MNIST (Validation, seed=0, 32 samples)', evaluation_dataset=Evaluation Set (10 samples), test_dataset='MNIST (Test)', train_augmentations=Sequential(), initial_training_set_indices=[46413, 55726, 25576, 55469, 39617, 35783, 36962, 56698, 4436, 24251, 27760, 7593, 15110, 21413, 31797, 42500, 34791, 46864, 47424, 57533], evaluation_set_indices=[43895, 47051, 56807, 13452, 39664, 38002, 53721, 37072, 18635, 52360], ood_dataset=None, ood_exposure=False, device='cuda')

In [None]:
# slow

# Same configuration on CIFAR-10. Note that `device` is only forwarded as
# `device_hint`: the recorded output reports device='cpu' although "cuda" was requested.
load_standard_experiment_data(
    id_dataset_name="CIFAR-10",
    initial_training_set_size=20,
    validation_set_size=32,
    evaluation_set_size=16,
    id_repetitions=1.0,
    add_dataset_noise=False,
    validation_split_random_state=0,
    ood_dataset_config=None,
    device="cuda",
)

Files already downloaded and verified


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)


Files already downloaded and verified


ExperimentData(active_learning=ActiveLearningData(base_dataset='CIFAR-10 (Train, seed=0, 49968 samples)', num_training_samples=20, num_pool_samples=49938), validation_dataset='CIFAR-10 (Validation, seed=0, 32 samples)', evaluation_dataset=Evaluation Set (10 samples), test_dataset='CIFAR-10 (Test)', train_augmentations=Sequential(
  (0): RandomCrop(crop_size=(32, 32), padding=4, fill=0, pad_if_needed=False, padding_mode=constant, resample=BILINEAR, p=1.0, p_batch=1.0, same_on_batch=False, return_transform=False)
  (1): RandomHorizontalFlip(p=0.5, p_batch=1.0, same_on_batch=False, return_transform=None)
), initial_training_set_indices=[5618, 30732, 1910, 25225, 6409, 17895, 49063, 49577, 41071, 10377, 27423, 811, 27285, 22836, 26253, 5916, 49126, 40676, 31804, 13474], evaluation_set_indices=[36153, 11586, 36207, 16977, 1000, 10548, 11403, 2005, 41796, 25579], ood_dataset=None, ood_exposure=False, device='cpu')

In [None]:
# slow

# CIFAR-10 with MNIST as OOD dataset, both requested on "cuda".
load_standard_experiment_data(
    id_dataset_name="CIFAR-10",
    initial_training_set_size=20,
    validation_set_size=32,
    evaluation_set_size=16,
    id_repetitions=1.0,
    add_dataset_noise=False,
    validation_split_random_state=0,
    ood_dataset_config=OoDDatasetConfig(ood_dataset_name="MNIST", ood_repetitions=1.0, ood_exposure=False),
    device="cuda",
)

# This call is expected to fail: the two datasets end up on different devices,
# and the assertion message tells you what to do instead (put both on "cpu").

Files already downloaded and verified


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)


Files already downloaded and verified


AssertionError: ID dataset resides on cpu, while OOD dataset is on cuda;try to put both on "cpu"!

In [None]:
# slow

# Same as the failing cell above, but with device="cpu" so that the ID and OOD
# datasets reside on the same device and loading succeeds.
load_standard_experiment_data(
    id_dataset_name="CIFAR-10",
    initial_training_set_size=20,
    validation_set_size=32,
    evaluation_set_size=16,
    id_repetitions=1.0,
    add_dataset_noise=False,
    validation_split_random_state=0,
    ood_dataset_config=OoDDatasetConfig(ood_dataset_name="MNIST", ood_repetitions=1.0, ood_exposure=False),
    device="cpu",
)

Files already downloaded and verified
Files already downloaded and verified


ExperimentData(active_learning=<batchbald_redux.active_learning.ActiveLearningData object at 0x7ff6f58d5c40>, validation_dataset='CIFAR-10 (Validation, seed=0, 32 samples)', evaluation_dataset=Evaluation Set (10 samples), test_dataset='CIFAR-10 (Test)', train_augmentations=Sequential(
  (0): RandomCrop(crop_size=(32, 32), padding=4, fill=0, pad_if_needed=False, padding_mode=constant, resample=BILINEAR, p=1.0, p_batch=1.0, same_on_batch=False, return_transform=False)
  (1): RandomHorizontalFlip(p=0.5, p_batch=1.0, same_on_batch=False, return_transform=None)
), initial_training_set_indices=[5618, 30732, 1910, 25225, 6409, 17895, 49063, 49577, 41071, 10377, 27423, 811, 27285, 22836, 26253, 5916, 49126, 40676, 31804, 13474], evaluation_set_indices=[36153, 11586, 36207, 16977, 1000, 10548, 11403, 2005, 41796, 25579], ood_dataset='MNIST (Train, seed=0, 60000 samples)', device='cpu')

In [None]:
# slow

# CIFAR-10 with CIFAR-100 as OOD dataset and no OOD exposure
# (the OOD samples get the constant target -1).
load_standard_experiment_data(
    id_dataset_name="CIFAR-10",
    initial_training_set_size=20,
    validation_set_size=32,
    evaluation_set_size=16,
    id_repetitions=1.0,
    add_dataset_noise=False,
    validation_split_random_state=0,
    ood_dataset_config=OoDDatasetConfig(ood_dataset_name="CIFAR-100", ood_repetitions=1.0, ood_exposure=False),
    device="cuda",
)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


ExperimentData(active_learning=<batchbald_redux.active_learning.ActiveLearningData object at 0x7f434cd5ca90>, validation_dataset='CIFAR-10 (Validation, seed=0, 32 samples)', evaluation_dataset=Evaluation Set (10 samples), test_dataset='CIFAR-10 (Test)', train_augmentations=Sequential(
  (0): RandomCrop(crop_size=(32, 32), padding=4, fill=0, pad_if_needed=False, padding_mode=constant, resample=BILINEAR, p=1.0, p_batch=1.0, same_on_batch=False, return_transform=False)
  (1): RandomHorizontalFlip(p=0.5, p_batch=1.0, same_on_batch=False, return_transform=None)
), initial_training_set_indices=[5618, 30732, 1910, 25225, 6409, 17895, 49063, 49577, 41071, 10377, 27423, 811, 27285, 22836, 26253, 5916, 49126, 40676, 31804, 13474], evaluation_set_indices=[36153, 11586, 36207, 16977, 1000, 10548, 11403, 2005, 41796, 25579], ood_dataset='CIFAR-100 (Train, seed=0, 50000 samples)')

In [None]:
# slow

# CIFAR-10 with CIFAR-100 as OOD dataset and OOD exposure
# (one-hot ID targets, uniform targets for the OOD samples).
load_standard_experiment_data(
    id_dataset_name="CIFAR-10",
    initial_training_set_size=20,
    validation_set_size=32,
    evaluation_set_size=16,
    id_repetitions=1.0,
    add_dataset_noise=False,
    validation_split_random_state=0,
    ood_dataset_config=OoDDatasetConfig(ood_dataset_name="CIFAR-100", ood_repetitions=1.0, ood_exposure=True),
    device="cuda",
)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


ExperimentData(active_learning=<batchbald_redux.active_learning.ActiveLearningData object at 0x7f434db5f2e0>, validation_dataset='CIFAR-10 (Validation, seed=0, 32 samples)', evaluation_dataset=Evaluation Set (10 samples), test_dataset='CIFAR-10 (Test)', train_augmentations=Sequential(
  (0): RandomCrop(crop_size=(32, 32), padding=4, fill=0, pad_if_needed=False, padding_mode=constant, resample=BILINEAR, p=1.0, p_batch=1.0, same_on_batch=False, return_transform=False)
  (1): RandomHorizontalFlip(p=0.5, p_batch=1.0, same_on_batch=False, return_transform=None)
), initial_training_set_indices=[5618, 30732, 1910, 25225, 6409, 17895, 49063, 49577, 41071, 10377, 27423, 811, 27285, 22836, 26253, 5916, 49126, 40676, 31804, 13474], evaluation_set_indices=[36153, 11586, 36207, 16977, 1000, 10548, 11403, 2005, 41796, 25579], ood_dataset='CIFAR-100 (Train, seed=0, 50000 samples)')

# `ImbalancedTestDistributionExperimentDataConfig`

In [None]:
# exports


@dataclass
class ImbalancedTestDistributionExperimentDataConfig(ExperimentDataConfig):
    """Config whose validation, test and evaluation sets under-represent the minority classes.

    `load` forwards every field unchanged to `load_imbalanced_experiment_data`.
    """

    dataset_name: str
    repetitions: float

    initial_training_set_size: int

    validation_set_size: int
    validation_split_random_state: int

    evaluation_set_size: int

    add_dataset_noise: bool

    minority_classes: Set[int]
    minority_class_percentage: float

    def load(self, device) -> ExperimentData:
        """Build the `ExperimentData` described by this config for `device`."""
        loader_kwargs = dict(
            dataset_name=self.dataset_name,
            repetitions=self.repetitions,
            initial_training_set_size=self.initial_training_set_size,
            validation_set_size=self.validation_set_size,
            validation_split_random_state=self.validation_split_random_state,
            evaluation_set_size=self.evaluation_set_size,
            add_dataset_noise=self.add_dataset_noise,
            minority_classes=self.minority_classes,
            minority_class_percentage=self.minority_class_percentage,
        )
        return load_imbalanced_experiment_data(device=device, **loader_kwargs)


def load_imbalanced_experiment_data(
    *,
    dataset_name: str,
    repetitions: float,
    initial_training_set_size: int,
    validation_set_size: int,
    validation_split_random_state: int,
    evaluation_set_size: int,
    add_dataset_noise: bool,
    minority_classes: Set[int],
    minority_class_percentage: float,
    device: str,
) -> ExperimentData:
    """Load a dataset whose evaluation, validation and test sets are class-imbalanced.

    The initial training set stays balanced (`initial_training_set_size //
    num_classes` samples per class). In the evaluation set, every class in
    `minority_classes` only receives `minority_class_percentage` percent of the
    samples a regular class receives; the validation and test splits are
    sub-sampled with `imbalance_subsample` using the same settings.

    `validation_split_random_state` seeds the validation split, the index
    sampling and the sub-sampling. Raises `RuntimeError` when `repetitions > 1`
    is combined with `add_dataset_noise=False`.
    """
    splits = get_dataset(
        dataset_name,
        root="data",
        validation_set_size=validation_set_size,
        validation_split_random_state=validation_split_random_state,
        normalize_like_cifar10=True,
        device_hint=device,
    )

    pool_dataset = splits.train

    # Shrinking has to happen before any indices are drawn from the train set.
    if repetitions < 1:
        pool_dataset = pool_dataset * repetitions

    num_classes = pool_dataset.get_num_classes()
    # The initial training set is kept balanced.
    initial_per_class = initial_training_set_size // num_classes

    # A minority class only counts as a fraction of a class when distributing
    # the evaluation budget.
    weighted_num_classes = num_classes - len(minority_classes) * (1 - minority_class_percentage / 100)
    regular_eval_count = int(evaluation_set_size / weighted_num_classes)
    minority_eval_count = int(regular_eval_count * minority_class_percentage / 100)

    evaluation_set_class_counts = []
    for class_id in range(num_classes):
        if class_id in minority_classes:
            evaluation_set_class_counts.append(minority_eval_count)
        else:
            evaluation_set_class_counts.append(regular_eval_count)

    print("Evaluation Set Class Counts:", evaluation_set_class_counts)

    rng = np.random.default_rng(validation_split_random_state)
    # Draw initial + evaluation samples per class in one go so they never overlap.
    per_class_totals = [initial_per_class + eval_count for eval_count in evaluation_set_class_counts]
    indices_by_class = get_class_indices_by_class(
        pool_dataset.get_targets(),
        class_counts=per_class_totals,
        generator=rng,
    )

    initial_training_set_indices = []
    evaluation_set_indices = []
    for class_indices in indices_by_class.values():
        initial_training_set_indices.extend(class_indices[:initial_per_class])
        evaluation_set_indices.extend(class_indices[initial_per_class:])

    # Over-sampling happens after the indices are fixed to avoid duplicates among them.
    if repetitions > 1:
        pool_dataset = pool_dataset * repetitions

    if add_dataset_noise:
        pool_dataset = AdditiveGaussianNoise(pool_dataset, 0.1)
    elif repetitions > 1:
        raise RuntimeError("`add_dataset_noise`==False, even though repeated id!")

    active_learning_data = ActiveLearningData(pool_dataset)
    active_learning_data.acquire_base_indices(initial_training_set_indices)

    evaluation_subset = active_learning_data.extract_dataset_from_base_indices(evaluation_set_indices)
    evaluation_dataset = AliasDataset(
        evaluation_subset,
        f"Evaluation Set ({len(evaluation_set_indices)} samples)",
    )

    # The test and validation splits are imbalanced with identical settings.
    subsample_kwargs = dict(
        minority_classes=minority_classes,
        minority_percentage=minority_class_percentage,
        seed=validation_split_random_state,
    )
    test_dataset = splits.test.imbalance_subsample(**subsample_kwargs)
    validation_dataset = splits.validation.imbalance_subsample(**subsample_kwargs)

    return ExperimentData(
        active_learning=active_learning_data,
        validation_dataset=validation_dataset,
        test_dataset=test_dataset,
        evaluation_dataset=evaluation_dataset,
        train_augmentations=splits.train_augmentations,
        initial_training_set_indices=initial_training_set_indices,
        evaluation_set_indices=evaluation_set_indices,
        ood_dataset=None,
        ood_exposure=False,
        device=splits.device,
    )

## Tests

In [None]:
# slow

# MNIST with classes 1, 2, 5 and 8 reduced to 20% of a regular class in the
# evaluation, validation and test sets (see the class counts in the output below).
load_imbalanced_experiment_data(
    dataset_name="MNIST",
    repetitions=1.0,
    initial_training_set_size=20,
    validation_set_size=5000,
    validation_split_random_state=0,
    evaluation_set_size=1000,
    minority_classes={1, 2, 5, 8},
    minority_class_percentage=20,
    add_dataset_noise=False,
    device="cuda",
)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)


Evaluation Set Class Counts: [147, 29, 29, 147, 147, 29, 147, 147, 29, 147]


ExperimentData(active_learning=ActiveLearningData(base_dataset='MNIST (Train, seed=0, 55000 samples)', num_training_samples=20, num_pool_samples=53982), validation_dataset=ImbalancedClassSplitDataset(dataset='MNIST (Validation, seed=0, 5000 samples)', {'minority_classes': {8, 1, 2, 5}, 'minority_percentage': 20, 'seed': 0, 'class_counts': [500, 100, 100, 500, 500, 100, 500, 500, 100, 500]}), evaluation_dataset=Evaluation Set (998 samples), test_dataset=ImbalancedClassSplitDataset(dataset='MNIST (Test)', {'minority_classes': {8, 1, 2, 5}, 'minority_percentage': 20, 'seed': 0, 'class_counts': [1000, 200, 200, 1000, 1000, 200, 1000, 1000, 200, 1000]}), train_augmentations=Sequential(), initial_training_set_indices=[33294, 1023, 978, 2538, 25764, 30627, 9954, 8642, 1608, 23054, 39798, 24785, 29444, 7316, 13075, 30426, 14915, 37374, 42515, 3029], evaluation_set_indices=[26100, 3612, 35806, 25748, 2738, 3717, 5686, 30976, 11659, 49732, 27975, 38073, 6695, 24587, 11606, 4100, 44524, 53294, 50

In [None]:
# slow

# Extreme case: the minority classes are dropped entirely (0%) from the
# evaluation, validation and test sets; dataset noise is enabled this time.
load_imbalanced_experiment_data(
    dataset_name="MNIST",
    repetitions=1.0,
    initial_training_set_size=20,
    validation_set_size=5000,
    validation_split_random_state=0,
    evaluation_set_size=1000,
    minority_classes={1, 2, 5, 8},
    minority_class_percentage=0,
    add_dataset_noise=True,
    device="cuda",
)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)


Evaluation Set Class Counts: [166, 0, 0, 166, 166, 0, 166, 166, 0, 166]


ExperimentData(active_learning=ActiveLearningData(base_dataset='MNIST (Train, seed=0, 55000 samples)' + 𝓝(0;σ=0.1), num_training_samples=20, num_pool_samples=53984), validation_dataset=ImbalancedClassSplitDataset(dataset='MNIST (Validation, seed=0, 5000 samples)', {'minority_classes': {8, 1, 2, 5}, 'minority_percentage': 0, 'seed': 0, 'class_counts': [500, 0, 0, 500, 500, 0, 500, 500, 0, 500]}), evaluation_dataset=Evaluation Set (996 samples), test_dataset=ImbalancedClassSplitDataset(dataset='MNIST (Test)', {'minority_classes': {8, 1, 2, 5}, 'minority_percentage': 0, 'seed': 0, 'class_counts': [1000, 0, 0, 1000, 1000, 0, 1000, 1000, 0, 1000]}), train_augmentations=Sequential(), initial_training_set_indices=[33294, 1023, 978, 2538, 25764, 30627, 9954, 8642, 1608, 23054, 39798, 24785, 29444, 7316, 13075, 30426, 14915, 37374, 42515, 3029], evaluation_set_indices=[26100, 3612, 35806, 25748, 2738, 3717, 5686, 30976, 11659, 49732, 27975, 38073, 6695, 24587, 11606, 4100, 44524, 53294, 50990, 

# `OODClassesDistributionExperimentDataConfig`

In [None]:
# exports


@dataclass
class OODClassesDistributionExperimentDataConfig(ExperimentDataConfig):
    """Treat a subset of a dataset's own classes as out-of-distribution data.

    Samples of `ood_classes` are split off the train set and re-added as OOD
    data (with uniform targets if `ood_exposure` is set, with the constant
    target -1 otherwise); these classes are removed from the validation and
    test sets. `load` forwards every field unchanged to
    `load_ood_classes_experiment_data`.
    """

    dataset_name: str
    repetitions: float

    initial_training_set_size: int

    validation_set_size: int
    validation_split_random_state: int

    evaluation_set_size: int

    add_dataset_noise: bool

    ood_classes: Set[int]
    ood_repetitions: float
    ood_exposure: bool

    def load(self, device) -> ExperimentData:
        """Build the `ExperimentData` described by this config for `device`."""
        return load_ood_classes_experiment_data(
            dataset_name=self.dataset_name,
            repetitions=self.repetitions,
            initial_training_set_size=self.initial_training_set_size,
            validation_set_size=self.validation_set_size,
            validation_split_random_state=self.validation_split_random_state,
            evaluation_set_size=self.evaluation_set_size,
            add_dataset_noise=self.add_dataset_noise,
            ood_classes=self.ood_classes,
            ood_exposure=self.ood_exposure,
            ood_repetitions=self.ood_repetitions,
            device=device,
        )


def load_ood_classes_experiment_data(
    *,
    dataset_name: str,
    repetitions: float,
    initial_training_set_size: int,
    validation_set_size: int,
    validation_split_random_state: int,
    evaluation_set_size: int,
    add_dataset_noise: bool,
    ood_classes: Set[int],
    ood_exposure: bool,
    ood_repetitions: float,
    device: str,
) -> ExperimentData:
    """Load a dataset in which some of its own classes act as OOD data.

    Args:
        dataset_name: Name passed to `get_dataset`.
        repetitions: Factor the in-distribution part is multiplied with. Values
            < 1 are applied before the initial/evaluation indices are drawn,
            values > 1 afterwards.
        initial_training_set_size: Requested size of the initial training set;
            floor-divided by the number of in-distribution classes.
        validation_set_size: Size of the validation split requested from `get_dataset`.
        validation_split_random_state: Seed for the validation split, the index
            sampling and the validation/test sub-sampling.
        evaluation_set_size: Requested size of the evaluation set; floor-divided
            by the number of in-distribution classes.
        add_dataset_noise: Wrap the final train set in `AdditiveGaussianNoise(..., 0.1)`.
        ood_classes: Class labels whose train samples are treated as OOD.
        ood_exposure: True gives one-hot ID targets and uniform OOD targets;
            False gives the OOD samples the constant target -1.
        ood_repetitions: Factor the OOD part is multiplied with when it is != 1.
        device: Device hint passed to `get_dataset`.

    Returns:
        The assembled `ExperimentData`; its `ood_dataset` is the OOD subset of
        the train split before target rewriting.

    Raises:
        AssertionError: If no in-distribution class is left.
        RuntimeError: If ID or OOD data is repeated (> 1) without dataset noise.
    """
    split_dataset = get_dataset(
        dataset_name,
        root="data",
        validation_set_size=validation_set_size,
        validation_split_random_state=validation_split_random_state,
        normalize_like_cifar10=True,
        device_hint=device,
    )

    # Split off OOD dataset (single pass over the train targets).
    train_targets = get_targets(split_dataset.train)
    id_indices = []
    ood_indices = []
    for index, target in enumerate(train_targets):
        if int(target) in ood_classes:
            ood_indices.append(index)
        else:
            id_indices.append(index)

    original_ood_dataset = NamedDataset(
        split_dataset.train.subset(ood_indices), f"{split_dataset.train}[target in {ood_classes}]"
    )
    id_dataset = AliasDataset(
        split_dataset.train.subset(id_indices), f"{split_dataset.train}[target not in {ood_classes}]"
    )

    # If we reduce the train set, we need to do so before picking the initial train set.
    if repetitions < 1:
        id_dataset = id_dataset * repetitions

    num_classes = split_dataset.train.get_num_classes()
    num_id_classes = num_classes - len(ood_classes)
    assert num_id_classes > 0, "`ood_classes` must leave at least one in-distribution class"

    # Keep the initial training set balanced at least.
    initial_samples_per_class = initial_training_set_size // num_id_classes
    evaluation_set_samples_per_class = evaluation_set_size // num_id_classes

    # OOD classes contribute no samples to the initial training or evaluation set.
    class_counts = [
        0 if i in ood_classes else initial_samples_per_class + evaluation_set_samples_per_class
        for i in range(num_classes)
    ]

    print("Initial Samples + Evaluation Set Class Counts:", class_counts)

    generator = np.random.default_rng(validation_split_random_state)
    # The indices are relative to `id_dataset`; they stay valid as base indices
    # because the ID data comes first in the concatenated train set below.
    class_indices_by_class = get_class_indices_by_class(
        get_targets(id_dataset), class_counts=class_counts, generator=generator
    )

    initial_training_set_indices = [
        idx for by_class in class_indices_by_class.values() for idx in by_class[:initial_samples_per_class]
    ]
    evaluation_set_indices = [
        idx for by_class in class_indices_by_class.values() for idx in by_class[initial_samples_per_class:]
    ]

    # If we over-sample the train set, we do so after picking the initial train set to avoid duplicates.
    if repetitions > 1:
        id_dataset = id_dataset * repetitions

    if ood_exposure:
        id_dataset = id_dataset.one_hot(device=split_dataset.device)
        ood_dataset = original_ood_dataset.uniform_target(device=split_dataset.device, num_classes=num_classes)
    else:
        ood_dataset = original_ood_dataset.constant_target(
            target=torch.tensor(-1, device=split_dataset.device), num_classes=num_classes
        )

    if ood_repetitions != 1:
        ood_dataset = ood_dataset * ood_repetitions

    train_dataset = id_dataset + ood_dataset

    if add_dataset_noise:
        train_dataset = AdditiveGaussianNoise(train_dataset, 0.1)
    elif repetitions > 1 or ood_repetitions > 1:
        # Same guard as in `load_standard_experiment_data`: repeating ID or OOD
        # samples without noise would only add exact duplicates to the pool.
        raise RuntimeError("`add_dataset_noise`==False, even though repeated id or ood data!")

    active_learning_data = ActiveLearningData(train_dataset)

    active_learning_data.acquire_base_indices(initial_training_set_indices)

    evaluation_dataset = AliasDataset(
        active_learning_data.extract_dataset_from_base_indices(evaluation_set_indices),
        f"Evaluation Set ({len(evaluation_set_indices)} samples)",
    )

    # Passing the OOD classes as minority classes with 0% removes them from
    # the test and validation sets.
    test_dataset = split_dataset.test.imbalance_subsample(
        minority_classes=ood_classes, minority_percentage=0, seed=validation_split_random_state
    )
    validation_dataset = split_dataset.validation.imbalance_subsample(
        minority_classes=ood_classes, minority_percentage=0, seed=validation_split_random_state
    )

    return ExperimentData(
        active_learning=active_learning_data,
        validation_dataset=validation_dataset,
        test_dataset=test_dataset,
        evaluation_dataset=evaluation_dataset,
        train_augmentations=split_dataset.train_augmentations,
        initial_training_set_indices=initial_training_set_indices,
        evaluation_set_indices=evaluation_set_indices,
        ood_dataset=original_ood_dataset,
        ood_exposure=ood_exposure,
        device=split_dataset.device,
    )

## Tests

In [None]:
# slow

# MNIST with classes 1, 2, 5 and 8 as OOD classes and OOD exposure. With 6
# remaining ID classes, the 20 requested initial samples are floored to 3 per
# class (18 training samples in the output below).
load_ood_classes_experiment_data(
    dataset_name="MNIST",
    repetitions=1.0,
    initial_training_set_size=20,
    validation_set_size=5000,
    validation_split_random_state=0,
    evaluation_set_size=1000,
    ood_classes={1, 2, 5, 8},
    ood_exposure=True,
    ood_repetitions=1.0,
    add_dataset_noise=True,
    device="cuda",
)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)


33025
21975
Initial Samples + Evaluation Set Class Counts: [169, 0, 0, 169, 169, 0, 169, 169, 0, 169]


ExperimentData(active_learning=ActiveLearningData(base_dataset='MNIST (Train, seed=0, 55000 samples)'[target not in {8, 1, 2, 5}] | one_hot_targets{'num_classes': 10} + 'MNIST (Train, seed=0, 55000 samples)'[target in {8, 1, 2, 5}] | uniform_targets{'num_classes': 10} + 𝓝(0;σ=0.1), num_training_samples=18, num_pool_samples=53986), validation_dataset=ImbalancedClassSplitDataset(dataset='MNIST (Validation, seed=0, 5000 samples)', {'minority_classes': {8, 1, 2, 5}, 'minority_percentage': 0, 'seed': 0, 'class_counts': [500, 0, 0, 500, 500, 0, 500, 500, 0, 500]}), evaluation_dataset=Evaluation Set (996 samples), test_dataset=ImbalancedClassSplitDataset(dataset='MNIST (Test)', {'minority_classes': {8, 1, 2, 5}, 'minority_percentage': 0, 'seed': 0, 'class_counts': [1000, 0, 0, 1000, 1000, 0, 1000, 1000, 0, 1000]}), train_augmentations=Sequential(), initial_training_set_indices=[8278, 24082, 9356, 28456, 30930, 26581, 7107, 15612, 20067, 18098, 18619, 26280, 1849, 16652, 162, 19205, 8484, 2500

In [None]:
# slow

# Same setup without OOD exposure: the OOD samples get the constant target -1
# instead of uniform targets.
load_ood_classes_experiment_data(
    dataset_name="MNIST",
    repetitions=1.0,
    initial_training_set_size=20,
    validation_set_size=5000,
    validation_split_random_state=0,
    evaluation_set_size=1000,
    ood_classes={1, 2, 5, 8},
    ood_exposure=False,
    ood_repetitions=1.0,
    add_dataset_noise=True,
    device="cuda",
)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)


33025
21975
Initial Samples + Evaluation Set Class Counts: [169, 0, 0, 169, 169, 0, 169, 169, 0, 169]


ExperimentData(active_learning=ActiveLearningData(base_dataset='MNIST (Train, seed=0, 55000 samples)'[target not in {8, 1, 2, 5}] + 'MNIST (Train, seed=0, 55000 samples)'[target in {8, 1, 2, 5}] | constant_target{'target': tensor(-1, device='cuda:0'), 'num_classes': 10} + 𝓝(0;σ=0.1), num_training_samples=18, num_pool_samples=53986), validation_dataset=ImbalancedClassSplitDataset(dataset='MNIST (Validation, seed=0, 5000 samples)', {'minority_classes': {8, 1, 2, 5}, 'minority_percentage': 0, 'seed': 0, 'class_counts': [500, 0, 0, 500, 500, 0, 500, 500, 0, 500]}), evaluation_dataset=Evaluation Set (996 samples), test_dataset=ImbalancedClassSplitDataset(dataset='MNIST (Test)', {'minority_classes': {8, 1, 2, 5}, 'minority_percentage': 0, 'seed': 0, 'class_counts': [1000, 0, 0, 1000, 1000, 0, 1000, 1000, 0, 1000]}), train_augmentations=Sequential(), initial_training_set_indices=[8278, 24082, 9356, 28456, 30930, 26581, 7107, 15612, 20067, 18098, 18619, 26280, 1849, 16652, 162, 19205, 8484, 25

# `CinicCifarShiftExperimentDataConfig`

In [None]:
# exports


@dataclass
class CinicCifarShiftExperimentDataConfig(ExperimentDataConfig):
    """Config for the CINIC-10 (train) to CIFAR-10 (test/eval) shift setup.

    The fields mirror the keyword arguments of
    `load_cinic_cifar_shift_experiment_data`; the target device is not stored
    here and is supplied when calling `load`.
    """

    # Forwarded unchanged to the loader; see its documentation for the effect.
    train_imagenet_only: bool

    # Total number of samples requested for the initial training set.
    initial_training_set_size: int

    # Size of the validation split and the random state used to draw it.
    validation_set_size: int
    validation_split_random_state: int

    # Total number of samples requested for the evaluation set.
    evaluation_set_size: int

    def load(self, device) -> ExperimentData:
        """Materialize the experiment data described by this config.

        Args:
            device: Passed through as `device` to the loader function.

        Returns:
            Whatever `load_cinic_cifar_shift_experiment_data` produces for
            this configuration.
        """
        loader_kwargs = dict(
            train_imagenet_only=self.train_imagenet_only,
            initial_training_set_size=self.initial_training_set_size,
            validation_set_size=self.validation_set_size,
            validation_split_random_state=self.validation_split_random_state,
            evaluation_set_size=self.evaluation_set_size,
        )
        return load_cinic_cifar_shift_experiment_data(device=device, **loader_kwargs)


def load_cinic_cifar_shift_experiment_data(
    *,
    train_imagenet_only: bool,
    initial_training_set_size: int,
    validation_set_size: int,
    validation_split_random_state: int,
    evaluation_set_size: int,
    device: str,
) -> ExperimentData:
    """Build experiment data with an IMAGENET-CINIC-10 pool and CIFAR-10 evaluation.

    The active-learning pool is the "IMAGENET-CINIC-10" train split, while the
    validation, test and evaluation datasets all come from "CIFAR-10".

    Args:
        train_imagenet_only: If true, the pool consists of the
            IMAGENET-CINIC-10 train split only. Otherwise the CIFAR-10 train
            samples that were *not* selected for the evaluation set are added
            to the pool as well.
        initial_training_set_size: Requested size of the initial training set.
            It is integer-divided by the number of classes to get the
            per-class count, so any remainder is dropped.
        validation_set_size: Validation set size requested from the CIFAR-10
            factory (the IMAGENET-CINIC-10 factory is asked for 0).
        validation_split_random_state: Random state passed to both dataset
            factories; also reused as the seed for both balanced index draws.
        evaluation_set_size: Requested size of the evaluation set; integer-
            divided by the number of classes like the initial training set.
        device: Passed to the dataset factories as `device_hint`.

    Returns:
        An `ExperimentData` with no OoD dataset (`ood_dataset=None`,
        `ood_exposure=False`). `initial_training_set_indices` were drawn from
        the targets of the pool dataset; `evaluation_set_indices` were drawn
        from the targets of the CIFAR-10 train split.
    """
    seed = validation_split_random_state

    # Everything except the dataset name and the validation size is identical
    # for the two factory calls.
    shared_factory_kwargs = dict(
        root="data",
        validation_split_random_state=seed,
        normalize_like_cifar10=True,
        device_hint=device,
    )
    cinic_splits = get_dataset(name="IMAGENET-CINIC-10", validation_set_size=0, **shared_factory_kwargs)
    cifar_splits = get_dataset(name="CIFAR-10", validation_set_size=validation_set_size, **shared_factory_kwargs)

    # Both datasets are later combined/used side by side, so they have to agree
    # on the device they ended up on.
    assert cinic_splits.device == cifar_splits.device

    pool_dataset = cinic_splits.train
    num_classes = cifar_splits.train.get_num_classes()

    # Carve the evaluation set out of the CIFAR-10 train split.
    balanced_evaluation_indices = get_balanced_sample_indices(
        targets=cifar_splits.train.get_targets(),
        num_classes=num_classes,
        samples_per_class=evaluation_set_size // num_classes,
        seed=seed,
    )
    evaluation_dataset, cifar_train_remainder = cifar_splits.train.split(balanced_evaluation_indices)

    # When CIFAR-10 is added back to the pool, only the part that is disjoint
    # from the evaluation samples is used.
    if not train_imagenet_only:
        pool_dataset += cifar_train_remainder

    # Pick the initial training samples from the (possibly extended) pool.
    balanced_initial_indices = get_balanced_sample_indices(
        targets=pool_dataset.get_targets(),
        num_classes=num_classes,
        samples_per_class=initial_training_set_size // num_classes,
        seed=seed,
    )

    active_learning_data = ActiveLearningData(pool_dataset)
    active_learning_data.acquire_base_indices(balanced_initial_indices)

    return ExperimentData(
        active_learning=active_learning_data,
        validation_dataset=cifar_splits.validation,
        evaluation_dataset=evaluation_dataset,
        test_dataset=cifar_splits.test,
        train_augmentations=cinic_splits.train_augmentations,
        initial_training_set_indices=balanced_initial_indices,
        evaluation_set_indices=balanced_evaluation_indices,
        ood_dataset=None,
        ood_exposure=False,
        device=cinic_splits.device,
    )

## Test

In [None]:
# slow

# Smoke test with `train_imagenet_only=False`: the CIFAR-10 train samples that
# are not part of the evaluation set are added to the IMAGENET-CINIC-10 pool.
# The call stays the last expression of the cell so the notebook displays the
# repr of the returned `ExperimentData`.
load_cinic_cifar_shift_experiment_data(
    train_imagenet_only=False,
    initial_training_set_size=20,
    validation_set_size=5000,
    validation_split_random_state=0,
    evaluation_set_size=1000,
    device="cuda",
)

Files already downloaded and verified
Files already downloaded and verified


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)


Files already downloaded and verified


ExperimentData(active_learning=ActiveLearningData(base_dataset=('CINIC-10 (Train, imagenet_only=True, seed=0, 140000 samples)') + 'CIFAR-10 (Train, seed=0, 45000 samples)'[~[28964, 11411, 14886, 9609, 39292, 398, 3431, 21404, 44846, 11777, 6361, 36817, 5044, 37231, 14346, 24286, 4294, 28590, 16297, 12733, 19940, 27283, 27046, 17495, 4417, 40795, 10717, 3957, 14535, 20341, 27604, 43757, 26320, 40449, 10574, 12396, 14656, 21304, 44149, 12180, 27762, 22949, 32997, 11309, 29865, 36001, 20338, 24032, 34368, 9137, 23376, 13769, 44858, 15640, 40594, 407, 40764, 7166, 17277, 15347, 7175, 10233, 14617, 35065, 39662, 32385, 28273, 15891, 26145, 27266, 38700, 14319, 31039, 4596, 21831, 6428, 27461, 6582, 518, 20455, 6795, 21079, 30299, 33470, 38939, 27229, 22701, 33968, 19425, 6796, 5874, 32641, 32181, 5994, 43189, 38244, 32894, 18469, 34402, 20303, 20577, 32160, 36055, 40702, 27739, 32548, 34043, 37950, 485, 6804, 21241, 28762, 5895, 23667, 24825, 6571, 35106, 32930, 30491, 2144, 44711, 31990, 1

In [None]:
# slow

# Smoke test with `train_imagenet_only=True`: the pool is the
# IMAGENET-CINIC-10 train split only; CIFAR-10 still provides the validation,
# evaluation and test datasets.
# The call stays the last expression of the cell so the notebook displays the
# repr of the returned `ExperimentData`.
load_cinic_cifar_shift_experiment_data(
    train_imagenet_only=True,
    initial_training_set_size=20,
    validation_set_size=5000,
    validation_split_random_state=0,
    evaluation_set_size=1000,
    device="cuda",
)

Files already downloaded and verified
Files already downloaded and verified


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)


Files already downloaded and verified


ExperimentData(active_learning=ActiveLearningData(base_dataset='CINIC-10 (Train, imagenet_only=True, seed=0, 140000 samples)', num_training_samples=20, num_pool_samples=139980), validation_dataset='CIFAR-10 (Validation, seed=0, 5000 samples)', evaluation_dataset='CIFAR-10 (Train, seed=0, 45000 samples)'[[28964, 11411, 14886, 9609, 39292, 398, 3431, 21404, 44846, 11777, 6361, 36817, 5044, 37231, 14346, 24286, 4294, 28590, 16297, 12733, 19940, 27283, 27046, 17495, 4417, 40795, 10717, 3957, 14535, 20341, 27604, 43757, 26320, 40449, 10574, 12396, 14656, 21304, 44149, 12180, 27762, 22949, 32997, 11309, 29865, 36001, 20338, 24032, 34368, 9137, 23376, 13769, 44858, 15640, 40594, 407, 40764, 7166, 17277, 15347, 7175, 10233, 14617, 35065, 39662, 32385, 28273, 15891, 26145, 27266, 38700, 14319, 31039, 4596, 21831, 6428, 27461, 6582, 518, 20455, 6795, 21079, 30299, 33470, 38939, 27229, 22701, 33968, 19425, 6796, 5874, 32641, 32181, 5994, 43189, 38244, 32894, 18469, 34402, 20303, 20577, 32160, 360