# Experiment
> Can we get better by training on our assumptions?

In [None]:
# default_exp experiment_xmi_pred

In [None]:
# hide
import blackhc.project.script

Import the modules and functions we are going to use.

In [None]:
# exports

import dataclasses
import traceback
from dataclasses import dataclass
from typing import Type, Union

import torch
import torch.utils.data
from blackhc.project import is_run_from_ipython
from blackhc.project.experiment import embedded_experiments
from torch.utils.data import Dataset

import batchbald_redux.acquisition_functions as acquisition_functions
from batchbald_redux.acquisition_functions import (
    CandidateBatchComputer,
    EvalModelBatchComputer,
)
from batchbald_redux.active_learning import ActiveLearningData, RandomFixedLengthSampler
from batchbald_redux.black_box_model_training import evaluate_old, train
from batchbald_redux.fast_mnist import FastMNIST
from batchbald_redux.dataset_challenges import (
    create_repeated_MNIST_dataset,
    get_base_dataset_index,
    get_target,
    NamedDataset
)
from batchbald_redux.di import DependencyInjection
from batchbald_redux.model_optimizer_factory import ModelOptimizerFactory
from batchbald_redux.models import MnistOptimizerFactory

In [None]:
# exports

# From the BatchBALD Repo
from batchbald_redux.train_eval_model import (
    TrainEvalModel,
    TrainSelfDistillationEvalModel,
)
from batchbald_redux.trained_model import TrainedBayesianModel

# Fixed dataset indices that `Experiment.run` uses as the initial training set
# whenever `initial_set_size` equals the length of this list.
mnist_initial_samples = [
    38043, 40091, 17418, 2094, 39879,
    3133, 5011, 40683, 54379, 24287,
    9849, 59305, 39508, 39356, 8758,
    52579, 13655, 7636, 21562, 41329,
]

In [None]:
# exports


@dataclass
class Experiment:
    """Configuration and driver for one active-learning run on FastMNIST.

    The fields are the hyper-parameters of the run. Besides being read directly
    by the methods below, the complete field dictionary (``vars(self)``) is
    handed to ``DependencyInjection`` to build the acquisition function and the
    train-eval model, so fields that are not referenced inside this class
    (e.g. ``acquisition_size``, ``num_pool_samples``, ``temperature``) are
    presumably consumed there -- confirm against the injected dataclasses.
    """

    # Seed passed to `torch.manual_seed` at the start of `run`.
    seed: int = 1337
    acquisition_size: int = 5
    # `run` stops once the training set holds at least this many samples.
    max_training_set: int = 300
    num_pool_samples: int = 20
    # Passed as `validation_samples` to `train` and as `num_samples` to `evaluate_old`.
    num_validation_samples: int = 20
    # Passed as `training_samples` to `train`.
    num_training_samples: int = 1
    # Passed as `patience` / `max_epochs` to `train`.
    num_patience_epochs: int = 5 * 4
    max_training_epochs: int = 30 * 4
    # Batch size of the training data loader built in `run`.
    training_batch_size: int = 64
    device: str = "cuda"
    # Number of samples extracted from the pool as validation set.
    validation_set_size: int = 1024
    # Only 0 and len(mnist_initial_samples) are supported by `run`.
    initial_set_size: int = 20
    # Length handed to `RandomFixedLengthSampler` for the training loader.
    min_samples_per_epoch: int = 1024
    # Values > 1 multiply the training dataset by this factor.
    repeated_mnist_repetitions: int = 1
    add_dataset_noise: bool = False
    acquisition_function: Union[
        Type[CandidateBatchComputer], Type[EvalModelBatchComputer]
    ] = acquisition_functions.BALD
    # A class (not an instance); instantiated via dependency injection.
    train_eval_model: Type[TrainEvalModel] = TrainSelfDistillationEvalModel
    model_optimizer_factory: Type[ModelOptimizerFactory] = MnistOptimizerFactory
    acquisition_function_args: dict = None
    temperature: float = 0.0

    def load_dataset(self, initial_training_set_indices) -> (ActiveLearningData, Dataset, Dataset):
        """Build the active-learning data, the validation set and the test set.

        The training targets are replaced by the predictions stored in
        ``./data/mnist_train_predictions.pt``; the validation targets are the
        argmax of those predictions; the test set keeps its own FastMNIST targets.

        Args:
            initial_training_set_indices: indices acquired into the training
                set before the validation set is split off the pool.

        Returns:
            ``(active_learning_data, validation_dataset, test_dataset)``.
        """
        train_dataset = NamedDataset(
            FastMNIST("data", train=True, download=True, device=self.device), "FastMNIST (train)"
        )
        train_predictions = torch.load("./data/mnist_train_predictions.pt", map_location=self.device)

        # If we over-sample the train set, we do so after picking the initial train set to avoid duplicates.
        # NOTE(review): the targets are overridden after the repetition below, so with
        # repetitions > 1 the stored predictions presumably have to match the repeated
        # dataset -- confirm.
        if self.repeated_mnist_repetitions > 1:
            train_dataset = train_dataset * self.repeated_mnist_repetitions

        train_dataset = train_dataset.override_targets(targets=train_predictions)

        if self.add_dataset_noise:
            # NOTE(review): `AdditiveGaussianNoise` is not imported in any visible cell;
            # this branch looks like it raises a NameError unless the name is provided
            # elsewhere -- confirm and add the import.
            train_dataset = AdditiveGaussianNoise(train_dataset, 0.1)

        active_learning_data = ActiveLearningData(train_dataset)

        active_learning_data.acquire(initial_training_set_indices)

        validation_dataset = active_learning_data.extract_dataset_from_pool(self.validation_set_size)
        validation_dataset = NamedDataset(validation_dataset, f"FastMNIST (validation, {len(validation_dataset)} samples)")
        # Validation uses class indices (argmax over dim 1 of the overridden targets);
        # `run` validates with NLLLoss.
        validation_dataset = validation_dataset.override_targets(targets=validation_dataset.get_targets().argmax(dim=1))

        test_dataset = FastMNIST("data", train=False, device=None)
        test_dataset = NamedDataset(test_dataset, f"FastMNIST (test, {len(test_dataset)} samples)")

        return active_learning_data, validation_dataset, test_dataset

    # Simple Dependency Injection
    def create_acquisition_function(self):
        """Instantiate `self.acquisition_function` using this experiment's fields as arguments."""
        di = DependencyInjection(vars(self))
        return di.create_dataclass_type(self.acquisition_function)

    def create_train_eval_model(self, runtime_config) -> TrainEvalModel:
        """Instantiate `self.train_eval_model`.

        Args:
            runtime_config: dict of values only known during the run; its
                entries take precedence over same-named experiment fields.
        """
        config = {**vars(self), **runtime_config}
        di = DependencyInjection(config, [])
        return di.create_dataclass_type(self.train_eval_model)

    def run(self, store):
        """Run the active-learning loop and log everything into `store`.

        Each iteration creates a fresh model, trains it on the current training
        set (skipped while the set is empty), evaluates it on the test set and,
        unless `max_training_set` has been reached, acquires a new batch from
        the pool.

        Args:
            store: mutable mapping that receives the keys
                ``"initial_training_set_indices"``, ``"dataset_info"`` and
                ``"active_learning_steps"``.

        Raises:
            ValueError: if `initial_set_size` is neither 0 nor
                ``len(mnist_initial_samples)``, or if the acquisition function
                is of an unknown kind.
        """
        torch.manual_seed(self.seed)

        if self.initial_set_size == len(mnist_initial_samples):
            initial_training_set_indices = mnist_initial_samples
        elif self.initial_set_size == 0:
            initial_training_set_indices = []
        else:
            # Fixed: the message previously referenced the undefined bare name
            # `initial_set_size`, which raised a NameError instead of this error.
            raise ValueError(f"initial_set_size {self.initial_set_size} not supported here!")

        store["initial_training_set_indices"] = initial_training_set_indices

        # Active Learning setup
        active_learning_data, validation_dataset, test_dataset = self.load_dataset(initial_training_set_indices)
        store["dataset_info"] = dict(training=repr(active_learning_data.base_dataset), test=repr(test_dataset))

        train_loader = torch.utils.data.DataLoader(
            active_learning_data.training_dataset,
            # Use the configured batch size (previously hard-coded to 64, the field's default).
            batch_size=self.training_batch_size,
            sampler=RandomFixedLengthSampler(active_learning_data.training_dataset, self.min_samples_per_epoch),
            drop_last=True,
        )
        pool_loader = torch.utils.data.DataLoader(
            active_learning_data.pool_dataset, batch_size=128, drop_last=False, shuffle=False
        )

        validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=128, drop_last=False)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, drop_last=False)

        # Read the list back from the store instead of keeping the local literal;
        # presumably the store may wrap assigned values -- keep this order.
        store["active_learning_steps"] = []
        active_learning_steps = store["active_learning_steps"]

        acquisition_function = self.create_acquisition_function()

        # Active Training Loop
        while True:
            training_set_size = len(active_learning_data.training_dataset)
            print(f"Training set size {training_set_size}:")

            # Same read-back pattern as above for the per-iteration log entry.
            active_learning_steps.append({})
            iteration_log = active_learning_steps[-1]

            iteration_log["training"] = {}

            # A fresh model/optimizer pair is created for every iteration.
            model_optimizer = self.model_optimizer_factory().create_model_optimizer()

            if training_set_size > 0:
                train(
                    model=model_optimizer.model,
                    optimizer=model_optimizer.optimizer,
                    training_samples=self.num_training_samples,
                    validation_samples=self.num_validation_samples,
                    train_loader=train_loader,
                    validation_loader=validation_loader,
                    patience=self.num_patience_epochs,
                    max_epochs=self.max_training_epochs,
                    device=self.device,
                    training_log=iteration_log["training"],
                    # Training targets are the overridden (prediction) targets, hence a
                    # KL loss; validation targets are class indices, hence NLL.
                    loss=torch.nn.KLDivLoss(log_target=False, reduction="batchmean"),
                    validation_loss=torch.nn.NLLLoss(),
                )

            evaluation_metrics = evaluate_old(
                model=model_optimizer.model,
                num_samples=self.num_validation_samples,
                loader=test_loader,
                device=self.device,
            )
            iteration_log["evaluation_metrics"] = evaluation_metrics
            print(f"Perf after training {evaluation_metrics}")

            if training_set_size >= self.max_training_set:
                print("Done.")
                break

            trained_model = TrainedBayesianModel(model=model_optimizer.model)

            if isinstance(acquisition_function, CandidateBatchComputer):
                candidate_batch = acquisition_function.compute_candidate_batch(trained_model, pool_loader, self.device)
            elif isinstance(acquisition_function, EvalModelBatchComputer):
                # NOTE(review): "best_epoch" is only looked up here, never written in this
                # class; when training was skipped (empty training set) the key is
                # presumably missing and this raises a KeyError -- confirm.
                current_max_epochs = iteration_log["training"]["best_epoch"]

                train_eval_model = self.create_train_eval_model(
                    dict(
                        max_epochs=current_max_epochs,
                        training_dataset=active_learning_data.training_dataset,
                        eval_dataset=active_learning_data.pool_dataset,
                        validation_loader=validation_loader,
                        trained_model=trained_model,
                    )
                )

                iteration_log["eval_training"] = {}
                trained_eval_model = train_eval_model(training_log=iteration_log["eval_training"], device=self.device)

                candidate_batch = acquisition_function.compute_candidate_batch(
                    trained_model, trained_eval_model, pool_loader, device=self.device
                )
            else:
                raise ValueError(f"Unknown acquisition function {acquisition_function}!")

            # Translate pool-relative indices into indices of the base dataset for logging.
            candidate_global_indices = [
                get_base_dataset_index(active_learning_data.pool_dataset, index).index
                for index in candidate_batch.indices
            ]
            # Logged labels are the argmax of the (overridden) base-dataset targets.
            candidate_labels = [
                get_target(active_learning_data.base_dataset, index).argmax(dim=0).item() for index in candidate_global_indices
            ]

            iteration_log["acquisition"] = dict(
                indices=candidate_global_indices, labels=candidate_labels, scores=candidate_batch.scores
            )

            # Acquisition itself uses the pool-relative indices.
            active_learning_data.acquire(candidate_batch.indices)

            ls = ", ".join(f"{label} ({score:.4})" for label, score in zip(candidate_labels, candidate_batch.scores))
            print(f"Acquiring (label, score)s: {ls}")

In [None]:
# experiment

# Short CoreSetBALD run that starts from an empty training set.
experiment = Experiment(
    # Reproducibility and hardware.
    seed=1120,
    device="cuda",
    # Start empty and acquire 10 samples per step until 130 are reached.
    initial_set_size=0,
    acquisition_size=10,
    max_training_set=130,
    # Acquisition settings.
    acquisition_function=acquisition_functions.CoreSetBALD,
    num_pool_samples=20,
    temperature=8,
    # Keep every training round short.
    max_training_epochs=5,
    num_patience_epochs=5,
)

# `run` writes its log into this dictionary.
results = dict()
experiment.run(store=results)

Resolved: CoreSetBALD with {'acquisition_size': 10}
Creating: CoreSetBALD(acquisition_size=10)
Training set size 0:


[1/79]   1%|1          [00:00<?]

Perf after training {'accuracy': 0.1047, 'crossentropy': 2.3365583019256593}


get_predictions_labels:   0%|          | 0/1179520 [00:00<?, ?it/s]

Acquiring (label, score)s: 8 (0.06852), 8 (0.06501), 8 (0.06416), 8 (0.06332), 8 (0.06326), 8 (0.06308), 8 (0.06303), 8 (0.06295), 8 (0.06286), 8 (0.0625)
Training set size 10:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.0830078125, 'crossentropy': 80.1539077758789}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.0830078125, 'crossentropy': 73.50736045837402}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.0830078125, 'crossentropy': 67.76939964294434}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.0830078125, 'crossentropy': 60.18304920196533}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.0830078125, 'crossentropy': 53.462871074676514}
RestoringEarlyStopping: Restoring best parameters. (Score: -53.462871074676514)
RestoringEarlyStopping: Restoring optimizer.


[1/79]   1%|1          [00:00<?]

Perf after training {'accuracy': 0.0974, 'crossentropy': 52.69513300170898}


get_predictions_labels:   0%|          | 0/1179320 [00:00<?, ?it/s]

Acquiring (label, score)s: 0 (2.996), 0 (2.996), 0 (2.996), 0 (2.996), 0 (2.996), 0 (2.996), 0 (2.996), 0 (2.996), 0 (2.996), 0 (2.996)
Training set size 20:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.171875, 'crossentropy': 18.25990891456604}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.1728515625, 'crossentropy': 16.943273305892944}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.1728515625, 'crossentropy': 15.827420234680176}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.171875, 'crossentropy': 15.969228029251099}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.1728515625, 'crossentropy': 14.921754360198975}
RestoringEarlyStopping: Restoring best parameters. (Score: -14.921754360198975)
RestoringEarlyStopping: Restoring optimizer.


[1/79]   1%|1          [00:00<?]

Perf after training {'accuracy': 0.1903, 'crossentropy': 15.213939340209961}


get_predictions_labels:   0%|          | 0/1179120 [00:00<?, ?it/s]

Acquiring (label, score)s: 1 (2.939), 4 (2.938), 1 (2.93), 1 (2.928), 9 (2.927), 4 (2.926), 1 (2.924), 1 (2.923), 9 (2.923), 1 (2.922)
Training set size 30:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.337890625, 'crossentropy': 8.328922271728516}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.3349609375, 'crossentropy': 9.288024067878723}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.3388671875, 'crossentropy': 8.356063544750214}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.3427734375, 'crossentropy': 8.738175511360168}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.3447265625, 'crossentropy': 8.88105833530426}
RestoringEarlyStopping: 4 / 5
RestoringEarlyStopping: Restoring best parameters. (Score: -8.328922271728516)
RestoringEarlyStopping: Restoring optimizer.


[1/79]   1%|1          [00:00<?]

Perf after training {'accuracy': 0.3565, 'crossentropy': 8.24685298461914}


get_predictions_labels:   0%|          | 0/1178920 [00:00<?, ?it/s]

Acquiring (label, score)s: 3 (2.99), 3 (2.989), 5 (2.988), 5 (2.987), 3 (2.985), 3 (2.985), 3 (2.985), 3 (2.983), 3 (2.983), 5 (2.982)
Training set size 40:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.3212890625, 'crossentropy': 6.385581731796265}


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.3427734375, 'crossentropy': 6.836143791675568}
RestoringEarlyStopping: 1 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.3408203125, 'crossentropy': 7.358673572540283}
RestoringEarlyStopping: 2 / 5


[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.3466796875, 'crossentropy': 7.287662506103516}
RestoringEarlyStopping: 3 / 5


[1/79]   1%|1          [00:00<?]

Engine run is terminating due to exception: .


Epoch metrics: {'accuracy': 0.3603515625, 'crossentropy': 6.9419087171554565}
RestoringEarlyStopping: 4 / 5
RestoringEarlyStopping: Restoring best parameters. (Score: -6.385581731796265)
RestoringEarlyStopping: Restoring optimizer.


KeyboardInterrupt: 

In [None]:
# Display the log dictionary filled in by `experiment.run` as the cell output.
results

{'initial_training_set_indices': [38043,
  40091,
  17418,
  2094,
  39879,
  3133,
  5011,
  40683,
  54379,
  24287,
  9849,
  59305,
  39508,
  39356,
  8758,
  52579,
  13655,
  7636,
  21562,
  41329],
 'dataset_info': {'training': "'FastMNIST (Train)'",
  'test': "'FastMNIST (Test)'"},
 'active_learning_steps': [{'training': {'epochs': [{'accuracy': 0.62109375,
      'crossentropy': 2.6530187726020813},
     {'accuracy': 0.6376953125, 'crossentropy': 2.762658029794693},
     {'accuracy': 0.646484375, 'crossentropy': 3.056214064359665},
     {'accuracy': 0.6416015625, 'crossentropy': 3.1257119178771973}],
    'best_epoch': 1},
   'evaluation_metrics': {'accuracy': 0.631,
    'crossentropy': 2.6251225173950195}}]}

In [None]:
# experiment

# NOTE(review): `AcquisitionFunction` is neither imported nor defined in any visible
# cell, so this cell presumably predates the `acquisition_functions` module API and
# would raise a NameError if re-run -- confirm, then update or remove it.
experiment = Experiment(
    max_training_epochs=1, max_training_set=25, acquisition_function=AcquisitionFunction.randombaldical
)

# `run` writes its log into this dictionary.
results = {}
experiment.run(results)

# Display the collected log as the cell output.
results

Training set size 20:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/384]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -6.529030114412308)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.5367, 'crossentropy': 6.438035237884521}


get_predictions:   0%|          | 0/463616 [00:00<?, ?it/s]

100%|##########| 1/1 [00:00<?, ?it/s]

[1/1811]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -5.1637596152722836)
RestoringEarlyStopping: Restoring optimizer.


get_predictions:   0%|          | 0/2317680 [00:00<?, ?it/s]

get_predictions:   0%|          | 0/2317680 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Acquiring (label, score)s: 8 (0.8711), 8 (0.8687), 3 (0.876), 3 (0.8465), 3 (0.8811)
Training set size 25:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/384]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -4.6851686127483845)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.6256, 'crossentropy': 4.484497045135498}
Done.


{'initial_training_set_indices': [38043,
  40091,
  17418,
  2094,
  39879,
  3133,
  5011,
  40683,
  54379,
  24287,
  9849,
  59305,
  39508,
  39356,
  8758,
  52579,
  13655,
  7636,
  21562,
  41329],
 'active_learning_steps': [{'training': {'epochs': [{'accuracy': 0.538818359375,
      'crossentropy': 6.529030114412308}],
    'best_epoch': 1},
   'evalution_metrics': {'accuracy': 0.5367,
    'crossentropy': 6.438035237884521},
   'pool_training': {'epochs': [{'accuracy': 0.531005859375,
      'crossentropy': 5.1637596152722836}],
    'best_epoch': 1},
   'acquisition': {'indices': [63338, 10856, 63452, 81864, 109287],
    'labels': [8, 8, 3, 3, 3],
    'scores': [0.8710822958846325,
     0.8687216999221631,
     0.8759664372823723,
     0.8464646732511746,
     0.8810812784952251]}},
  {'training': {'epochs': [{'accuracy': 0.62255859375,
      'crossentropy': 4.6851686127483845}],
    'best_epoch': 1},
   'evalution_metrics': {'accuracy': 0.6256,
    'crossentropy': 4.4844970451

In [None]:
# exports

# Job grid: 5 seeds x 2 acquisition functions x 1 acquisition size x 1 pool-sample
# count = 10 experiment configurations, each limited to 150 training samples.
configs = [
    Experiment(
        seed=run_seed,
        acquisition_function=acquirer,
        acquisition_size=batch_size,
        num_pool_samples=pool_samples,
        max_training_set=150,
    )
    for run_seed in range(5)
    for acquirer in (
        acquisition_functions.BALD,
        acquisition_functions.CoreSetBALD,
    )
    for batch_size in (1,)
    for pool_samples in (100,)
]

# Only run the jobs when executed as a script outside of IPython.
if not is_run_from_ipython() and __name__ == "__main__":
    for job_id, store in embedded_experiments(__file__, len(configs)):
        config = configs[job_id]
        # NOTE(review): the job id is added on top of the seed already set in the grid,
        # so the two acquisition functions never share a seed -- confirm this is intended.
        config.seed = config.seed + job_id
        print(config)
        store["config"] = dataclasses.asdict(config)
        store["log"] = {}

        try:
            config.run(store=store)
        except Exception:
            # Persist the traceback in the job's store before propagating the failure.
            store["exception"] = traceback.format_exc()
            raise

In [None]:
# Number of experiment configurations in the job grid.
len(configs)

40