# Uniform Target OOD Experiment
> Can we get better by training on our assumptions?

In [None]:
# default_exp uniform_target_ood_experiment

In [None]:
# hide
import blackhc.project.script

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


Import the modules and functions we are going to use.

In [None]:
# exports

import dataclasses
import traceback
from dataclasses import dataclass
from typing import List, Optional, Type, Union

import torch
import torch.utils.data
from blackhc.project import is_run_from_ipython
from blackhc.project.experiment import embedded_experiments
from torch.utils.data import Dataset

import batchbald_redux.acquisition_functions as acquisition_functions
from batchbald_redux.acquisition_functions import (
    CandidateBatchComputer,
    EvalCandidateBatchComputer,
)
from batchbald_redux.active_learning import ActiveLearningData, RandomFixedLengthSampler
from batchbald_redux.black_box_model_training import evaluate, train
from batchbald_redux.dataset_challenges import (
    AdditiveGaussianNoise,
    AliasDataset,
    NamedDataset,
    get_balanced_sample_indices,
    get_base_dataset_index,
    get_target,
)
from batchbald_redux.datasets import train_validation_split
from batchbald_redux.di import DependencyInjection
from batchbald_redux.fast_mnist import FastFashionMNIST, FastMNIST
from batchbald_redux.model_optimizer_factory import ModelOptimizerFactory
from batchbald_redux.models import MnistOptimizerFactory
from batchbald_redux.train_eval_model import (
    TrainEvalModel,
    TrainSelfDistillationPoolModel,
)
from batchbald_redux.trained_model import TrainedMCDropoutModel

In [None]:
# exports


@dataclass
class ExperimentDatasets:
    """Bundle of the datasets returned by ``UniformTargetOodExperiment.load_dataset``."""

    # Active-learning wrapper around the combined (in-distribution + OoD) train pool.
    active_learning: ActiveLearningData
    # The out-of-distribution dataset; compared against base datasets to tag
    # acquired samples as "ood" or "id".
    ood_dataset: NamedDataset
    # Held-out split used for early stopping during training.
    validation_dataset: torch.utils.data.Dataset
    # Dataset the model is evaluated on after every training round.
    test_dataset: torch.utils.data.Dataset
    # Indices (into the train dataset) acquired before the first training round.
    initial_training_set_indices: List[int]


@dataclass
class UniformTargetOodExperiment:
    """Active-learning experiment mixing FastMNIST with an out-of-distribution pool.

    The pool is FastMNIST (one-hot targets) concatenated with FastFashionMNIST
    (uniform targets). Each round trains a fresh model with a KL-divergence
    loss, evaluates it on the FastMNIST test set, and then lets the configured
    acquisition function pick the next batch from the pool.
    """

    # Passed to ``torch.manual_seed`` at the start of ``run``.
    seed: int = 1337
    # Not read directly in this class; exposed to the acquisition function
    # through dependency injection over ``vars(self)``.
    acquisition_size: int = 5
    # ``run`` stops once the training set has reached this many samples.
    max_training_set: int = 450
    # Forwarded to ``TrainedMCDropoutModel(num_pool_samples=...)``.
    num_pool_samples: int = 20
    # Used as ``validation_samples`` in ``train`` and ``num_samples`` in ``evaluate``.
    num_eval_samples: int = 20
    # Used as ``training_samples`` in ``train``.
    num_training_samples: int = 1
    # Early-stopping patience passed to ``train``.
    num_patience_epochs: int = 3
    # Upper bound on epochs passed to ``train``.
    max_training_epochs: int = 30
    # Batch size of the training ``DataLoader``.
    training_batch_size: int = 64
    device: str = "cuda"
    validation_set_size: int = 1024
    # Seeds the train/validation split and is also reused as the seed for
    # picking the balanced initial training set.
    validation_split_random_state: int = 0
    initial_training_set_size: int = 20
    # Length given to ``RandomFixedLengthSampler`` for the training loader.
    samples_per_epoch: int = 5056
    # Factor the MNIST train set is multiplied by (``dataset * factor``).
    mnist_repetitions: float = 1
    # Factor the OoD set is multiplied by; only applied when greater than 1.
    ood_fmnist_repetitions: float = 1
    # When set, the combined train set is wrapped in ``AdditiveGaussianNoise(..., 0.1)``.
    add_dataset_noise: bool = False
    # Class (not instance) of the acquisition function; instantiated via
    # dependency injection in ``create_acquisition_function``.
    acquisition_function: Union[
        Type[CandidateBatchComputer], Type[EvalCandidateBatchComputer]
    ] = acquisition_functions.BALD
    # Class instantiated via dependency injection in ``create_train_eval_model``.
    train_eval_model: Type[TrainEvalModel] = TrainSelfDistillationPoolModel
    # Factory class; a new instance is created for every active-learning round.
    model_optimizer_factory: Type[ModelOptimizerFactory] = MnistOptimizerFactory
    # Not read directly in this class; available to dependency injection.
    acquisition_function_args: Optional[dict] = None
    # Not read directly in this class; available to dependency injection.
    temperature: float = 0.0

    def load_dataset(self, initial_training_set_size) -> "ExperimentDatasets":
        """Build the train/validation/test/OoD datasets for the experiment.

        Args:
            initial_training_set_size: total number of samples to acquire up
                front, spread over the classes of the MNIST train set.

        Returns:
            An ``ExperimentDatasets`` whose ``active_learning`` member already
            has the initial training indices acquired.
        """
        # num_classes = 10, input_size = 28
        train_dataset = NamedDataset(
            FastMNIST("data", train=True, download=True, device=self.device), "FastMNIST (train)"
        )

        ood_dataset = FastFashionMNIST("data", train=True, download=True, device=self.device)
        ood_dataset = NamedDataset(ood_dataset, f"OoD Dataset ({len(ood_dataset)} samples)")
        # NOTE(review): unlike mnist_repetitions, a factor below 1 is ignored
        # here - confirm whether sub-sampling the OoD set should be supported.
        if self.ood_fmnist_repetitions > 1:
            ood_dataset = ood_dataset * self.ood_fmnist_repetitions

        train_dataset, validation_dataset = train_validation_split(
            full_train_dataset=train_dataset,
            full_validation_dataset=train_dataset,
            train_labels=train_dataset.get_targets().cpu(),
            validation_set_size=self.validation_set_size,
            validation_split_random_state=self.validation_split_random_state,
        )

        train_dataset = AliasDataset(train_dataset, f"FastMNIST (train; {len(train_dataset)} samples)")
        validation_dataset = AliasDataset(
            validation_dataset, f"FastMNIST (validation; {len(validation_dataset)} samples)"
        )

        # If we reduce the train set, we need to do so before picking the initial train set.
        if self.mnist_repetitions < 1:
            train_dataset = train_dataset * self.mnist_repetitions

        num_classes = train_dataset.get_num_classes()
        # True division, so this is a float (e.g. 2.0 for 20 samples / 10 classes).
        # NOTE(review): confirm get_balanced_sample_indices copes with
        # non-integer values when the size is not a multiple of num_classes.
        samples_per_class = initial_training_set_size / num_classes
        initial_training_set_indices = get_balanced_sample_indices(
            train_dataset,
            num_classes=num_classes,
            samples_per_class=samples_per_class,
            seed=self.validation_split_random_state,
        )

        # If we over-sample the train set, we do so after picking the initial train set to avoid duplicates.
        if self.mnist_repetitions > 1:
            train_dataset = train_dataset * self.mnist_repetitions

        # In-distribution samples get one-hot targets, OoD samples uniform targets.
        train_dataset = train_dataset.one_hot(device=self.device) + ood_dataset.uniform_target(device=self.device)

        if self.add_dataset_noise:
            train_dataset = AdditiveGaussianNoise(train_dataset, 0.1)

        test_dataset = FastMNIST("data", train=False, device=None)
        test_dataset = NamedDataset(test_dataset, f"FastMNIST (test, {len(test_dataset)} samples)")

        active_learning_data = ActiveLearningData(train_dataset)

        active_learning_data.acquire(initial_training_set_indices)

        return ExperimentDatasets(
            active_learning=active_learning_data,
            validation_dataset=validation_dataset,
            test_dataset=test_dataset,
            initial_training_set_indices=initial_training_set_indices,
            ood_dataset=ood_dataset,
        )

    # Simple Dependency Injection
    def create_acquisition_function(self):
        """Instantiate ``self.acquisition_function`` from this experiment's fields."""
        di = DependencyInjection(vars(self))
        return di.create_dataclass_type(self.acquisition_function)

    def create_train_eval_model(self, runtime_config) -> TrainEvalModel:
        """Instantiate ``self.train_eval_model``.

        Args:
            runtime_config: extra values that are merged over this experiment's
                fields (runtime values win on key collisions).
        """
        config = {**vars(self), **runtime_config}
        di = DependencyInjection(config, [])
        return di.create_dataclass_type(self.train_eval_model)

    def run(self, store):
        """Run the active-learning loop and record its progress in ``store``.

        Args:
            store: mapping that receives ``"dataset_info"``,
                ``"initial_training_set_indices"`` and
                ``"active_learning_steps"`` (one dict per round).

        Raises:
            ValueError: if the acquisition function is neither a
                ``CandidateBatchComputer`` nor an ``EvalCandidateBatchComputer``.
        """
        torch.manual_seed(self.seed)

        # Active Learning setup
        data = self.load_dataset(self.initial_training_set_size)
        store["dataset_info"] = dict(training=repr(data.active_learning.base_dataset), test=repr(data.test_dataset))
        store["initial_training_set_indices"] = data.initial_training_set_indices

        train_loader = torch.utils.data.DataLoader(
            data.active_learning.training_dataset,
            # Honour the configured batch size instead of a hard-coded 64.
            batch_size=self.training_batch_size,
            sampler=RandomFixedLengthSampler(data.active_learning.training_dataset, self.samples_per_epoch),
            drop_last=True,
        )
        pool_loader = torch.utils.data.DataLoader(
            data.active_learning.pool_dataset, batch_size=128, drop_last=False, shuffle=False
        )

        validation_loader = torch.utils.data.DataLoader(data.validation_dataset, batch_size=512, drop_last=False)
        test_loader = torch.utils.data.DataLoader(data.test_dataset, batch_size=512, drop_last=False)

        # Read the list back from the store rather than keeping the local literal.
        # NOTE(review): presumably the store may wrap assigned values - confirm.
        store["active_learning_steps"] = []
        active_learning_steps = store["active_learning_steps"]

        acquisition_function = self.create_acquisition_function()

        # Active Training Loop
        while True:
            training_set_size = len(data.active_learning.training_dataset)
            print(f"Training set size {training_set_size}:")

            # Same read-back pattern as above for the per-round log entry.
            active_learning_steps.append({})
            iteration_log = active_learning_steps[-1]

            iteration_log["training"] = {}

            # A fresh model and optimizer are created for every round.
            model_optimizer = self.model_optimizer_factory().create_model_optimizer()

            # Targets are probability vectors (one-hot or uniform), hence a
            # KL-divergence training loss with non-log targets.
            loss = torch.nn.KLDivLoss(log_target=False, reduction="batchmean")

            train(
                model=model_optimizer.model,
                optimizer=model_optimizer.optimizer,
                training_samples=self.num_training_samples,
                validation_samples=self.num_eval_samples,
                train_loader=train_loader,
                validation_loader=validation_loader,
                patience=self.num_patience_epochs,
                max_epochs=self.max_training_epochs,
                device=self.device,
                training_log=iteration_log["training"],
                loss=loss,
                validation_loss=torch.nn.NLLLoss(),
            )

            evaluation_metrics = evaluate(
                model=model_optimizer.model, num_samples=self.num_eval_samples, loader=test_loader, device=self.device
            )
            iteration_log["evaluation_metrics"] = evaluation_metrics
            print(f"Perf after training {evaluation_metrics}")

            # Stop only after the final model has been trained and evaluated.
            if training_set_size >= self.max_training_set:
                print("Done.")
                break

            trained_model = TrainedMCDropoutModel(num_pool_samples=self.num_pool_samples, model=model_optimizer.model)

            if isinstance(acquisition_function, CandidateBatchComputer):
                candidate_batch = acquisition_function.compute_candidate_batch(trained_model, pool_loader, self.device)
            elif isinstance(acquisition_function, EvalCandidateBatchComputer):
                # The auxiliary model trains for as many epochs as the main
                # model needed to reach its best epoch.
                current_max_epochs = iteration_log["training"]["best_epoch"]

                train_eval_model = self.create_train_eval_model(
                    dict(
                        max_epochs=current_max_epochs,
                        training_dataset=data.active_learning.training_dataset,
                        pool_dataset=data.active_learning.pool_dataset,
                        validation_loader=validation_loader,
                        trained_model=trained_model,
                    )
                )

                iteration_log["eval_training"] = {}
                trained_eval_model = train_eval_model(training_log=iteration_log["eval_training"], device=self.device)

                candidate_batch = acquisition_function.compute_candidate_batch(
                    trained_model, trained_eval_model, pool_loader, device=self.device
                )
            else:
                raise ValueError(f"Unknown acquisition function {acquisition_function}!")

            # Translate pool indices into (dataset tag, base dataset index)
            # pairs and look up the targets before the pool is mutated below.
            candidate_global_dataset_indices = []
            candidate_labels = []
            for index in candidate_batch.indices:
                base_di = get_base_dataset_index(data.active_learning.pool_dataset, index)
                dataset_type = "ood" if base_di.dataset == data.ood_dataset else "id"
                candidate_global_dataset_indices.append((dataset_type, base_di.index))
                label = get_target(data.active_learning.pool_dataset, index)
                candidate_labels.append(label)

            iteration_log["acquisition"] = dict(
                indices=candidate_global_dataset_indices, labels=candidate_labels, scores=candidate_batch.scores
            )

            data.active_learning.acquire(candidate_batch.indices)

            print(candidate_batch)
            print(candidate_global_dataset_indices)

            ls = ", ".join(f"{label} ({score:.4})" for label, score in zip(candidate_labels, candidate_batch.scores))
            print(f"Acquiring (label, score)s: {ls}")

In [None]:
# experiment

# Build the datasets on the CPU with a sub-sampled MNIST train set and show
# the repr of the combined base dataset.
cpu_experiment = UniformTargetOodExperiment(mnist_repetitions=0.1, device="cpu")
cpu_experiment.load_dataset(20).active_learning.base_dataset

((FastMNIST (train; 58976 samples))~x0.1 | one_hot_targets{'num_classes': 10}) + ('OoD Dataset (60000 samples)' | uniform_targets{'num_classes': 10})

In [None]:
# experiment

# Interactive TemperedBALD run that stops at 130 training samples;
# the per-round logs are collected in `results`.
results = {}

experiment = UniformTargetOodExperiment(
    device="cuda",
    seed=1120,
    mnist_repetitions=1,
    acquisition_function=acquisition_functions.TemperedBALD,
    acquisition_size=10,
    temperature=8,
    num_pool_samples=20,
    max_training_set=130,
    max_training_epochs=30,
    num_patience_epochs=5,
)
experiment.run(store=results)

Resolved: TemperedBALD with {'acquisition_size': 10, 'temperature': 8}
Creating: TemperedBALD(acquisition_size=10,temperature=8)
Training set size 20:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6552734375, 'crossentropy': 2.281791090965271}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6396484375, 'crossentropy': 2.764202117919922}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.65234375, 'crossentropy': 2.7625155448913574}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.640625, 'crossentropy': 3.117473602294922}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.654296875, 'crossentropy': 2.7298072576522827}
RestoringEarlyStopping: 4 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6376953125, 'crossentropy': 2.7832967042922974}
RestoringEarlyStopping: 5 / 5
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -2.281791090965271)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.686, 'crossentropy': 1.9357322483062744}


get_predictions_labels:   0%|          | 0/2379120 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118956 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118956 [00:00<?, ?it/s]

CandidateBatch(scores=[0.7525752484798431, 0.9461209774017334, 0.7917467355728149, 0.7990282773971558, 0.9037496745586395, 0.9071029424667358, 0.42743825912475586, 0.844712495803833, 0.8469107151031494, 0.64350426197052], indices=[58755, 49156, 58931, 39220, 44700, 36190, 101609, 48457, 9998, 8542])
[('id', 40198), ('id', 19315), ('id', 6432), ('id', 4909), ('id', 20280), ('id', 47039), ('ood', 42653), ('id', 41721), ('id', 44678), ('id', 42857)]
Acquiring (label, score)s: tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0') (0.7526), tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1.], device='cuda:0') (0.9461), tensor([0., 0., 0., 0., 0., 0., 0., 1., 0., 0.], device='cuda:0') (0.7917), tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0.], device='cuda:0') (0.799), tensor([0., 0., 0., 0., 0., 0., 0., 1., 0., 0.], device='cuda:0') (0.9037), tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], device='cuda:0') (0.9071), tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1

  3%|3         | 1/30 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.626953125, 'crossentropy': 1.9055257439613342}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.634765625, 'crossentropy': 2.397943377494812}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.65234375, 'crossentropy': 2.5055283308029175}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.64453125, 'crossentropy': 2.761470317840576}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6455078125, 'crossentropy': 2.6420114040374756}
RestoringEarlyStopping: 4 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6318359375, 'crossentropy': 2.841150403022766}
RestoringEarlyStopping: 5 / 5
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.9055257439613342)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.6721, 'crossentropy': 1.6736589221954346}


get_predictions_labels:   0%|          | 0/2378920 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118946 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118946 [00:00<?, ?it/s]

CandidateBatch(scores=[0.6050038933753967, 0.9222344160079956, 0.8689695596694946, 0.7768614292144775, 0.8677030205726624, 0.5678052306175232, 0.690179705619812, 0.6932134628295898, 0.5644712448120117, 0.46727848052978516], indices=[38397, 23638, 17756, 32259, 19761, 8146, 14276, 37799, 57707, 10893])
[('id', 31094), ('id', 24726), ('id', 36589), ('id', 46895), ('id', 15855), ('id', 4251), ('id', 18446), ('id', 23332), ('id', 13003), ('id', 37256)]
Acquiring (label, score)s: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1.], device='cuda:0') (0.605), tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], device='cuda:0') (0.9222), tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0') (0.869), tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], device='cuda:0') (0.7769), tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], device='cuda:0') (0.8677), tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], device='cuda:0') (0.5678), tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], device='cuda:0') 

  3%|3         | 1/30 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.67578125, 'crossentropy': 1.7049498558044434}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6923828125, 'crossentropy': 2.062513589859009}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.69921875, 'crossentropy': 2.0917270183563232}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.70703125, 'crossentropy': 2.3297260999679565}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.69921875, 'crossentropy': 2.459813952445984}
RestoringEarlyStopping: 4 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7109375, 'crossentropy': 2.6445037126541138}
RestoringEarlyStopping: 5 / 5
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.7049498558044434)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.7102, 'crossentropy': 1.4282880615234375}


get_predictions_labels:   0%|          | 0/2378720 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118936 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118936 [00:00<?, ?it/s]

CandidateBatch(scores=[0.8289439082145691, 0.8253107666969299, 0.7205442786216736, 0.8806598782539368, 0.6559193134307861, 0.7009842991828918, 0.6863618493080139, 0.7819263935089111, 0.6963304281234741, 0.5840955376625061], indices=[56998, 57653, 19320, 56832, 1840, 56163, 37562, 48009, 41350, 44307])
[('id', 17404), ('id', 25960), ('id', 57034), ('id', 37858), ('id', 47057), ('id', 55698), ('id', 43363), ('id', 43747), ('id', 50343), ('id', 37465)]
Acquiring (label, score)s: tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], device='cuda:0') (0.8289), tensor([0., 0., 0., 0., 1., 0., 0., 0., 0., 0.], device='cuda:0') (0.8253), tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0') (0.7205), tensor([0., 0., 0., 0., 1., 0., 0., 0., 0., 0.], device='cuda:0') (0.8807), tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0.], device='cuda:0') (0.6559), tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], device='cuda:0') (0.701), tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0.], device='cuda:0'

  3%|3         | 1/30 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.72265625, 'crossentropy': 1.5206688046455383}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.77734375, 'crossentropy': 1.4437467455863953}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7763671875, 'crossentropy': 1.6880419254302979}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.787109375, 'crossentropy': 1.7344599962234497}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7705078125, 'crossentropy': 1.8612591624259949}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.78515625, 'crossentropy': 1.7511099576950073}
RestoringEarlyStopping: 4 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.76953125, 'crossentropy': 2.021477460861206}
RestoringEarlyStopping: 5 / 5
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.4437467455863953)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.7906, 'crossentropy': 1.3133127180099486}


get_predictions_labels:   0%|          | 0/2378520 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118926 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118926 [00:00<?, ?it/s]

CandidateBatch(scores=[0.5788464546203613, 0.8498013615608215, 0.8048466444015503, 0.5830477476119995, 0.8645184636116028, 0.4955335855484009, 0.620869517326355, 0.8390311002731323, 0.447624146938324, 0.7256765365600586], indices=[73851, 13781, 1281, 24388, 26778, 75920, 31513, 7269, 67994, 5689])
[('ood', 14924), ('id', 41013), ('id', 55629), ('id', 13634), ('id', 22281), ('ood', 16993), ('id', 32425), ('id', 24853), ('ood', 9067), ('id', 24773)]
Acquiring (label, score)s: tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
        0.1000], device='cuda:0') (0.5788), tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0.], device='cuda:0') (0.8498), tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0') (0.8048), tensor([0., 0., 0., 0., 1., 0., 0., 0., 0., 0.], device='cuda:0') (0.583), tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0.], device='cuda:0') (0.8645), tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
        0.1000

  3%|3         | 1/30 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Engine run is terminating due to exception: .


KeyboardInterrupt: 

In [None]:
# experiment

results

{'dataset_info': {'training': "(FastMNIST (train; 58976 samples)) + ('OoD Dataset (60000 samples)' | constant_target{'target': tensor(-1, device='cuda:0'), 'num_classes': 10})",
  'test': "'FastMNIST (test, 10000 samples)'"},
 'initial_training_set_indices': [30392,
  53434,
  12640,
  8533,
  22304,
  37915,
  58226,
  44119,
  3091,
  14640,
  58125,
  39579,
  43812,
  53689,
  52296,
  46037,
  22015,
  40334,
  57520,
  43803],
 'active_learning_steps': [{'training': {'epochs': [], 'best_epoch': None},
   'evaluation_metrics': {'accuracy': 0.1129,
    'crossentropy': 2.35244740562439},
   'acquisition': {'indices': [('ood', 27822),
     ('id', 32367),
     ('id', 55086),
     ('id', 53929),
     ('id', 48696),
     ('ood', 37815),
     ('ood', 47304),
     ('ood', 28667),
     ('ood', 40866),
     ('ood', 35119)],
    'labels': [-1, 1, 7, 6, 2, -1, -1, -1, -1, -1],
    'scores': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}},
  {'training': {'epochs': [], 'best_epoch': None}

In [None]:
# experiment

# NOTE(review): `Experiment` and `AcquisitionFunction` are neither defined nor
# imported in the visible part of this notebook, so this cell looks like a
# leftover from an earlier version and would raise NameError if re-run as is.
# Confirm the intended class and acquisition function before updating it.
experiment = Experiment(
    max_training_epochs=1, max_training_set=25, acquisition_function=AcquisitionFunction.randombaldical
)

# Collect the run's logs in a plain dict and display them.
results = {}
experiment.run(results)

results

Training set size 20:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/384]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -6.529030114412308)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.5367, 'crossentropy': 6.438035237884521}


get_predictions:   0%|          | 0/463616 [00:00<?, ?it/s]

100%|##########| 1/1 [00:00<?, ?it/s]

[1/1811]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -5.1637596152722836)
RestoringEarlyStopping: Restoring optimizer.


get_predictions:   0%|          | 0/2317680 [00:00<?, ?it/s]

get_predictions:   0%|          | 0/2317680 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Acquiring (label, score)s: 8 (0.8711), 8 (0.8687), 3 (0.876), 3 (0.8465), 3 (0.8811)
Training set size 25:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/384]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -4.6851686127483845)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.6256, 'crossentropy': 4.484497045135498}
Done.


{'initial_training_set_indices': [38043,
  40091,
  17418,
  2094,
  39879,
  3133,
  5011,
  40683,
  54379,
  24287,
  9849,
  59305,
  39508,
  39356,
  8758,
  52579,
  13655,
  7636,
  21562,
  41329],
 'active_learning_steps': [{'training': {'epochs': [{'accuracy': 0.538818359375,
      'crossentropy': 6.529030114412308}],
    'best_epoch': 1},
   'evalution_metrics': {'accuracy': 0.5367,
    'crossentropy': 6.438035237884521},
   'pool_training': {'epochs': [{'accuracy': 0.531005859375,
      'crossentropy': 5.1637596152722836}],
    'best_epoch': 1},
   'acquisition': {'indices': [63338, 10856, 63452, 81864, 109287],
    'labels': [8, 8, 3, 3, 3],
    'scores': [0.8710822958846325,
     0.8687216999221631,
     0.8759664372823723,
     0.8464646732511746,
     0.8810812784952251]}},
  {'training': {'epochs': [{'accuracy': 0.62255859375,
      'crossentropy': 4.6851686127483845}],
    'best_epoch': 1},
   'evalution_metrics': {'accuracy': 0.6256,
    'crossentropy': 4.4844970451

In [None]:
# exports

# TemperedBALD sweep: every seed is combined with every acquisition size.
_tempered_bald_configs = [
    UniformTargetOodExperiment(
        seed=seed,
        acquisition_function=acquisition_functions.TemperedBALD,
        acquisition_size=acquisition_size,
        num_pool_samples=num_pool_samples,
        temperature=8,
    )
    for seed in range(5)
    for acquisition_size in (5, 10, 20, 50)
    for num_pool_samples in (100,)
]

# Random-acquisition baseline runs.
_random_baseline_configs = [
    UniformTargetOodExperiment(
        seed=seed,
        acquisition_function=acquisition_functions.Random,
        acquisition_size=5,
    )
    for seed in range(20)
]

# One entry per job; the job id indexes into this list.
configs = _tempered_bald_configs + _random_baseline_configs

if not is_run_from_ipython() and __name__ == "__main__":
    for job_id, store in embedded_experiments(__file__, len(configs)):
        config = configs[job_id]
        # Offset the configured seed by the job id before running.
        config.seed = config.seed + job_id
        print(config)
        store["config"] = dataclasses.asdict(config)
        store["log"] = {}

        try:
            config.run(store=store)
        except Exception:
            # Persist the traceback in the store, then let the failure propagate.
            store["exception"] = traceback.format_exc()
            raise

In [None]:
# Number of configured experiments, i.e. the job count handed to embedded_experiments.
len(configs)

40

In [None]:
# Display the full list of experiment configurations for inspection.
configs

[RejectionOodExperiment(seed=0, acquisition_size=5, max_training_set=450, num_pool_samples=50, num_eval_samples=20, num_training_samples=1, num_patience_epochs=3, max_training_epochs=30, training_batch_size=64, device='cuda', validation_set_size=1024, validation_split_random_state=0, initial_training_set_size=20, samples_per_epoch=5056, mnist_repetitions=1, ood_fmnist_repetitions=1, add_dataset_noise=False, acquisition_function=<class 'batchbald_redux.acquisition_functions.TemperedBALD'>, train_eval_model=<class 'batchbald_redux.train_eval_model.TrainSelfDistillationPoolModel'>, model_optimizer_factory=<class 'batchbald_redux.models.MnistOptimizerFactory'>, acquisition_function_args=None, temperature=8),
 RejectionOodExperiment(seed=0, acquisition_size=10, max_training_set=450, num_pool_samples=50, num_eval_samples=20, num_training_samples=1, num_patience_epochs=3, max_training_epochs=30, training_batch_size=64, device='cuda', validation_set_size=1024, validation_split_random_state=0, 