# Rejection OOD Experiment
> Can we get better by training on our assumptions?

In [None]:
# default_exp rejection_ood_experiment

In [None]:
# hide
import blackhc.project.script

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


Import the modules and functions we are going to use.

In [None]:
# exports

import dataclasses
import traceback
from dataclasses import dataclass
from typing import List, Optional, Type, Union

import torch
import torch.utils.data
from blackhc.project import is_run_from_ipython
from blackhc.project.experiment import embedded_experiments
from torch.utils.data import Dataset

import batchbald_redux.acquisition_functions as acquisition_functions
from batchbald_redux.acquisition_functions import (
    CandidateBatchComputer,
    EvalCandidateBatchComputer,
)
from batchbald_redux.active_learning import ActiveLearningData, RandomFixedLengthSampler
from batchbald_redux.black_box_model_training import evaluate, train
from batchbald_redux.dataset_challenges import (
    AdditiveGaussianNoise,
    AliasDataset,
    NamedDataset,
    get_base_dataset_index,
    get_target,
    get_balanced_sample_indices_by_class,
)
from batchbald_redux.datasets import train_validation_split
from batchbald_redux.di import DependencyInjection
from batchbald_redux.fast_mnist import FastFashionMNIST, FastMNIST
from batchbald_redux.model_optimizer_factory import ModelOptimizerFactory
from batchbald_redux.models import MnistOptimizerFactory
from batchbald_redux.train_eval_model import (
    TrainEvalModel,
    TrainSelfDistillationEvalModel,
)
from batchbald_redux.trained_model import TrainedMCDropoutModel

In [None]:
# exports


@dataclass
class ExperimentData:
    """Datasets and index bookkeeping returned by ``RejectionOodExperiment.load_experiment_data``."""

    # Active-learning state built over the MNIST training split concatenated with the OoD samples.
    active_learning: ActiveLearningData
    # The OoD dataset as loaded (a relabelled copy with constant target -1 is what goes into the pool).
    ood_dataset: NamedDataset
    validation_dataset: Dataset
    test_dataset: Dataset
    # Dataset built from ``evaluation_set_indices`` via ``extract_dataset_from_base_indices``.
    evaluation_dataset: Dataset
    # Class-balanced indices (into the MNIST part of the base dataset) acquired as the initial training set.
    initial_training_set_indices: List[int]
    # Class-balanced indices (into the MNIST part of the base dataset) used for the evaluation set.
    evaluation_set_indices: List[int]


@dataclass
class RejectionOodExperiment:
    """Active learning on MNIST with a pool polluted by FashionMNIST out-of-distribution samples.

    The OoD samples are added to the pool with the constant target ``-1``. When the acquisition
    function selects one of them it is *rejected*: it is logged, but not added to the training set.
    Training, evaluation and acquisition repeat until the training set has at least
    ``max_training_set`` samples.
    """

    seed: int = 1337
    acquisition_size: int = 5
    max_training_set: int = 450
    num_pool_samples: int = 20
    num_validation_samples: int = 20
    num_training_samples: int = 1
    num_patience_epochs: int = 3
    max_training_epochs: int = 30
    training_batch_size: int = 64
    device: str = "cuda"
    validation_set_size: int = 1024
    evaluation_set_size: int = 10 * 10
    validation_split_random_state: int = 0
    initial_training_set_size: int = 20
    samples_per_epoch: int = 5056
    mnist_repetitions: float = 1
    add_dataset_noise: bool = False
    acquisition_function: Union[
        Type[CandidateBatchComputer], Type[EvalCandidateBatchComputer]
    ] = acquisition_functions.BALD
    train_eval_model: Type[TrainEvalModel] = TrainSelfDistillationEvalModel
    model_optimizer_factory: Type[ModelOptimizerFactory] = MnistOptimizerFactory
    # NOTE(review): not read directly in this class; presumably consumed through dependency injection.
    acquisition_function_args: Optional[dict] = None
    # NOTE(review): not read directly in this class; presumably consumed through dependency injection.
    temperature: float = 0.0

    def load_experiment_data(self) -> ExperimentData:
        """Build the MNIST + OoD active-learning datasets for this experiment.

        The MNIST training set is split into train/validation. The train part is optionally
        shrunk or repeated (``mnist_repetitions``) and concatenated with the FashionMNIST OoD
        samples relabelled to ``-1``. It is optionally perturbed with additive Gaussian noise.
        A class-balanced initial training set is acquired, and a class-balanced evaluation
        dataset is extracted.

        Returns:
            The assembled :class:`ExperimentData`.
        """
        # num_classes = 10, input_size = 28
        full_train_dataset = NamedDataset(
            FastMNIST("data", train=True, download=True, device=self.device), "FastMNIST (train)"
        )
        ood_dataset = FastFashionMNIST("data", train=True, download=True, device=self.device)
        ood_dataset = NamedDataset(ood_dataset, f"OoD Dataset ({len(ood_dataset)} samples)")

        train_dataset, validation_dataset = train_validation_split(
            full_train_dataset=full_train_dataset,
            full_validation_dataset=full_train_dataset,
            train_labels=full_train_dataset.get_targets().cpu(),
            validation_set_size=self.validation_set_size,
            validation_split_random_state=self.validation_split_random_state,
        )

        train_dataset = AliasDataset(train_dataset, f"FastMNIST (train; {len(train_dataset)} samples)")
        validation_dataset = AliasDataset(
            validation_dataset, f"FastMNIST (validation; {len(validation_dataset)} samples)"
        )

        # If we reduce the train set, we need to do so before picking the initial train set.
        if self.mnist_repetitions < 1:
            train_dataset = train_dataset * self.mnist_repetitions

        num_classes = train_dataset.get_num_classes()
        # Integer division: these counts are used as slice bounds below, which must be ints.
        # Sizes that are not multiples of num_classes are rounded down per class.
        initial_samples_per_class = self.initial_training_set_size // num_classes
        evaluation_set_samples_per_class = self.evaluation_set_size // num_classes
        samples_per_class = initial_samples_per_class + evaluation_set_samples_per_class
        balanced_samples_indices = get_balanced_sample_indices_by_class(
            train_dataset,
            num_classes=num_classes,
            samples_per_class=samples_per_class,
            seed=self.validation_split_random_state,
        )
        # Per class: the first samples go to the initial training set, the rest to the evaluation set.
        initial_training_set_indices = [
            idx for by_class in balanced_samples_indices for idx in by_class[:initial_samples_per_class]
        ]
        evaluation_set_indices = [
            idx for by_class in balanced_samples_indices for idx in by_class[initial_samples_per_class:]
        ]

        # If we over-sample the train set, we do so after picking the initial train set to avoid duplicates.
        if self.mnist_repetitions > 1:
            train_dataset = train_dataset * self.mnist_repetitions

        # Append the OoD samples with the constant target -1 so they can be recognised (and rejected) on acquisition.
        train_dataset = train_dataset + ood_dataset.constant_target(
            target=torch.tensor(-1, device=self.device), num_classes=train_dataset.get_num_classes()
        )

        if self.add_dataset_noise:
            train_dataset = AdditiveGaussianNoise(train_dataset, 0.1)

        test_dataset = FastMNIST("data", train=False, device=None)
        test_dataset = NamedDataset(test_dataset, f"FastMNIST (test, {len(test_dataset)} samples)")

        active_learning_data = ActiveLearningData(train_dataset)

        active_learning_data.acquire_base_indices(initial_training_set_indices)
        evaluation_dataset = active_learning_data.extract_dataset_from_base_indices(evaluation_set_indices)

        return ExperimentData(
            active_learning=active_learning_data,
            ood_dataset=ood_dataset,
            validation_dataset=validation_dataset,
            test_dataset=test_dataset,
            evaluation_dataset=evaluation_dataset,
            initial_training_set_indices=initial_training_set_indices,
            evaluation_set_indices=evaluation_set_indices,
        )

    # Simple Dependency Injection
    def create_acquisition_function(self):
        """Instantiate ``self.acquisition_function``, resolving its arguments from this config's fields."""
        di = DependencyInjection(vars(self))
        return di.create_dataclass_type(self.acquisition_function)

    def create_train_eval_model(self, runtime_config) -> TrainEvalModel:
        """Instantiate ``self.train_eval_model`` from this config merged with ``runtime_config``.

        Entries in ``runtime_config`` override config fields with the same name.
        """
        config = {**vars(self), **runtime_config}
        di = DependencyInjection(config, [])
        return di.create_dataclass_type(self.train_eval_model)

    def run(self, store):
        """Run the active-learning loop, logging everything into ``store``.

        Each iteration trains a freshly created model on the current training set and evaluates
        it on the test set. Until the training set has at least ``max_training_set`` samples, it
        then acquires a candidate batch from the pool. Candidates whose target is ``-1`` (the OoD
        samples) are logged but not added to the training set.

        Args:
            store: dict-like sink. Written keys: ``dataset_info``, ``initial_training_set_indices``,
                ``evaluation_set_indices`` and ``active_learning_steps``. The last one holds one
                dict per iteration with ``training`` and ``evaluation_metrics``. Iterations that
                acquire also get ``acquisition``, and ``eval_training`` for eval-based acquisition
                functions.

        Raises:
            ValueError: if the acquisition function is neither a ``CandidateBatchComputer`` nor an
                ``EvalCandidateBatchComputer``.
        """
        torch.manual_seed(self.seed)

        # Active Learning setup
        data = self.load_experiment_data()
        store["dataset_info"] = dict(training=repr(data.active_learning.base_dataset), test=repr(data.test_dataset))
        store["initial_training_set_indices"] = data.initial_training_set_indices
        store["evaluation_set_indices"] = data.evaluation_set_indices

        train_loader = torch.utils.data.DataLoader(
            data.active_learning.training_dataset,
            batch_size=self.training_batch_size,
            sampler=RandomFixedLengthSampler(data.active_learning.training_dataset, self.samples_per_epoch),
            drop_last=True,
        )
        pool_loader = torch.utils.data.DataLoader(
            data.active_learning.pool_dataset, batch_size=128, drop_last=False, shuffle=False
        )

        validation_loader = torch.utils.data.DataLoader(data.validation_dataset, batch_size=512, drop_last=False)
        test_loader = torch.utils.data.DataLoader(data.test_dataset, batch_size=512, drop_last=False)

        store["active_learning_steps"] = []
        active_learning_steps = store["active_learning_steps"]

        acquisition_function = self.create_acquisition_function()

        # Active Training Loop
        while True:
            training_set_size = len(data.active_learning.training_dataset)
            print(f"Training set size {training_set_size}:")

            active_learning_steps.append({})
            iteration_log = active_learning_steps[-1]

            iteration_log["training"] = {}

            # A new model/optimizer pair is created every iteration.
            model_optimizer = self.model_optimizer_factory().create_model_optimizer()

            train(
                model=model_optimizer.model,
                optimizer=model_optimizer.optimizer,
                training_samples=self.num_training_samples,
                validation_samples=self.num_validation_samples,
                train_loader=train_loader,
                validation_loader=validation_loader,
                patience=self.num_patience_epochs,
                max_epochs=self.max_training_epochs,
                device=self.device,
                training_log=iteration_log["training"],
            )

            evaluation_metrics = evaluate(
                model=model_optimizer.model,
                num_samples=self.num_validation_samples,
                loader=test_loader,
                device=self.device,
            )
            iteration_log["evaluation_metrics"] = evaluation_metrics
            print(f"Perf after training {evaluation_metrics}")

            if training_set_size >= self.max_training_set:
                print("Done.")
                break

            trained_model = TrainedMCDropoutModel(num_samples=self.num_pool_samples, model=model_optimizer.model)

            if isinstance(acquisition_function, CandidateBatchComputer):
                candidate_batch = acquisition_function.compute_candidate_batch(trained_model, pool_loader, self.device)
            elif isinstance(acquisition_function, EvalCandidateBatchComputer):
                # Cap the eval model's training at the best epoch recorded by train() for the main model.
                current_max_epochs = iteration_log["training"]["best_epoch"]

                # The evaluation set is built in load_experiment_data and stored on ExperimentData.
                if self.evaluation_set_size:
                    eval_dataset = data.evaluation_dataset
                else:
                    eval_dataset = data.active_learning.pool_dataset

                train_eval_model = self.create_train_eval_model(
                    dict(
                        max_epochs=current_max_epochs,
                        training_dataset=data.active_learning.training_dataset,
                        eval_dataset=eval_dataset,
                        validation_loader=validation_loader,
                        trained_model=trained_model,
                    )
                )

                iteration_log["eval_training"] = {}
                trained_eval_model = train_eval_model(training_log=iteration_log["eval_training"], device=self.device)

                candidate_batch = acquisition_function.compute_candidate_batch(
                    trained_model, trained_eval_model, pool_loader, device=self.device
                )
            else:
                raise ValueError(f"Unknown acquisition function {acquisition_function}!")

            # Resolve pool indices BEFORE acquiring: acquire() changes the pool, after which the same
            # pool index can refer to a different sample.
            candidate_dataset_indices = [
                get_base_dataset_index(data.active_learning.pool_dataset, index) for index in candidate_batch.indices
            ]
            candidate_global_indices = [dataset_index.index for dataset_index in candidate_dataset_indices]
            candidate_labels = [
                get_target(data.active_learning.pool_dataset, index).item() for index in candidate_batch.indices
            ]

            iteration_log["acquisition"] = dict(
                indices=candidate_global_indices, labels=candidate_labels, scores=candidate_batch.scores
            )

            # Reject OoD candidates (target -1): only in-distribution samples join the training set.
            # NOTE(review): rejected samples are not passed to acquire(), so they presumably remain in the
            # pool and can be selected again later. Confirm this is intended.
            data.active_learning.acquire(
                [index for index, label in zip(candidate_batch.indices, candidate_labels) if label != -1]
            )

            print(candidate_batch)
            print(list(zip(candidate_batch.indices, candidate_dataset_indices)))

            ls = ", ".join(f"{label} ({score:.4})" for label, score in zip(candidate_labels, candidate_batch.scores))
            print(f"Acquiring (label, score)s: {ls}")

In [None]:
# experiment

# Inspect the dataset the active-learning pool is built from (MNIST train split + OoD samples with target -1).
RejectionOodExperiment(device="cpu").load_experiment_data().active_learning.base_dataset

((FastMNIST (train; 58976 samples))~x0.1) + ('OoD Dataset (60000 samples)' | constant_target{'target': tensor(-1), 'num_classes': 10})

In [None]:
# experiment

# Short BALD run: 10 acquisitions per step, stopping once the training set reaches 130 samples.
results = {}
experiment = RejectionOodExperiment(
    acquisition_function=acquisition_functions.BALD,
    acquisition_size=10,
    device="cuda",
    max_training_epochs=5,
    max_training_set=130,
    num_pool_samples=2,
    seed=1,
    temperature=5,
)
experiment.run(results)

Resolved: BALD with {'acquisition_size': 10}
Creating: BALD(acquisition_size=10)
Training set size 20:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6181640625, 'crossentropy': 2.4419749975204468}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.61328125, 'crossentropy': 3.1026943922042847}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6103515625, 'crossentropy': 3.1747686862945557}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6103515625, 'crossentropy': 3.4691431522369385}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -2.4419749975204468)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.6366, 'crossentropy': 2.4271124660491945}


get_predictions_labels:   0%|          | 0/131754 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/65877 [00:00<?, ?it/s]

Entropy:   0%|          | 0/65877 [00:00<?, ?it/s]

CandidateBatch(scores=[0.692390663549304, 0.6920224502682686, 0.6917422339320183, 0.6912369038909674, 0.6911634802818298, 0.6910493969917297, 0.6908157765865326, 0.6905537843704224, 0.690106600522995, 0.689697802066803], indices=[2223, 395, 2510, 754, 4054, 4241, 4883, 1696, 3784, 5685])
[(2223, DatasetIndex(dataset='FastMNIST (train)', index=47269)), (395, DatasetIndex(dataset='FastMNIST (train)', index=6002)), (2510, DatasetIndex(dataset='FastMNIST (train)', index=7430)), (754, DatasetIndex(dataset='FastMNIST (train)', index=21829)), (4054, DatasetIndex(dataset='FastMNIST (train)', index=33301)), (4241, DatasetIndex(dataset='FastMNIST (train)', index=37039)), (4883, DatasetIndex(dataset='FastMNIST (train)', index=19923)), (1696, DatasetIndex(dataset='FastMNIST (train)', index=37185)), (3784, DatasetIndex(dataset='FastMNIST (train)', index=43361)), (5685, DatasetIndex(dataset='FastMNIST (train)', index=38026))]
Acquiring (label, score)s: 7 (0.6924), 7 (0.692), 7 (0.6917), 7 (0.6912), 

 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.634765625, 'crossentropy': 2.5922300815582275}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6552734375, 'crossentropy': 2.792037010192871}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6376953125, 'crossentropy': 3.0876342058181763}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.673828125, 'crossentropy': 2.8696995973587036}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -2.5922300815582275)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.6519, 'crossentropy': 2.323693693161011}


get_predictions_labels:   0%|          | 0/131734 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/65867 [00:00<?, ?it/s]

Entropy:   0%|          | 0/65867 [00:00<?, ?it/s]

CandidateBatch(scores=[0.6686588823795319, 0.6632348895072937, 0.6590745151042938, 0.654802531003952, 0.6512910425662994, 0.6469181776046753, 0.6458415687084198, 0.6401853486895561, 0.6392222195863724, 0.6391425430774689], indices=[3886, 5253, 5780, 513, 3058, 3771, 717, 1031, 3754, 2355])
[(3886, DatasetIndex(dataset='FastMNIST (train)', index=21256)), (5253, DatasetIndex(dataset='FastMNIST (train)', index=461)), (5780, DatasetIndex(dataset='FastMNIST (train)', index=38538)), (513, DatasetIndex(dataset='FastMNIST (train)', index=7572)), (3058, DatasetIndex(dataset='FastMNIST (train)', index=46161)), (3771, DatasetIndex(dataset='FastMNIST (train)', index=24148)), (717, DatasetIndex(dataset='FastMNIST (train)', index=29334)), (1031, DatasetIndex(dataset='FastMNIST (train)', index=20835)), (3754, DatasetIndex(dataset='FastMNIST (train)', index=9416)), (2355, DatasetIndex(dataset='FastMNIST (train)', index=26621))]
Acquiring (label, score)s: 3 (0.6687), 3 (0.6632), 3 (0.6591), 2 (0.6548),

 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7158203125, 'crossentropy': 1.6478411555290222}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7080078125, 'crossentropy': 2.2967125177383423}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.724609375, 'crossentropy': 2.3752167224884033}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7197265625, 'crossentropy': 2.294450879096985}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.6478411555290222)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.7411, 'crossentropy': 1.4925826753616334}


get_predictions_labels:   0%|          | 0/131714 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/65857 [00:00<?, ?it/s]

Entropy:   0%|          | 0/65857 [00:00<?, ?it/s]

CandidateBatch(scores=[0.6277315467596054, 0.6259744167327881, 0.6241333484649658, 0.6217308640480042, 0.6215141415596008, 0.6167296469211578, 0.6158503890037537, 0.614221841096878, 0.6115047633647919, 0.6082864105701447], indices=[68, 3556, 3881, 3605, 744, 399, 4098, 3151, 1494, 596])
[(68, DatasetIndex(dataset='FastMNIST (train)', index=47575)), (3556, DatasetIndex(dataset='FastMNIST (train)', index=40491)), (3881, DatasetIndex(dataset='FastMNIST (train)', index=41671)), (3605, DatasetIndex(dataset='FastMNIST (train)', index=50449)), (744, DatasetIndex(dataset='FastMNIST (train)', index=14514)), (399, DatasetIndex(dataset='FastMNIST (train)', index=49405)), (4098, DatasetIndex(dataset='FastMNIST (train)', index=23641)), (3151, DatasetIndex(dataset='FastMNIST (train)', index=42083)), (1494, DatasetIndex(dataset='FastMNIST (train)', index=35443)), (596, DatasetIndex(dataset='FastMNIST (train)', index=6484))]
Acquiring (label, score)s: 5 (0.6277), 5 (0.626), 5 (0.6241), 5 (0.6217), 5 (

 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7255859375, 'crossentropy': 1.6036733388900757}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.712890625, 'crossentropy': 2.002104878425598}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.736328125, 'crossentropy': 2.018679678440094}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.708984375, 'crossentropy': 2.2136075496673584}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.6036733388900757)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.746, 'crossentropy': 1.4262506231307983}


get_predictions_labels:   0%|          | 0/131694 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/65847 [00:00<?, ?it/s]

Entropy:   0%|          | 0/65847 [00:00<?, ?it/s]

CandidateBatch(scores=[0.6444505453109741, 0.6414125487208366, 0.6350710317492485, 0.6295631229877472, 0.6293644905090332, 0.6242876946926117, 0.6208561956882477, 0.619663417339325, 0.619554728269577, 0.6187689006328583], indices=[5260, 976, 5591, 2966, 5227, 2586, 4641, 67, 1294, 2250])
[(5260, DatasetIndex(dataset='FastMNIST (train)', index=13719)), (976, DatasetIndex(dataset='FastMNIST (train)', index=27530)), (5591, DatasetIndex(dataset='FastMNIST (train)', index=7562)), (2966, DatasetIndex(dataset='FastMNIST (train)', index=44989)), (5227, DatasetIndex(dataset='FastMNIST (train)', index=2007)), (2586, DatasetIndex(dataset='FastMNIST (train)', index=4577)), (4641, DatasetIndex(dataset='FastMNIST (train)', index=21034)), (67, DatasetIndex(dataset='FastMNIST (train)', index=47575)), (1294, DatasetIndex(dataset='FastMNIST (train)', index=36645)), (2250, DatasetIndex(dataset='FastMNIST (train)', index=49537))]
Acquiring (label, score)s: 4 (0.6445), 0 (0.6414), 4 (0.6351), 3 (0.6296), 4

 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6953125, 'crossentropy': 1.7569584250450134}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7197265625, 'crossentropy': 1.8181380033493042}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.703125, 'crossentropy': 2.24788761138916}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7275390625, 'crossentropy': 2.056207537651062}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.7569584250450134)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.7064, 'crossentropy': 1.6216364873886109}


get_predictions_labels:   0%|          | 0/131674 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/65837 [00:00<?, ?it/s]

Entropy:   0%|          | 0/65837 [00:00<?, ?it/s]

CandidateBatch(scores=[0.6916591227054596, 0.6865983195602894, 0.685850091278553, 0.6858017891645432, 0.6853644624352455, 0.685111328959465, 0.684671938419342, 0.6834591552615166, 0.6828739866614342, 0.6826837062835693], indices=[3523, 3391, 2635, 2229, 2974, 22913, 5425, 3464, 1580, 1475])
[(3523, DatasetIndex(dataset='FastMNIST (train)', index=14188)), (3391, DatasetIndex(dataset='FastMNIST (train)', index=8315)), (2635, DatasetIndex(dataset='FastMNIST (train)', index=43847)), (2229, DatasetIndex(dataset='FastMNIST (train)', index=4831)), (2974, DatasetIndex(dataset='FastMNIST (train)', index=27679)), (22913, DatasetIndex(dataset='OoD Dataset (60000 samples)', index=17085)), (5425, DatasetIndex(dataset='FastMNIST (train)', index=15327)), (3464, DatasetIndex(dataset='FastMNIST (train)', index=42364)), (1580, DatasetIndex(dataset='FastMNIST (train)', index=55412)), (1475, DatasetIndex(dataset='FastMNIST (train)', index=14704))]
Acquiring (label, score)s: 2 (0.6917), 4 (0.6866), 3 (0.68

 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.69140625, 'crossentropy': 1.508621096611023}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.73828125, 'crossentropy': 1.5373234748840332}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.744140625, 'crossentropy': 1.941004753112793}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7294921875, 'crossentropy': 2.077140748500824}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.508621096611023)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.735, 'crossentropy': 1.353954902267456}


get_predictions_labels:   0%|          | 0/131656 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/65828 [00:00<?, ?it/s]

Entropy:   0%|          | 0/65828 [00:00<?, ?it/s]

CandidateBatch(scores=[0.6577868759632111, 0.6562519818544388, 0.6560469269752502, 0.6526933908462524, 0.6509876698255539, 0.6503854393959045, 0.6463260054588318, 0.643702358007431, 0.6434788703918457, 0.6433059722185135], indices=[5652, 72, 5417, 5659, 1015, 1365, 937, 1983, 3251, 1014])
[(5652, DatasetIndex(dataset='FastMNIST (train)', index=41769)), (72, DatasetIndex(dataset='FastMNIST (train)', index=63)), (5417, DatasetIndex(dataset='FastMNIST (train)', index=15327)), (5659, DatasetIndex(dataset='FastMNIST (train)', index=52427)), (1015, DatasetIndex(dataset='FastMNIST (train)', index=27991)), (1365, DatasetIndex(dataset='FastMNIST (train)', index=49749)), (937, DatasetIndex(dataset='FastMNIST (train)', index=59257)), (1983, DatasetIndex(dataset='FastMNIST (train)', index=31533)), (3251, DatasetIndex(dataset='FastMNIST (train)', index=42777)), (1014, DatasetIndex(dataset='FastMNIST (train)', index=40426))]
Acquiring (label, score)s: 2 (0.6578), 8 (0.6563), 3 (0.656), 5 (0.6527), 3

 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.705078125, 'crossentropy': 1.4032326936721802}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.712890625, 'crossentropy': 1.671923577785492}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.716796875, 'crossentropy': 2.1178619861602783}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7197265625, 'crossentropy': 1.7620881795883179}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.4032326936721802)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.7231, 'crossentropy': 1.2891359241485596}


get_predictions_labels:   0%|          | 0/131636 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/65818 [00:00<?, ?it/s]

Entropy:   0%|          | 0/65818 [00:00<?, ?it/s]

CandidateBatch(scores=[0.6631099283695221, 0.6597822308540344, 0.6486785411834717, 0.6428181529045105, 0.6393815577030182, 0.6268095970153809, 0.6262959539890289, 0.6231429874897003, 0.621695339679718, 0.6200662404298782], indices=[3879, 1431, 1354, 4505, 5345, 3598, 1276, 1352, 2588, 5268])
[(3879, DatasetIndex(dataset='FastMNIST (train)', index=21006)), (1431, DatasetIndex(dataset='FastMNIST (train)', index=35000)), (1354, DatasetIndex(dataset='FastMNIST (train)', index=49520)), (4505, DatasetIndex(dataset='FastMNIST (train)', index=9337)), (5345, DatasetIndex(dataset='FastMNIST (train)', index=2448)), (3598, DatasetIndex(dataset='FastMNIST (train)', index=16077)), (1276, DatasetIndex(dataset='FastMNIST (train)', index=43157)), (1352, DatasetIndex(dataset='FastMNIST (train)', index=36163)), (2588, DatasetIndex(dataset='FastMNIST (train)', index=33813)), (5268, DatasetIndex(dataset='FastMNIST (train)', index=23265))]
Acquiring (label, score)s: 7 (0.6631), 0 (0.6598), 7 (0.6487), 0 (0.

 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.71484375, 'crossentropy': 1.3069573640823364}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7373046875, 'crossentropy': 1.4336466193199158}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.736328125, 'crossentropy': 1.5378057360649109}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7353515625, 'crossentropy': 1.6425959467887878}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.3069573640823364)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.7319, 'crossentropy': 1.280282580947876}


get_predictions_labels:   0%|          | 0/131616 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/65808 [00:00<?, ?it/s]

Entropy:   0%|          | 0/65808 [00:00<?, ?it/s]

CandidateBatch(scores=[0.6065114140510559, 0.551966667175293, 0.5503286719322205, 0.5480679273605347, 0.5340226292610168, 0.5274611115455627, 0.5260781049728394, 0.5244690775871277, 0.5213816165924072, 0.5168037414550781], indices=[3001, 4121, 1004, 4591, 497, 761, 65479, 1051, 3553, 42569])
[(3001, DatasetIndex(dataset='FastMNIST (train)', index=47187)), (4121, DatasetIndex(dataset='FastMNIST (train)', index=10501)), (1004, DatasetIndex(dataset='FastMNIST (train)', index=88)), (4591, DatasetIndex(dataset='FastMNIST (train)', index=2120)), (497, DatasetIndex(dataset='FastMNIST (train)', index=12249)), (761, DatasetIndex(dataset='FastMNIST (train)', index=9530)), (65479, DatasetIndex(dataset='OoD Dataset (60000 samples)', index=59679)), (1051, DatasetIndex(dataset='FastMNIST (train)', index=29158)), (3553, DatasetIndex(dataset='FastMNIST (train)', index=2824)), (42569, DatasetIndex(dataset='OoD Dataset (60000 samples)', index=36769))]
Acquiring (label, score)s: 8 (0.6065), 3 (0.552), 6 

 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7353515625, 'crossentropy': 1.3632447719573975}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.759765625, 'crossentropy': 1.473086655139923}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7529296875, 'crossentropy': 1.9823876023292542}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.767578125, 'crossentropy': 1.7056154012680054}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.3632447719573975)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.7265, 'crossentropy': 1.3068740461349488}


get_predictions_labels:   0%|          | 0/131600 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/65800 [00:00<?, ?it/s]

Entropy:   0%|          | 0/65800 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Inspect the log filled in by `experiment.run(results)` (partial if the run was interrupted).
results

In [None]:
# experiment

# NOTE(review): this cell previously used `Experiment` and `AcquisitionFunction.randombaldical`, neither of
# which is defined in this notebook (NameError). It now runs a one-epoch smoke test of this notebook's
# experiment with the random-acquisition baseline; adjust if a different acquisition function was intended.
experiment = RejectionOodExperiment(
    max_training_epochs=1, max_training_set=25, acquisition_function=acquisition_functions.Random
)

results = {}
experiment.run(results)

results

In [None]:
# exports

# Experiment grid: TemperedEvalBALD sweeps (5 seeds x 4 acquisition sizes x 100 pool samples),
# followed by 20 seeds of the random-acquisition baseline.
configs = [
    *(
        RejectionOodExperiment(
            seed=seed,
            acquisition_function=acquisition_functions.TemperedEvalBALD,
            acquisition_size=batch_size,
            num_pool_samples=pool_samples,
            temperature=8,
        )
        for seed in range(5)
        for batch_size in (5, 10, 20, 50)
        for pool_samples in (100,)
    ),
    *(
        RejectionOodExperiment(
            seed=seed,
            acquisition_function=acquisition_functions.Random,
            acquisition_size=5,
        )
        for seed in range(20)
    ),
]

# Script entry point (skipped inside IPython): run the config selected by each job handed out by
# embedded_experiments, one job per entry in `configs`.
if not is_run_from_ipython() and __name__ == "__main__":
    for job_id, store in embedded_experiments(__file__, len(configs)):
        config = configs[job_id]
        # NOTE(review): every entry in `configs` already sets an explicit seed, so adding job_id offsets
        # it a second time. Confirm this double offset is intended.
        config.seed += job_id
        print(config)
        store["config"] = dataclasses.asdict(config)
        # NOTE(review): run() does not visibly write to store["log"]; possibly a leftover.
        store["log"] = {}

        try:
            config.run(store=store)
        except Exception:
            # Record the traceback in the store, then re-raise so the job still fails.
            store["exception"] = traceback.format_exc()
            raise

In [None]:
# Number of configurations, i.e. the number of jobs the script entry point requests from embedded_experiments.
len(configs)