# SVHN/CIFAR-10 OOD Experiment
> Can we get better by training on our assumptions?

In [None]:
# default_exp svhn_cifar10_ood_experiment

In [None]:
# hide
import blackhc.project.script

Import the modules and functions we are going to use.

In [None]:
# exports

import dataclasses
import traceback
from dataclasses import dataclass
from typing import Type, Union

import torch
import torch.utils.data
from blackhc.project import is_run_from_ipython
from blackhc.project.experiment import embedded_experiments

import batchbald_redux.acquisition_functions as acquisition_functions
from batchbald_redux.acquisition_functions import (
    CandidateBatchComputer,
    EvalCandidateBatchComputer,
)
from batchbald_redux.black_box_model_training import evaluate
from batchbald_redux.dataset_challenges import (
    get_base_dataset_index,
    get_target,
)
from batchbald_redux.di import DependencyInjection
from batchbald_redux.experiment_data import ExperimentData, ExperimentDataConfig, OoDDatasetConfig
from batchbald_redux.resnet_models import Cifar10ModelTrainer
from batchbald_redux.train_eval_model import (
    TrainEvalModel,
    TrainSelfDistillationEvalModel,
)
from batchbald_redux.trained_model import ModelTrainer

In [None]:
# exports


@dataclass
class UnifiedExperiment:
    """Configuration and driver for one active-learning run on an in-distribution/OoD dataset pair.

    The dataclass fields double as the configuration namespace: `vars(self)` is handed to
    `DependencyInjection`, which builds the data configs, the model trainer, the acquisition
    function and the train-eval model from whatever fields they ask for.
    """

    seed: int

    # Dataset selection (consumed through dependency injection when the experiment data is loaded).
    id_dataset_name: str
    ood_dataset_name: str
    ood_exposure: bool
    initial_training_set_size: int = 0
    validation_set_size: int = 1024
    evaluation_set_size: int = 1024
    id_repetitions: float = 1
    ood_repetitions: float = 1
    add_dataset_noise: bool = False
    validation_split_random_state: int = 0

    # Active-learning schedule.
    acquisition_size: int = 5
    max_training_set: int = 200

    # Number of model samples used for pool scoring / validation+evaluation / training.
    num_pool_samples: int = 20
    num_validation_samples: int = 20
    num_training_samples: int = 1

    device: str = "cuda"
    acquisition_function: Union[
        Type[CandidateBatchComputer], Type[EvalCandidateBatchComputer]
    ] = acquisition_functions.BALD
    train_eval_model: Type[TrainEvalModel] = TrainSelfDistillationEvalModel
    model_trainer_factory: Type[ModelTrainer] = Cifar10ModelTrainer

    # Not read by any method of this class; only visible to injected components via `vars(self)`.
    temperature: float = 0.0

    def load_experiment_data(self) -> ExperimentData:
        """Build the OoD and experiment data configs from this experiment's fields and load the data."""
        di = DependencyInjection(vars(self), [])
        odc: OoDDatasetConfig = di.create_dataclass_type(OoDDatasetConfig)
        edc: ExperimentDataConfig = di.create_dataclass_type(ExperimentDataConfig, ood_dataset_config=odc)
        return edc.load()

    # Simple Dependency Injection
    def create_acquisition_function(self):
        """Instantiate `self.acquisition_function` from this experiment's fields."""
        di = DependencyInjection(vars(self))
        return di.create_dataclass_type(self.acquisition_function)

    def create_train_eval_model(self, runtime_config) -> TrainEvalModel:
        """Instantiate `self.train_eval_model`.

        `runtime_config` entries are merged on top of the experiment's fields, so they win on
        key collisions.
        """
        config = {**vars(self), **runtime_config}
        di = DependencyInjection(config, [])
        return di.create_dataclass_type(self.train_eval_model)

    def create_model_trainer(self) -> ModelTrainer:
        """Instantiate `self.model_trainer_factory` from this experiment's fields."""
        di = DependencyInjection(vars(self))
        return di.create_dataclass_type(self.model_trainer_factory)

    def _create_losses(self):
        """Return the `(training_loss, validation_loss)` pair for this experiment.

        With OoD exposure the training loss is a KL divergence against (non-log) target
        distributions, while validation uses NLL; without exposure one NLL loss object is
        shared by both.
        """
        if self.ood_exposure:
            return torch.nn.KLDivLoss(log_target=False, reduction="batchmean"), torch.nn.NLLLoss()
        shared_loss = torch.nn.NLLLoss()
        return shared_loss, shared_loss

    def _compute_candidate_batch(
        self, acquisition_function, data, model_trainer, trained_model, pool_loader, validation_loader, iteration_log
    ):
        """Score the pool and return the candidate batch for the current iteration.

        Eval-based acquisition functions additionally get a train-eval model whose training log
        is stored under `iteration_log["eval_training"]`.

        Raises:
            ValueError: if `acquisition_function` is neither a `CandidateBatchComputer` nor an
                `EvalCandidateBatchComputer`.
        """
        if isinstance(acquisition_function, CandidateBatchComputer):
            return acquisition_function.compute_candidate_batch(trained_model, pool_loader, self.device)

        if not isinstance(acquisition_function, EvalCandidateBatchComputer):
            raise ValueError(f"Unknown acquisition function {acquisition_function}!")

        # Fall back to the pool as the evaluation set when no dedicated evaluation set was requested.
        if self.evaluation_set_size:
            eval_dataset = data.evaluation_dataset
        else:
            eval_dataset = data.active_learning.pool_dataset

        train_eval_model = self.create_train_eval_model(
            dict(
                model_trainer=model_trainer,
                training_dataset=data.active_learning.training_dataset,
                train_augmentations=data.train_augmentations,
                eval_dataset=eval_dataset,
                validation_loader=validation_loader,
                trained_model=trained_model,
            )
        )

        iteration_log["eval_training"] = {}
        trained_eval_model = train_eval_model(device=self.device, training_log=iteration_log["eval_training"])

        return acquisition_function.compute_candidate_batch(
            trained_model, trained_eval_model, pool_loader, device=self.device
        )

    @staticmethod
    def _resolve_candidates(data, candidate_batch):
        """Map pool indices of the candidates to `("ood"|"id", base_index)` pairs and their labels."""
        global_dataset_indices = []
        labels = []
        for index in candidate_batch.indices:
            base_di = get_base_dataset_index(data.active_learning.pool_dataset, index)
            dataset_type = "ood" if base_di.dataset == data.ood_dataset else "id"
            global_dataset_indices.append((dataset_type, base_di.index))
            labels.append(get_target(data.active_learning.pool_dataset, index))
        return global_dataset_indices, labels

    def run(self, store):
        """Run the active-learning loop and record everything in `store`.

        Writes `dataset_info`, `initial_training_set_indices`, `evaluation_set_indices` and
        `active_learning_steps` (one dict per iteration with `training`, `evaluation_metrics`
        and, unless it is the final iteration, `acquisition` and possibly `eval_training`).

        The loop stops once the training set has reached `max_training_set` or after
        `max_iterations` acquisitions, whichever happens first.
        """
        torch.manual_seed(self.seed)

        # Active Learning setup
        data = self.load_experiment_data()
        store["dataset_info"] = dict(training=repr(data.active_learning.base_dataset), test=repr(data.test_dataset))
        store["initial_training_set_indices"] = data.initial_training_set_indices
        store["evaluation_set_indices"] = data.evaluation_set_indices

        model_trainer = self.create_model_trainer()

        train_loader = model_trainer.get_train_dataloader(data.active_learning.training_dataset)
        pool_loader = model_trainer.get_evaluation_dataloader(data.active_learning.pool_dataset)
        validation_loader = model_trainer.get_evaluation_dataloader(data.validation_dataset)
        test_loader = model_trainer.get_evaluation_dataloader(data.test_dataset)

        # The list is deliberately read back from the store rather than kept as a local literal
        # (presumably the store may wrap assigned values — TODO confirm).
        store["active_learning_steps"] = []
        active_learning_steps = store["active_learning_steps"]

        acquisition_function = self.create_acquisition_function()

        # The losses only depend on `ood_exposure`, so build them once instead of per iteration.
        loss, validation_loss = self._create_losses()

        num_iterations = 0
        # Allow 50% more iterations than strictly needed: without OoD exposure some acquired
        # candidates are dropped, so the training set can grow by less than `acquisition_size`.
        max_iterations = int(1.5 * (self.max_training_set - self.initial_training_set_size) / self.acquisition_size)

        # Active Training Loop
        while True:
            training_set_size = len(data.active_learning.training_dataset)
            print(f"Training set size {training_set_size}:")

            # Same read-back pattern as above for the per-iteration log entry.
            active_learning_steps.append({})
            iteration_log = active_learning_steps[-1]

            iteration_log["training"] = {}

            trained_model = model_trainer.get_trained(
                train_loader=train_loader,
                train_augmentations=data.train_augmentations,
                validation_loader=validation_loader,
                log=iteration_log["training"],
                loss=loss,
                validation_loss=validation_loss,
            )

            evaluation_metrics = evaluate(
                model=trained_model, num_samples=self.num_validation_samples, loader=test_loader, device=self.device
            )
            iteration_log["evaluation_metrics"] = evaluation_metrics
            print(f"Perf after training {evaluation_metrics}")

            if training_set_size >= self.max_training_set or num_iterations >= max_iterations:
                print("Done.")
                break

            candidate_batch = self._compute_candidate_batch(
                acquisition_function, data, model_trainer, trained_model, pool_loader, validation_loader, iteration_log
            )

            candidate_global_dataset_indices, candidate_labels = self._resolve_candidates(data, candidate_batch)

            iteration_log["acquisition"] = dict(
                indices=candidate_global_dataset_indices, labels=candidate_labels, scores=candidate_batch.scores
            )

            print(candidate_batch)
            print(candidate_global_dataset_indices)

            if self.ood_exposure:
                data.active_learning.acquire(candidate_batch.indices)
            else:
                # Skip candidates labelled -1 (presumably the OoD marker target — TODO confirm).
                data.active_learning.acquire(
                    [index for index, label in zip(candidate_batch.indices, candidate_labels) if label != -1]
                )

            ls = ", ".join(f"{label} ({score:.4})" for label, score in zip(candidate_labels, candidate_batch.scores))
            print(f"Acquiring (label, score)s: {ls}")

            num_iterations += 1

In [None]:
# experiment

# Smoke test: tiny CIFAR-10 (in-distribution) / SVHN (OoD) run with OoD exposure and EvalBALD.
# NOTE(review): `max_training_epochs` is not among the fields that `UnifiedExperiment` declares
# in this notebook, so this constructor call would be rejected as written — presumably the
# class lost the field or this cell is stale; TODO confirm before re-running.
experiment = UnifiedExperiment(
    ood_exposure=True,
    id_dataset_name="CIFAR-10",
    ood_dataset_name="SVHN",
    seed=1,
    max_training_epochs=1,
    max_training_set=20 + 10,
    acquisition_function=acquisition_functions.EvalBALD,
    evaluation_set_size=100,
    acquisition_size=10,
    num_pool_samples=2,
    device="cuda",
)

# A plain dict stands in for the experiment store; the last expression displays what was logged.
results = {}
experiment.run(results)
results

Creating: OoDDatasetConfig(
	ood_dataset_name=SVHN,
	ood_repetitions=1,
	ood_exposure=True
)
Creating: ExperimentDataConfig(
	id_dataset_name=CIFAR-10,
	id_repetitions=1,
	initial_training_set_size=0,
	validation_set_size=1024,
	validation_split_random_state=0,
	evaluation_set_size=100,
	add_dataset_noise=False,
	device=cuda,
	ood_dataset_config=OoDDatasetConfig(ood_dataset_name='SVHN', ood_repetitions=1, ood_exposure=True)
)
Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: data/SVHN/train_32x32.mat
Using downloaded and verified file: data/SVHN/test_32x32.mat
Creating: Cifar10ModelTrainer(
	device=cuda,
	num_training_samples=1,
	num_validation_samples=20,
	max_training_epochs=1,
	min_samples_per_epoch=5056
)
Creating: EvalBALD(
	acquisition_size=10,
	num_pool_samples=2
)
Training set size 0:


get_predictions_labels:   0%|          | 0/200000 [00:00<?, ?it/s]

Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "/home/blackhc/anaconda3/envs/active_learning/lib/python3.8/multiprocessing/queues.py", line 235, in _feed
    close()
  File "/home/blackhc/anaconda3/envs/active_learning/lib/python3.8/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/home/blackhc/anaconda3/envs/active_learning/lib/python3.8/multiprocessing/connection.py", line 361, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/blackhc/anaconda3/envs/active_learning/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/home/blackhc/anaconda3/envs/active_learning/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/blackhc/anaconda3/envs/active_learning/lib/python3.8/multiprocessing/queues.py", line 266,

Perf after training {'accuracy': 0.1, 'crossentropy': tensor(2.3507)}
Creating: TrainSelfDistillationEvalModel(
	num_pool_samples=2,
	training_dataset=<torch.utils.data.dataset.Subset object at 0x7ff8e6e7d490>,
	eval_dataset=Evaluation Set (100 samples),
	validation_loader=<torch.utils.data.dataloader.DataLoader object at 0x7ff8e6e7d3a0>,
	training_batch_size=128,
	trained_model=TrainedBayesianModel(model=BayesianResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): Identity()
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 

get_predictions_labels:   0%|          | 0/200 [00:00<?, ?it/s]

TypeError: get_trained() missing 1 required keyword-only argument: 'log'

In [None]:
# experiment

# NOTE(review): `OodExperiment` and its `uniform_ood` flag are not defined or imported anywhere
# in the visible part of this notebook — presumably this cell predates `UnifiedExperiment`
# (`ood_exposure`); TODO confirm and port or delete.
experiment = OodExperiment(
    uniform_ood=True,
    id_dataset_name="CIFAR-10",
    ood_dataset_name="SVHN",
    seed=1,
    max_training_epochs=1,
    max_training_set=20 + 10,
    acquisition_function=acquisition_functions.EvalBALD,
    evaluation_set_size=100,
    acquisition_size=10,
    num_pool_samples=2,
    device="cuda",
)

# A plain dict stands in for the experiment store; the last expression displays what was logged.
results = {}
experiment.run(results)
results

Creating: ExperimentDataConfig(
	uniform_ood=True,
	id_dataset_name=CIFAR-10,
	ood_dataset_name=SVHN,
	initial_training_set_size=0,
	validation_set_size=1024,
	evaluation_set_size=100,
	id_repetitions=1,
	ood_repetitions=1,
	add_dataset_noise=False,
	validation_split_random_state=0,
	device=cuda
)
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: data/SVHN/train_32x32.mat
Using downloaded and verified file: data/SVHN/train_32x32.mat
Using downloaded and verified file: data/SVHN/test_32x32.mat
Creating: EvalBALD(
	acquisition_size=10
)
Training set size 0:


100%|##########| 1/1 [00:00<?, ?it/s]

Current run is terminating due to exception: integer division or modulo by zero.
Engine run is terminating due to exception: integer division or modulo by zero.


ZeroDivisionError: integer division or modulo by zero

In [None]:
# experiment

# Print the combined active-learning base dataset for both ID/OoD pairings.
# NOTE(review): `OodExperiment` is not defined or imported in the visible part of this
# notebook — presumably a stale cell from before `UnifiedExperiment`; TODO confirm.
print(
    OodExperiment(seed=0, device="cpu", id_dataset_name="SVHN", ood_dataset_name="CIFAR-10", uniform_ood=False)
    .load_experiment_data()
    .active_learning.base_dataset
)
print(
    OodExperiment(seed=1, device="cpu", id_dataset_name="CIFAR-10", ood_dataset_name="SVHN", uniform_ood=True)
    .load_experiment_data()
    .active_learning.base_dataset
)

Creating: ExperimentDataConfig(
	uniform_ood=False,
	id_dataset_name=SVHN,
	ood_dataset_name=CIFAR-10,
	initial_training_set_size=0,
	validation_set_size=1024,
	evaluation_set_size=1024,
	id_repetitions=1,
	ood_repetitions=1,
	add_dataset_noise=False,
	validation_split_random_state=0,
	device=cpu
)
Using downloaded and verified file: data/SVHN/train_32x32.mat
Using downloaded and verified file: data/SVHN/train_32x32.mat
Using downloaded and verified file: data/SVHN/test_32x32.mat
Using downloaded and verified file: data/SVHN/train_32x32.mat
Using downloaded and verified file: data/SVHN/train_32x32.mat
Using downloaded and verified file: data/SVHN/test_32x32.mat
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
(Augmentation Node) + ('CIFAR-10 (Train, seed=0, 50000 samples)' | constant_target{'target': tensor(-1), 'num_classes': 10})
Creating: ExperimentDataConfig(
	uniform_ood=True,
	id_dataset_name=CIFAR-10,
	ood_dataset_

In [None]:
# Display the `results` dict filled in by the most recently executed `experiment.run(results)`.
results

{'dataset_info': {'training': "(Augmentation Node | one_hot_targets{'num_classes': 10}) + ('SVHN (Train, seed=0, 73257 samples)' | uniform_targets{'num_classes': 10})",
  'test': "'CIFAR-10 (Test)'"},
 'initial_training_set_indices': [12980,
  44617,
  6984,
  21168,
  33976,
  35571,
  33058,
  43729,
  26944,
  24745,
  66,
  14046,
  46542,
  39478,
  6000,
  5915,
  39360,
  20774,
  27084,
  44464],
 'evaluation_set_indices': [3812,
  42704,
  6729,
  38942,
  48125,
  16968,
  5652,
  4045,
  10740,
  19606,
  37164,
  33354,
  47307,
  17878,
  26665,
  40819,
  14805,
  201,
  47956,
  44739,
  15578,
  36667,
  5551,
  23088,
  32496,
  5705,
  23255,
  25559,
  11975,
  44032,
  47518,
  36303,
  18452,
  34447,
  24821,
  36157,
  48089,
  25120,
  44689,
  6509,
  11001,
  6995,
  10899,
  36881,
  7002,
  19049,
  13388,
  40737,
  9210,
  22684,
  45656,
  5604,
  9134,
  35979,
  19757,
  43627,
  35248,
  23566,
  727,
  34909,
  25443,
  45862,
  30730,
  9611,
  43077

In [None]:
# experiment
# NOTE(review): `OodExperiment`, `FastMNIST` and `FastFashionMNIST` are not defined or imported
# in the visible part of this notebook — presumably a stale cell from an earlier MNIST variant
# of the experiment; TODO confirm and port or delete.
experiment = OodExperiment(
    uniform_ood=False,
    id_dataset_type=FastMNIST,
    ood_dataset=FastFashionMNIST,
    seed=1,
    max_training_epochs=1,
    max_training_set=20 + 10,
    acquisition_function=acquisition_functions.BALD,
    evaluation_set_size=100,
    acquisition_size=10,
    num_pool_samples=2,
    temperature=5,
    device="cuda",
)

# A plain dict stands in for the experiment store; the last expression displays what was logged.
results = {}
experiment.run(results)
results

Creating: ExperimentDataConfig(
	uniform_ood=False,
	id_dataset_type=<class 'batchbald_redux.fast_mnist.FastMNIST'>,
	ood_dataset_type=<class 'batchbald_redux.fast_mnist.FastFashionMNIST'>,
	initial_training_set_size=20,
	validation_set_size=1024,
	evaluation_set_size=100,
	id_repetitions=1,
	ood_repetitions=1,
	add_dataset_noise=False,
	validation_split_random_state=0,
	device=cuda
)
Creating: BALD(
	acquisition_size=10
)
Training set size 20:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6328125, 'crossentropy': 2.2057557106018066}
RestoringEarlyStopping: Restoring best parameters. (Score: 0.6328125)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.6557, 'crossentropy': 2.067915026473999}


get_predictions_labels:   0%|          | 0/237712 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118856 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118856 [00:00<?, ?it/s]

CandidateBatch(scores=[0.6930213868618011, 0.6930111590772867, 0.6930059990845621, 0.6929988877382129, 0.6929858811199665, 0.69298055768013, 0.6929786503314972, 0.6929723918437958, 0.6929657161235809, 0.6929287277162075], indices=[37754, 20802, 5448, 184, 51966, 13296, 26425, 46199, 21745, 4910])
[('id', 46785), ('id', 36269), ('id', 47902), ('id', 26859), ('id', 58402), ('id', 34634), ('id', 46751), ('id', 46535), ('id', 21910), ('id', 5884)]
Acquiring (label, score)s: 0 (0.693), 0 (0.693), 0 (0.693), 0 (0.693), 0 (0.693), 0 (0.693), 0 (0.693), 0 (0.693), 0 (0.693), 0 (0.6929)
Training set size 30:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.611328125, 'crossentropy': 2.4593361616134644}
RestoringEarlyStopping: Restoring best parameters. (Score: 0.611328125)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.6549, 'crossentropy': 2.1096548225402834}
Done.


{'dataset_info': {'training': "(FastMNIST (train; 58976 samples)) + ('FastFashionMNIST Train (60000 samples)' | constant_target{'target': tensor(-1, device='cuda:0'), 'num_classes': 10})",
  'test': "'FastMNIST Test (10000 samples)'"},
 'initial_training_set_indices': [53434,
  8533,
  14640,
  39579,
  30392,
  58125,
  37915,
  3091,
  57520,
  43803,
  44119,
  52296,
  58226,
  40334,
  46037,
  22015,
  22304,
  43812,
  12640,
  53689],
 'evaluation_set_indices': [29974,
  55573,
  35472,
  44048,
  48031,
  5616,
  10110,
  47420,
  56990,
  34198,
  3792,
  5715,
  15969,
  32775,
  19757,
  34588,
  28991,
  47417,
  26501,
  12108,
  5573,
  48032,
  40646,
  43252,
  2404,
  36797,
  29079,
  40018,
  37047,
  41512,
  45567,
  801,
  10664,
  52801,
  42890,
  32972,
  45974,
  20801,
  23496,
  5803,
  10508,
  46870,
  49549,
  306,
  38725,
  13074,
  19689,
  27135,
  16068,
  18137,
  2728,
  43321,
  29950,
  380,
  27254,
  50466,
  31965,
  24052,
  44454,
  20076,


In [None]:
# experiment

# NOTE(review): `OodExperiment`, `FastMNIST` and `FastFashionMNIST` are not defined or imported
# in the visible part of this notebook — presumably a stale cell from an earlier MNIST variant
# of the experiment (here with EvalBALD); TODO confirm and port or delete.
experiment = OodExperiment(
    uniform_ood=False,
    id_dataset_type=FastMNIST,
    ood_dataset=FastFashionMNIST,
    seed=1,
    max_training_epochs=1,
    max_training_set=20 + 10,
    acquisition_function=acquisition_functions.EvalBALD,
    evaluation_set_size=100,
    acquisition_size=10,
    num_pool_samples=2,
    temperature=5,
    device="cuda",
)

# A plain dict stands in for the experiment store; the last expression displays what was logged.
results = {}
experiment.run(results)
results

Creating: ExperimentDataConfig(
	uniform_ood=False,
	id_dataset_type=<class 'batchbald_redux.fast_mnist.FastMNIST'>,
	ood_dataset_type=<class 'batchbald_redux.fast_mnist.FastFashionMNIST'>,
	initial_training_set_size=20,
	validation_set_size=1024,
	evaluation_set_size=100,
	id_repetitions=1,
	ood_repetitions=1,
	add_dataset_noise=False,
	validation_split_random_state=0,
	device=cuda
)
Creating: EvalBALD(
	acquisition_size=10
)
Training set size 20:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6328125, 'crossentropy': 2.1983988285064697}
RestoringEarlyStopping: Restoring best parameters. (Score: 0.6328125)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.6559, 'crossentropy': 2.0655145431518553}
Creating: TrainSelfDistillationEvalModel(
	num_pool_samples=2,
	num_training_samples=1,
	num_validation_samples=20,
	num_patience_epochs=3,
	max_epochs=3,
	training_dataset=<torch.utils.data.dataset.Subset object at 0x7fdc2a2f6520>,
	eval_dataset=Evaluation Set (100 samples),
	validation_loader=<torch.utils.data.dataloader.DataLoader object at 0x7fdc2a2f6f40>,
	training_batch_size=64,
	model_optimizer_factory=<class 'batchbald_redux.models.MnistOptimizerFactory'>,
	trained_model=TrainedMCDropoutModel(num_samples=2, model=BayesianMNISTCNN(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv1_drop): ConsistentMCDropout2d(p=0.5)
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): ConsistentMCDropout2d(p=0.5)
  (fc1): Linear(in_features=1024, out_features=128, bias=True)
  (fc1_drop): ConsistentMCDropout(p=0.5)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
))

get_predictions_labels:   0%|          | 0/240 [00:00<?, ?it/s]

 33%|###3      | 1/3 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.5966796875, 'crossentropy': 1.5679776072502136}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.5908203125, 'crossentropy': 1.7015655040740967}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.5869140625, 'crossentropy': 1.7359371781349182}
RestoringEarlyStopping: 2 / 3
RestoringEarlyStopping: Restoring best parameters. (Score: 0.5966796875)
RestoringEarlyStopping: Restoring optimizer.


get_predictions_labels:   0%|          | 0/237712 [00:00<?, ?it/s]

get_predictions_labels:   0%|          | 0/237712 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118856 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118856 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118856 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118856 [00:00<?, ?it/s]

CandidateBatch(scores=[0.6806451752781868, 0.6790880858898163, 0.6763254124671221, 0.6758237481117249, 0.675022229552269, 0.6749783530831337, 0.6741648018360138, 0.6734838634729385, 0.6730879247188568, 0.6723472699522972], indices=[54855, 56792, 22122, 40241, 10674, 15235, 14226, 26593, 14211, 1276])
[('id', 55525), ('id', 45346), ('id', 8446), ('id', 28278), ('id', 54369), ('id', 51180), ('id', 53366), ('id', 16103), ('id', 13247), ('id', 55629)]
Acquiring (label, score)s: 2 (0.6806), 2 (0.6791), 2 (0.6763), 2 (0.6758), 9 (0.675), 7 (0.675), 2 (0.6742), 9 (0.6735), 9 (0.6731), 2 (0.6723)
Training set size 30:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6474609375, 'crossentropy': 2.016913414001465}
RestoringEarlyStopping: Restoring best parameters. (Score: 0.6474609375)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.6889, 'crossentropy': 1.9067923236846924}
Done.


{'dataset_info': {'training': "(FastMNIST (train; 58976 samples)) + ('FastFashionMNIST Train (60000 samples)' | constant_target{'target': tensor(-1, device='cuda:0'), 'num_classes': 10})",
  'test': "'FastMNIST Test (10000 samples)'"},
 'initial_training_set_indices': [53434,
  8533,
  14640,
  39579,
  30392,
  58125,
  37915,
  3091,
  57520,
  43803,
  44119,
  52296,
  58226,
  40334,
  46037,
  22015,
  22304,
  43812,
  12640,
  53689],
 'evaluation_set_indices': [29974,
  55573,
  35472,
  44048,
  48031,
  5616,
  10110,
  47420,
  56990,
  34198,
  3792,
  5715,
  15969,
  32775,
  19757,
  34588,
  28991,
  47417,
  26501,
  12108,
  5573,
  48032,
  40646,
  43252,
  2404,
  36797,
  29079,
  40018,
  37047,
  41512,
  45567,
  801,
  10664,
  52801,
  42890,
  32972,
  45974,
  20801,
  23496,
  5803,
  10508,
  46870,
  49549,
  306,
  38725,
  13074,
  19689,
  27135,
  16068,
  18137,
  2728,
  43321,
  29950,
  380,
  27254,
  50466,
  31965,
  24052,
  44454,
  20076,


In [None]:
# exports

def _build_experiment_grid():
    """Return the full sweep: 3 seeds x 2 acquisition functions x 2 exposure modes x 2 dataset pairings."""
    base_seeds = [offset + 1234 for offset in range(3)]
    batch_acquisition_functions = [acquisition_functions.BatchEvalBALD, acquisition_functions.BatchBALD]
    evaluation_set_sizes = [1024]
    pool_sample_counts = [100]
    exposure_modes = [True, False]
    dataset_pairings = [("CIFAR-10", "SVHN"), ("SVHN", "CIFAR-10")]

    # Axis order matters: job ids index into this list, so seeds vary slowest and
    # dataset pairings vary fastest.
    return [
        UnifiedExperiment(
            seed=base_seed,
            id_dataset_name=id_name,
            ood_dataset_name=ood_name,
            ood_exposure=exposure,
            evaluation_set_size=eval_size,
            acquisition_size=5,
            num_pool_samples=pool_samples,
            acquisition_function=acquisition_fn,
        )
        for base_seed in base_seeds
        for acquisition_fn in batch_acquisition_functions
        for eval_size in evaluation_set_sizes
        for pool_samples in pool_sample_counts
        for exposure in exposure_modes
        for id_name, ood_name in dataset_pairings
    ]


configs = _build_experiment_grid()

# Sweep driver: only runs when the exported module is executed as a script and
# `is_run_from_ipython()` is false (presumably i.e. not from inside the notebook — TODO confirm).
if not is_run_from_ipython() and __name__ == "__main__":
    # `embedded_experiments` yields (job_id, store) pairs; each job_id selects one config.
    for job_id, store in embedded_experiments(__file__, len(configs)):
        config = configs[job_id]
        # Offset the config's base seed by the job id (mutates the entry in `configs`).
        config.seed += job_id
        print(config)
        # Persist the effective configuration (after the seed offset) alongside the results.
        store["config"] = dataclasses.asdict(config)
        store["log"] = {}

        try:
            config.run(store=store)
        except Exception:
            # Record the traceback in the store, then let the failure propagate.
            store["exception"] = traceback.format_exc()
            raise

In [None]:
# Sanity check of the grid size: 3 seeds * 2 acquisition functions * 1 * 1 * 2 exposure modes * 2 dataset pairings.
len(configs)

24

In [None]:
# slow
import prettyprinter

# Enable prettyprinter's dataclass support before printing the experiment configs.
prettyprinter.install_extras(include={"dataclasses"})

# Pretty-print every experiment configuration in the sweep.
prettyprinter.pprint(configs)

[
    OodExperiment(
        seed=1234,
        uniform_ood=True,
        id_dataset_name='CIFAR-10',
        ood_dataset_name='SVHN',
        num_pool_samples=100,
        # class
        acquisition_function=batchbald_redux.acquisition_functions.BatchEvalBALD
    ),
    OodExperiment(
        seed=1234,
        uniform_ood=True,
        id_dataset_name='SVHN',
        ood_dataset_name='CIFAR-10',
        num_pool_samples=100,
        # class
        acquisition_function=batchbald_redux.acquisition_functions.BatchEvalBALD
    ),
    OodExperiment(
        seed=1234,
        uniform_ood=False,
        id_dataset_name='CIFAR-10',
        ood_dataset_name='SVHN',
        num_pool_samples=100,
        # class
        acquisition_function=batchbald_redux.acquisition_functions.BatchEvalBALD
    ),
    OodExperiment(
        seed=1234,
        uniform_ood=False,
        id_dataset_name='SVHN',
        ood_dataset_name='CIFAR-10',
        num_pool_samples=100,
        # class
        acquisit