# An experiment template
> Modularity might not be the solution, but it's all we got.

In [None]:
# default_exp experiment

In [None]:
# hide
import blackhc.project.script

Import the modules and functions we are going to use.

In [None]:
# exports

import dataclasses
from dataclasses import dataclass
from typing import Optional

import torch
import torch.utils.data
from ignite.contrib.engines.common import (
    add_early_stopping_by_val_score,
    setup_common_training_handlers,
)
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
from ignite.metrics import Accuracy, Loss, RunningAverage
from torch import nn
from torch.utils.data import Dataset

from batchbald_redux.active_learning import (
    ActiveLearningData,
    RandomFixedLengthSampler,
    get_balanced_sample_indices,
    get_base_indices,
)
from batchbald_redux.batchbald import get_bald_batch
from batchbald_redux.black_box_model_training import evaluate, get_predictions, train
from batchbald_redux.consistent_mc_dropout import (
    GeometricMeanPrediction,
    SamplerModel,
    geometric_mean_loss,
    multi_sample_loss,
)
from batchbald_redux.example_models import BayesianMNISTCNN
from batchbald_redux.fast_mnist import FastMNIST
from batchbald_redux.repeated_mnist import create_repeated_MNIST_dataset

In [None]:
# exports


@dataclass
class Experiment:
    """Configuration and driver for one BALD-based active-learning run.

    Every attribute below is a dataclass field, so each one can be overridden
    through the constructor and is recorded in ``results["hparams"]`` by
    :meth:`run`.
    """

    # Number of pool samples requested from `get_bald_batch` per step.
    acquisition_size: int = 10
    # The loop stops once the training set holds at least this many samples.
    max_training_set: int = 300
    # Forwarded to `SamplerModel` when predicting on the pool set.
    num_pool_samples: int = 20
    # Forwarded to `train` (validation_samples) and `evaluate` (num_samples).
    num_eval_samples: int = 4
    # Forwarded to `train` as `training_samples`.
    num_training_samples: int = 1
    # Forwarded to `train` as `patience`.
    num_patience_epochs: int = 3
    # Forwarded to `train` as `max_epochs`.
    max_training_epochs: int = 10
    # Number of samples extracted from the pool as the validation set.
    validation_set_size: int = 1024
    # Size of the initial training set (split evenly over 10 classes in `run`).
    initial_set_size: int = 20
    # Length of the `RandomFixedLengthSampler` used by the training loader.
    samples_per_epoch: int = 32768
    # Device used for training, evaluation and scoring. It needs the type
    # annotation to be a real dataclass field (constructor argument + part of
    # `asdict`). It is declared last so the positional order of the fields
    # above stays unchanged.
    device: str = "cuda"

    def load_dataset(self) -> (ActiveLearningData, Dataset, Dataset):
        """Create the datasets used by the experiment.

        Returns:
            A tuple ``(active_learning_data, validation_dataset, test_dataset)``.
            The validation dataset is obtained via
            ``extract_dataset_from_pool(self.validation_set_size)``.
        """
        train_dataset, test_dataset = create_repeated_MNIST_dataset(num_repetitions=1, add_noise=False)
        active_learning_data = ActiveLearningData(train_dataset)

        validation_dataset = active_learning_data.extract_dataset_from_pool(self.validation_set_size)

        return active_learning_data, validation_dataset, test_dataset

    def new_model(self):
        """Return a fresh `BayesianMNISTCNN`; called once per active-learning step."""
        return BayesianMNISTCNN()

    def new_optimizer(self, model):
        """Return an Adam optimizer (weight_decay=5e-4) over ``model``'s parameters."""
        return torch.optim.Adam(model.parameters(), weight_decay=5e-4)

    def get_candidate_batch(self, model, pool_loader):
        """Score the pool with ``model`` and return the batch chosen by `get_bald_batch`.

        Args:
            model: Trained model; wrapped in `SamplerModel` with
                ``self.num_pool_samples`` samples.
            pool_loader: Loader over the current pool dataset.

        Returns:
            The result of `get_bald_batch`; `run` reads its ``indices`` and
            ``scores`` attributes.
        """
        # Evaluate pool set
        bald_model = SamplerModel(model, self.num_pool_samples)
        pool_log_probs_N_K_C = get_predictions(model=bald_model, loader=pool_loader, device=self.device)

        # Evaluate BALD scores
        candidate_batch = get_bald_batch(
            pool_log_probs_N_K_C, batch_size=self.acquisition_size, dtype=torch.double, device=self.device
        )
        return candidate_batch

    def run(self, results):
        """Run the active-learning loop and record everything into ``results``.

        Args:
            results: Mutable mapping that is filled in place with the keys
                ``"hparams"``, ``"initial_training_set_indices"`` and
                ``"active_learning_steps"``.

        Returns:
            The list stored under ``results["active_learning_steps"]``: one log
            dict per iteration, including the final iteration (whose
            ``"acquisition"`` entry stays ``None`` because nothing is acquired
            after the last training round).
        """
        results["hparams"] = dataclasses.asdict(self)

        # Active Learning setup
        active_learning_data, validation_dataset, test_dataset = self.load_dataset()

        # Class-balanced initial training set: initial_set_size // 10 samples
        # for each of 10 classes.
        initial_training_set_indices = get_balanced_sample_indices(
            active_learning_data.dataset.targets, 10, self.initial_set_size // 10
        )
        active_learning_data.acquire(initial_training_set_indices)

        results["initial_training_set_indices"] = initial_training_set_indices

        train_loader = torch.utils.data.DataLoader(
            active_learning_data.training_dataset,
            batch_size=64,
            sampler=RandomFixedLengthSampler(active_learning_data.training_dataset, self.samples_per_epoch),
        )
        pool_loader = torch.utils.data.DataLoader(active_learning_data.pool_dataset, batch_size=64, drop_last=False)

        validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=64, drop_last=False)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, drop_last=False)

        results["active_learning_steps"] = []
        active_learning_steps = results["active_learning_steps"]

        # Active Training Loop
        while True:
            training_set_size = len(active_learning_data.training_dataset)
            print(f"Training set size {training_set_size}:")

            # NOTE: the "evalution_metrics" key spelling is kept as-is so that
            # existing consumers of the results keep working.
            iteration_log = dict(training_log=[], evalution_metrics=None, acquisition=None)
            # Register the log immediately (it is mutated in place below) so
            # that the final iteration, which leaves the loop via `break`, is
            # recorded as well.
            active_learning_steps.append(iteration_log)

            # Train a fresh model from scratch on the current training set.
            model = self.new_model()
            optimizer = self.new_optimizer(model)
            train(
                model=model,
                optimizer=optimizer,
                training_samples=self.num_training_samples,
                validation_samples=self.num_eval_samples,
                train_loader=train_loader,
                validation_loader=validation_loader,
                patience=self.num_patience_epochs,
                max_epochs=self.max_training_epochs,
                device=self.device,
                epochs_log=iteration_log["training_log"],
            )

            evaluation_metrics = evaluate(
                model=model, num_samples=self.num_eval_samples, loader=test_loader, device=self.device
            )
            iteration_log["evalution_metrics"] = evaluation_metrics
            print(f"Perf after training {evaluation_metrics}")

            if training_set_size >= self.max_training_set:
                print("Done.")
                break

            candidate_batch = self.get_candidate_batch(model, pool_loader)

            # Map pool-relative indices back to indices into the base dataset
            # so that the labels can be looked up and logged.
            candidate_global_indices = get_base_indices(active_learning_data.pool_dataset, candidate_batch.indices)
            candidate_labels = [
                active_learning_data.dataset.targets[index].item() for index in candidate_global_indices
            ]

            iteration_log["acquisition"] = dict(
                indices=candidate_global_indices, labels=candidate_labels, scores=candidate_batch.scores
            )

            active_learning_data.acquire(candidate_batch.indices)

            ls = ", ".join(f"{label} ({score:.4})" for label, score in zip(candidate_labels, candidate_batch.scores))
            print(f"Acquiring (label, score)s: {ls}")

        return active_learning_steps

In [None]:
# experiment

# Short demo run: a single training epoch per step, stopping once the
# training set has reached 80 samples. `run` fills `results` in place.
results = {}
experiment = Experiment(max_training_epochs=1, max_training_set=80)
experiment.run(results)

results

Training set size 20:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=512.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=157.0), HTML(value='')))

Perf after training {'accuracy': 0.6184, 'crossentropy': 4.713012725067139}


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=922.0), HTML(value='')))

HBox(children=(HTML(value='Conditional Entropy'), FloatProgress(value=0.0, max=58956.0), HTML(value='')))

HBox(children=(HTML(value='Entropy'), FloatProgress(value=0.0, max=58956.0), HTML(value='')))

Acquiring (label, score)s: 7 (1.314), 7 (1.303), 7 (1.237), 7 (1.215), 2 (1.21), 7 (1.207), 2 (1.198), 2 (1.196), 7 (1.196), 7 (1.196)
Training set size 30:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=512.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=157.0), HTML(value='')))

Perf after training {'accuracy': 0.6501, 'crossentropy': 4.576156882095337}


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=922.0), HTML(value='')))

HBox(children=(HTML(value='Conditional Entropy'), FloatProgress(value=0.0, max=58946.0), HTML(value='')))

HBox(children=(HTML(value='Entropy'), FloatProgress(value=0.0, max=58946.0), HTML(value='')))

Acquiring (label, score)s: 5 (1.403), 8 (1.331), 4 (1.303), 3 (1.3), 3 (1.265), 9 (1.25), 3 (1.249), 3 (1.241), 4 (1.222), 9 (1.22)
Training set size 40:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=512.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=157.0), HTML(value='')))

Perf after training {'accuracy': 0.7086, 'crossentropy': 2.193311638832092}


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=921.0), HTML(value='')))

HBox(children=(HTML(value='Conditional Entropy'), FloatProgress(value=0.0, max=58936.0), HTML(value='')))

HBox(children=(HTML(value='Entropy'), FloatProgress(value=0.0, max=58936.0), HTML(value='')))

Acquiring (label, score)s: 4 (1.297), 4 (1.232), 5 (1.232), 5 (1.231), 5 (1.219), 5 (1.206), 0 (1.195), 5 (1.189), 7 (1.173), 9 (1.17)
Training set size 50:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=512.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=157.0), HTML(value='')))

Perf after training {'accuracy': 0.7608, 'crossentropy': 1.9754455966949462}


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=921.0), HTML(value='')))

HBox(children=(HTML(value='Conditional Entropy'), FloatProgress(value=0.0, max=58926.0), HTML(value='')))

HBox(children=(HTML(value='Entropy'), FloatProgress(value=0.0, max=58926.0), HTML(value='')))

Acquiring (label, score)s: 0 (1.234), 6 (1.214), 2 (1.208), 3 (1.201), 0 (1.201), 9 (1.197), 1 (1.197), 2 (1.191), 3 (1.19), 2 (1.19)
Training set size 60:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=512.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=157.0), HTML(value='')))

Perf after training {'accuracy': 0.7984, 'crossentropy': 1.5577151654243468}


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=921.0), HTML(value='')))

HBox(children=(HTML(value='Conditional Entropy'), FloatProgress(value=0.0, max=58916.0), HTML(value='')))

HBox(children=(HTML(value='Entropy'), FloatProgress(value=0.0, max=58916.0), HTML(value='')))

Acquiring (label, score)s: 1 (1.261), 9 (1.249), 2 (1.248), 0 (1.217), 7 (1.206), 5 (1.198), 3 (1.195), 2 (1.192), 3 (1.191), 9 (1.183)
Training set size 70:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=512.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=157.0), HTML(value='')))

Perf after training {'accuracy': 0.8284, 'crossentropy': 1.070880346775055}


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=921.0), HTML(value='')))

HBox(children=(HTML(value='Conditional Entropy'), FloatProgress(value=0.0, max=58906.0), HTML(value='')))

HBox(children=(HTML(value='Entropy'), FloatProgress(value=0.0, max=58906.0), HTML(value='')))

Acquiring (label, score)s: 2 (1.203), 2 (1.198), 2 (1.189), 0 (1.166), 7 (1.162), 5 (1.143), 3 (1.137), 5 (1.135), 4 (1.133), 4 (1.127)
Training set size 80:


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=512.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=157.0), HTML(value='')))

Perf after training {'accuracy': 0.8569, 'crossentropy': 0.913114131975174}
Done.


{'hparams': {'acquisition_size': 10,
  'max_training_set': 80,
  'num_pool_samples': 20,
  'num_eval_samples': 4,
  'num_training_samples': 1,
  'num_patience_epochs': 3,
  'max_training_epochs': 1,
  'validation_set_size': 1024,
  'initial_set_size': 20,
  'samples_per_epoch': 32768},
 'initial_training_set_indices': tensor([41376, 32395, 47549,  5049, 49034, 33411, 11929, 46723, 10032, 27715,
          4160, 37802, 13018, 18609, 30609, 25413, 26449,  1904, 48999, 40092]),
 'active_learning_steps': [{'training_log': [{'accuracy': 0.6044921875,
     'crossentropy': 4.612063035368919}],
   'evalution_metrics': {'accuracy': 0.6184,
    'crossentropy': 4.713012725067139},
   'acquisition': {'indices': [17963,
     28675,
     25782,
     12819,
     43191,
     39687,
     19229,
     45077,
     25797,
     52443],
    'labels': [7, 7, 7, 7, 2, 7, 2, 2, 7, 7],
    'scores': [1.3138152360916138,
     1.3030644655227661,
     1.2367550432682037,
     1.2148968577384949,
     1.210195183753