# Experiment
> Can we get better by training on our assumptions?

In [None]:
# default_exp experiment

In [None]:
# hide
import blackhc.project.script

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


Import the modules and functions we are going to use.

In [None]:
# exports

import dataclasses
import traceback
from dataclasses import dataclass
from typing import Type, Union

import torch
import torch.utils.data
from blackhc.project import is_run_from_ipython
from blackhc.project.experiment import embedded_experiments
from torch import nn
from torch.utils.data import Dataset

import batchbald_redux.acquisition_functions as acquisition_functions
from batchbald_redux.acquisition_functions import (
    CandidateBatchComputer,
    EvalCandidateBatchComputer,
)
from batchbald_redux.active_learning import ActiveLearningData, RandomFixedLengthSampler
from batchbald_redux.black_box_model_training import evaluate_old, train, evaluate
from batchbald_redux.dataset_challenges import (
    create_repeated_MNIST_dataset,
    get_base_dataset_index,
    get_target,
)
from batchbald_redux.di import DependencyInjection
from batchbald_redux.model_optimizer_factory import ModelOptimizerFactory
from batchbald_redux.models import MnistOptimizerFactory, MnistModelTrainer

In [None]:
# exports

# From the BatchBALD Repo
from batchbald_redux.train_eval_model import (
    TrainEvalModel,
    TrainSelfDistillationEvalModel,
)
from batchbald_redux.trained_model import TrainedBayesianModel, ModelTrainer

# Fixed dataset indices (20 entries) that `Experiment.run` acquires as the
# initial training set.
mnist_initial_samples = [
    38043, 40091, 17418, 2094, 39879,
    3133, 5011, 40683, 54379, 24287,
    9849, 59305, 39508, 39356, 8758,
    52579, 13655, 7636, 21562, 41329,
]

In [None]:
# exports


@dataclass
class Experiment:
    """Configuration and driver for one active-learning run.

    Every field is a hyperparameter. `run` executes the active-learning loop
    and writes its log into a caller-supplied store. The fields are also
    handed by name (through `vars(self)`) to `DependencyInjection` when the
    acquisition function, the train/eval model and the model trainer are
    created.
    """

    # Seed passed to `torch.manual_seed` at the start of `run`.
    seed: int = 1337
    # Not read directly in this class; available to dependency injection by name.
    acquisition_size: int = 5
    # `run` stops once the training set holds at least this many samples.
    max_training_set: int = 300
    # Not read directly in this class; available to dependency injection by name.
    num_pool_samples: int = 20
    # Passed to the model trainer and, as `num_samples`, to `evaluate`.
    num_validation_samples: int = 20
    # The next three are passed to the model trainer.
    num_training_samples: int = 1
    num_patience_epochs: int = 5*4
    max_training_epochs: int = 30*4
    # Batch size of the training data loader.
    training_batch_size: int = 64
    # Device handed to training, evaluation and the acquisition function.
    device: str = "cuda"
    # Number of samples extracted from the pool as validation set.
    validation_set_size: int = 2048
    # Not consulted by `run`, which uses the fixed `mnist_initial_samples`.
    initial_set_size: int = 20
    # Length argument of the `RandomFixedLengthSampler` used for training.
    min_samples_per_epoch: int = 1024
    # Forwarded to `create_repeated_MNIST_dataset` (`num_repetitions`, `add_noise`).
    repeated_mnist_repetitions: int = 1
    add_dataset_noise: bool = False
    # Class that is instantiated (via dependency injection) to pick candidates.
    acquisition_function: Union[
        Type[CandidateBatchComputer], Type[EvalCandidateBatchComputer]
    ] = acquisition_functions.BALD
    # Class instantiated for `EvalCandidateBatchComputer` acquisition functions.
    train_eval_model_factory: Type[TrainEvalModel] = TrainSelfDistillationEvalModel
    # Class instantiated by `create_model_trainer`.
    model_trainer_factory: Type[ModelTrainer] = MnistModelTrainer
    # Not read directly in this class; available to dependency injection by name.
    acquisition_function_args: dict = None
    temperature: float = 0.0

    def load_dataset(self, initial_training_set_indices) -> (ActiveLearningData, Dataset, Dataset):
        """Build the active-learning data, validation set and test set.

        The given indices are acquired into the training set first; afterwards
        `validation_set_size` samples are extracted from the remaining pool.

        Returns:
            `(active_learning_data, validation_dataset, test_dataset)`.
        """
        train_dataset, test_dataset = create_repeated_MNIST_dataset(
            num_repetitions=self.repeated_mnist_repetitions, add_noise=self.add_dataset_noise
        )
        active_learning_data = ActiveLearningData(train_dataset)

        active_learning_data.acquire(initial_training_set_indices)

        validation_dataset = active_learning_data.extract_dataset_from_pool(self.validation_set_size)

        return active_learning_data, validation_dataset, test_dataset

    # Simple Dependency Injection
    def create_acquisition_function(self):
        """Instantiate `acquisition_function`, resolving its arguments from this config."""
        di = DependencyInjection(vars(self))
        return di.create_dataclass_type(self.acquisition_function)

    def create_train_eval_model(self, runtime_config) -> TrainEvalModel:
        """Instantiate `train_eval_model_factory`.

        `runtime_config` entries override same-named fields of this config.
        """
        config = {**vars(self), **runtime_config}
        di = DependencyInjection(config, [])
        return di.create_dataclass_type(self.train_eval_model_factory)

    def create_model_trainer(self) -> ModelTrainer:
        """Instantiate `model_trainer_factory`, resolving its arguments from this config."""
        di = DependencyInjection(vars(self))
        return di.create_dataclass_type(self.model_trainer_factory)

    def run(self, store):
        """Run the active-learning loop and log everything into `store`.

        Keys written to `store`: "initial_training_set_indices",
        "dataset_info" and "active_learning_steps". Each step is a dict with
        "training", "evaluation_metrics" and, unless it is the final step,
        "acquisition" (plus "eval_training" for eval-based acquisition
        functions).

        Raises:
            ValueError: if the created acquisition function is neither a
                `CandidateBatchComputer` nor an `EvalCandidateBatchComputer`.
        """
        torch.manual_seed(self.seed)

        # The initial training set is fixed rather than sampled, so
        # `initial_set_size` plays no role here.
        initial_training_set_indices = mnist_initial_samples
        store["initial_training_set_indices"] = initial_training_set_indices

        # Active Learning setup
        active_learning_data, validation_dataset, test_dataset = self.load_dataset(initial_training_set_indices)
        store["dataset_info"] = dict(training=repr(active_learning_data.base_dataset), test=repr(test_dataset))

        # Use the configured batch size instead of a literal so that
        # `training_batch_size` actually takes effect.
        train_loader = torch.utils.data.DataLoader(
            active_learning_data.training_dataset,
            batch_size=self.training_batch_size,
            sampler=RandomFixedLengthSampler(active_learning_data.training_dataset, self.min_samples_per_epoch),
            drop_last=True,
        )
        pool_loader = torch.utils.data.DataLoader(
            active_learning_data.pool_dataset, batch_size=128, drop_last=False, shuffle=False
        )

        validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=128, drop_last=False)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, drop_last=False)

        # Read the list back from the store so that appends go through
        # whatever object the store keeps.
        store["active_learning_steps"] = []
        active_learning_steps = store["active_learning_steps"]

        acquisition_function = self.create_acquisition_function()

        # NOTE(review): the trainer is built explicitly here; `model_trainer_factory`
        # and `create_model_trainer` are not used by `run`.
        model_trainer = MnistModelTrainer(
            num_training_samples=self.num_training_samples,
            num_validation_samples=self.num_validation_samples,
            num_patience_epochs=self.num_patience_epochs,
            max_training_epochs=self.max_training_epochs,
            device=self.device
        )

        # Active Training Loop
        while True:
            training_set_size = len(active_learning_data.training_dataset)
            print(f"Training set size {training_set_size}:")

            active_learning_steps.append({})
            iteration_log = active_learning_steps[-1]

            iteration_log["training"] = {}
            trained_model = model_trainer.get_trained(train_loader=train_loader, train_augmentations=None,
                                                      validation_loader=validation_loader, log=iteration_log["training"])

            evaluation_metrics = evaluate(model=trained_model, num_samples=self.num_validation_samples, loader=test_loader, device=self.device)

            iteration_log["evaluation_metrics"] = evaluation_metrics
            print(f"Perf after training {evaluation_metrics}")

            # Train and evaluate one last time at the target size, then stop
            # without acquiring more samples.
            if training_set_size >= self.max_training_set:
                print("Done.")
                break

            if isinstance(acquisition_function, CandidateBatchComputer):
                candidate_batch = acquisition_function.compute_candidate_batch(trained_model, pool_loader, self.device)
            elif isinstance(acquisition_function, EvalCandidateBatchComputer):
                # Eval-based acquisition needs a second model, trained by the
                # configured train/eval model.
                train_eval_model = self.create_train_eval_model(
                    dict(
                        model_trainer=model_trainer,
                        training_dataset=active_learning_data.training_dataset,
                        eval_dataset=active_learning_data.pool_dataset,
                        validation_loader=validation_loader,
                        trained_model=trained_model,
                    )
                )

                iteration_log["eval_training"] = {}
                trained_eval_model = train_eval_model(training_log=iteration_log["eval_training"], device=self.device)

                candidate_batch = acquisition_function.compute_candidate_batch(
                    trained_model, trained_eval_model, pool_loader, device=self.device
                )
            else:
                raise ValueError(f"Unknown acquisition function {acquisition_function}!")

            # Translate pool-relative indices into base-dataset indices for the log.
            candidate_global_indices = [
                get_base_dataset_index(active_learning_data.pool_dataset, index).index
                for index in candidate_batch.indices
            ]
            candidate_labels = [
                get_target(active_learning_data.base_dataset, index).item() for index in candidate_global_indices
            ]

            iteration_log["acquisition"] = dict(
                indices=candidate_global_indices, labels=candidate_labels, scores=candidate_batch.scores
            )

            # `acquire` receives the pool-relative indices, not the global ones.
            active_learning_data.acquire(candidate_batch.indices)

            ls = ", ".join(f"{label} ({score:.4})" for label, score in zip(candidate_labels, candidate_batch.scores))
            print(f"Acquiring (label, score)s: {ls}")

In [None]:
# experiment

# Interactive run: SoftmaxBALD at temperature 8, acquiring 10 samples per step
# until the training set reaches 130 samples.
experiment = Experiment(
    acquisition_function=acquisition_functions.SoftmaxBALD,
    temperature=8,
    acquisition_size=10,
    num_pool_samples=20,
    max_training_set=130,
    max_training_epochs=30,
    seed=1120,
    device="cuda",
)

# `run` writes its log into the mapping it is given.
results = dict()
experiment.run(store=results)

Creating: SoftmaxBALD(
	acquisition_size=10,
	temperature=8
)
Training set size 20:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.59033203125, 'crossentropy': 1.861450120806694}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.609375, 'crossentropy': 1.8868751600384712}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.61572265625, 'crossentropy': 2.2216100096702576}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6376953125, 'crossentropy': 2.2530888095498085}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.62939453125, 'crossentropy': 2.50810956209898}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6357421875, 'crossentropy': 2.4853495061397552}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6328125, 'crossentropy': 2.6377444863319397}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.642578125, 'crossentropy': 2.675654709339142}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6494140625, 'crossentropy': 2.673930361866951}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63671875, 'crossentropy': 3.057991608977318}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6455078125, 'crossentropy': 2.974304348230362}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.642578125, 'crossentropy': 3.0666737258434296}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63427734375, 'crossentropy': 3.1648862659931183}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.64013671875, 'crossentropy': 3.036598652601242}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.64501953125, 'crossentropy': 3.1013021171092987}
RestoringEarlyStopping: 6 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.646484375, 'crossentropy': 3.1335627138614655}
RestoringEarlyStopping: 7 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.638671875, 'crossentropy': 3.256747245788574}
RestoringEarlyStopping: 8 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6328125, 'crossentropy': 3.138277366757393}
RestoringEarlyStopping: 9 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.64111328125, 'crossentropy': 3.1264200508594513}
RestoringEarlyStopping: 10 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63427734375, 'crossentropy': 3.142566606402397}
RestoringEarlyStopping: 11 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63525390625, 'crossentropy': 3.3125554025173187}
RestoringEarlyStopping: 12 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.64404296875, 'crossentropy': 3.1048810184001923}
RestoringEarlyStopping: 13 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6396484375, 'crossentropy': 3.209494262933731}
RestoringEarlyStopping: 14 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63818359375, 'crossentropy': 3.2090319395065308}
RestoringEarlyStopping: 15 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6455078125, 'crossentropy': 3.268053323030472}
RestoringEarlyStopping: 16 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6396484375, 'crossentropy': 3.357768729329109}
RestoringEarlyStopping: 17 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6396484375, 'crossentropy': 3.211002081632614}
RestoringEarlyStopping: 18 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.64306640625, 'crossentropy': 3.375944823026657}
RestoringEarlyStopping: 19 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.634765625, 'crossentropy': 3.4645328372716904}
RestoringEarlyStopping: 20 / 20
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: 0.6494140625)
RestoringEarlyStopping: Restoring optimizer.


[1/79]   1%|1          [00:00<?]

Perf after training {'accuracy': 0.6619, 'crossentropy': 2.678555955505371}


get_predictions_labels:   0%|          | 0/1158640 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/57932 [00:00<?, ?it/s]

Entropy:   0%|          | 0/57932 [00:00<?, ?it/s]

Acquiring (label, score)s: 1 (1.932), 4 (1.772), 7 (1.01), 7 (1.0), 0 (1.021), 8 (1.771), 3 (1.964), 3 (1.62), 8 (1.673), 1 (1.012)
Training set size 30:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.4375, 'crossentropy': 1.983402356505394}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.60009765625, 'crossentropy': 1.844376228749752}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6337890625, 'crossentropy': 1.9153877794742584}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.61279296875, 'crossentropy': 2.430958680808544}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.64453125, 'crossentropy': 2.3921838775277138}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63818359375, 'crossentropy': 2.634309396147728}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6416015625, 'crossentropy': 2.627225875854492}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.634765625, 'crossentropy': 2.755842551589012}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.62158203125, 'crossentropy': 2.9181164503097534}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.630859375, 'crossentropy': 2.908192679286003}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.634765625, 'crossentropy': 2.9231721311807632}
RestoringEarlyStopping: 6 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.62841796875, 'crossentropy': 3.060195416212082}
RestoringEarlyStopping: 7 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6455078125, 'crossentropy': 2.98862624168396}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.62060546875, 'crossentropy': 3.0973827093839645}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.62451171875, 'crossentropy': 3.1818181574344635}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63037109375, 'crossentropy': 3.179712325334549}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6552734375, 'crossentropy': 3.141026258468628}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.630859375, 'crossentropy': 3.09803469479084}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6416015625, 'crossentropy': 2.984661117196083}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.62109375, 'crossentropy': 3.277120530605316}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.638671875, 'crossentropy': 2.9718064218759537}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.62255859375, 'crossentropy': 3.2840312719345093}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.646484375, 'crossentropy': 3.0958415865898132}
RestoringEarlyStopping: 6 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6318359375, 'crossentropy': 3.3209871351718903}
RestoringEarlyStopping: 7 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6318359375, 'crossentropy': 3.3115693032741547}
RestoringEarlyStopping: 8 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6396484375, 'crossentropy': 3.2767363786697388}
RestoringEarlyStopping: 9 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63037109375, 'crossentropy': 3.4181878119707108}
RestoringEarlyStopping: 10 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6298828125, 'crossentropy': 3.594739645719528}
RestoringEarlyStopping: 11 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6376953125, 'crossentropy': 3.528226226568222}
RestoringEarlyStopping: 12 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.65283203125, 'crossentropy': 3.2282912731170654}
RestoringEarlyStopping: 13 / 20
RestoringEarlyStopping: Restoring best parameters. (Score: 0.6552734375)
RestoringEarlyStopping: Restoring optimizer.


[1/79]   1%|1          [00:00<?]

Perf after training {'accuracy': 0.642, 'crossentropy': 3.1842077194213867}


get_predictions_labels:   0%|          | 0/1158440 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/57922 [00:00<?, ?it/s]

Entropy:   0%|          | 0/57922 [00:00<?, ?it/s]

Acquiring (label, score)s: 5 (1.886), 7 (1.037), 3 (1.396), 5 (1.328), 4 (1.046), 0 (1.171), 3 (1.563), 1 (1.003), 4 (1.022), 0 (1.012)
Training set size 40:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.4619140625, 'crossentropy': 1.9796495214104652}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.583984375, 'crossentropy': 1.8143711909651756}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.62109375, 'crossentropy': 2.1308560520410538}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.62744140625, 'crossentropy': 2.4476575404405594}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.646484375, 'crossentropy': 2.4642718583345413}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63818359375, 'crossentropy': 2.7880875319242477}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.62646484375, 'crossentropy': 2.965988278388977}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.64697265625, 'crossentropy': 3.045136660337448}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.64306640625, 'crossentropy': 3.0356389731168747}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.62841796875, 'crossentropy': 3.219661369919777}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Engine run is terminating due to exception: .


Epoch metrics: {'accuracy': 0.64208984375, 'crossentropy': 3.0388206243515015}
RestoringEarlyStopping: 3 / 20


KeyboardInterrupt: 

In [None]:
results

{'initial_training_set_indices': [38043,
  40091,
  17418,
  2094,
  39879,
  3133,
  5011,
  40683,
  54379,
  24287,
  9849,
  59305,
  39508,
  39356,
  8758,
  52579,
  13655,
  7636,
  21562,
  41329],
 'dataset_info': {'training': "'FastMNIST (Train)'",
  'test': "'FastMNIST (Test)'"},
 'active_learning_steps': [{'training': {'epochs': []}}]}

In [None]:
# experiment

# Quick run: one training epoch per step, stopping once the training set
# reaches 25 samples.
# NOTE(review): `AcquisitionFunction` is not imported or defined in the cells
# shown here (the imports expose the `acquisition_functions` module instead) --
# presumably a leftover from an earlier API; confirm before re-running.
experiment = Experiment(
    max_training_epochs=1, max_training_set=25, acquisition_function=AcquisitionFunction.randombaldical
)

# `run` fills this dict with the experiment log.
results = {}
experiment.run(results)

# Display the collected log as the cell output.
results

Training set size 20:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/384]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -6.529030114412308)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.5367, 'crossentropy': 6.438035237884521}


get_predictions:   0%|          | 0/463616 [00:00<?, ?it/s]

100%|##########| 1/1 [00:00<?, ?it/s]

[1/1811]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -5.1637596152722836)
RestoringEarlyStopping: Restoring optimizer.


get_predictions:   0%|          | 0/2317680 [00:00<?, ?it/s]

get_predictions:   0%|          | 0/2317680 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Acquiring (label, score)s: 8 (0.8711), 8 (0.8687), 3 (0.876), 3 (0.8465), 3 (0.8811)
Training set size 25:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/384]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -4.6851686127483845)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.6256, 'crossentropy': 4.484497045135498}
Done.


{'initial_training_set_indices': [38043,
  40091,
  17418,
  2094,
  39879,
  3133,
  5011,
  40683,
  54379,
  24287,
  9849,
  59305,
  39508,
  39356,
  8758,
  52579,
  13655,
  7636,
  21562,
  41329],
 'active_learning_steps': [{'training': {'epochs': [{'accuracy': 0.538818359375,
      'crossentropy': 6.529030114412308}],
    'best_epoch': 1},
   'evalution_metrics': {'accuracy': 0.5367,
    'crossentropy': 6.438035237884521},
   'pool_training': {'epochs': [{'accuracy': 0.531005859375,
      'crossentropy': 5.1637596152722836}],
    'best_epoch': 1},
   'acquisition': {'indices': [63338, 10856, 63452, 81864, 109287],
    'labels': [8, 8, 3, 3, 3],
    'scores': [0.8710822958846325,
     0.8687216999221631,
     0.8759664372823723,
     0.8464646732511746,
     0.8810812784952251]}},
  {'training': {'epochs': [{'accuracy': 0.62255859375,
      'crossentropy': 4.6851686127483845}],
    'best_epoch': 1},
   'evalution_metrics': {'accuracy': 0.6256,
    'crossentropy': 4.4844970451

In [None]:
# Disabled experiment grid, kept for reference (`if False:` never executes):
# batchbald/batchbaldical with 100 pool samples, bald/thompsonbald/randombald
# over four acquisition sizes, and 40 seeds of the random baseline.
# NOTE(review): `AcquisitionFunction` is not defined in the cells visible here --
# presumably an older API; confirm before re-enabling.
if False:
    configs = (
        [
            Experiment(
                seed=seed + 120,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=100,
            )
            for seed in range(5)
            for acquisition_size in [5, 10]
            for acquisition_function in [AcquisitionFunction.batchbald, AcquisitionFunction.batchbaldical]
        ]
        + [
            Experiment(
                seed=seed + 120,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=max(20, acquisition_size),
            )
            for seed in range(5)
            for acquisition_size in [5, 10, 20, 50]
            for acquisition_function in [
                AcquisitionFunction.bald,
                AcquisitionFunction.thompsonbald,
                AcquisitionFunction.randombald,
            ]
        ]
        + [
            Experiment(
                seed=seed + 120,
                acquisition_function=AcquisitionFunction.random,
                acquisition_size=5,
                num_pool_samples=20,
            )
            for seed in range(40)
            for acquisition_size in [5]
        ]
    )
# Disabled experiment grid, kept for reference: baldical and randombaldical
# over 5 seeds (offset 240) and four acquisition sizes.
# NOTE(review): `AcquisitionFunction` is not defined in the cells visible here --
# presumably an older API; confirm before re-enabling.
if False:
    configs = [
        Experiment(
            seed=seed + 240,
            acquisition_function=acquisition_function,
            acquisition_size=acquisition_size,
            num_pool_samples=max(20, acquisition_size),
        )
        for seed in range(5)
        for acquisition_size in [5, 10, 20, 50]
        for acquisition_function in [
            AcquisitionFunction.baldical,
            AcquisitionFunction.randombaldical,
        ]
    ]
# Disabled experiment grid, kept for reference: temperedbald with
# temperatures 13, 15 and 18 over 5 seeds (offset 340) and four acquisition sizes.
# NOTE(review): `AcquisitionFunction` is not defined in the cells visible here --
# presumably an older API; confirm before re-enabling.
if False:
    configs = [
        Experiment(
            seed=seed + 340,
            acquisition_function=acquisition_function,
            acquisition_size=acquisition_size,
            num_pool_samples=20,
            temperature=temperature,
        )
        for seed in range(5)
        for acquisition_size in [5, 10, 20, 50]
        for acquisition_function in [
            AcquisitionFunction.temperedbald,
        ]
        for temperature in [13, 15, 18]
    ]
# Disabled experiment grid, kept for reference: temperedbald (temperatures 8, 10)
# followed by temperedbaldical (temperatures 8, 10, 13), both with seed offset 340.
# NOTE(review): `AcquisitionFunction` is not defined in the cells visible here --
# presumably an older API; confirm before re-enabling.
if False:
    configs = [
        Experiment(
            seed=seed + 340,
            acquisition_function=acquisition_function,
            acquisition_size=acquisition_size,
            num_pool_samples=20,
            temperature=temperature,
        )
        for seed in range(5)
        for acquisition_size in [5, 10, 20, 50]
        for acquisition_function in [
            AcquisitionFunction.temperedbald,
        ]
        for temperature in [8, 10]
    ] + [
        Experiment(
            seed=seed + 340,
            acquisition_function=acquisition_function,
            acquisition_size=acquisition_size,
            num_pool_samples=20,
            temperature=temperature,
        )
        for seed in range(5)
        for acquisition_size in [5, 10, 20, 50]
        for acquisition_function in [
            AcquisitionFunction.temperedbaldical,
        ]
        for temperature in [8, 10, 13]
    ]
# Disabled experiment grid, kept for reference: temperedical (temperatures 5, 8, 11;
# seed offset 500) followed by plain ical (seed offset 600).
# NOTE(review): `AcquisitionFunction` is not defined in the cells visible here --
# presumably an older API; confirm before re-enabling.
if False:
    configs = [
        Experiment(
            seed=seed + 500,
            acquisition_function=acquisition_function,
            acquisition_size=acquisition_size,
            num_pool_samples=20,
            temperature=temperature,
        )
        for seed in range(5)
        for acquisition_size in [5, 10, 20, 50]
        for acquisition_function in [
            AcquisitionFunction.temperedical,
        ]
        for temperature in [5, 8, 11]
    ] + [
        Experiment(
            seed=seed + 600,
            acquisition_function=acquisition_function,
            acquisition_size=acquisition_size,
            num_pool_samples=20,
        )
        for seed in range(5)
        for acquisition_size in [5, 10, 20, 50]
        for acquisition_function in [
            AcquisitionFunction.ical,
        ]
    ]
# Disabled experiment grid, kept for reference. Four concatenated parts:
#   1. tempered variants (ical, bald, baldical) x temperatures 2, 5, 8 (seed offset 1000)
#   2. ical, bald, baldical, randombald (seed offset 2000)
#   3. thompsonbald, with num_pool_samples raised to at least acquisition_size
#   4. batchbaldical and batchbald at acquisition size 5 with 100 pool samples
# NOTE(review): `AcquisitionFunction` is not defined in the cells visible here --
# presumably an older API; confirm before re-enabling.
if False:
    configs = (
        [
            Experiment(
                seed=seed + 1000,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=num_pool_samples,
                temperature=temperature,
            )
            for seed in range(5)
            for acquisition_size in [5, 10, 20, 50]
            for num_pool_samples in [20, 100]
            for acquisition_function in [
                AcquisitionFunction.temperedical,
                AcquisitionFunction.temperedbald,
                AcquisitionFunction.temperedbaldical,
            ]
            for temperature in [2, 5, 8]
        ]
        + [
            Experiment(
                seed=seed + 2000,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=num_pool_samples,
            )
            for seed in range(5)
            for acquisition_size in [5, 10, 20, 50]
            for num_pool_samples in [20, 100]
            for acquisition_function in [
                AcquisitionFunction.ical,
                AcquisitionFunction.bald,
                AcquisitionFunction.baldical,
                AcquisitionFunction.randombald,
            ]
        ]
        + [
            Experiment(
                seed=seed + 2000,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=max(num_pool_samples, acquisition_size),
            )
            for seed in range(5)
            for acquisition_size in [5, 10, 20, 50]
            for num_pool_samples in [20, 100]
            for acquisition_function in [
                AcquisitionFunction.thompsonbald,
            ]
        ]
        + [
            Experiment(
                seed=seed + 3000,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=100,
            )
            for seed in range(5)
            for acquisition_size in [5]
            for acquisition_function in [
                AcquisitionFunction.batchbaldical,
                AcquisitionFunction.batchbald,
            ]
        ]
    )
# Disabled experiment grid, kept for reference (`if False:` never executes):
# BALD over 5 seeds x 4 acquisition sizes x 4 pool-sample counts, followed by
# 20 seeds of the Random baseline.
if False:
    # BALD sweep first ...
    configs = [
        Experiment(
            seed=seed,
            acquisition_function=acquisition_functions.BALD,
            acquisition_size=acquisition_size,
            num_pool_samples=num_pool_samples,
        )
        for seed in range(5)
        for acquisition_size in [5, 10, 20, 50]
        for num_pool_samples in [10, 20, 50, 100]
    ]
    # ... then the Random baseline, appended in seed order.
    configs = configs + [
        Experiment(
            seed=seed,
            acquisition_function=acquisition_functions.Random,
            acquisition_size=5,
            num_pool_samples=20,
        )
        for seed in range(20)
    ]


In [None]:
# exports

# Active grid: SoftmaxBALD on MNIST repeated twice (dataset noise is switched on
# whenever there is more than one repetition), sweeping four temperatures for
# each of five seeds.
configs = [
    Experiment(
        seed=seed + 315,
        max_training_set=150,
        acquisition_function=acquisition_function,
        temperature=temperature,
        acquisition_size=acquisition_size,
        num_pool_samples=num_pool_samples,
        repeated_mnist_repetitions=repeated_mnist_repetitions,
        add_dataset_noise=repeated_mnist_repetitions > 1,
    )
    for seed in range(5)
    for acquisition_function in [acquisition_functions.SoftmaxBALD]
    for temperature in [1 / 32, 1 / 64, 1 / 128, 1 / 256]
    for acquisition_size in [5]
    for num_pool_samples in [100]
    for repeated_mnist_repetitions in [2]
]

# Script entry point: only runs when executed as `__main__` and
# `is_run_from_ipython()` is false (presumably: not inside a notebook session).
if not is_run_from_ipython() and __name__ == "__main__":
    # Each yielded job_id selects one entry of `configs`; `store` receives its log.
    for job_id, store in embedded_experiments(__file__, len(configs)):
        config = configs[job_id]
        # Offset the seed by the job id before it is recorded and used.
        config.seed += job_id
        print(config)
        store["config"] = dataclasses.asdict(config)
        store["log"] = {}

        try:
            config.run(store=store)
        except Exception:
            # Record the traceback in the store, then let the failure propagate.
            store["exception"] = traceback.format_exc()
            raise

In [None]:
len(configs)

90

In [None]:
# slow
import prettyprinter

# Enable prettyprinter's `dataclasses` extra before printing the configs.
prettyprinter.install_extras(include={"dataclasses"})

# Show the whole experiment grid for a visual check.
prettyprinter.pprint(configs)

[
    Experiment(
        seed=115,
        max_training_set=150,
        num_pool_samples=100,
        repeated_mnist_repetitions=2,
        add_dataset_noise=True,
        # class
        acquisition_function=batchbald_redux.acquisition_functions.SoftmaxBALD,
        temperature=0.0625
    ),
    Experiment(
        seed=115,
        max_training_set=150,
        num_pool_samples=100,
        repeated_mnist_repetitions=2,
        add_dataset_noise=True,
        # class
        acquisition_function=batchbald_redux.acquisition_functions.SoftmaxBALD,
        temperature=0.125
    ),
    Experiment(
        seed=115,
        max_training_set=150,
        num_pool_samples=100,
        repeated_mnist_repetitions=2,
        add_dataset_noise=True,
        # class
        acquisition_function=batchbald_redux.acquisition_functions.SoftmaxBALD,
        temperature=0.25
    ),
    Experiment(
        seed=115,
        max_training_set=150,
        num_pool_samples=100,
        repeated_mnist_re