# Experiment
> Can we get better by training on our assumptions?

In [None]:
# default_exp experiment

In [None]:
# hide
import blackhc.project.script

Import the modules and functions we are going to use.

In [None]:
# exports

import dataclasses
import traceback
from dataclasses import dataclass
from typing import Type, Union

import torch
import torch.utils.data
from blackhc.project import is_run_from_ipython
from blackhc.project.experiment import embedded_experiments
from torch.utils.data import Dataset

import batchbald_redux.acquisition_functions as acquisition_functions
from batchbald_redux.acquisition_functions import (
    CandidateBatchComputer,
    EvalCandidateBatchComputer,
)
from batchbald_redux.active_learning import ActiveLearningData, RandomFixedLengthSampler
from batchbald_redux.black_box_model_training import evaluate, train, train_double_snapshots
from batchbald_redux.dataset_challenges import (
    create_repeated_MNIST_dataset,
    get_base_dataset_index,
    get_target,
)
from batchbald_redux.di import DependencyInjection
from batchbald_redux.model_optimizer_factory import ModelOptimizerFactory
from batchbald_redux.models import MnistOptimizerFactory

In [None]:
# exports

# From the BatchBALD Repo
from batchbald_redux.train_eval_model import (
    TrainEvalModel,
    TrainSelfDistillationEvalModel,
)
from batchbald_redux.trained_model import TrainedMCDropoutModel

# Fixed set of 20 dataset indices that `Experiment.run` acquires as the initial
# training set, so every run starts from the same labelled samples.
# NOTE(review): presumably indices into the MNIST training split - confirm.
mnist_initial_samples = [
    38043, 40091, 17418, 2094, 39879,
    3133, 5011, 40683, 54379, 24287,
    9849, 59305, 39508, 39356, 8758,
    52579, 13655, 7636, 21562, 41329,
]

In [None]:
# exports


@dataclass
class Experiment:
    """Configuration and driver for one active-learning experiment.

    The fields hold the hyper-parameters of the run; `run` executes the
    train / evaluate / acquire loop and records everything it does in the
    `store` mapping passed by the caller.

    Fields that are not read directly in this class (for example
    `acquisition_size` or `temperature`) are exposed through `vars(self)` to
    the dependency-injection helper, which presumably uses them to build the
    acquisition function and the train-eval model - confirm against
    `DependencyInjection`.
    """

    # Seed passed to `torch.manual_seed` at the start of `run`.
    seed: int = 1337
    acquisition_size: int = 5
    # The loop stops once the training set has reached this many samples.
    max_training_set: int = 300
    # Number of samples used by `TrainedMCDropoutModel` for the pool.
    num_pool_samples: int = 20
    # Number of samples used for validation during training and for the test evaluation.
    num_validation_samples: int = 20
    num_training_samples: int = 1
    num_patience_epochs: int = 5*4
    max_training_epochs: int = 30*4
    # Batch size of the training data loader.
    training_batch_size: int = 64
    device: str = "cuda"
    # Number of samples extracted from the pool to form the validation set.
    validation_set_size: int = 2048
    # NOTE(review): currently unused by `run`, which always starts from `mnist_initial_samples`.
    initial_set_size: int = 20
    # Length given to `RandomFixedLengthSampler` for the training loader.
    min_samples_per_epoch: int = 1024
    repeated_mnist_repetitions: int = 1
    add_dataset_noise: bool = False
    # Class (not instance) of the acquisition function; instantiated via dependency injection.
    acquisition_function: Union[
        Type[CandidateBatchComputer], Type[EvalCandidateBatchComputer]
    ] = acquisition_functions.BALD
    # Class (not instance) of the train-eval model; instantiated via dependency injection.
    train_eval_model: Type[TrainEvalModel] = TrainSelfDistillationEvalModel
    model_optimizer_factory: Type[ModelOptimizerFactory] = MnistOptimizerFactory
    acquisition_function_args: Union[dict, None] = None
    temperature: float = 0.0

    def load_dataset(self, initial_training_set_indices) -> (ActiveLearningData, Dataset, Dataset):
        """Create the datasets used by the experiment.

        Builds the repeated-MNIST train/test datasets, wraps the training set
        in `ActiveLearningData`, acquires `initial_training_set_indices` as the
        initial training set and then extracts `validation_set_size` samples
        from the remaining pool as the validation set.

        Returns:
            A tuple `(active_learning_data, validation_dataset, test_dataset)`.
        """
        train_dataset, test_dataset = create_repeated_MNIST_dataset(
            num_repetitions=self.repeated_mnist_repetitions, add_noise=self.add_dataset_noise
        )
        active_learning_data = ActiveLearningData(train_dataset)

        active_learning_data.acquire(initial_training_set_indices)

        # The validation set is taken from the pool only after the initial acquisition.
        validation_dataset = active_learning_data.extract_dataset_from_pool(self.validation_set_size)

        return active_learning_data, validation_dataset, test_dataset

    # Simple Dependency Injection
    def create_acquisition_function(self):
        """Instantiate `self.acquisition_function` using this experiment's fields as the config."""
        di = DependencyInjection(vars(self))
        return di.create_dataclass_type(self.acquisition_function)

    def create_train_eval_model(self, runtime_config) -> TrainEvalModel:
        """Instantiate `self.train_eval_model`.

        Args:
            runtime_config: Values only known while the loop is running (datasets,
                loaders, trained model, ...). They are merged over the experiment's
                own fields and win on key collisions.
        """
        config = {**vars(self), **runtime_config}
        di = DependencyInjection(config, [])
        return di.create_dataclass_type(self.train_eval_model)

    def run(self, store):
        """Run the active-learning loop and log its progress into `store`.

        Args:
            store: Mutable mapping that receives the keys
                `"initial_training_set_indices"`, `"dataset_info"` and
                `"active_learning_steps"`. The latter is a list with one dict per
                iteration holding `"training"`, `"evaluation_metrics"` and, unless
                it is the final iteration, `"acquisition"` (plus `"eval_training"`
                when an `EvalCandidateBatchComputer` is used).

        Raises:
            ValueError: If the created acquisition function is neither a
                `CandidateBatchComputer` nor an `EvalCandidateBatchComputer`.
        """
        torch.manual_seed(self.seed)

        initial_training_set_indices = mnist_initial_samples
        store["initial_training_set_indices"] = initial_training_set_indices

        # Active Learning setup
        active_learning_data, validation_dataset, test_dataset = self.load_dataset(initial_training_set_indices)
        store["dataset_info"] = dict(training=repr(active_learning_data.base_dataset), test=repr(test_dataset))

        # initial_training_set_indices = active_learning_data.get_random_pool_indices(self.initial_set_size)
        # initial_training_set_indices = get_balanced_sample_indices(
        #     active_learning_data.pool_dataset, 10, self.initial_set_size // 10
        # )

        train_loader = torch.utils.data.DataLoader(
            active_learning_data.training_dataset,
            # Honour the configured batch size instead of a hard-coded 64 (the default is still 64).
            batch_size=self.training_batch_size,
            sampler=RandomFixedLengthSampler(active_learning_data.training_dataset, self.min_samples_per_epoch),
            drop_last=True,
        )
        pool_loader = torch.utils.data.DataLoader(
            active_learning_data.pool_dataset, batch_size=128, drop_last=False, shuffle=False
        )

        validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=128, drop_last=False)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, drop_last=False)

        # Re-read the list from the store so we append to whatever object the store keeps.
        store["active_learning_steps"] = []
        active_learning_steps = store["active_learning_steps"]

        acquisition_function = self.create_acquisition_function()

        # Active Training Loop
        while True:
            training_set_size = len(active_learning_data.training_dataset)
            print(f"Training set size {training_set_size}:")

            # iteration_log = dict(training={}, pool_training={}, evaluation_metrics=None, acquisition=None)
            active_learning_steps.append({})
            iteration_log = active_learning_steps[-1]

            iteration_log["training"] = {}

            # A fresh model and optimizer are created for every iteration.
            model_optimizer = self.model_optimizer_factory().create_model_optimizer()

            train(
                model=model_optimizer.model,
                optimizer=model_optimizer.optimizer,
                training_samples=self.num_training_samples,
                validation_samples=self.num_validation_samples,
                train_loader=train_loader,
                validation_loader=validation_loader,
                patience=self.num_patience_epochs,
                max_epochs=self.max_training_epochs,
                device=self.device,
                training_log=iteration_log["training"],
            )

            evaluation_metrics = evaluate(
                model=model_optimizer.model,
                num_samples=self.num_validation_samples,
                loader=test_loader,
                device=self.device,
            )
            iteration_log["evaluation_metrics"] = evaluation_metrics
            print(f"Perf after training {evaluation_metrics}")

            # Stop after evaluating, so the final training-set size is also measured.
            if training_set_size >= self.max_training_set:
                print("Done.")
                break

            trained_model = TrainedMCDropoutModel(num_samples=self.num_pool_samples, model=model_optimizer.model)

            if isinstance(acquisition_function, CandidateBatchComputer):
                candidate_batch = acquisition_function.compute_candidate_batch(trained_model, pool_loader, self.device)
            elif isinstance(acquisition_function, EvalCandidateBatchComputer):
                # The eval model is trained for as many epochs as the main model's best epoch.
                current_max_epochs = iteration_log["training"]["best_epoch"]

                train_eval_model = self.create_train_eval_model(
                    dict(
                        max_epochs=current_max_epochs,
                        training_dataset=active_learning_data.training_dataset,
                        eval_dataset=active_learning_data.pool_dataset,
                        validation_loader=validation_loader,
                        trained_model=trained_model,
                    )
                )

                iteration_log["eval_training"] = {}
                trained_eval_model = train_eval_model(training_log=iteration_log["eval_training"], device=self.device)

                candidate_batch = acquisition_function.compute_candidate_batch(
                    trained_model, trained_eval_model, pool_loader, device=self.device
                )
            else:
                raise ValueError(f"Unknown acquisition function {acquisition_function}!")

            # Candidate indices are relative to the pool; map them back to the base
            # dataset before acquiring, because acquiring changes the pool.
            candidate_global_indices = [
                get_base_dataset_index(active_learning_data.pool_dataset, index).index
                for index in candidate_batch.indices
            ]
            candidate_labels = [
                get_target(active_learning_data.base_dataset, index).item() for index in candidate_global_indices
            ]

            iteration_log["acquisition"] = dict(
                indices=candidate_global_indices, labels=candidate_labels, scores=candidate_batch.scores
            )

            active_learning_data.acquire(candidate_batch.indices)

            ls = ", ".join(f"{label} ({score:.4})" for label, score in zip(candidate_labels, candidate_batch.scores))
            print(f"Acquiring (label, score)s: {ls}")

In [None]:
# experiment

# Mapping that `Experiment.run` fills with the logs of the run.
results = {}

# Short CoreSetPIG run: stops once the training set reaches 130 samples
# and trains for at most 30 epochs per iteration.
experiment = Experiment(
    device="cuda",
    seed=1120,
    acquisition_function=acquisition_functions.CoreSetPIG,
    acquisition_size=10,
    temperature=8,
    num_pool_samples=20,
    max_training_set=130,
    max_training_epochs=30,
)
experiment.run(results)

Creating: CoreSetPIG(
	acquisition_size=10
)
Training set size 20:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.58984375, 'crossentropy': 1.861675076186657}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.609375, 'crossentropy': 1.8966327086091042}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.61328125, 'crossentropy': 2.2349564135074615}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.638671875, 'crossentropy': 2.2560648694634438}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.62646484375, 'crossentropy': 2.5049142464995384}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63427734375, 'crossentropy': 2.484818249940872}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6298828125, 'crossentropy': 2.687239721417427}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63427734375, 'crossentropy': 2.723147451877594}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.646484375, 'crossentropy': 2.708091840147972}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63720703125, 'crossentropy': 3.0899680256843567}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6435546875, 'crossentropy': 3.0240620225667953}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6396484375, 'crossentropy': 3.111246943473816}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63037109375, 'crossentropy': 3.20389886200428}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63818359375, 'crossentropy': 3.054595798254013}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.64404296875, 'crossentropy': 3.118549272418022}
RestoringEarlyStopping: 6 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.64306640625, 'crossentropy': 3.1763417422771454}
RestoringEarlyStopping: 7 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6376953125, 'crossentropy': 3.3015053272247314}
RestoringEarlyStopping: 8 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63134765625, 'crossentropy': 3.177192836999893}
RestoringEarlyStopping: 9 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6357421875, 'crossentropy': 3.133062794804573}
RestoringEarlyStopping: 10 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63671875, 'crossentropy': 3.0634772777557373}
RestoringEarlyStopping: 11 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63671875, 'crossentropy': 3.2152883112430573}
RestoringEarlyStopping: 12 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63720703125, 'crossentropy': 3.1880739629268646}
RestoringEarlyStopping: 13 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63818359375, 'crossentropy': 3.2935092598199844}
RestoringEarlyStopping: 14 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63916015625, 'crossentropy': 3.2426420748233795}
RestoringEarlyStopping: 15 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.64306640625, 'crossentropy': 3.394015669822693}
RestoringEarlyStopping: 16 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63623046875, 'crossentropy': 3.478360116481781}
RestoringEarlyStopping: 17 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.634765625, 'crossentropy': 3.3266737908124924}
RestoringEarlyStopping: 18 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63720703125, 'crossentropy': 3.433752804994583}
RestoringEarlyStopping: 19 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6318359375, 'crossentropy': 3.5318945348262787}
RestoringEarlyStopping: 20 / 20
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: 0.646484375)
RestoringEarlyStopping: Restoring optimizer.


[1/79]   1%|1          [00:00<?]

Perf after training {'accuracy': 0.6592, 'crossentropy': 2.7150737060546875}
Creating: TrainSelfDistillationEvalModel(
	num_pool_samples=20,
	num_training_samples=1,
	num_validation_samples=20,
	num_patience_epochs=20,
	max_epochs=9,
	training_dataset=<torch.utils.data.dataset.Subset object at 0x7feff9a91a00>,
	eval_dataset=<torch.utils.data.dataset.Subset object at 0x7feff9a91160>,
	validation_loader=<torch.utils.data.dataloader.DataLoader object at 0x7feff9a91c40>,
	training_batch_size=64,
	model_optimizer_factory=<class 'batchbald_redux.models.MnistOptimizerFactory'>,
	trained_model=TrainedMCDropoutModel(num_samples=20, model=BayesianMNISTCNN(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv1_drop): ConsistentMCDropout2d(p=0.5)
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): ConsistentMCDropout2d(p=0.5)
  (fc1): Linear(in_features=1024, out_features=128, bias=True)
  (fc1_drop): ConsistentMCDropout(p=0.5)
  (fc2): Linear(in_features=12

get_predictions_labels:   0%|          | 0/1159040 [00:00<?, ?it/s]

 11%|#1        | 1/9 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.43017578125, 'crossentropy': 1.8302679657936096}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.638671875, 'crossentropy': 1.3730441853404045}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.67138671875, 'crossentropy': 1.2699518576264381}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6455078125, 'crossentropy': 1.3164364397525787}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.65625, 'crossentropy': 1.3247218653559685}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6318359375, 'crossentropy': 1.337683230638504}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.64404296875, 'crossentropy': 1.3481997102499008}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6533203125, 'crossentropy': 1.2842165678739548}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.65283203125, 'crossentropy': 1.267775110900402}
RestoringEarlyStopping: 6 / 20
RestoringEarlyStopping: Restoring best parameters. (Score: 0.67138671875)
RestoringEarlyStopping: Restoring optimizer.


get_predictions_labels:   0%|          | 0/1158640 [00:00<?, ?it/s]

get_predictions_labels:   0%|          | 0/1158640 [00:00<?, ?it/s]

Acquiring (label, score)s: 8 (11.27), 8 (10.29), 8 (9.887), 8 (9.7), 8 (9.488), 2 (9.172), 8 (9.05), 8 (9.03), 8 (9.028), 8 (8.879)
Training set size 30:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.09814453125, 'crossentropy': 2.36720934510231}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.47412109375, 'crossentropy': 1.9964702874422073}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.5673828125, 'crossentropy': 2.023635044693947}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.59716796875, 'crossentropy': 2.4468328282237053}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.61376953125, 'crossentropy': 2.475491687655449}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.60595703125, 'crossentropy': 2.7150914072990417}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.5986328125, 'crossentropy': 3.1024139523506165}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.60546875, 'crossentropy': 3.224428102374077}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.61474609375, 'crossentropy': 3.174846813082695}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.62109375, 'crossentropy': 3.1456754207611084}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.59912109375, 'crossentropy': 3.4486187547445297}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.5927734375, 'crossentropy': 3.461155340075493}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.61669921875, 'crossentropy': 3.534189835190773}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.62109375, 'crossentropy': 3.5443670451641083}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.609375, 'crossentropy': 3.4280517399311066}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6240234375, 'crossentropy': 3.2481976002454758}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.603515625, 'crossentropy': 3.803278312087059}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.60693359375, 'crossentropy': 3.457992807030678}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.619140625, 'crossentropy': 3.5223197788000107}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6083984375, 'crossentropy': 3.768906906247139}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6162109375, 'crossentropy': 3.76054884493351}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.609375, 'crossentropy': 3.9079986810684204}
RestoringEarlyStopping: 6 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6181640625, 'crossentropy': 3.8486253023147583}
RestoringEarlyStopping: 7 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.61181640625, 'crossentropy': 3.7971579879522324}
RestoringEarlyStopping: 8 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.60791015625, 'crossentropy': 3.6871055513620377}
RestoringEarlyStopping: 9 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.61865234375, 'crossentropy': 3.7344785034656525}
RestoringEarlyStopping: 10 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.61376953125, 'crossentropy': 3.744379445910454}
RestoringEarlyStopping: 11 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.60595703125, 'crossentropy': 3.9774679839611053}
RestoringEarlyStopping: 12 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.60791015625, 'crossentropy': 4.051399990916252}
RestoringEarlyStopping: 13 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.630859375, 'crossentropy': 3.5242514312267303}
RestoringEarlyStopping: Restoring best parameters. (Score: 0.630859375)
RestoringEarlyStopping: Restoring optimizer.


[1/79]   1%|1          [00:00<?]

Perf after training {'accuracy': 0.6439, 'crossentropy': 3.3748569875717163}
Creating: TrainSelfDistillationEvalModel(
	num_pool_samples=20,
	num_training_samples=1,
	num_validation_samples=20,
	num_patience_epochs=20,
	max_epochs=30,
	training_dataset=<torch.utils.data.dataset.Subset object at 0x7feff9a91a00>,
	eval_dataset=<torch.utils.data.dataset.Subset object at 0x7feff9a91160>,
	validation_loader=<torch.utils.data.dataloader.DataLoader object at 0x7feff9a91c40>,
	training_batch_size=64,
	model_optimizer_factory=<class 'batchbald_redux.models.MnistOptimizerFactory'>,
	trained_model=TrainedMCDropoutModel(num_samples=20, model=BayesianMNISTCNN(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv1_drop): ConsistentMCDropout2d(p=0.5)
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): ConsistentMCDropout2d(p=0.5)
  (fc1): Linear(in_features=1024, out_features=128, bias=True)
  (fc1_drop): ConsistentMCDropout(p=0.5)
  (fc2): Linear(in_features=1

get_predictions_labels:   0%|          | 0/1159040 [00:00<?, ?it/s]

  3%|3         | 1/30 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.267578125, 'crossentropy': 2.0938870534300804}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.55517578125, 'crossentropy': 1.598021686077118}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6123046875, 'crossentropy': 1.4015748500823975}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.65283203125, 'crossentropy': 1.338459052145481}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.646484375, 'crossentropy': 1.3440247178077698}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.64453125, 'crossentropy': 1.3631226867437363}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.64404296875, 'crossentropy': 1.3815156817436218}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.638671875, 'crossentropy': 1.4001456573605537}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.650390625, 'crossentropy': 1.3500040546059608}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.65673828125, 'crossentropy': 1.3488147482275963}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.65625, 'crossentropy': 1.3011836484074593}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6669921875, 'crossentropy': 1.3105402663350105}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.65234375, 'crossentropy': 1.4201809763908386}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.662109375, 'crossentropy': 1.3530398905277252}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.65625, 'crossentropy': 1.3171802461147308}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.65576171875, 'crossentropy': 1.3881082311272621}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6435546875, 'crossentropy': 1.392128013074398}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6357421875, 'crossentropy': 1.4009165540337563}
RestoringEarlyStopping: 6 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.66015625, 'crossentropy': 1.3795245438814163}
RestoringEarlyStopping: 7 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.65771484375, 'crossentropy': 1.364939920604229}
RestoringEarlyStopping: 8 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.658203125, 'crossentropy': 1.3537090197205544}
RestoringEarlyStopping: 9 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6669921875, 'crossentropy': 1.3423882573843002}
RestoringEarlyStopping: 10 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.63818359375, 'crossentropy': 1.4446620345115662}
RestoringEarlyStopping: 11 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.66015625, 'crossentropy': 1.311586432158947}
RestoringEarlyStopping: 12 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.65869140625, 'crossentropy': 1.4008911326527596}
RestoringEarlyStopping: 13 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.65576171875, 'crossentropy': 1.349029116332531}
RestoringEarlyStopping: 14 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.65087890625, 'crossentropy': 1.3849503919482231}
RestoringEarlyStopping: 15 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.65283203125, 'crossentropy': 1.3950904682278633}
RestoringEarlyStopping: 16 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.65380859375, 'crossentropy': 1.4009385406970978}
RestoringEarlyStopping: 17 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.64892578125, 'crossentropy': 1.374249517917633}
RestoringEarlyStopping: 18 / 20
RestoringEarlyStopping: Restoring best parameters. (Score: 0.6669921875)
RestoringEarlyStopping: Restoring optimizer.


get_predictions_labels:   0%|          | 0/1158440 [00:00<?, ?it/s]

get_predictions_labels:   0%|          | 0/1158440 [00:00<?, ?it/s]

Acquiring (label, score)s: 2 (15.77), 3 (13.1), 9 (12.49), 9 (12.41), 9 (12.3), 2 (12.27), 3 (12.2), 9 (12.14), 5 (11.54), 3 (11.23)
Training set size 40:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.111328125, 'crossentropy': 2.3318781703710556}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.34912109375, 'crossentropy': 1.8821969851851463}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.537109375, 'crossentropy': 1.7203078269958496}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6591796875, 'crossentropy': 1.5206803306937218}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.6943359375, 'crossentropy': 1.5637250021100044}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.716796875, 'crossentropy': 1.616085909307003}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.68603515625, 'crossentropy': 1.8539494574069977}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.70654296875, 'crossentropy': 1.7446252256631851}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.71044921875, 'crossentropy': 1.7009914070367813}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.70947265625, 'crossentropy': 1.7612690478563309}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.70947265625, 'crossentropy': 1.7257337048649788}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.716796875, 'crossentropy': 1.8035038113594055}
RestoringEarlyStopping: 6 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.70703125, 'crossentropy': 1.843830369412899}
RestoringEarlyStopping: 7 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.7216796875, 'crossentropy': 1.8283613175153732}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.72607421875, 'crossentropy': 1.832522302865982}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.71875, 'crossentropy': 1.9175538644194603}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.73583984375, 'crossentropy': 1.8122596442699432}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.7197265625, 'crossentropy': 1.9268733710050583}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.71728515625, 'crossentropy': 1.996249035000801}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.7333984375, 'crossentropy': 1.8442267403006554}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.72509765625, 'crossentropy': 1.9454561695456505}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.724609375, 'crossentropy': 1.9881621971726418}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.7138671875, 'crossentropy': 2.0779401287436485}
RestoringEarlyStopping: 6 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.72265625, 'crossentropy': 2.0707436874508858}
RestoringEarlyStopping: 7 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.70751953125, 'crossentropy': 2.163809895515442}
RestoringEarlyStopping: 8 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.71435546875, 'crossentropy': 2.225319117307663}
RestoringEarlyStopping: 9 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.70556640625, 'crossentropy': 2.332764506340027}
RestoringEarlyStopping: 10 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.724609375, 'crossentropy': 2.1506718322634697}
RestoringEarlyStopping: 11 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.7255859375, 'crossentropy': 2.0951114147901535}
RestoringEarlyStopping: 12 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.73193359375, 'crossentropy': 2.06838371604681}
RestoringEarlyStopping: 13 / 20
RestoringEarlyStopping: Restoring best parameters. (Score: 0.73583984375)
RestoringEarlyStopping: Restoring optimizer.


[1/79]   1%|1          [00:00<?]

Perf after training {'accuracy': 0.7359, 'crossentropy': 1.8007461918830872}
Creating: TrainSelfDistillationEvalModel(
	num_pool_samples=20,
	num_training_samples=1,
	num_validation_samples=20,
	num_patience_epochs=20,
	max_epochs=17,
	training_dataset=<torch.utils.data.dataset.Subset object at 0x7feff9a91a00>,
	eval_dataset=<torch.utils.data.dataset.Subset object at 0x7feff9a91160>,
	validation_loader=<torch.utils.data.dataloader.DataLoader object at 0x7feff9a91c40>,
	training_batch_size=64,
	model_optimizer_factory=<class 'batchbald_redux.models.MnistOptimizerFactory'>,
	trained_model=TrainedMCDropoutModel(num_samples=20, model=BayesianMNISTCNN(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv1_drop): ConsistentMCDropout2d(p=0.5)
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): ConsistentMCDropout2d(p=0.5)
  (fc1): Linear(in_features=1024, out_features=128, bias=True)
  (fc1_drop): ConsistentMCDropout(p=0.5)
  (fc2): Linear(in_features=1

get_predictions_labels:   0%|          | 0/1159040 [00:00<?, ?it/s]

  6%|5         | 1/17 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.45556640625, 'crossentropy': 1.9962274953722954}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Exception ignored in: <function tqdm.__del__ at 0x7feffda580d0>
Traceback (most recent call last):
  File "/home/blackhc/anaconda3/envs/active_learning/lib/python3.8/site-packages/tqdm/std.py", line 1134, in __del__
    def __del__(self):
KeyboardInterrupt: 


Epoch metrics: {'accuracy': 0.732421875, 'crossentropy': 1.235934503376484}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.74267578125, 'crossentropy': 1.0898063816130161}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.74755859375, 'crossentropy': 0.9996308349072933}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.76416015625, 'crossentropy': 0.9109205268323421}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.74951171875, 'crossentropy': 0.9393021687865257}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Exception ignored in: <function tqdm.__del__ at 0x7feffda580d0>
Traceback (most recent call last):
  File "/home/blackhc/anaconda3/envs/active_learning/lib/python3.8/site-packages/tqdm/std.py", line 1134, in __del__
    def __del__(self):
KeyboardInterrupt: 


Epoch metrics: {'accuracy': 0.7666015625, 'crossentropy': 0.8904788084328175}


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.76416015625, 'crossentropy': 0.900515228509903}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.75732421875, 'crossentropy': 0.9173464551568031}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.76025390625, 'crossentropy': 0.8849859461188316}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.7626953125, 'crossentropy': 0.8823465593159199}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/16]   6%|6          [00:00<?]

Epoch metrics: {'accuracy': 0.755859375, 'crossentropy': 0.8908888138830662}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

Engine run is terminating due to exception: .


KeyboardInterrupt: 

In [None]:
# Display the current contents of `results` as this notebook cell's output.
results

{'initial_training_set_indices': [38043,
  40091,
  17418,
  2094,
  39879,
  3133,
  5011,
  40683,
  54379,
  24287,
  9849,
  59305,
  39508,
  39356,
  8758,
  52579,
  13655,
  7636,
  21562,
  41329],
 'dataset_info': {'training': "'FastMNIST (Train)'",
  'test': "'FastMNIST (Test)'"},
 'active_learning_steps': [{'training': {'epochs': [{'accuracy': 0.62109375,
      'crossentropy': 2.6530187726020813},
     {'accuracy': 0.6376953125, 'crossentropy': 2.762658029794693},
     {'accuracy': 0.646484375, 'crossentropy': 3.056214064359665},
     {'accuracy': 0.6416015625, 'crossentropy': 3.1257119178771973}],
    'best_epoch': 1},
   'evaluation_metrics': {'accuracy': 0.631,
    'crossentropy': 2.6251225173950195}}]}

In [None]:
# experiment

# Short end-to-end run: a single training epoch per step and a training-set
# cap of 25, using the random BALD-ICAL acquisition function.
experiment = Experiment(
    max_training_epochs=1,
    max_training_set=25,
    acquisition_function=AcquisitionFunction.randombaldical,
)

# `run` fills this dict in place; the bare expression below displays it.
results = dict()
experiment.run(results)

results

Training set size 20:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/384]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -6.529030114412308)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.5367, 'crossentropy': 6.438035237884521}


get_predictions:   0%|          | 0/463616 [00:00<?, ?it/s]

100%|##########| 1/1 [00:00<?, ?it/s]

[1/1811]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -5.1637596152722836)
RestoringEarlyStopping: Restoring optimizer.


get_predictions:   0%|          | 0/2317680 [00:00<?, ?it/s]

get_predictions:   0%|          | 0/2317680 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Acquiring (label, score)s: 8 (0.8711), 8 (0.8687), 3 (0.876), 3 (0.8465), 3 (0.8811)
Training set size 25:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/384]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -4.6851686127483845)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.6256, 'crossentropy': 4.484497045135498}
Done.


{'initial_training_set_indices': [38043,
  40091,
  17418,
  2094,
  39879,
  3133,
  5011,
  40683,
  54379,
  24287,
  9849,
  59305,
  39508,
  39356,
  8758,
  52579,
  13655,
  7636,
  21562,
  41329],
 'active_learning_steps': [{'training': {'epochs': [{'accuracy': 0.538818359375,
      'crossentropy': 6.529030114412308}],
    'best_epoch': 1},
   'evalution_metrics': {'accuracy': 0.5367,
    'crossentropy': 6.438035237884521},
   'pool_training': {'epochs': [{'accuracy': 0.531005859375,
      'crossentropy': 5.1637596152722836}],
    'best_epoch': 1},
   'acquisition': {'indices': [63338, 10856, 63452, 81864, 109287],
    'labels': [8, 8, 3, 3, 3],
    'scores': [0.8710822958846325,
     0.8687216999221631,
     0.8759664372823723,
     0.8464646732511746,
     0.8810812784952251]}},
  {'training': {'epochs': [{'accuracy': 0.62255859375,
      'crossentropy': 4.6851686127483845}],
    'best_epoch': 1},
   'evalution_metrics': {'accuracy': 0.6256,
    'crossentropy': 4.4844970451

In [None]:
if False:
    # Disabled sweep, kept for reference (the guard is never true, so nothing
    # here runs). Three sub-grids are concatenated in order; every seed is
    # offset by 120.
    configs = (
        # Batch variants with 100 pool samples: 5 seeds x 2 sizes x 2 functions.
        [
            Experiment(
                seed=seed + 120,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=100,
            )
            for seed in range(5)
            for acquisition_size in [5, 10]
            for acquisition_function in [AcquisitionFunction.batchbald, AcquisitionFunction.batchbaldical]
        ]
        # bald / thompsonbald / randombald: 5 seeds x 4 sizes x 3 functions.
        # The pool sample count is at least 20 and grows with the acquisition size.
        + [
            Experiment(
                seed=seed + 120,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=max(20, acquisition_size),
            )
            for seed in range(5)
            for acquisition_size in [5, 10, 20, 50]
            for acquisition_function in [
                AcquisitionFunction.bald,
                AcquisitionFunction.thompsonbald,
                AcquisitionFunction.randombald,
            ]
        ]
        # Random baseline: 40 seeds at a single acquisition size.
        + [
            Experiment(
                seed=seed + 120,
                acquisition_function=AcquisitionFunction.random,
                # Use the loop variable instead of repeating the literal 5, so
                # the size list below is the single source of truth.
                acquisition_size=acquisition_size,
                num_pool_samples=20,
            )
            for seed in range(40)
            for acquisition_size in [5]
        ]
    )
if False:
    # Disabled sweep (never executed): baldical and randombaldical,
    # seeds 240..244, four acquisition sizes -> 5 x 4 x 2 = 40 configs.
    configs = [
        Experiment(
            seed=run_seed,
            acquisition_function=acq_fn,
            acquisition_size=batch_size,
            num_pool_samples=max(20, batch_size),
        )
        for run_seed in range(240, 245)
        for batch_size in (5, 10, 20, 50)
        for acq_fn in (AcquisitionFunction.baldical, AcquisitionFunction.randombaldical)
    ]
if False:
    # Disabled sweep (never executed): temperedbald at temperatures 13/15/18,
    # seeds 340..344, four acquisition sizes -> 5 x 4 x 3 = 60 configs.
    configs = [
        Experiment(
            seed=run_seed,
            acquisition_function=AcquisitionFunction.temperedbald,
            acquisition_size=batch_size,
            num_pool_samples=20,
            temperature=temp,
        )
        for run_seed in range(340, 345)
        for batch_size in (5, 10, 20, 50)
        for temp in (13, 15, 18)
    ]
if False:
    # Disabled sweep (never executed), seeds 340..344 and four acquisition sizes:
    #   - temperedbald at temperatures 8 and 10        (5 x 4 x 2 = 40 configs)
    #   - temperedbaldical at temperatures 8, 10, 13   (5 x 4 x 3 = 60 configs)
    configs = [
        Experiment(
            seed=run_seed,
            acquisition_function=AcquisitionFunction.temperedbald,
            acquisition_size=batch_size,
            num_pool_samples=20,
            temperature=temp,
        )
        for run_seed in range(340, 345)
        for batch_size in (5, 10, 20, 50)
        for temp in (8, 10)
    ] + [
        Experiment(
            seed=run_seed,
            acquisition_function=AcquisitionFunction.temperedbaldical,
            acquisition_size=batch_size,
            num_pool_samples=20,
            temperature=temp,
        )
        for run_seed in range(340, 345)
        for batch_size in (5, 10, 20, 50)
        for temp in (8, 10, 13)
    ]
if False:
    # Disabled sweep (never executed), four acquisition sizes each:
    #   - temperedical at temperatures 5, 8, 11, seeds 500..504  (60 configs)
    #   - ical without a temperature, seeds 600..604             (20 configs)
    configs = [
        Experiment(
            seed=run_seed,
            acquisition_function=AcquisitionFunction.temperedical,
            acquisition_size=batch_size,
            num_pool_samples=20,
            temperature=temp,
        )
        for run_seed in range(500, 505)
        for batch_size in (5, 10, 20, 50)
        for temp in (5, 8, 11)
    ] + [
        Experiment(
            seed=run_seed,
            acquisition_function=AcquisitionFunction.ical,
            acquisition_size=batch_size,
            num_pool_samples=20,
        )
        for run_seed in range(600, 605)
        for batch_size in (5, 10, 20, 50)
    ]
if False:
    # Disabled sweep (never executed). Four sub-grids are concatenated in order:
    #   1. tempered functions x temperatures 2/5/8, seeds 1000..1004
    #   2. untempered functions, seeds 2000..2004
    #   3. thompsonbald, seeds 2000..2004, pool samples raised to the batch size
    #   4. batch functions at acquisition size 5 with 100 pool samples, seeds 3000..3004
    configs = [
        Experiment(
            seed=run_seed,
            acquisition_function=acq_fn,
            acquisition_size=batch_size,
            num_pool_samples=pool_samples,
            temperature=temp,
        )
        for run_seed in range(1000, 1005)
        for batch_size in (5, 10, 20, 50)
        for pool_samples in (20, 100)
        for acq_fn in (
            AcquisitionFunction.temperedical,
            AcquisitionFunction.temperedbald,
            AcquisitionFunction.temperedbaldical,
        )
        for temp in (2, 5, 8)
    ] + [
        Experiment(
            seed=run_seed,
            acquisition_function=acq_fn,
            acquisition_size=batch_size,
            num_pool_samples=pool_samples,
        )
        for run_seed in range(2000, 2005)
        for batch_size in (5, 10, 20, 50)
        for pool_samples in (20, 100)
        for acq_fn in (
            AcquisitionFunction.ical,
            AcquisitionFunction.bald,
            AcquisitionFunction.baldical,
            AcquisitionFunction.randombald,
        )
    ] + [
        Experiment(
            seed=run_seed,
            acquisition_function=AcquisitionFunction.thompsonbald,
            acquisition_size=batch_size,
            num_pool_samples=max(pool_samples, batch_size),
        )
        for run_seed in range(2000, 2005)
        for batch_size in (5, 10, 20, 50)
        for pool_samples in (20, 100)
    ] + [
        Experiment(
            seed=run_seed,
            acquisition_function=acq_fn,
            acquisition_size=5,
            num_pool_samples=100,
        )
        for run_seed in range(3000, 3005)
        for acq_fn in (AcquisitionFunction.batchbaldical, AcquisitionFunction.batchbald)
    ]
if False:
    # Disabled sweep (never executed): BALD over 5 seeds x 4 acquisition sizes
    # x 4 pool sample counts (80 configs), followed by 20 Random baseline runs.
    # The whole expression is indented under the guard so it is visibly part
    # of the disabled branch.
    configs = [
        Experiment(
            seed=seed,
            acquisition_function=acquisition_functions.BALD,
            acquisition_size=acquisition_size,
            num_pool_samples=num_pool_samples,
        )
        for seed in range(5)
        for acquisition_size in [5, 10, 20, 50]
        for num_pool_samples in [10, 20, 50, 100]
    ] + [
        Experiment(
            seed=seed,
            acquisition_function=acquisition_functions.Random,
            acquisition_size=5,
            num_pool_samples=20,
        )
        for seed in range(20)
    ]


In [None]:
# exports

# Active experiment grid: 5 seeds x 6 acquisition functions x 2 Repeated-MNIST
# repetition settings, at acquisition size 5 with 100 pool samples (60 configs).
configs = [
    Experiment(
        seed=789 + seed_index,
        acquisition_function=acquisition_fn,
        acquisition_size=batch_size,
        num_pool_samples=pool_samples,
        repeated_mnist_repetitions=repetitions,
    )
    for seed_index in range(5)
    for acquisition_fn in (
        acquisition_functions.EIG,
        acquisition_functions.BALD,
        acquisition_functions.EvalBALD,
        acquisition_functions.BatchBALD,
        acquisition_functions.BatchEvalBALD,
        acquisition_functions.BatchEIG,
    )
    for batch_size in (5,)
    for pool_samples in (100,)
    for repetitions in (1, 2)
]

# Script entry point: skipped inside IPython/Jupyter and on plain import.
if not is_run_from_ipython() and __name__ == "__main__":
    for job_id, store in embedded_experiments(__file__, len(configs)):
        job_config = configs[job_id]
        # Shift the seed by the job index before the config is printed and stored.
        job_config.seed = job_config.seed + job_id
        print(job_config)
        store["config"] = dataclasses.asdict(job_config)
        store["log"] = dict()

        try:
            job_config.run(store=store)
        except Exception:
            # Keep the traceback alongside the results, then let the failure propagate.
            store["exception"] = traceback.format_exc()
            raise

In [None]:
# Sanity check: number of experiment configurations in the grid.
len(configs)

60

In [None]:
# slow
import prettyprinter

# Register prettyprinter's "dataclasses" extra before printing the Experiment objects.
prettyprinter.install_extras(include={"dataclasses"})

# Dump the full configuration grid for visual inspection.
prettyprinter.pprint(configs)

[
    Experiment(
        seed=789,
        num_pool_samples=100,
        acquisition_function=batchbald_redux.acquisition_functions.EIG  # class
    ),
    Experiment(
        seed=789,
        num_pool_samples=100,
        repeated_mnist_repetitions=2,
        acquisition_function=batchbald_redux.acquisition_functions.EIG  # class
    ),
    Experiment(seed=789, num_pool_samples=100),
    Experiment(
        seed=789,
        num_pool_samples=100,
        repeated_mnist_repetitions=2
    ),
    Experiment(
        seed=789,
        num_pool_samples=100,
        # class
        acquisition_function=batchbald_redux.acquisition_functions.EvalBALD
    ),
    Experiment(
        seed=789,
        num_pool_samples=100,
        repeated_mnist_repetitions=2,
        # class
        acquisition_function=batchbald_redux.acquisition_functions.EvalBALD
    ),
    Experiment(
        seed=789,
        num_pool_samples=100,
        # class
        acquisition_function=batchbald_redux.acquisition_f