# Experiment CIFAR-10
> Can we get better by training on our assumptions?

In [None]:
# default_exp experiment_cifar10

In [None]:
# hide
import blackhc.project.script

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


Import the modules and functions we are going to use.

In [None]:
# exports

import dataclasses
import traceback
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Type, Union

import torch
import torch.utils.data
from blackhc.project import is_run_from_ipython
from blackhc.project.experiment import embedded_experiments
from torch import nn
from torch.utils.data import Dataset

import batchbald_redux.acquisition_functions as acquisition_functions
from batchbald_redux.acquisition_functions import (
    CandidateBatchComputer,
    EvalCandidateBatchComputer,
)
from batchbald_redux.active_learning import ActiveLearningData, RandomFixedLengthSampler
from batchbald_redux.black_box_model_training import evaluate_old, train, train_with_schedule
from batchbald_redux.dataset_challenges import (
    NamedDataset,
    create_repeated_MNIST_dataset,
    get_balanced_sample_indices,
    get_base_dataset_index,
    get_target, AdditiveGaussianNoise, AliasDataset, get_balanced_sample_indices_by_class,
)
from batchbald_redux.datasets import get_dataset
from batchbald_redux.datasets import train_validation_split
from batchbald_redux.di import DependencyInjection
from batchbald_redux.fast_mnist import FastMNIST
from batchbald_redux.model_optimizer_factory import ModelOptimizerFactory
from batchbald_redux.resnet_models import Cifar10BayesianResnetFactory
from batchbald_redux.train_eval_model import (
    TrainEvalModel,
    TrainSelfDistillationEvalModel,
    TrainSelfDistillationEvalModelWithSchedule
)
from batchbald_redux.trained_model import TrainedBayesianModel

In [None]:
# exports

@dataclass
class ExperimentData:
    """Datasets and index bookkeeping returned by ``load_experiment_data``."""

    # Acquired/pool bookkeeping over ``train_dataset``.
    active_learning: ActiveLearningData
    # Training split after any optional repetition / noise wrapping.
    train_dataset: Dataset
    train_augmentations: nn.Module
    validation_dataset: Dataset
    test_dataset: Dataset
    # Subset built from ``evaluation_set_indices`` (empty when no evaluation set was requested).
    evaluation_dataset: Dataset
    # Base-dataset indices that were acquired as the initial training set.
    initial_training_set_indices: List[int]
    # Base-dataset indices that make up ``evaluation_dataset``.
    evaluation_set_indices: List[int]


@dataclass
class ExperimentDataConfig:
    """Dataset-related slice of an experiment configuration.

    Each field has the same name as a keyword-only parameter of
    ``load_experiment_data``. This lets an instance be created by dependency
    injection from a larger config and then materialised with ``load()``.
    """

    id_dataset_name: str
    initial_training_set_size: int
    validation_set_size: int
    evaluation_set_size: int
    id_repetitions: float
    add_dataset_noise: bool
    validation_split_random_state: int

    device: str

    def load(self) -> ExperimentData:
        """Load the datasets and active-learning split described by this config."""
        # Field names map one-to-one onto load_experiment_data's keyword arguments;
        # all values are plain scalars/strings, so asdict's copying is a no-op.
        loader_kwargs = dataclasses.asdict(self)
        return load_experiment_data(**loader_kwargs)


def load_experiment_data(
    *,
    id_dataset_name: str,
    initial_training_set_size: int,
    validation_set_size: int,
    evaluation_set_size: int,
    id_repetitions: float,
    add_dataset_noise: bool,
    validation_split_random_state: int,
    device: str,
) -> ExperimentData:
    """Load a dataset and prepare the active-learning split for one experiment.

    A class-balanced sample of the training split is drawn, seeded with
    ``validation_split_random_state``. Within each class, the first
    ``initial_training_set_size // num_classes`` indices become the initial
    training set and the following ``evaluation_set_size // num_classes``
    indices become the evaluation set. Both sizes are therefore rounded down to
    a multiple of the number of classes.

    ``id_repetitions < 1`` shrinks the training split *before* sampling.
    ``id_repetitions > 1`` repeats it *after* sampling, so the sampled indices
    come from un-repeated data and contain no duplicates. With
    ``add_dataset_noise`` the final training split is wrapped in
    ``AdditiveGaussianNoise(..., 0.1)``.

    ``device`` is accepted but not used by this function.
    """
    splits = get_dataset(
        id_dataset_name,
        root="data",
        validation_set_size=validation_set_size,
        validation_split_random_state=validation_split_random_state,
        normalize_like_cifar10=True,
    )

    train_split = splits.train

    # Shrinking has to happen before the balanced draw so the draw only sees retained samples.
    if id_repetitions < 1:
        train_split = train_split * id_repetitions

    class_count = train_split.get_num_classes()
    initial_per_class = initial_training_set_size // class_count
    evaluation_per_class = evaluation_set_size // class_count
    indices_by_class = get_balanced_sample_indices_by_class(
        train_split,
        num_classes=class_count,
        samples_per_class=initial_per_class + evaluation_per_class,
        seed=validation_split_random_state,
    )

    # Per class: the head goes to the initial training set, the tail to the evaluation set.
    initial_indices = []
    evaluation_indices = []
    for class_indices in indices_by_class.values():
        initial_indices.extend(class_indices[:initial_per_class])
        evaluation_indices.extend(class_indices[initial_per_class:])

    # Over-sampling happens only after the draw, so the drawn indices contain no duplicates.
    if id_repetitions > 1:
        train_split = train_split * id_repetitions

    if add_dataset_noise:
        train_split = AdditiveGaussianNoise(train_split, 0.1)

    active_learning = ActiveLearningData(train_split)
    active_learning.acquire_base_indices(initial_indices)

    evaluation_subset = active_learning.extract_dataset_from_base_indices(evaluation_indices)
    evaluation_dataset = AliasDataset(
        evaluation_subset,
        f"Evaluation Set ({len(evaluation_indices)} samples)",
    )

    return ExperimentData(
        active_learning=active_learning,
        train_dataset=train_split,
        train_augmentations=splits.train_augmentations,
        validation_dataset=splits.validation,
        test_dataset=splits.test,
        evaluation_dataset=evaluation_dataset,
        initial_training_set_indices=initial_indices,
        evaluation_set_indices=evaluation_indices,
    )

In [None]:
# exports

@dataclass
class Experiment:
    """Hyper-parameters and driver for a single active-learning experiment.

    Each iteration of ``run``:
    - trains a fresh model on the current training set;
    - evaluates that model on the test set;
    - acquires a new batch from the pool with the configured acquisition function.

    It stops once the training set reaches ``max_training_set`` or the pool is
    exhausted. Components (data config, acquisition function, eval model) are
    built through ``DependencyInjection`` from ``vars(self)``. Field names
    therefore double as constructor-argument names for those components.
    """

    seed: int
    # Acquisition-function *class*; it is instantiated via dependency injection in ``run``.
    acquisition_function: Union[
        Type[CandidateBatchComputer], Type[EvalCandidateBatchComputer]
    ]

    # Dataset settings, forwarded to ``ExperimentDataConfig``.
    id_dataset_name: str = "CIFAR-10"
    initial_training_set_size: int = 5000
    validation_set_size: int = 5000
    evaluation_set_size: int = 0
    id_repetitions: float = 1
    add_dataset_noise: bool = False
    validation_split_random_state: int = 0

    acquisition_size: int = 2500
    max_training_set: int = 40000
    num_pool_samples: int = 20
    num_validation_samples: int = 20
    num_training_samples: int = 1
    max_training_epochs: int = 120
    training_batch_size: int = 128
    device: str = "cuda"
    min_samples_per_epoch: int = 5056
    # Both schedules are passed straight through to ``train_with_schedule``.
    patience_schedule: Tuple[int, ...] = (6, 4, 2)
    factor_schedule: Tuple[float, ...] = (0.1,)
    train_eval_model: Type[TrainEvalModel] = TrainSelfDistillationEvalModelWithSchedule
    model_optimizer_factory: Type[ModelOptimizerFactory] = Cifar10BayesianResnetFactory
    # Not read by ``run`` itself; visible to DI-constructed components through ``vars(self)``.
    acquisition_function_args: Optional[dict] = None
    temperature: float = 0.0
    prefer_accuracy: bool = True

    def load_experiment_data(self) -> ExperimentData:
        """Build an ``ExperimentDataConfig`` from this experiment's fields and load it."""
        di = DependencyInjection(vars(self))
        edc: ExperimentDataConfig = di.create_dataclass_type(ExperimentDataConfig)
        return edc.load()

    # Simple Dependency Injection
    def create_acquisition_function(self):
        """Instantiate ``self.acquisition_function`` from this experiment's fields."""
        di = DependencyInjection(vars(self))
        return di.create_dataclass_type(self.acquisition_function)

    def create_train_eval_model(self, runtime_config) -> TrainEvalModel:
        """Instantiate ``self.train_eval_model``.

        Values in ``runtime_config`` override same-named experiment fields. Examples
        are the current datasets and the trained model.
        """
        config = {**vars(self), **runtime_config}
        di = DependencyInjection(config, [])
        return di.create_dataclass_type(self.train_eval_model)

    def run(self, store):
        """Run the active-learning loop and record everything into ``store``.

        Args:
            store: mutable mapping. It receives ``dataset_info``,
                ``initial_training_set_indices``, ``evaluation_set_indices`` and an
                ``active_learning_steps`` list with one dict per iteration.

        Raises:
            ValueError: if the acquisition function is neither a
                ``CandidateBatchComputer`` nor an ``EvalCandidateBatchComputer``.
        """
        torch.manual_seed(self.seed)

        # Active Learning setup
        data = self.load_experiment_data()
        store["dataset_info"] = dict(training=repr(data.active_learning.base_dataset), test=repr(data.test_dataset))
        store["initial_training_set_indices"] = data.initial_training_set_indices
        store["evaluation_set_indices"] = data.evaluation_set_indices

        # The fixed-length sampler presumably lets small training sets still yield
        # `min_samples_per_epoch` samples per epoch — confirm in active_learning.
        train_loader = torch.utils.data.DataLoader(
            data.active_learning.training_dataset,
            batch_size=self.training_batch_size,
            sampler=RandomFixedLengthSampler(data.active_learning.training_dataset, self.min_samples_per_epoch),
            drop_last=True,
        )
        pool_loader = torch.utils.data.DataLoader(
            data.active_learning.pool_dataset, batch_size=128, drop_last=False, shuffle=False
        )

        validation_loader = torch.utils.data.DataLoader(data.validation_dataset, batch_size=512, drop_last=False)
        test_loader = torch.utils.data.DataLoader(data.test_dataset, batch_size=512, drop_last=False)

        store["active_learning_steps"] = []
        active_learning_steps = store["active_learning_steps"]

        acquisition_function = self.create_acquisition_function()

        # Active Training Loop
        while True:
            training_set_size = len(data.active_learning.training_dataset)
            print(f"Training set size {training_set_size}:")

            active_learning_steps.append({})
            iteration_log = active_learning_steps[-1]

            iteration_log["training"] = {}

            # A new model/optimizer pair is created every iteration.
            model_optimizer = self.model_optimizer_factory().create_model_optimizer()

            if training_set_size > 0:
                train_with_schedule(
                    model=model_optimizer.model,
                    optimizer=model_optimizer.optimizer,
                    train_augmentations=data.train_augmentations,
                    training_samples=self.num_training_samples,
                    validation_samples=self.num_validation_samples,
                    train_loader=train_loader,
                    validation_loader=validation_loader,
                    patience_schedule=self.patience_schedule,
                    factor_schedule=self.factor_schedule,
                    max_epochs=self.max_training_epochs,
                    device=self.device,
                    training_log=iteration_log["training"],
                    prefer_accuracy=self.prefer_accuracy,
                )

            evaluation_metrics = evaluate_old(
                model=model_optimizer.model,
                num_samples=self.num_validation_samples,
                loader=test_loader,
                device=self.device,
            )
            iteration_log["evaluation_metrics"] = evaluation_metrics
            print(f"Perf after training {evaluation_metrics}")

            if training_set_size >= self.max_training_set:
                print("Done.")
                break

            # Nothing left to acquire: stop rather than scoring an empty pool. Scoring it
            # would either fail or acquire nothing and loop forever.
            if len(data.active_learning.pool_dataset) == 0:
                print("Pool exhausted. Done.")
                break

            trained_model = TrainedBayesianModel(model=model_optimizer.model)

            if isinstance(acquisition_function, CandidateBatchComputer):
                candidate_batch = acquisition_function.compute_candidate_batch(trained_model, pool_loader, self.device)
            elif isinstance(acquisition_function, EvalCandidateBatchComputer):
                # Training is skipped for an empty training set, and then no "epochs"
                # entry is logged. Treat that case as zero epochs instead of raising KeyError.
                current_max_epochs = len(iteration_log["training"].get("epochs", []))

                # Score against the held-out evaluation set when one was requested,
                # otherwise against the pool itself.
                if self.evaluation_set_size:
                    eval_dataset = data.evaluation_dataset
                else:
                    eval_dataset = data.active_learning.pool_dataset

                train_eval_model = self.create_train_eval_model(
                    dict(
                        max_epochs=current_max_epochs + 2,
                        training_dataset=data.active_learning.training_dataset,
                        eval_dataset=eval_dataset,
                        validation_loader=validation_loader,
                        trained_model=trained_model,
                        train_augmentations=data.train_augmentations,
                    )
                )

                iteration_log["eval_training"] = {}
                trained_eval_model = train_eval_model(training_log=iteration_log["eval_training"], device=self.device)

                candidate_batch = acquisition_function.compute_candidate_batch(
                    trained_model, trained_eval_model, pool_loader, device=self.device
                )
            else:
                raise ValueError(f"Unknown acquisition function {acquisition_function}!")

            # Map pool-relative indices back to base-dataset indices for logging.
            candidate_global_indices = [
                get_base_dataset_index(data.active_learning.pool_dataset, index).index
                for index in candidate_batch.indices
            ]
            candidate_labels = [
                get_target(data.active_learning.pool_dataset, index).item() for index in candidate_batch.indices
            ]

            iteration_log["acquisition"] = dict(
                indices=candidate_global_indices, labels=candidate_labels, scores=candidate_batch.scores
            )

            data.active_learning.acquire(candidate_batch.indices)

            ls = ", ".join(f"{label} ({score:.4})" for label, score in zip(candidate_labels, candidate_batch.scores))
            print(f"Acquiring (label, score)s: {ls}")

In [None]:
# experiment

# Quick EvalBALD run: one training epoch per model and a 20000-sample initial set.
# max_training_set=20005 with acquisition_size=5 means exactly one acquisition round.
experiment = Experiment(
    acquisition_function=acquisition_functions.EvalBALD,
    seed=1120,
    device="cuda",
    initial_training_set_size=20000,
    max_training_set=20005,
    acquisition_size=5,
    num_pool_samples=20,
    max_training_epochs=1,
    min_samples_per_epoch=5000,
    temperature=8,
)

results = {}
experiment.run(results)

Creating: ExperimentDataConfig(
	id_dataset_name=CIFAR-10,
	initial_training_set_size=20000,
	validation_set_size=5000,
	evaluation_set_size=0,
	id_repetitions=1,
	add_dataset_noise=False,
	validation_split_random_state=0,
	device=cuda
)
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Creating: EvalBALD(
	acquisition_size=5
)
Training set size 20000:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/39]   3%|2          [00:00<?]

[1/10]  10%|#          [00:00<?]

Epoch metrics: {'accuracy': 0.1998, 'crossentropy': 2.38129228515625}


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.1965, 'crossentropy': 2.36414514465332}
Creating: TrainSelfDistillationEvalModelWithSchedule(
	num_pool_samples=20,
	num_training_samples=1,
	num_validation_samples=20,
	patience_schedule=(6, 4, 2),
	factor_schedule=(0.1,),
	max_epochs=3,
	training_dataset=<torch.utils.data.dataset.Subset object at 0x7f687bbd9d90>,
	eval_dataset=<torch.utils.data.dataset.Subset object at 0x7f687bbd9af0>,
	validation_loader=<torch.utils.data.dataloader.DataLoader object at 0x7f687bbf1e20>,
	training_batch_size=128,
	model_optimizer_factory=<class 'batchbald_redux.resnet_models.Cifar10BayesianResnetFactory'>,
	trained_model=TrainedMCDropoutModel(num_samples=20, model=BayesianResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): Identity()
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Co

get_predictions_labels:   0%|          | 0/900000 [00:00<?, ?it/s]

 33%|###3      | 1/3 [00:00<?, ?it/s]

[1/39]   3%|2          [00:00<?]

[1/10]  10%|#          [00:00<?]

Epoch metrics: {'accuracy': 0.1992, 'crossentropy': 2.42520475769043}


[1/39]   3%|2          [00:00<?]

[1/10]  10%|#          [00:00<?]

Epoch metrics: {'accuracy': 0.155, 'crossentropy': 2.468776754760742}
Epoch 2: 0.155 worse than 0.1992, patience: 1/6!


[1/39]   3%|2          [00:00<?]

[1/10]  10%|#          [00:00<?]

Epoch metrics: {'accuracy': 0.2074, 'crossentropy': 2.1420733009338377}


get_predictions_labels:   0%|          | 0/500000 [00:00<?, ?it/s]

get_predictions_labels:   0%|          | 0/500000 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/25000 [00:00<?, ?it/s]

Entropy:   0%|          | 0/25000 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/25000 [00:00<?, ?it/s]

Entropy:   0%|          | 0/25000 [00:00<?, ?it/s]

Acquiring (label, score)s: 1 (0.1945), 1 (0.1942), 9 (0.1849), 9 (0.1817), 1 (0.181)
Training set size 20005:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/39]   3%|2          [00:00<?]

[1/10]  10%|#          [00:00<?]

Epoch metrics: {'accuracy': 0.2416, 'crossentropy': 2.68387165145874}


[1/20]   5%|5          [00:00<?]

Perf after training {'accuracy': 0.2411, 'crossentropy': 2.77079727935791}
Done.


In [None]:
# Display the results store populated by `experiment.run(results)`.
results

{'dataset_info': {'training': "'CIFAR-10 (Train, seed=0, 45000 samples)'",
  'test': "'CIFAR-10 (Test)'"},
 'initial_training_set_indices': [28964,
  11411,
  14886,
  9609,
  39292,
  398,
  3431,
  21404,
  44846,
  11777,
  6361,
  36817,
  5044,
  37231,
  14346,
  24286,
  4294,
  28590,
  16297,
  12733,
  19940,
  27283,
  27046,
  17495,
  4417,
  40795,
  10717,
  3957,
  14535,
  20341,
  27604,
  43757,
  26320,
  40449,
  10574,
  12396,
  14656,
  21304,
  44149,
  12180,
  27762,
  22949,
  32997,
  11309,
  29865,
  36001,
  20338,
  24032,
  34368,
  9137,
  23376,
  13769,
  44858,
  15640,
  40594,
  407,
  40764,
  7166,
  17277,
  15347,
  7175,
  10233,
  14617,
  35065,
  39662,
  32385,
  28273,
  15891,
  26145,
  27266,
  38700,
  14319,
  31039,
  4596,
  21831,
  6428,
  27461,
  6582,
  518,
  20455,
  6795,
  21079,
  30299,
  33470,
  38939,
  27229,
  22701,
  33968,
  19425,
  6796,
  5874,
  32641,
  32181,
  5994,
  43189,
  38244,
  32894,
  18469,
  

In [None]:
# experiment

# NOTE(review): this cell looks stale and will not run as written.
# - `AcquisitionFunction` is not imported in any visible cell of this notebook.
# - `Experiment` requires a `seed` argument, which is not passed here.
# The captured output below also comes from an older version of `Experiment.run`:
# it shows "evalution_metrics"/"pool_training" keys, and its training set size
# differs from the current default. Confirm the intended acquisition function
# before re-running.
experiment = Experiment(
    max_training_epochs=1, max_training_set=25, acquisition_function=AcquisitionFunction.randombaldical
)

results = {}
experiment.run(results)

results

Training set size 20:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/384]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -6.529030114412308)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.5367, 'crossentropy': 6.438035237884521}


get_predictions:   0%|          | 0/463616 [00:00<?, ?it/s]

100%|##########| 1/1 [00:00<?, ?it/s]

[1/1811]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -5.1637596152722836)
RestoringEarlyStopping: Restoring optimizer.


get_predictions:   0%|          | 0/2317680 [00:00<?, ?it/s]

get_predictions:   0%|          | 0/2317680 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Acquiring (label, score)s: 8 (0.8711), 8 (0.8687), 3 (0.876), 3 (0.8465), 3 (0.8811)
Training set size 25:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/384]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -4.6851686127483845)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.6256, 'crossentropy': 4.484497045135498}
Done.


{'initial_training_set_indices': [38043,
  40091,
  17418,
  2094,
  39879,
  3133,
  5011,
  40683,
  54379,
  24287,
  9849,
  59305,
  39508,
  39356,
  8758,
  52579,
  13655,
  7636,
  21562,
  41329],
 'active_learning_steps': [{'training': {'epochs': [{'accuracy': 0.538818359375,
      'crossentropy': 6.529030114412308}],
    'best_epoch': 1},
   'evalution_metrics': {'accuracy': 0.5367,
    'crossentropy': 6.438035237884521},
   'pool_training': {'epochs': [{'accuracy': 0.531005859375,
      'crossentropy': 5.1637596152722836}],
    'best_epoch': 1},
   'acquisition': {'indices': [63338, 10856, 63452, 81864, 109287],
    'labels': [8, 8, 3, 3, 3],
    'scores': [0.8710822958846325,
     0.8687216999221631,
     0.8759664372823723,
     0.8464646732511746,
     0.8810812784952251]}},
  {'training': {'epochs': [{'accuracy': 0.62255859375,
      'crossentropy': 4.6851686127483845}],
    'best_epoch': 1},
   'evalution_metrics': {'accuracy': 0.6256,
    'crossentropy': 4.4844970451

In [None]:
# exports

# Sweep definition. For each (acquisition function, temperature) pair — SoftmaxBALD
# at 1/64 first, then BALD at 0 — cover 5 seeds x 4 id_repetitions settings.
# That gives 2 * 5 * 4 = 40 configurations, ordered pair-major, then seed, then id_repetitions.
configs = [
    Experiment(
        seed=seed_offset + 8945,
        acquisition_function=acquisition_fn,
        acquisition_size=200,
        num_pool_samples=100,
        initial_training_set_size=5000,
        evaluation_set_size=0,
        max_training_set=20000,
        temperature=acquisition_temperature,
        id_repetitions=repetitions,
        add_dataset_noise=True,
    )
    for acquisition_fn, acquisition_temperature in (
        (acquisition_functions.SoftmaxBALD, 1 / 64),
        (acquisition_functions.BALD, 0),
    )
    for seed_offset in range(5)
    for repetitions in (1, 5, 10, 20)
]

# Script entry point. Runs only when executed as a plain script, not inside IPython.
# `embedded_experiments` yields (job_id, store) pairs; each job runs `configs[job_id]`.
if not is_run_from_ipython() and __name__ == "__main__":
    for job_id, store in embedded_experiments(__file__, len(configs)):
        config = configs[job_id]
        # The seed is offset by job_id on top of the per-config `seed + 8945`.
        # This mutates the config object, so the stored config records the offset seed.
        config.seed += job_id
        print(config)
        store["config"] = dataclasses.asdict(config)
        store["log"] = {}

        try:
            config.run(store=store)
        except Exception:
            # Record the traceback in the job's store before re-raising, presumably so
            # failed jobs can be inspected from their persisted results.
            store["exception"] = traceback.format_exc()
            raise

In [None]:
# Number of sweep configurations defined in `configs`.
len(configs)

40

In [None]:
# Pretty-print the sweep configurations. Enabling the "dataclasses" extra lets
# prettyprinter render dataclass instances such as `Experiment` field by field.
import prettyprinter
prettyprinter.install_extras({"dataclasses"})
prettyprinter.pprint(configs)

[
    Experiment(
        seed=8945,
        # class
        acquisition_function=batchbald_redux.acquisition_functions.SoftmaxBALD,
        add_dataset_noise=True,
        acquisition_size=100,
        max_training_set=20000,
        num_pool_samples=100,
        temperature=0.015625
    ),
    Experiment(
        seed=8945,
        # class
        acquisition_function=batchbald_redux.acquisition_functions.SoftmaxBALD,
        id_repetitions=5,
        add_dataset_noise=True,
        acquisition_size=100,
        max_training_set=20000,
        num_pool_samples=100,
        temperature=0.015625
    ),
    Experiment(
        seed=8945,
        # class
        acquisition_function=batchbald_redux.acquisition_functions.SoftmaxBALD,
        id_repetitions=10,
        add_dataset_noise=True,
        acquisition_size=100,
        max_training_set=20000,
        num_pool_samples=100,
        temperature=0.015625
    ),
    Experiment(
        seed=8945,
        # class
        acquisition_f