# Uniform Target OOD Experiment
> Can we get better by training on our assumptions?

In [None]:
# default_exp uniform_target_ood_experiment

In [None]:
# hide
import blackhc.project.script

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


Import the modules and functions we are going to use.

In [None]:
# exports

import dataclasses
import traceback
from dataclasses import dataclass
from typing import List, Optional, Type, Union

import torch
import torch.utils.data
from blackhc.project import is_run_from_ipython
from blackhc.project.experiment import embedded_experiments
from torch.utils.data import Dataset

import batchbald_redux.acquisition_functions as acquisition_functions
from batchbald_redux.acquisition_functions import (
    CandidateBatchComputer,
    EvalCandidateBatchComputer,
)
from batchbald_redux.active_learning import ActiveLearningData, RandomFixedLengthSampler
from batchbald_redux.black_box_model_training import evaluate, train
from batchbald_redux.dataset_challenges import (
    AdditiveGaussianNoise,
    AliasDataset,
    NamedDataset,
    get_balanced_sample_indices,
    get_base_dataset_index,
    get_target,
)
from batchbald_redux.datasets import train_validation_split
from batchbald_redux.di import DependencyInjection
from batchbald_redux.fast_mnist import FastFashionMNIST, FastMNIST
from batchbald_redux.model_optimizer_factory import ModelOptimizerFactory
from batchbald_redux.models import MnistOptimizerFactory
from batchbald_redux.train_eval_model import (
    TrainEvalModel,
    TrainSelfDistillationPoolModel,
)
from batchbald_redux.trained_model import TrainedMCDropoutModel

In [None]:
# exports


@dataclass
class ExperimentDatasets:
    """Datasets produced by `UniformTargetOodExperiment.load_dataset`."""

    # Pool/training split of the combined MNIST + OoD data; the initial training
    # samples have already been acquired into the training set.
    active_learning: ActiveLearningData
    # The OoD (FashionMNIST) dataset mixed into the pool; compared against the base
    # dataset of acquired samples to tag them as "ood" or "id".
    ood_dataset: NamedDataset
    # MNIST samples split off the training data and used for validation during training.
    validation_dataset: torch.utils.data.Dataset
    # MNIST test set used for the per-iteration evaluation.
    test_dataset: torch.utils.data.Dataset
    # Indices of the class-balanced initial training samples (as passed to
    # `active_learning.acquire`).
    initial_training_set_indices: List[int]


@dataclass
class UniformTargetOodExperiment:
    """Active learning on MNIST with an OoD (FashionMNIST) pool that carries uniform targets.

    The pool is the concatenation of MNIST training samples with one-hot targets and
    FashionMNIST samples with uniform targets over the classes. Training uses a
    KL-divergence loss, so the model is trained towards a uniform prediction on the OoD
    samples it acquires. Each iteration:

    1. trains a fresh model;
    2. evaluates it on the MNIST test set;
    3. stops once the training set has reached `max_training_set` samples;
    4. otherwise scores the pool with the acquisition function and acquires its
       candidate batch.

    The acquisition function and the train-eval model are built by dependency injection
    from `vars(self)`, so field names double as injection keys (e.g. `acquisition_size`,
    `temperature`).
    """

    seed: int = 1337
    acquisition_size: int = 5
    max_training_set: int = 450
    num_pool_samples: int = 20
    num_eval_samples: int = 20
    num_training_samples: int = 1
    num_patience_epochs: int = 3
    max_training_epochs: int = 30
    training_batch_size: int = 64
    device: str = "cuda"
    validation_set_size: int = 1024
    validation_split_random_state: int = 0
    initial_training_set_size: int = 20
    samples_per_epoch: int = 5056
    mnist_repetitions: int = 1
    ood_fmnist_repetitions: int = 1
    add_dataset_noise: bool = False
    acquisition_function: Union[
        Type[CandidateBatchComputer], Type[EvalCandidateBatchComputer]
    ] = acquisition_functions.BALD
    # A class (instantiated via dependency injection), not an instance.
    train_eval_model: Type[TrainEvalModel] = TrainSelfDistillationPoolModel
    model_optimizer_factory: Type[ModelOptimizerFactory] = MnistOptimizerFactory
    acquisition_function_args: Optional[dict] = None
    temperature: float = 0.0

    def load_dataset(self, initial_training_set_size) -> ExperimentDatasets:
        """Build the active-learning pool, the validation/test sets and the initial training set.

        Args:
            initial_training_set_size: Total number of class-balanced MNIST samples to
                acquire up front. It is split evenly across classes; any remainder that
                does not divide by the number of classes is dropped.

        Returns:
            An `ExperimentDatasets` whose `active_learning` data already contains the
            initial training samples.
        """
        # num_classes = 10, input_size = 28
        train_dataset = NamedDataset(
            FastMNIST("data", train=True, download=True, device=self.device), "FastMNIST (train)"
        )
        if self.mnist_repetitions > 1:
            # NOTE(review): repetition happens before the validation split, so copies of a
            # sample could land in both train and validation — confirm this is acceptable.
            train_dataset = train_dataset * self.mnist_repetitions

        ood_dataset = FastFashionMNIST("data", train=True, download=True, device=self.device)
        ood_dataset = NamedDataset(ood_dataset, f"OoD Dataset ({len(ood_dataset)} samples)")
        if self.ood_fmnist_repetitions > 1:
            ood_dataset = ood_dataset * self.ood_fmnist_repetitions

        # Split MNIST before mixing in the OoD data, so validation is in-distribution only.
        train_dataset, validation_dataset = train_validation_split(
            full_train_dataset=train_dataset,
            full_validation_dataset=train_dataset,
            train_labels=train_dataset.get_targets().cpu(),
            validation_set_size=self.validation_set_size,
            validation_split_random_state=self.validation_split_random_state,
        )

        train_dataset = AliasDataset(train_dataset, f"FastMNIST (train; {len(train_dataset)} samples)")
        validation_dataset = AliasDataset(
            validation_dataset, f"FastMNIST (validation; {len(validation_dataset)} samples)"
        )

        num_classes = train_dataset.get_num_classes()
        # A per-class sample count must be an integer; true division would yield e.g. 2.5
        # for 25 samples over 10 classes.
        samples_per_class = initial_training_set_size // num_classes
        # NOTE(review): seeded by the validation split state rather than `self.seed`, so the
        # initial set stays fixed across experiment seeds — confirm this is intended.
        initial_training_set_indices = get_balanced_sample_indices(
            train_dataset,
            num_classes=num_classes,
            samples_per_class=samples_per_class,
            seed=self.validation_split_random_state,
        )

        # MNIST comes first in the concatenation; presumably this keeps the indices computed
        # above valid for the combined dataset — verify against the `+` implementation.
        train_dataset = train_dataset.one_hot(device=self.device) + ood_dataset.uniform_target(device=self.device)

        if self.add_dataset_noise:
            train_dataset = AdditiveGaussianNoise(train_dataset, 0.1)

        test_dataset = FastMNIST("data", train=False, device=None)
        test_dataset = NamedDataset(test_dataset, f"FastMNIST (test, {len(test_dataset)} samples)")

        active_learning_data = ActiveLearningData(train_dataset)

        active_learning_data.acquire(initial_training_set_indices)

        return ExperimentDatasets(
            active_learning=active_learning_data,
            validation_dataset=validation_dataset,
            test_dataset=test_dataset,
            initial_training_set_indices=initial_training_set_indices,
            ood_dataset=ood_dataset,
        )

    # Simple Dependency Injection
    def create_acquisition_function(self):
        """Instantiate `self.acquisition_function`, resolving its fields from this config."""
        di = DependencyInjection(vars(self))
        return di.create_dataclass_type(self.acquisition_function)

    def create_train_eval_model(self, runtime_config) -> TrainEvalModel:
        """Instantiate `self.train_eval_model` from this config, overridden by `runtime_config`."""
        config = {**vars(self), **runtime_config}
        di = DependencyInjection(config, [])
        return di.create_dataclass_type(self.train_eval_model)

    def _compute_candidate_batch(
        self, acquisition_function, trained_model, data, pool_loader, validation_loader, iteration_log
    ):
        """Score the pool with `acquisition_function` and return its candidate batch.

        An `EvalCandidateBatchComputer` additionally needs a separately trained eval model;
        its training log is written to ``iteration_log["eval_training"]``.

        Raises:
            ValueError: If the acquisition function is of neither supported kind.
        """
        if isinstance(acquisition_function, CandidateBatchComputer):
            return acquisition_function.compute_candidate_batch(trained_model, pool_loader, self.device)

        if isinstance(acquisition_function, EvalCandidateBatchComputer):
            # Cap the eval model's training at the epoch where the main model did best.
            current_max_epochs = iteration_log["training"]["best_epoch"]

            train_eval_model = self.create_train_eval_model(
                dict(
                    max_epochs=current_max_epochs,
                    training_dataset=data.active_learning.training_dataset,
                    pool_dataset=data.active_learning.pool_dataset,
                    validation_loader=validation_loader,
                    trained_model=trained_model,
                )
            )

            iteration_log["eval_training"] = {}
            trained_eval_model = train_eval_model(training_log=iteration_log["eval_training"], device=self.device)

            return acquisition_function.compute_candidate_batch(
                trained_model, trained_eval_model, pool_loader, device=self.device
            )

        raise ValueError(f"Unknown acquisition function {acquisition_function}!")

    @staticmethod
    def _describe_candidates(data, candidate_indices):
        """Map pool indices to ``("ood" | "id", base index)`` pairs and their labels.

        Must be called before the candidates are acquired: the pool shrinks on `acquire`,
        and the indices refer to the pool as it is now.
        """
        global_indices = []
        labels = []
        for index in candidate_indices:
            base_di = get_base_dataset_index(data.active_learning.pool_dataset, index)
            dataset_type = "ood" if base_di.dataset == data.ood_dataset else "id"
            global_indices.append((dataset_type, base_di.index))
            labels.append(get_target(base_di.dataset.dataset, base_di.index))
        return global_indices, labels

    def run(self, store):
        """Run the active-learning loop, recording progress in `store`.

        Args:
            store: A dict-like object. It receives the following entries:

                - ``"dataset_info"``
                - ``"initial_training_set_indices"``
                - ``"active_learning_steps"``: a list with one entry per iteration,
                  holding the training log and evaluation metrics and, except in the
                  final iteration, the acquisition.
        """
        torch.manual_seed(self.seed)

        # Active Learning setup
        data = self.load_dataset(self.initial_training_set_size)
        store["dataset_info"] = dict(training=repr(data.active_learning.base_dataset), test=repr(data.test_dataset))
        store["initial_training_set_indices"] = data.initial_training_set_indices

        train_loader = torch.utils.data.DataLoader(
            data.active_learning.training_dataset,
            batch_size=self.training_batch_size,
            sampler=RandomFixedLengthSampler(data.active_learning.training_dataset, self.samples_per_epoch),
            drop_last=True,
        )
        pool_loader = torch.utils.data.DataLoader(
            data.active_learning.pool_dataset, batch_size=128, drop_last=False, shuffle=False
        )

        validation_loader = torch.utils.data.DataLoader(data.validation_dataset, batch_size=512, drop_last=False)
        test_loader = torch.utils.data.DataLoader(data.test_dataset, batch_size=512, drop_last=False)

        # Store first, then read back, so later writes go to the object held by `store`
        # (NOTE(review): presumably `store` may wrap the values it holds — confirm).
        store["active_learning_steps"] = []
        active_learning_steps = store["active_learning_steps"]

        acquisition_function = self.create_acquisition_function()

        # Loss modules are stateless, so build them once.
        # Training targets are distributions (one-hot for MNIST, uniform for OoD), hence KL
        # divergence. Validation is MNIST with class labels, hence NLL.
        training_loss = torch.nn.KLDivLoss(log_target=False, reduction="batchmean")
        validation_loss = torch.nn.NLLLoss()

        # Active Training Loop
        while True:
            training_set_size = len(data.active_learning.training_dataset)
            print(f"Training set size {training_set_size}:")

            active_learning_steps.append({})
            iteration_log = active_learning_steps[-1]

            iteration_log["training"] = {}

            # Retrain from scratch every iteration.
            model_optimizer = self.model_optimizer_factory().create_model_optimizer()

            train(
                model=model_optimizer.model,
                optimizer=model_optimizer.optimizer,
                training_samples=self.num_training_samples,
                validation_samples=self.num_eval_samples,
                train_loader=train_loader,
                validation_loader=validation_loader,
                patience=self.num_patience_epochs,
                max_epochs=self.max_training_epochs,
                device=self.device,
                training_log=iteration_log["training"],
                loss=training_loss,
                validation_loss=validation_loss,
            )

            evaluation_metrics = evaluate(
                model=model_optimizer.model, num_samples=self.num_eval_samples, loader=test_loader, device=self.device
            )
            iteration_log["evaluation_metrics"] = evaluation_metrics
            print(f"Perf after training {evaluation_metrics}")

            if training_set_size >= self.max_training_set:
                print("Done.")
                break

            trained_model = TrainedMCDropoutModel(num_pool_samples=self.num_pool_samples, model=model_optimizer.model)

            candidate_batch = self._compute_candidate_batch(
                acquisition_function, trained_model, data, pool_loader, validation_loader, iteration_log
            )

            candidate_global_dataset_indices, candidate_labels = self._describe_candidates(
                data, candidate_batch.indices
            )

            iteration_log["acquisition"] = dict(
                indices=candidate_global_dataset_indices, labels=candidate_labels, scores=candidate_batch.scores
            )

            data.active_learning.acquire(candidate_batch.indices)

            print(candidate_batch)
            print(candidate_global_dataset_indices)

            ls = ", ".join(f"{label} ({score:.4})" for label, score in zip(candidate_labels, candidate_batch.scores))
            print(f"Acquiring (label, score)s: {ls}")

In [None]:
# experiment

# Sanity check: build the datasets on CPU and display the composed pool/training dataset.
_sanity_check_datasets = UniformTargetOodExperiment(device="cpu").load_dataset(initial_training_set_size=20)
_sanity_check_datasets.active_learning.base_dataset

(FastMNIST (train; 58976 samples) | one_hot_targets{'num_classes': 10}) + ('OoD Dataset (60000 samples)' | uniform_targets{'num_classes': 10})

In [None]:
# experiment

# TemperedBALD (temperature 8), acquiring 10 samples per step until the training set
# reaches 130 samples.
experiment = UniformTargetOodExperiment(
    acquisition_function=acquisition_functions.TemperedBALD,
    acquisition_size=10,
    temperature=8,
    num_pool_samples=20,
    max_training_set=130,
    max_training_epochs=30,
    num_patience_epochs=5,
    seed=1120,
    device="cuda",
)

results = dict()
experiment.run(results)

Resolved: TemperedBALD with {'acquisition_size': 10, 'temperature': 8}
Creating: TemperedBALD(acquisition_size=10,temperature=8)
Training set size 20:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.662109375, 'crossentropy': 2.29226815700531}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6435546875, 'crossentropy': 2.7630242109298706}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.65625, 'crossentropy': 2.741236925125122}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6396484375, 'crossentropy': 3.0881651639938354}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.642578125, 'crossentropy': 2.6988677978515625}
RestoringEarlyStopping: 4 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6220703125, 'crossentropy': 2.984955906867981}
RestoringEarlyStopping: 5 / 5
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -2.29226815700531)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.6873, 'crossentropy': 1.9491361656188966}


get_predictions_labels:   0%|          | 0/2379120 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118956 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118956 [00:00<?, ?it/s]

CandidateBatch(scores=[0.7649296522140503, 0.6463639736175537, 0.7584395408630371, 0.7132210731506348, 0.4925457239151001, 0.8587175011634827, 0.5526180863380432, 0.7923805117607117, 0.8029134273529053, 0.9720259308815002], indices=[28649, 39810, 24821, 52274, 16992, 57124, 42485, 37550, 51575, 23789])
[('id', 34790), ('id', 57689), ('id', 14635), ('id', 2559), ('id', 48296), ('id', 44676), ('id', 7280), ('id', 47833), ('id', 47698), ('id', 4530)]
Acquiring (label, score)s: 9 (0.7649), 3 (0.6464), 3 (0.7584), 9 (0.7132), 8 (0.4925), 8 (0.8587), 4 (0.5526), 9 (0.7924), 7 (0.8029), 7 (0.972)
Training set size 30:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6806640625, 'crossentropy': 1.92692232131958}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6650390625, 'crossentropy': 2.4903862476348877}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.662109375, 'crossentropy': 3.1092079877853394}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6943359375, 'crossentropy': 2.6525484323501587}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6826171875, 'crossentropy': 2.904231071472168}
RestoringEarlyStopping: 4 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.6650390625, 'crossentropy': 2.880333423614502}
RestoringEarlyStopping: 5 / 5
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.92692232131958)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.6956, 'crossentropy': 1.974534324645996}


get_predictions_labels:   0%|          | 0/2378920 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118946 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118946 [00:00<?, ?it/s]

CandidateBatch(scores=[0.9801360964775085, 0.9810218811035156, 0.9297162890434265, 1.04594224691391, 0.9509317874908447, 0.9027801156044006, 0.6929494142532349, 0.8805331587791443, 0.9176118969917297, 0.7259587645530701], indices=[12134, 651, 56582, 29503, 533, 42207, 4415, 32130, 58408, 18633])
[('id', 22083), ('id', 57091), ('id', 46400), ('id', 23154), ('id', 3141), ('id', 23019), ('id', 3513), ('id', 49607), ('id', 38579), ('id', 12339)]
Acquiring (label, score)s: 2 (0.9801), 0 (0.981), 2 (0.9297), 0 (1.046), 5 (0.9509), 6 (0.9028), 8 (0.6929), 3 (0.8805), 0 (0.9176), 0 (0.726)
Training set size 40:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7431640625, 'crossentropy': 1.406740665435791}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.74609375, 'crossentropy': 1.7936604022979736}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7451171875, 'crossentropy': 1.9130260944366455}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.751953125, 'crossentropy': 1.8783711791038513}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.73828125, 'crossentropy': 2.1438169479370117}
RestoringEarlyStopping: 4 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.771484375, 'crossentropy': 2.0849498510360718}
RestoringEarlyStopping: 5 / 5
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.406740665435791)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.7467, 'crossentropy': 1.3747571056365966}


get_predictions_labels:   0%|          | 0/2378720 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118936 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118936 [00:00<?, ?it/s]

CandidateBatch(scores=[0.6628928780555725, 0.9062231779098511, 0.7181797027587891, 0.8345595598220825, 0.9604091048240662, 0.8515303134918213, 0.9072465896606445, 0.854984700679779, 0.8592267632484436, 0.8482197523117065], indices=[38220, 26928, 51065, 18200, 49667, 29363, 52936, 31682, 13228, 54543])
[('id', 18202), ('id', 51431), ('id', 7832), ('id', 39531), ('id', 49824), ('id', 5841), ('id', 517), ('id', 47844), ('id', 58464), ('id', 17263)]
Acquiring (label, score)s: 4 (0.6629), 8 (0.9062), 3 (0.7182), 0 (0.8346), 8 (0.9604), 2 (0.8515), 8 (0.9072), 7 (0.855), 8 (0.8592), 2 (0.8482)
Training set size 50:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7001953125, 'crossentropy': 1.5207454562187195}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.744140625, 'crossentropy': 1.6204071044921875}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7353515625, 'crossentropy': 1.707343339920044}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7392578125, 'crossentropy': 1.9458958506584167}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.748046875, 'crossentropy': 1.9507938027381897}
RestoringEarlyStopping: 4 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.755859375, 'crossentropy': 2.0034207105636597}
RestoringEarlyStopping: 5 / 5
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.5207454562187195)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.7069, 'crossentropy': 1.4323673820495606}


get_predictions_labels:   0%|          | 0/2378520 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118926 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118926 [00:00<?, ?it/s]

CandidateBatch(scores=[0.6243653297424316, 0.785733163356781, 0.6850082874298096, 0.5655449628829956, 0.8411515951156616, 0.7936561703681946, 0.6278172731399536, 0.5508645176887512, 0.6094920039176941, 0.6971174478530884], indices=[12056, 16243, 39814, 266, 23968, 26111, 31665, 50972, 48065, 15568])
[('id', 51971), ('id', 44807), ('id', 47914), ('id', 56303), ('id', 12449), ('id', 19276), ('id', 22341), ('id', 54577), ('id', 26748), ('id', 35022)]
Acquiring (label, score)s: 7 (0.6244), 7 (0.7857), 0 (0.685), 9 (0.5655), 5 (0.8412), 6 (0.7937), 2 (0.6278), 5 (0.5509), 9 (0.6095), 3 (0.6971)
Training set size 60:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.744140625, 'crossentropy': 1.1800535917282104}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7607421875, 'crossentropy': 1.2563155889511108}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7666015625, 'crossentropy': 1.4744848012924194}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7734375, 'crossentropy': 1.5940645337104797}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.76953125, 'crossentropy': 1.6297144889831543}
RestoringEarlyStopping: 4 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7705078125, 'crossentropy': 1.7202304005622864}
RestoringEarlyStopping: 5 / 5
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.1800535917282104)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.7678, 'crossentropy': 1.164126841545105}


get_predictions_labels:   0%|          | 0/2378320 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118916 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118916 [00:00<?, ?it/s]

CandidateBatch(scores=[0.5548802018165588, 0.6266323328018188, 0.5661203861236572, 0.4892882704734802, 0.5311686396598816, 0.7994290590286255, 0.7255553007125854, 0.668678879737854, 0.8141413331031799, 0.6423149108886719], indices=[58362, 11279, 56959, 50766, 44170, 14545, 16316, 9500, 46831, 46027])
[('id', 5553), ('id', 58428), ('id', 8099), ('id', 31805), ('id', 13709), ('id', 3988), ('id', 24598), ('id', 43833), ('id', 49348), ('id', 2960)]
Acquiring (label, score)s: 6 (0.5549), 5 (0.6266), 3 (0.5661), 3 (0.4893), 6 (0.5312), 5 (0.7994), 8 (0.7256), 3 (0.6687), 5 (0.8141), 3 (0.6423)
Training set size 70:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7529296875, 'crossentropy': 1.1912946701049805}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7763671875, 'crossentropy': 1.4222760200500488}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7744140625, 'crossentropy': 1.5467740893363953}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.775390625, 'crossentropy': 1.7166051268577576}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7626953125, 'crossentropy': 1.9085527062416077}
RestoringEarlyStopping: 4 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7587890625, 'crossentropy': 1.7439901232719421}
RestoringEarlyStopping: 5 / 5
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.1912946701049805)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.7632, 'crossentropy': 1.1292011716842651}


get_predictions_labels:   0%|          | 0/2378120 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118906 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118906 [00:00<?, ?it/s]

CandidateBatch(scores=[1.0132839679718018, 0.9294059872627258, 0.7564026117324829, 0.7244990468025208, 0.7026988863945007, 0.44244319200515747, 0.6692291498184204, 0.5551822185516357, 0.6044751405715942, 0.602584958076477], indices=[17980, 56327, 47342, 7249, 33112, 32663, 35730, 39203, 15533, 44575])
[('id', 47322), ('id', 30378), ('id', 56004), ('id', 9588), ('id', 34520), ('id', 12438), ('id', 55856), ('id', 12633), ('id', 26072), ('id', 996)]
Acquiring (label, score)s: 8 (1.013), 9 (0.9294), 8 (0.7564), 7 (0.7245), 6 (0.7027), 4 (0.4424), 9 (0.6692), 8 (0.5552), 1 (0.6045), 6 (0.6026)
Training set size 80:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7421875, 'crossentropy': 1.1499993205070496}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7900390625, 'crossentropy': 1.2193711400032043}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7890625, 'crossentropy': 1.3602639436721802}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.79296875, 'crossentropy': 1.371889889240265}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7919921875, 'crossentropy': 1.6086080074310303}
RestoringEarlyStopping: 4 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8037109375, 'crossentropy': 1.5090669989585876}
RestoringEarlyStopping: 5 / 5
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.1499993205070496)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.7558, 'crossentropy': 1.1705545278549194}


get_predictions_labels:   0%|          | 0/2377920 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118896 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118896 [00:00<?, ?it/s]

CandidateBatch(scores=[0.4609565734863281, 0.4971466660499573, 0.6920607089996338, 0.40427231788635254, 0.6079421043395996, 0.4499483108520508, 0.47887396812438965, 0.6789156198501587, 0.5227339267730713, 0.5833992958068848], indices=[104914, 55147, 42192, 90928, 38278, 76195, 1656, 29456, 20884, 59766])
[('ood', 46018), ('id', 30648), ('id', 47624), ('ood', 32032), ('id', 41399), ('ood', 17299), ('id', 7958), ('id', 31046), ('id', 34121), ('ood', 870)]
Acquiring (label, score)s: 7 (0.461), 6 (0.4971), 5 (0.6921), 5 (0.4043), 2 (0.6079), 5 (0.4499), 4 (0.4789), 6 (0.6789), 6 (0.5227), 7 (0.5834)
Training set size 90:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.775390625, 'crossentropy': 0.9796779155731201}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8115234375, 'crossentropy': 0.924060046672821}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.82421875, 'crossentropy': 0.8886075019836426}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8232421875, 'crossentropy': 1.1618123054504395}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8212890625, 'crossentropy': 1.249982237815857}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8203125, 'crossentropy': 1.2220505475997925}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8173828125, 'crossentropy': 1.4013068079948425}
RestoringEarlyStopping: 4 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.83203125, 'crossentropy': 1.3042701482772827}
RestoringEarlyStopping: 5 / 5
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -0.8886075019836426)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.8258, 'crossentropy': 0.9901506677627564}


get_predictions_labels:   0%|          | 0/2377720 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118886 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118886 [00:00<?, ?it/s]

CandidateBatch(scores=[0.785701334476471, 0.6953006386756897, 1.0358776450157166, 0.8161039352416992, 0.6050604283809662, 1.0230085849761963, 0.7517524510622025, 0.8154670000076294, 0.5249642729759216, 0.5878008008003235], indices=[40162, 42070, 24142, 25168, 53707, 7138, 52108, 36687, 58152, 27968])
[('id', 2738), ('id', 7112), ('id', 35970), ('id', 59201), ('id', 8426), ('id', 28389), ('id', 11084), ('id', 55865), ('id', 46645), ('id', 30325)]
Acquiring (label, score)s: 2 (0.7857), 5 (0.6953), 2 (1.036), 5 (0.8161), 5 (0.6051), 2 (1.023), 2 (0.7518), 5 (0.8155), 9 (0.525), 8 (0.5878)
Training set size 100:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.79296875, 'crossentropy': 1.0250415802001953}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8232421875, 'crossentropy': 0.9144939482212067}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.814453125, 'crossentropy': 1.0325961112976074}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.826171875, 'crossentropy': 1.0396987199783325}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8154296875, 'crossentropy': 0.9862814545631409}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.81640625, 'crossentropy': 1.1241838335990906}
RestoringEarlyStopping: 4 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.806640625, 'crossentropy': 1.1315916776657104}
RestoringEarlyStopping: 5 / 5
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -0.9144939482212067)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.8295, 'crossentropy': 0.8799887210845947}


get_predictions_labels:   0%|          | 0/2377520 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118876 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118876 [00:00<?, ?it/s]

CandidateBatch(scores=[0.8403417468070984, 0.5301047563552856, 0.868882417678833, 0.5900723338127136, 0.7904089093208313, 0.43132004141807556, 0.7150993347167969, 0.5607828944921494, 0.7269548773765564, 0.9900280237197876], indices=[32415, 15509, 5749, 53534, 22589, 54135, 45306, 19137, 61077, 9059])
[('id', 3424), ('id', 14678), ('id', 826), ('id', 32555), ('id', 25695), ('id', 27778), ('id', 37293), ('id', 58861), ('ood', 2198), ('id', 24219)]
Acquiring (label, score)s: 3 (0.8403), 8 (0.5301), 9 (0.8689), 9 (0.5901), 3 (0.7904), 5 (0.4313), 3 (0.7151), 0 (0.5608), 0 (0.727), 3 (0.99)
Training set size 110:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7763671875, 'crossentropy': 1.0281965136528015}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.830078125, 'crossentropy': 0.8675925433635712}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8466796875, 'crossentropy': 0.8614071011543274}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.837890625, 'crossentropy': 0.8216311633586884}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8447265625, 'crossentropy': 0.9421774446964264}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.837890625, 'crossentropy': 1.0396858155727386}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.845703125, 'crossentropy': 0.9711623191833496}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.828125, 'crossentropy': 1.1070913076400757}
RestoringEarlyStopping: 4 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8388671875, 'crossentropy': 1.071656048297882}
RestoringEarlyStopping: 5 / 5
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -0.8216311633586884)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.8579, 'crossentropy': 0.7984678560256958}


get_predictions_labels:   0%|          | 0/2377320 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118866 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118866 [00:00<?, ?it/s]

CandidateBatch(scores=[0.7083787322044373, 0.7269266247749329, 0.7572749257087708, 0.8621625304222107, 0.8396031856536865, 0.4663722515106201, 0.7750089168548584, 0.7862884998321533, 0.6026186943054199, 0.6440697908401489], indices=[45725, 26589, 27310, 38568, 28714, 44741, 42114, 19988, 34416, 55194])
[('id', 2980), ('id', 29843), ('id', 59101), ('id', 50905), ('id', 33362), ('id', 53199), ('id', 28674), ('id', 2580), ('id', 37385), ('id', 57301)]
Acquiring (label, score)s: 7 (0.7084), 6 (0.7269), 8 (0.7573), 7 (0.8622), 7 (0.8396), 1 (0.4664), 9 (0.775), 3 (0.7863), 4 (0.6026), 1 (0.6441)
Training set size 120:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8603515625, 'crossentropy': 0.8311156928539276}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8701171875, 'crossentropy': 0.7022132575511932}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.86328125, 'crossentropy': 0.7403413951396942}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.890625, 'crossentropy': 0.6433476209640503}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.892578125, 'crossentropy': 0.6431189477443695}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.87890625, 'crossentropy': 0.7555327117443085}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.892578125, 'crossentropy': 0.7210302352905273}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9013671875, 'crossentropy': 0.6657683253288269}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8984375, 'crossentropy': 0.7139799296855927}
RestoringEarlyStopping: 4 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8994140625, 'crossentropy': 0.6940875053405762}
RestoringEarlyStopping: 5 / 5
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -0.6431189477443695)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.8993, 'crossentropy': 0.6387750835418701}


get_predictions_labels:   0%|          | 0/2377120 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/118856 [00:00<?, ?it/s]

Entropy:   0%|          | 0/118856 [00:00<?, ?it/s]

CandidateBatch(scores=[0.5113729238510132, 0.668445885181427, 0.47866612672805786, 0.7121152877807617, 0.5793507099151611, 0.3499523401260376, 0.9433722794055939, 0.9978058338165283, 0.8205229640007019, 0.6541504561901093], indices=[94053, 29657, 65368, 18575, 16859, 86783, 4233, 13406, 28632, 10913])
[('ood', 35196), ('id', 10038), ('ood', 6509), ('id', 3094), ('id', 44789), ('ood', 27925), ('id', 18884), ('id', 2014), ('id', 10967), ('id', 54574)]
Acquiring (label, score)s: 1 (0.5114), 9 (0.6684), 8 (0.4787), 8 (0.7121), 8 (0.5794), 6 (0.35), 8 (0.9434), 7 (0.9978), 4 (0.8205), 2 (0.6542)
Training set size 130:


  3%|3         | 1/30 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8427734375, 'crossentropy': 0.8000897169113159}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8759765625, 'crossentropy': 0.6306926608085632}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.892578125, 'crossentropy': 0.5581420660018921}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.90625, 'crossentropy': 0.5141166895627975}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.896484375, 'crossentropy': 0.6222748458385468}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9013671875, 'crossentropy': 0.6084293723106384}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.90234375, 'crossentropy': 0.6095933318138123}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8984375, 'crossentropy': 0.6646586060523987}
RestoringEarlyStopping: 4 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.900390625, 'crossentropy': 0.6744061410427094}
RestoringEarlyStopping: 5 / 5
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -0.5141166895627975)
RestoringEarlyStopping: Restoring optimizer.


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.9188, 'crossentropy': 0.49625730819702146}
Done.


In [None]:
# experiment

# Display the `results` dict produced by an earlier experiment run in this notebook
# (NOTE(review): it is assigned in a cell outside this view — confirm which run populated it).
results

{'dataset_info': {'training': "(FastMNIST (train; 58976 samples)) + ('OoD Dataset (60000 samples)' | constant_target{'target': tensor(-1, device='cuda:0'), 'num_classes': 10})",
  'test': "'FastMNIST (test, 10000 samples)'"},
 'initial_training_set_indices': [30392,
  53434,
  12640,
  8533,
  22304,
  37915,
  58226,
  44119,
  3091,
  14640,
  58125,
  39579,
  43812,
  53689,
  52296,
  46037,
  22015,
  40334,
  57520,
  43803],
 'active_learning_steps': [{'training': {'epochs': [], 'best_epoch': None},
   'evaluation_metrics': {'accuracy': 0.1129,
    'crossentropy': 2.35244740562439},
   'acquisition': {'indices': [('ood', 27822),
     ('id', 32367),
     ('id', 55086),
     ('id', 53929),
     ('id', 48696),
     ('ood', 37815),
     ('ood', 47304),
     ('ood', 28667),
     ('ood', 40866),
     ('ood', 35119)],
    'labels': [-1, 1, 7, 6, 2, -1, -1, -1, -1, -1],
    'scores': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}},
  {'training': {'epochs': [], 'best_epoch': None}

In [None]:
# experiment

# Quick interactive run: one training epoch and a training set capped at 25 samples.
# NOTE(review): `Experiment` and `AcquisitionFunction` do not appear in the imports at the top
# of this notebook, and the recorded output of this cell uses an older result schema
# (e.g. the "evalution_metrics" key). This cell looks stale relative to
# `UniformTargetOodExperiment`; confirm these names still exist before re-running.
experiment = Experiment(
    max_training_epochs=1, max_training_set=25, acquisition_function=AcquisitionFunction.randombaldical
)

# `run` is handed an empty dict that, judging by the output below, it fills in place with the
# experiment record (initial indices, per-step training/evaluation/acquisition info).
results = {}
experiment.run(results)

results

Training set size 20:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/384]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -6.529030114412308)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.5367, 'crossentropy': 6.438035237884521}


get_predictions:   0%|          | 0/463616 [00:00<?, ?it/s]

100%|##########| 1/1 [00:00<?, ?it/s]

[1/1811]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -5.1637596152722836)
RestoringEarlyStopping: Restoring optimizer.


get_predictions:   0%|          | 0/2317680 [00:00<?, ?it/s]

get_predictions:   0%|          | 0/2317680 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Acquiring (label, score)s: 8 (0.8711), 8 (0.8687), 3 (0.876), 3 (0.8465), 3 (0.8811)
Training set size 25:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/384]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -4.6851686127483845)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.6256, 'crossentropy': 4.484497045135498}
Done.


{'initial_training_set_indices': [38043,
  40091,
  17418,
  2094,
  39879,
  3133,
  5011,
  40683,
  54379,
  24287,
  9849,
  59305,
  39508,
  39356,
  8758,
  52579,
  13655,
  7636,
  21562,
  41329],
 'active_learning_steps': [{'training': {'epochs': [{'accuracy': 0.538818359375,
      'crossentropy': 6.529030114412308}],
    'best_epoch': 1},
   'evalution_metrics': {'accuracy': 0.5367,
    'crossentropy': 6.438035237884521},
   'pool_training': {'epochs': [{'accuracy': 0.531005859375,
      'crossentropy': 5.1637596152722836}],
    'best_epoch': 1},
   'acquisition': {'indices': [63338, 10856, 63452, 81864, 109287],
    'labels': [8, 8, 3, 3, 3],
    'scores': [0.8710822958846325,
     0.8687216999221631,
     0.8759664372823723,
     0.8464646732511746,
     0.8810812784952251]}},
  {'training': {'epochs': [{'accuracy': 0.62255859375,
      'crossentropy': 4.6851686127483845}],
    'best_epoch': 1},
   'evalution_metrics': {'accuracy': 0.6256,
    'crossentropy': 4.4844970451

In [None]:
# exports

# Sweep of experiment configurations; a job's id is its index into this list.
# First block: TemperedBALD (temperature 8) for every combination of seed, acquisition size
# and pool-sample count, in seed-major order.
configs = [
    UniformTargetOodExperiment(
        seed=seed,
        acquisition_function=acquisition_functions.TemperedBALD,
        acquisition_size=acquisition_size,
        num_pool_samples=num_pool_samples,
        temperature=8,
    )
    for seed in range(5)
    for acquisition_size in (5, 10, 20, 50)
    for num_pool_samples in (50,)
]
# Second block: a Random-acquisition baseline with acquisition size 5, one config per seed.
configs += [
    UniformTargetOodExperiment(
        seed=seed,
        acquisition_function=acquisition_functions.Random,
        acquisition_size=5,
    )
    for seed in range(20)
]

if not is_run_from_ipython() and __name__ == "__main__":
    # When executed as a script, `embedded_experiments` hands out (job_id, store) pairs;
    # `job_id` selects the configuration and `store` receives that job's record.
    for job_id, store in embedded_experiments(__file__, len(configs)):
        experiment_config = configs[job_id]
        # NOTE(review): the seed is offset by job_id on top of the per-config seed above, so
        # configs that share a `seed` value do not share an effective seed — confirm intended.
        experiment_config.seed = experiment_config.seed + job_id
        print(experiment_config)

        store["config"] = dataclasses.asdict(experiment_config)
        store["log"] = {}
        try:
            experiment_config.run(store=store)
        except Exception:
            # Record the traceback in the job's store before propagating the failure.
            store["exception"] = traceback.format_exc()
            raise

In [None]:
# Number of sweep configurations, i.e. the job count passed to `embedded_experiments`.
len(configs)

40

In [None]:
# Inspect every sweep configuration (before any per-job seed offset is applied).
configs

[RejectionOodExperiment(seed=0, acquisition_size=5, max_training_set=450, num_pool_samples=50, num_eval_samples=20, num_training_samples=1, num_patience_epochs=3, max_training_epochs=30, training_batch_size=64, device='cuda', validation_set_size=1024, validation_split_random_state=0, initial_training_set_size=20, samples_per_epoch=5056, mnist_repetitions=1, ood_fmnist_repetitions=1, add_dataset_noise=False, acquisition_function=<class 'batchbald_redux.acquisition_functions.TemperedBALD'>, train_eval_model=<class 'batchbald_redux.train_eval_model.TrainSelfDistillationPoolModel'>, model_optimizer_factory=<class 'batchbald_redux.models.MnistOptimizerFactory'>, acquisition_function_args=None, temperature=8),
 RejectionOodExperiment(seed=0, acquisition_size=10, max_training_set=450, num_pool_samples=50, num_eval_samples=20, num_training_samples=1, num_patience_epochs=3, max_training_epochs=30, training_batch_size=64, device='cuda', validation_set_size=1024, validation_split_random_state=0, 