# Extract BALD Scores for the Test Set During Training
> Resistance is futile.

In [None]:
# default_exp extract_bald_scores

In [None]:
# hide
import blackhc.project.script

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


Import the modules and functions we are going to use.

In [None]:
# exports

import dataclasses
import traceback
from dataclasses import dataclass
from typing import Optional, Type, Union

import torch
import torch.utils.data
from blackhc.project import is_run_from_ipython
from blackhc.project.experiment import embedded_experiment

import batchbald_redux.acquisition_functions.bald
import batchbald_redux.acquisition_functions.candidate_batch_computers
from batchbald_redux import acquisition_functions
from batchbald_redux import baseline_acquisition_functions
from batchbald_redux.acquisition_functions import (
    CandidateBatchComputer,
    EvalDatasetBatchComputer,
    EvalModelBatchComputer,
)
from batchbald_redux.black_box_model_training import evaluate
from batchbald_redux.dataset_challenges import get_base_dataset_index, get_target
from batchbald_redux.di import DependencyInjection
from batchbald_redux.experiment_data import (
    ExperimentData,
    ExperimentDataConfig,
    OoDDatasetConfig,
    StandardExperimentDataConfig,
)
from batchbald_redux.models import MnistModelTrainer
from batchbald_redux.resnet_models import Cifar10ModelTrainer
from batchbald_redux.train_eval_model import (
    TrainEvalModel,
    TrainSelfDistillationEvalModel,
)
from batchbald_redux.trained_model import ModelTrainer, BayesianEnsembleModelTrainer

from batchbald_redux.joint_entropy import compute_entropy
from batchbald_redux.acquisition_functions.bald import get_bald_scores

from batchbald_redux.dataset_challenges import get_base_dataset_index

In [None]:
# exports

@dataclass
class CustomBALD(batchbald_redux.acquisition_functions.bald.BALD):
    """BALD acquisition that keeps the intermediate results of its most recent call.

    Whenever the base class invokes `get_candidate_batch` or `extract_candidates`,
    the incoming argument is stashed on the instance before delegating:

    - `log_probs_N_K_C`: the log-probabilities handed to `get_candidate_batch`;
    - `scores_N`: the scores handed to `extract_candidates`.

    Both are plain instance attributes (not dataclass fields) and are overwritten on
    every call, so they always reflect the latest scored batch.
    """

    def get_candidate_batch(
        self, log_probs_N_K_C, device
    ) -> (batchbald_redux.acquisition_functions.candidate_batch_computers.CandidateBatch):
        """Remember `log_probs_N_K_C`, then defer to `BALD.get_candidate_batch`."""
        self.log_probs_N_K_C = log_probs_N_K_C
        candidate_batch = super().get_candidate_batch(log_probs_N_K_C, device)
        return candidate_batch

    def extract_candidates(
        self, scores_N
    ) -> (batchbald_redux.acquisition_functions.candidate_batch_computers.CandidateBatch):
        """Remember `scores_N`, then defer to `BALD.extract_candidates`."""
        self.scores_N = scores_N
        candidate_batch = super().extract_candidates(scores_N)
        return candidate_batch

In [None]:
# exports

@dataclass
class ActiveLearner:
    """Active-learning loop that additionally records BALD and entropy scores.

    Each iteration:
    1. Trains a model on the current training set.
    2. Evaluates it on the test set.
    3. Unless the training-set budget or the iteration limit has been reached:
       - computes a candidate batch from the pool with `acquisition_function`;
       - records BALD/entropy scores for the test set and for the scored pool,
         plus the pool's base-dataset indices;
       - acquires the candidates.

    The pool scores are read from the `log_probs_N_K_C` and `scores_N` attributes
    that `CustomBALD` stores while computing a candidate batch. The acquisition
    function must therefore record those attributes.
    """

    acquisition_size: int
    max_training_set: int

    num_validation_samples: int
    num_pool_samples: int

    acquisition_function: Union[CandidateBatchComputer, EvalModelBatchComputer]
    # Not used by the loop; kept so existing callers can keep passing it.
    train_eval_model: TrainEvalModel
    model_trainer: ModelTrainer
    data: ExperimentData

    disable_training_augmentations: bool

    device: Optional

    def __call__(self, log):
        """Run the loop, appending one dict per iteration to ``log["active_learning_steps"]``.

        Args:
            log: Mutable mapping that receives ``"seed"`` and ``"active_learning_steps"``.
        """
        # NOTE(review): `torch.seed()` re-seeds torch's global RNG with a non-deterministic
        # value, so a `torch.manual_seed` set by the caller does not govern training; the
        # fresh seed is logged instead. Confirm this is intended.
        log["seed"] = torch.seed()

        # Active Learning setup
        data = self.data
        model_trainer = self.model_trainer
        acquisition_function = self.acquisition_function

        train_augmentations = None if self.disable_training_augmentations else data.train_augmentations

        train_loader = model_trainer.get_train_dataloader(data.active_learning.training_dataset)
        pool_loader = model_trainer.get_evaluation_dataloader(data.active_learning.pool_dataset)
        validation_loader = model_trainer.get_evaluation_dataloader(data.validation_dataset)
        test_loader = model_trainer.get_evaluation_dataloader(data.test_dataset)

        log["active_learning_steps"] = []
        # Read the list back through `log` instead of keeping our own reference, in case the
        # store wraps assigned containers (TODO confirm against the experiment store).
        active_learning_steps = log["active_learning_steps"]

        num_iterations = 0
        # Cap the loop at 1.5x the number of acquisitions needed to reach the budget, in case
        # acquisitions are filtered out (see `_acquire`).
        max_iterations = int(
            1.5 * (self.max_training_set - len(data.active_learning.training_dataset)) / self.acquisition_size
        )

        # Active Training Loop
        while True:
            training_set_size = len(data.active_learning.training_dataset)
            print(f"Training set size {training_set_size}:")

            active_learning_steps.append({})
            iteration_log = active_learning_steps[-1]

            iteration_log["training"] = {}

            loss, validation_loss = self._select_losses()

            trained_model = model_trainer.get_trained(
                train_loader=train_loader,
                train_augmentations=train_augmentations,
                validation_loader=validation_loader,
                log=iteration_log["training"],
                loss=loss,
                validation_loss=validation_loss,
            )

            evaluation_metrics = evaluate(
                model=trained_model,
                num_samples=self.num_validation_samples,
                loader=test_loader,
                device=self.device,
                storage_device="cpu",
            )
            iteration_log["evaluation_metrics"] = evaluation_metrics
            print(f"Perf after training {evaluation_metrics}")

            if training_set_size >= self.max_training_set or num_iterations >= max_iterations:
                print("Done.")
                break

            candidate_batch = acquisition_function.compute_candidate_batch(trained_model, pool_loader, self.device)

            self._record_scores(iteration_log, trained_model, test_loader)

            candidate_global_dataset_indices, candidate_labels = self._describe_candidates(candidate_batch.indices)
            iteration_log["acquisition"] = dict(
                indices=candidate_global_dataset_indices, labels=candidate_labels, scores=candidate_batch.scores
            )

            print(candidate_batch)
            print(candidate_global_dataset_indices)

            self._acquire(candidate_batch.indices, candidate_labels)

            ls = ", ".join(f"{label} ({score:.4})" for label, score in zip(candidate_labels, candidate_batch.scores))
            print(f"Acquiring (label, score)s: {ls}")

            num_iterations += 1

    def _select_losses(self):
        """Return ``(loss, validation_loss)`` for the current data configuration.

        A KL-divergence training loss is used only when an OoD dataset is present with
        OoD exposure; otherwise a single NLL loss object serves both roles.
        """
        # TODO: this is a hack! :(
        data = self.data
        if data.ood_dataset is not None and data.ood_exposure:
            return torch.nn.KLDivLoss(log_target=False, reduction="batchmean"), torch.nn.NLLLoss()
        nll_loss = torch.nn.NLLLoss()
        return nll_loss, nll_loss

    def _record_scores(self, iteration_log, trained_model, test_loader):
        """Store test-set and pool BALD/entropy scores plus the pool's base-dataset indices.

        Raises:
            AttributeError: if `acquisition_function` did not record `log_probs_N_K_C`
                and `scores_N` (e.g. it is not a `CustomBALD`).
        """
        acquisition_function = self.acquisition_function
        # Fail fast, before the expensive test-set pass, if the scores were not captured.
        try:
            log_pool_probs_N_K_C = acquisition_function.log_probs_N_K_C
            bald_pool_scores_N = acquisition_function.scores_N
        except AttributeError as err:
            raise AttributeError(
                f"{type(acquisition_function).__name__} does not record `log_probs_N_K_C`/`scores_N`; "
                "use an acquisition function that stores them, such as CustomBALD."
            ) from err

        # Compute BALD scores on the test set.
        log_test_probs_N_K_C = trained_model.get_log_probs_N_K_C(test_loader, self.num_pool_samples, self.device, "cpu")
        bald_test_scores_N = get_bald_scores(log_test_probs_N_K_C, dtype=torch.double, device="cpu")
        entropy_test_scores_N = compute_entropy(log_test_probs_N_K_C)
        entropy_training_scores_N = compute_entropy(log_pool_probs_N_K_C)

        # The `*_training_*` keys hold the scores the acquisition function computed (presumably
        # over the pool). Key names, including the inconsistent `_N` suffix, are left unchanged
        # so readers of saved results keep working.
        iteration_log["bald_training_scores"] = bald_pool_scores_N
        iteration_log["entropy_training_scores_N"] = entropy_training_scores_N

        iteration_log["bald_test_scores"] = bald_test_scores_N
        iteration_log["entropy_test_scores"] = entropy_test_scores_N

        # Also save all the original indices. This is slow: one lookup per pool sample.
        print("Finding pool indices")
        pool_dataset = self.data.active_learning.pool_dataset
        pool_indices = [get_base_dataset_index(pool_dataset, i).index for i in range(len(pool_dataset))]
        iteration_log["pool_base_indices"] = pool_indices
        print("Storing pool indices")

    def _describe_candidates(self, candidate_indices):
        """Map pool-relative indices to ``("id" | "ood", base_index)`` pairs and their labels."""
        data = self.data
        pool_dataset = data.active_learning.pool_dataset
        candidate_global_dataset_indices = []
        candidate_labels = []
        for index in candidate_indices:
            base_di = get_base_dataset_index(pool_dataset, index)
            dataset_type = "ood" if base_di.dataset == data.ood_dataset else "id"
            candidate_global_dataset_indices.append((dataset_type, base_di.index))
            candidate_labels.append(get_target(pool_dataset, index).tolist())
        return candidate_global_dataset_indices, candidate_labels

    def _acquire(self, candidate_indices, candidate_labels):
        """Move the candidates from the pool into the training set.

        With an OoD dataset but no OoD exposure, candidates labelled -1 (presumably OoD
        samples) are skipped instead of being acquired.
        """
        data = self.data
        if data.ood_dataset is None or data.ood_exposure:
            data.active_learning.acquire(candidate_indices)
        else:
            data.active_learning.acquire(
                [index for index, label in zip(candidate_indices, candidate_labels) if label != -1]
            )


@dataclass
class UnifiedExperiment:
    """Configuration and entry point for one score-extracting active-learning run.

    `run` does the following:
    1. Seeds torch and loads the experiment data.
    2. Builds the acquisition function, the model trainer and the train/eval model
       from this dataclass's fields via `DependencyInjection`. The model trainer is
       wrapped in a `BayesianEnsembleModelTrainer` when `ensemble_size > 1`.
    3. Runs an `ActiveLearner` that writes its results into `store`.
    """

    seed: int

    experiment_data_config: ExperimentDataConfig

    acquisition_size: int = 5
    max_training_set: int = 200

    max_training_epochs: int = 300

    num_pool_samples: int = 100
    num_validation_samples: int = 20
    num_training_samples: int = 1

    device: str = "cuda"
    # `ActiveLearner` reads `log_probs_N_K_C` / `scores_N` off the acquisition function after
    # every acquisition. `CustomBALD` records them; plain `BALD` presumably does not and would
    # raise AttributeError there. So default to `CustomBALD`.
    acquisition_function: Union[Type[CandidateBatchComputer], Type[EvalModelBatchComputer]] = CustomBALD
    train_eval_model: Type[TrainEvalModel] = TrainSelfDistillationEvalModel
    model_trainer_factory: Type[ModelTrainer] = Cifar10ModelTrainer
    ensemble_size: int = 1

    # The fields below are not read directly here; they are available to the injected
    # factories (presumably EPIG-style acquisition functions) via dependency injection.
    temperature: float = 0.0
    epig_bootstrap_type: acquisition_functions.BootstrapType = acquisition_functions.BootstrapType.NO_BOOTSTRAP
    epig_bootstrap_factor: float = 1.
    epig_dtype: torch.dtype = torch.double
    disable_training_augmentations: bool = False
    cache_explicit_eval_model: bool = False

    def load_experiment_data(self) -> ExperimentData:
        """Print and load the configured experiment data onto `self.device`."""
        print(self.experiment_data_config)
        return self.experiment_data_config.load(self.device)

    # Simple Dependency Injection
    def _inject(self, factory):
        """Instantiate the dataclass `factory`, supplying its fields from this experiment's fields."""
        return DependencyInjection(vars(self)).create_dataclass_type(factory)

    def create_acquisition_function(self):
        return self._inject(self.acquisition_function)

    def create_train_eval_model(self) -> TrainEvalModel:
        return self._inject(self.train_eval_model)

    def create_model_trainer(self) -> ModelTrainer:
        return self._inject(self.model_trainer_factory)

    def run(self, store):
        """Run the experiment, recording dataset info and all loop logs into `store`."""
        torch.manual_seed(self.seed)

        # Active Learning setup
        data = self.load_experiment_data()
        store["dataset_info"] = dict(training=repr(data.active_learning.base_dataset), test=repr(data.test_dataset))
        store["initial_training_set_indices"] = data.initial_training_set_indices
        store["evaluation_set_indices"] = data.evaluation_set_indices

        acquisition_function = self.create_acquisition_function()
        model_trainer = self.create_model_trainer()
        if self.ensemble_size > 1:
            model_trainer = BayesianEnsembleModelTrainer(model_trainer=model_trainer, ensemble_size=self.ensemble_size)
        train_eval_model = self.create_train_eval_model()

        active_learner = ActiveLearner(
            acquisition_size=self.acquisition_size,
            max_training_set=self.max_training_set,
            num_validation_samples=self.num_validation_samples,
            num_pool_samples=self.num_pool_samples,
            disable_training_augmentations=self.disable_training_augmentations,
            acquisition_function=acquisition_function,
            train_eval_model=train_eval_model,
            model_trainer=model_trainer,
            data=data,
            device=self.device,
        )

        active_learner(store)

## MNIST only

In [None]:
# exports

# MNIST-only experiment: no OoD dataset is configured (`ood_dataset_config=None`).
# Scores are captured by `CustomBALD` while acquiring one sample per step, growing the
# training set from 20 initial samples up to 250, with a 2-model ensemble.

mnist_data_config = StandardExperimentDataConfig(
    id_dataset_name="MNIST",
    id_repetitions=1,
    initial_training_set_size=20,
    validation_set_size=4096,
    validation_split_random_state=0,
    evaluation_set_size=0,
    add_dataset_noise=False,
    ood_dataset_config=None,
)

experiment = UnifiedExperiment(
    seed=1,
    experiment_data_config=mnist_data_config,
    # Acquisition
    acquisition_function=CustomBALD,
    acquisition_size=1,
    max_training_set=250,
    num_pool_samples=100,
    # Model / training
    model_trainer_factory=MnistModelTrainer,
    max_training_epochs=120,
    ensemble_size=2,
    device="cuda",
)

# `run` fills this dict in place with dataset info and per-iteration logs.
results = {}
experiment.run(results)

StandardExperimentDataConfig(id_dataset_name='MNIST', id_repetitions=1, initial_training_set_size=20, validation_set_size=4096, validation_split_random_state=0, evaluation_set_size=0, add_dataset_noise=False, ood_dataset_config=None)
Creating: CustomBALD(
	acquisition_size=1,
	num_pool_samples=100
)
Creating: MnistModelTrainer(
	device=cuda,
	num_training_samples=1,
	num_validation_samples=20,
	max_training_epochs=120
)
Creating: TrainSelfDistillationEvalModel(
	num_pool_samples=100
)
Training set size 20:


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)


  1%|          | 1/120 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.5810546875, 'crossentropy': 1.7664165496826172}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.654541015625, 'crossentropy': 1.111796498298645}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.702880859375, 'crossentropy': 0.9382047653198242}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.713134765625, 'crossentropy': 0.9180607199668884}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.686279296875, 'crossentropy': 1.0661845207214355}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.712890625, 'crossentropy': 0.9724885821342468}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.705810546875, 'crossentropy': 0.9743767380714417}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.701416015625, 'crossentropy': 0.9962690472602844}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.69140625, 'crossentropy': 1.0272412300109863}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.694580078125, 'crossentropy': 1.0385000705718994}
RestoringEarlyStopping: 6 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.705078125, 'crossentropy': 1.0360578298568726}
RestoringEarlyStopping: 7 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.714599609375, 'crossentropy': 1.0005216598510742}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.703125, 'crossentropy': 1.0171035528182983}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.695068359375, 'crossentropy': 1.0967237949371338}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.678955078125, 'crossentropy': 1.2056372165679932}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.6923828125, 'crossentropy': 1.1356244087219238}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Exception ignored in: <function tqdm.__del__ at 0x7f5cdd9fd5e0>
Traceback (most recent call last):
  File "/home/blackhc/anaconda3/envs/active_learning/lib/python3.8/site-packages/tqdm/std.py", line 1134, in __del__
    def __del__(self):
KeyboardInterrupt: 


Epoch metrics: {'accuracy': 0.69189453125, 'crossentropy': 1.1392706632614136}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.694091796875, 'crossentropy': 1.1190389394760132}
RestoringEarlyStopping: 6 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.70166015625, 'crossentropy': 1.097624659538269}
RestoringEarlyStopping: 7 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.682373046875, 'crossentropy': 1.118462324142456}
RestoringEarlyStopping: 8 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Engine run is terminating due to exception: 
Engine run is terminating due to exception: 


KeyboardInterrupt: 

In [None]:
# exports

# Persist everything `experiment.run` wrote into `results` (torch.save pickles the object).
torch.save(obj=results, f="extracted_scores_results.tpickle")