# Experiment
> Can we get better by training on our assumptions?

In [None]:
# default_exp experiment

In [None]:
# hide
import blackhc.project.script

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


Import the modules and functions we are going to use.

In [None]:
# exports

import dataclasses
import traceback
from dataclasses import dataclass
from enum import Enum

import blackhc.project.script
import numpy as np
import torch
import torch.utils.data
from blackhc.project.experiment import embedded_experiments
from torch.utils.data import Dataset

from batchbald_redux.active_learning import (
    ActiveLearningData,
    RandomFixedLengthSampler,
)
from batchbald_redux.dataset_challenges import (
    create_repeated_MNIST_dataset,
    get_balanced_sample_indices,
    get_base_index,
)
from batchbald_redux.batchbald import (
    CandidateBatch,
    get_bald_ical_scores,
    get_bald_scores,
    get_batchbald_batch,
    get_batchbaldical_batch,
    get_coreset_bald_scores,
    get_batch_coreset_bald_batch,
    get_ical_scores,
    get_sampled_tempered_scorers,
    get_thompson_bald_batch,
    get_top_k_scorers,
    get_top_random_scorers,
)
from batchbald_redux.black_box_model_training import evaluate, get_predictions, get_predictions_labels, train
from batchbald_redux.consistent_mc_dropout import SamplerModel
from batchbald_redux.example_models import BayesianMNISTCNN

In [None]:
# exports


class AcquisitionFunction(Enum):
    """Acquisition functions selectable for an active-learning experiment.

    Each member's value is a human-readable label for that acquisition function.
    Member order is significant for iteration and is kept stable.
    """

    # Baselines and (Batch)BALD variants.
    random = "Random"
    bald = "BALD"
    batchbald = "BatchBALD"
    # Variants that use an additional "pool model".
    batchbaldical = "BatchBALD-ICAL"
    baldical = "BALD-ICAL"
    # CoreSet-style BALD variants.
    coresetbald = "CoreSetBALD"
    batchcoresetbald = "BatchCoreSetBALD"
    ical = "ICAL"
    # Stochastic selection strategies on top of the score functions above.
    randombald = "RandomBALD"
    randombaldical = "RandomBALD-ICAL"
    thompsonbald = "ThompsonBALD"
    temperedbald = "TemperedBALD"
    temperedcoresetbald = "TemperedCoreSetBALD"
    temperedbaldical = "TemperedBALD-ICAL"
    temperedical = "TemperedICAL"

In [None]:
# exports


class PredictionDataset(torch.utils.data.Dataset):
    """Pair every input of a dataset with a precomputed prediction instead of its label.

    ``dataset[i]`` must yield an ``(x, label)`` pair; ``self[i]`` returns
    ``(x, predictions[i])`` and discards the original label.

    Args:
        dataset: Indexable dataset of ``(x, label)`` pairs.
        predictions: Tensor whose first dimension matches ``len(dataset)``.

    Raises:
        ValueError: If ``len(dataset)`` and ``predictions.shape[0]`` differ.
    """

    dataset: torch.utils.data.Dataset
    predictions: torch.Tensor

    def __init__(self, dataset, predictions):
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(dataset) != predictions.shape[0]:
            raise ValueError(
                f"dataset and predictions must have the same length: {len(dataset)} != {predictions.shape[0]}"
            )

        self.dataset = dataset
        self.predictions = predictions

    def __getitem__(self, index):
        x, _label = self.dataset[index]
        return x, self.predictions[index]

    def __len__(self):
        return len(self.dataset)

In [None]:
# exports

# Fixed initial training-set indices, taken from the BatchBALD repo.
# `Experiment.run` passes them to `ActiveLearningData.acquire` as the first acquisition.
mnist_initial_samples = [
    38043, 40091, 17418, 2094, 39879,
    3133, 5011, 40683, 54379, 24287,
    9849, 59305, 39508, 39356, 8758,
    52579, 13655, 7636, 21562, 41329,
]

In [None]:
# exports

# Module-level sink for BALD scores: `Experiment.get_bald_candidate_batch` and
# `Experiment.get_coreset_bald_candidate_batch` append each step's scores here when
# `save_bald_scores` is enabled.
bald_scores = []


@dataclass
class Experiment:
    """Active-learning experiment on Repeated-MNIST.

    Each iteration of `run`:
    1. trains a fresh model on the current training set;
    2. evaluates it on the test set;
    3. if the training set is still smaller than `max_training_set`, scores the pool
       with `acquisition_function` and acquires `acquisition_size` new points.

    The *ICAL* family of acquisition functions additionally trains a second "pool model"
    on the first model's predictions over training set + pool (see `train_pool_model`).
    """

    seed: int = 1337
    acquisition_size: int = 5
    max_training_set: int = 300
    num_pool_samples: int = 20
    num_eval_samples: int = 20
    num_training_samples: int = 1
    num_patience_epochs: int = 3
    max_training_epochs: int = 30
    validation_set_size: int = 4096
    # NOTE: currently unused by `run`, which always starts from `mnist_initial_samples`.
    initial_set_size: int = 20
    samples_per_epoch: int = 5056
    repeated_mnist_repetitions: int = 2
    add_dataset_noise: bool = True
    acquisition_function: AcquisitionFunction = AcquisitionFunction.bald
    save_bald_scores: bool = False
    temperature: float = 0.0
    # Compute device for training and scoring. Annotated so it is a real (constructor-settable)
    # dataclass field; declared last so the positional order of the fields above is unchanged.
    device: str = "cuda"

    def load_dataset(self) -> (ActiveLearningData, Dataset, Dataset):
        """Build the Repeated-MNIST datasets and split a validation set off the pool.

        Returns:
            ``(active_learning_data, validation_dataset, test_dataset)``.
        """
        train_dataset, test_dataset = create_repeated_MNIST_dataset(
            num_repetitions=self.repeated_mnist_repetitions, add_noise=self.add_dataset_noise
        )
        active_learning_data = ActiveLearningData(train_dataset)

        validation_dataset = active_learning_data.extract_dataset_from_pool(self.validation_set_size)

        return active_learning_data, validation_dataset, test_dataset

    def new_model(self):
        """Return a freshly initialised `BayesianMNISTCNN`."""
        return BayesianMNISTCNN()

    def new_optimizer(self, model):
        """Return an Adam optimizer over `model`'s parameters (weight decay 5e-4)."""
        return torch.optim.Adam(model.parameters(), weight_decay=5e-4)

    def get_random_candidate_batch(self, num_pool_samples):
        """Pick `acquisition_size` distinct pool indices uniformly at random (all scores 0.0).

        Note: despite the name, `num_pool_samples` is the size of the pool here, not
        `self.num_pool_samples` (the number of MC samples used for scoring).
        Uses NumPy's global RNG, which `run` seeds with `self.seed`.
        """
        indices = np.random.choice(num_pool_samples, size=self.acquisition_size, replace=False)
        return CandidateBatch([0.0] * self.acquisition_size, indices)

    def get_thompson_bald_candidate_batch(self, model, pool_loader):
        """Select a batch with `get_thompson_bald_batch` from `model`'s pool predictions."""
        # Evaluate pool set (log-probs; by naming convention N pool points x K samples x C classes).
        log_probs_N_K_C = get_predictions(
            model=model, num_samples=self.num_pool_samples, num_classes=10, loader=pool_loader, device=self.device
        )

        candidate_batch = get_thompson_bald_batch(
            log_probs_N_K_C,
            batch_size=self.acquisition_size,
            dtype=torch.double,
            device=self.device,
        )
        return candidate_batch

    def get_coreset_bald_candidate_batch(self, model, pool_loader, *, get_scorers):
        """Score the pool with `get_coreset_bald_scores` and select via `get_scorers`.

        `get_scorers(scores_N, batch_size=...)` turns per-point scores into a `CandidateBatch`.
        The scores are appended to the module-level `bald_scores` if `save_bald_scores` is set.
        """
        log_probs_N_K_C, labels_N = get_predictions_labels(
            model=model, num_samples=self.num_pool_samples, num_classes=10, loader=pool_loader, device=self.device
        )

        scores_N = get_coreset_bald_scores(log_probs_N_K_C, labels_N, dtype=torch.double, device=self.device)

        if self.save_bald_scores:
            bald_scores.append(scores_N)

        return get_scorers(scores_N, batch_size=self.acquisition_size)

    def get_bald_candidate_batch(self, model, pool_loader, *, get_scorers):
        """Score the pool with `get_bald_scores` and select via `get_scorers`.

        `get_scorers(scores_N, batch_size=...)` turns per-point scores into a `CandidateBatch`.
        The scores are appended to the module-level `bald_scores` if `save_bald_scores` is set.
        """
        log_probs_N_K_C = get_predictions(
            model=model, num_samples=self.num_pool_samples, num_classes=10, loader=pool_loader, device=self.device
        )

        scores_N = get_bald_scores(log_probs_N_K_C, dtype=torch.double, device=self.device)

        if self.save_bald_scores:
            bald_scores.append(scores_N)

        return get_scorers(scores_N, batch_size=self.acquisition_size)

    def get_batchbald_ical_candidate_batch(self, model, pool_model, pool_loader):
        """Select a batch with `get_batchbaldical_batch` from `model` and `pool_model` pool predictions."""
        normal_log_probs_N_K_C = get_predictions(
            model=model, num_samples=self.num_pool_samples, num_classes=10, loader=pool_loader, device=self.device
        )

        pool_log_probs_N_K_C = get_predictions(
            model=pool_model, num_samples=self.num_pool_samples, num_classes=10, loader=pool_loader, device=self.device
        )

        candidate_batch = get_batchbaldical_batch(
            normal_log_probs_N_K_C,
            pool_log_probs_N_K_C,
            batch_size=self.acquisition_size,
            num_samples=1000000,
            dtype=torch.double,
            device=self.device,
        )
        return candidate_batch

    def get_bald_ical_candidate_batch(self, model, pool_model, *, pool_loader, get_scorers):
        """Score the pool with `get_bald_ical_scores` (both models) and select via `get_scorers`."""
        normal_log_probs_N_K_C = get_predictions(
            model=model, num_samples=self.num_pool_samples, num_classes=10, loader=pool_loader, device=self.device
        )

        pool_log_probs_N_K_C = get_predictions(
            model=pool_model, num_samples=self.num_pool_samples, num_classes=10, loader=pool_loader, device=self.device
        )

        scores_N = get_bald_ical_scores(
            normal_log_probs_N_K_C, pool_log_probs_N_K_C, dtype=torch.double, device=self.device
        )

        return get_scorers(scores_N, batch_size=self.acquisition_size)

    def get_ical_candidate_batch(self, model, pool_model, *, pool_loader, get_scorers):
        """Score the pool with `get_ical_scores` (both models) and select via `get_scorers`."""
        normal_log_probs_N_K_C = get_predictions(
            model=model, num_samples=self.num_pool_samples, num_classes=10, loader=pool_loader, device=self.device
        )

        pool_log_probs_N_K_C = get_predictions(
            model=pool_model, num_samples=self.num_pool_samples, num_classes=10, loader=pool_loader, device=self.device
        )

        scores_N = get_ical_scores(normal_log_probs_N_K_C, pool_log_probs_N_K_C, dtype=torch.double, device=self.device)

        return get_scorers(scores_N, batch_size=self.acquisition_size)

    def get_batchbald_candidate_batch(self, model, pool_loader):
        """Select a batch with `get_batchbald_batch` from `model`'s pool predictions."""
        log_probs_N_K_C = get_predictions(
            model=model, num_samples=self.num_pool_samples, num_classes=10, loader=pool_loader, device=self.device
        )

        candidate_batch = get_batchbald_batch(
            log_probs_N_K_C,
            batch_size=self.acquisition_size,
            num_samples=1000000,
            dtype=torch.double,
            device=self.device,
        )
        return candidate_batch

    def get_batch_coreset_bald_candidate_batch(self, model, pool_loader):
        """Select a batch with `get_batch_coreset_bald_batch` from `model`'s pool predictions and labels."""
        log_probs_N_K_C, labels_N = get_predictions_labels(
            model=model, num_samples=self.num_pool_samples, num_classes=10, loader=pool_loader, device=self.device
        )

        # NOTE: unlike the other acquisition functions, batch selection always runs on the CPU here.
        candidate_batch = get_batch_coreset_bald_batch(
            log_probs_N_K_C,
            labels_N,
            batch_size=self.acquisition_size,
            dtype=torch.double,
            device="cpu",
        )
        return candidate_batch

    def train_pool_model(
        self, *, model, train_pool_dataset, train_pool_loader, validation_loader, num_epochs, training_log
    ):
        """Train a fresh model to match `model`'s predictions on `train_pool_dataset`.

        Targets are `model`'s predicted log-probabilities averaged over `num_eval_samples`
        samples. The new model is trained with `KLDivLoss(log_target=True)` for at most
        `num_epochs` epochs, validated with NLL on `validation_loader`.

        Returns:
            The trained pool model.
        """
        # NOTE(review): this averages *log*-probabilities over dim 1 (presumably the sample
        # dimension K), i.e. an unnormalised geometric mean. If the log of the mean predictive
        # is intended, use `torch.logsumexp(log_probs_N_K_C, dim=1) - math.log(K)` instead.
        log_probs_N_C = (
            get_predictions(
                model=model,
                num_samples=self.num_eval_samples,
                num_classes=10,
                loader=train_pool_loader,
                device=self.device,
            )
            .mean(dim=1)
            .cpu()
        )

        train_pool_prediction_dataset = PredictionDataset(train_pool_dataset, log_probs_N_C)
        train_pool_prediction_loader = torch.utils.data.DataLoader(
            train_pool_prediction_dataset, batch_size=64, drop_last=True, shuffle=True
        )

        pool_model = self.new_model()
        pool_optimizer = self.new_optimizer(pool_model)

        loss = torch.nn.KLDivLoss(log_target=True, reduction="batchmean")

        train(
            model=pool_model,
            optimizer=pool_optimizer,
            loss=loss,
            validation_loss=torch.nn.NLLLoss(),
            training_samples=self.num_training_samples,
            validation_samples=self.num_eval_samples,
            train_loader=train_pool_prediction_loader,
            validation_loader=validation_loader,
            patience=self.num_patience_epochs,
            max_epochs=num_epochs,
            device=self.device,
            training_log=training_log,
        )

        return pool_model

    def _top_random_scorers(self, scores_N, batch_size):
        """`get_scorers` adapter for `get_top_random_scorers` over the 10 MNIST classes."""
        return get_top_random_scorers(scores_N, batch_size=batch_size, num_classes=10)

    def _tempered_scorers(self, scores_N, batch_size):
        """`get_scorers` adapter for `get_sampled_tempered_scorers` at `self.temperature`."""
        return get_sampled_tempered_scorers(scores_N, batch_size=batch_size, temperature=self.temperature)

    def _get_candidate_batch(self, *, model, active_learning_data, pool_loader, validation_loader, iteration_log):
        """Dispatch on `self.acquisition_function` and return the `CandidateBatch` to acquire.

        ICAL-family functions first train a pool model; its training log is written to
        `iteration_log["pool_training"]`.

        Raises:
            ValueError: If `self.acquisition_function` is not handled.
        """
        acquisition_function = self.acquisition_function

        if acquisition_function in (
            AcquisitionFunction.batchbaldical,
            AcquisitionFunction.baldical,
            AcquisitionFunction.randombaldical,
            AcquisitionFunction.temperedbaldical,
            AcquisitionFunction.ical,
            AcquisitionFunction.temperedical,
        ):
            train_pool_dataset = torch.utils.data.ConcatDataset(
                [active_learning_data.training_dataset, active_learning_data.pool_dataset]
            )
            train_pool_loader = torch.utils.data.DataLoader(train_pool_dataset, batch_size=64, drop_last=False)

            # Give the pool model the same epoch budget as the main model's best epoch.
            num_epochs = iteration_log["training"]["best_epoch"]

            iteration_log["pool_training"] = {}
            pool_model = self.train_pool_model(
                model=model,
                train_pool_dataset=train_pool_dataset,
                train_pool_loader=train_pool_loader,
                validation_loader=validation_loader,
                num_epochs=num_epochs,
                training_log=iteration_log["pool_training"],
            )

            if acquisition_function == AcquisitionFunction.batchbaldical:
                return self.get_batchbald_ical_candidate_batch(model, pool_model, pool_loader)
            if acquisition_function == AcquisitionFunction.baldical:
                return self.get_bald_ical_candidate_batch(
                    model, pool_model, pool_loader=pool_loader, get_scorers=get_top_k_scorers
                )
            if acquisition_function == AcquisitionFunction.randombaldical:
                return self.get_bald_ical_candidate_batch(
                    model, pool_model, pool_loader=pool_loader, get_scorers=self._top_random_scorers
                )
            if acquisition_function == AcquisitionFunction.temperedbaldical:
                return self.get_bald_ical_candidate_batch(
                    model, pool_model, pool_loader=pool_loader, get_scorers=self._tempered_scorers
                )
            if acquisition_function == AcquisitionFunction.ical:
                return self.get_ical_candidate_batch(
                    model, pool_model, pool_loader=pool_loader, get_scorers=get_top_k_scorers
                )
            if acquisition_function == AcquisitionFunction.temperedical:
                return self.get_ical_candidate_batch(
                    model, pool_model, pool_loader=pool_loader, get_scorers=self._tempered_scorers
                )
            # Raising a bare string is a TypeError in Python 3; raise a real exception.
            raise ValueError(f"Unexpected acquisition function {self.acquisition_function}!")

        if acquisition_function == AcquisitionFunction.bald:
            return self.get_bald_candidate_batch(model, pool_loader, get_scorers=get_top_k_scorers)
        if acquisition_function == AcquisitionFunction.coresetbald:
            return self.get_coreset_bald_candidate_batch(model, pool_loader, get_scorers=get_top_k_scorers)
        if acquisition_function == AcquisitionFunction.randombald:
            return self.get_bald_candidate_batch(model, pool_loader, get_scorers=self._top_random_scorers)
        if acquisition_function == AcquisitionFunction.temperedbald:
            return self.get_bald_candidate_batch(model, pool_loader, get_scorers=self._tempered_scorers)
        if acquisition_function == AcquisitionFunction.temperedcoresetbald:
            return self.get_coreset_bald_candidate_batch(model, pool_loader, get_scorers=self._tempered_scorers)
        if acquisition_function == AcquisitionFunction.batchbald:
            return self.get_batchbald_candidate_batch(model, pool_loader)
        if acquisition_function == AcquisitionFunction.batchcoresetbald:
            return self.get_batch_coreset_bald_candidate_batch(model, pool_loader)
        if acquisition_function == AcquisitionFunction.thompsonbald:
            return self.get_thompson_bald_candidate_batch(model, pool_loader)
        if acquisition_function == AcquisitionFunction.random:
            return self.get_random_candidate_batch(len(active_learning_data.pool_dataset))
        raise ValueError(f"Unknown acquisition function {self.acquisition_function}!")

    def run(self, store):
        """Run the active-learning loop, recording progress into `store`.

        Writes `store["initial_training_set_indices"]` and `store["active_learning_steps"]`,
        a list with one dict per iteration containing "training", "evalution_metrics"
        (key spelling kept as-is for existing consumers), "acquisition", and, for the
        ICAL family, "pool_training".
        """
        torch.manual_seed(self.seed)
        # Random acquisition draws from NumPy's global RNG; seed it too for reproducibility.
        np.random.seed(self.seed)

        # Active Learning setup
        active_learning_data, validation_dataset, test_dataset = self.load_dataset()

        # Alternatives to the fixed BatchBALD initial set:
        # initial_training_set_indices = active_learning_data.get_random_pool_indices(self.initial_set_size)
        # initial_training_set_indices = get_balanced_sample_indices(
        #     active_learning_data.pool_dataset, 10, self.initial_set_size // 10
        # )
        # Copy so the module-level constant is never aliased into (and mutated through) `store`.
        initial_training_set_indices = list(mnist_initial_samples)
        active_learning_data.acquire(initial_training_set_indices)

        store["initial_training_set_indices"] = initial_training_set_indices

        train_loader = torch.utils.data.DataLoader(
            active_learning_data.training_dataset,
            batch_size=64,
            sampler=RandomFixedLengthSampler(active_learning_data.training_dataset, self.samples_per_epoch),
            drop_last=True,
        )
        pool_loader = torch.utils.data.DataLoader(
            active_learning_data.pool_dataset, batch_size=64, drop_last=False, shuffle=False
        )

        validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=64, drop_last=False)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, drop_last=False)

        store["active_learning_steps"] = []
        active_learning_steps = store["active_learning_steps"]

        # Active Training Loop
        while True:
            training_set_size = len(active_learning_data.training_dataset)
            print(f"Training set size {training_set_size}:")

            active_learning_steps.append({})
            iteration_log = active_learning_steps[-1]

            iteration_log["training"] = {}

            model = self.new_model()
            optimizer = self.new_optimizer(model)
            train(
                model=model,
                optimizer=optimizer,
                training_samples=self.num_training_samples,
                validation_samples=self.num_eval_samples,
                train_loader=train_loader,
                validation_loader=validation_loader,
                patience=self.num_patience_epochs,
                max_epochs=self.max_training_epochs,
                device=self.device,
                training_log=iteration_log["training"],
            )

            evaluation_metrics = evaluate(
                model=model, num_samples=self.num_eval_samples, loader=test_loader, device=self.device
            )
            iteration_log["evalution_metrics"] = evaluation_metrics
            print(f"Perf after training {evaluation_metrics}")

            if training_set_size >= self.max_training_set:
                print("Done.")
                break

            candidate_batch = self._get_candidate_batch(
                model=model,
                active_learning_data=active_learning_data,
                pool_loader=pool_loader,
                validation_loader=validation_loader,
                iteration_log=iteration_log,
            )

            # Map pool-relative indices back to indices into the underlying dataset for logging.
            candidate_global_indices = [
                get_base_index(active_learning_data.pool_dataset, index) for index in candidate_batch.indices
            ]
            candidate_labels = [active_learning_data.dataset[index][1].item() for index in candidate_global_indices]

            iteration_log["acquisition"] = dict(
                indices=candidate_global_indices, labels=candidate_labels, scores=candidate_batch.scores
            )

            active_learning_data.acquire(candidate_batch.indices)

            ls = ", ".join(f"{label} ({score:.4})" for label, score in zip(candidate_labels, candidate_batch.scores))
            print(f"Acquiring (label, score)s: {ls}")

In [None]:
# Short BatchCoreSetBALD run: 5 training epochs per step, stop at 100 labelled points.
results = {}
experiment = Experiment(
    seed=1120,
    acquisition_function=AcquisitionFunction.batchcoresetbald,
    acquisition_size=5,
    max_training_set=100,
    max_training_epochs=5,
    num_pool_samples=10,
    save_bald_scores=False,
    temperature=5,
)
experiment.run(results)

Training set size 20:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.547607421875, 'crossentropy': 2.6741107683628798}


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.54296875, 'crossentropy': 3.1681929621845484}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.565185546875, 'crossentropy': 3.309005254879594}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.559814453125, 'crossentropy': 3.593839146196842}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -2.6741107683628798)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.5614, 'crossentropy': 2.337559079360962}


get_predictions_labels:   0%|          | 0/1158840 [00:00<?, ?it/s]

BatchCoreSetBALD:   0%|          | 0/5 [00:00<?, ?it/s]

Acquiring (label, score)s: 5 (2.273), 8 (2.303), 4 (2.303), 5 (2.303), 8 (2.303)
Training set size 25:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.470458984375, 'crossentropy': 2.5281595941632986}


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.49853515625, 'crossentropy': 3.0946903359144926}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.49462890625, 'crossentropy': 3.3822752833366394}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.49365234375, 'crossentropy': 3.7328700982034206}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -2.5281595941632986)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.5007, 'crossentropy': 2.1726310874938966}


get_predictions_labels:   0%|          | 0/1158790 [00:00<?, ?it/s]

BatchCoreSetBALD:   0%|          | 0/5 [00:00<?, ?it/s]

Acquiring (label, score)s: 6 (2.204), 9 (2.301), 9 (2.303), 6 (2.303), 9 (2.303)
Training set size 30:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.483154296875, 'crossentropy': 2.6777254678308964}


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.482177734375, 'crossentropy': 3.694551568478346}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.50830078125, 'crossentropy': 3.859209433197975}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.51123046875, 'crossentropy': 4.147335961461067}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -2.6777254678308964)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.4905, 'crossentropy': 2.4438606189727783}


get_predictions_labels:   0%|          | 0/1158740 [00:00<?, ?it/s]

BatchCoreSetBALD:   0%|          | 0/5 [00:00<?, ?it/s]

Acquiring (label, score)s: 7 (2.23), 7 (2.302), 7 (2.303), 7 (2.303), 7 (2.303)
Training set size 35:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.517822265625, 'crossentropy': 2.1099866963922977}


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.54443359375, 'crossentropy': 2.6397241074591875}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.538818359375, 'crossentropy': 3.114008691161871}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.551513671875, 'crossentropy': 3.1788441240787506}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -2.1099866963922977)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.5307, 'crossentropy': 1.9598717361450195}


get_predictions_labels:   0%|          | 0/1158690 [00:00<?, ?it/s]

BatchCoreSetBALD:   0%|          | 0/5 [00:00<?, ?it/s]

Acquiring (label, score)s: 3 (2.279), 3 (2.303), 3 (2.303), 3 (2.303), 3 (2.303)
Training set size 40:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.496337890625, 'crossentropy': 2.030161777511239}


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.52587890625, 'crossentropy': 2.865957338362932}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.536376953125, 'crossentropy': 3.2322004958987236}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.555908203125, 'crossentropy': 3.4477310813963413}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -2.030161777511239)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.5018, 'crossentropy': 1.9976309867858886}


get_predictions_labels:   0%|          | 0/1158640 [00:00<?, ?it/s]

BatchCoreSetBALD:   0%|          | 0/5 [00:00<?, ?it/s]

Acquiring (label, score)s: 8 (2.119), 2 (2.301), 8 (2.303), 2 (2.303), 8 (2.303)
Training set size 45:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.498291015625, 'crossentropy': 1.8848206475377083}


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.53955078125, 'crossentropy': 2.457843743264675}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.5322265625, 'crossentropy': 2.689140809699893}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.561767578125, 'crossentropy': 2.7950778268277645}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.8848206475377083)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.5148, 'crossentropy': 1.7391920001983643}


get_predictions_labels:   0%|          | 0/1158590 [00:00<?, ?it/s]

BatchCoreSetBALD:   0%|          | 0/5 [00:00<?, ?it/s]

Acquiring (label, score)s: 2 (2.018), 1 (2.299), 1 (2.303), 2 (2.303), 1 (2.303)
Training set size 50:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.493408203125, 'crossentropy': 1.8742922320961952}


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.5185546875, 'crossentropy': 2.4364928118884563}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.5361328125, 'crossentropy': 2.7442991957068443}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.53125, 'crossentropy': 2.896417271345854}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.8742922320961952)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.5244, 'crossentropy': 1.6999741312026977}


get_predictions_labels:   0%|          | 0/1158540 [00:00<?, ?it/s]

BatchCoreSetBALD:   0%|          | 0/5 [00:00<?, ?it/s]

Acquiring (label, score)s: 6 (1.915), 6 (2.291), 6 (2.302), 6 (2.303), 6 (2.303)
Training set size 55:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.548095703125, 'crossentropy': 1.7283028792589903}


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.526123046875, 'crossentropy': 2.3733168579638004}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.515380859375, 'crossentropy': 2.764618206769228}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.55126953125, 'crossentropy': 2.7649690210819244}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.7283028792589903)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.5815, 'crossentropy': 1.6331533550262451}


get_predictions_labels:   0%|          | 0/1158490 [00:00<?, ?it/s]

BatchCoreSetBALD:   0%|          | 0/5 [00:00<?, ?it/s]

Acquiring (label, score)s: 2 (1.998), 2 (2.293), 2 (2.302), 2 (2.303), 2 (2.303)
Training set size 60:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.50732421875, 'crossentropy': 1.9244072064757347}


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.54638671875, 'crossentropy': 2.314682502299547}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.5693359375, 'crossentropy': 2.6745465956628323}
RestoringEarlyStopping: 2 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.56591796875, 'crossentropy': 2.898985553532839}
RestoringEarlyStopping: 3 / 3
RestoringEarlyStopping: Out of patience
RestoringEarlyStopping: Restoring best parameters. (Score: -1.9244072064757347)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.5437, 'crossentropy': 1.7806164630889894}


get_predictions_labels:   0%|          | 0/1158440 [00:00<?, ?it/s]

BatchCoreSetBALD:   0%|          | 0/5 [00:00<?, ?it/s]

Acquiring (label, score)s: 0 (2.056), 0 (2.295), 0 (2.302), 0 (2.303), 0 (2.303)
Training set size 65:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.528564453125, 'crossentropy': 1.8578040953725576}


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.5869140625, 'crossentropy': 2.2518604043871164}
RestoringEarlyStopping: 1 / 3


[1/79]   1%|1          [00:00<?]

[1/64]   2%|1          [00:00<?]

Engine run is terminating due to exception: .
Engine run is terminating due to exception: .


KeyboardInterrupt: 

In [None]:
# RandomBALD run with acquisition size 10 that also keeps BALD scores (save_bald_scores=True).
experiment = Experiment(
    acquisition_function=AcquisitionFunction.randombald,
    acquisition_size=10,
    num_pool_samples=20,
    save_bald_scores=True,
    seed=1120,
)

results = {}
experiment.run(store=results)

# NOTE(review): `bald_scores` is not defined in this cell; presumably it is a notebook-level
# container filled during the run when save_bald_scores=True — confirm.
len(bald_scores)

SyntaxError: invalid syntax (<ipython-input-12-f8a60034558a>, line 4)

17

In [None]:
# Persist the collected BALD scores to disk using torch's (pickle-based) serializer.
torch.save(obj=bald_scores, f="bald_scores.tpickle")

TODO:

Validate `pool_model` on the pool set only, not on the combined training and pool sets.

In [None]:
# experiment

# Quick smoke run: one training epoch per step, training set capped at 25 samples.
experiment = Experiment(
    acquisition_function=AcquisitionFunction.randombaldical,
    max_training_epochs=1,
    max_training_set=25,
)

results = {}
experiment.run(store=results)

results

Training set size 20:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/384]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -6.529030114412308)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.5367, 'crossentropy': 6.438035237884521}


get_predictions:   0%|          | 0/463616 [00:00<?, ?it/s]

100%|##########| 1/1 [00:00<?, ?it/s]

[1/1811]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -5.1637596152722836)
RestoringEarlyStopping: Restoring optimizer.


get_predictions:   0%|          | 0/2317680 [00:00<?, ?it/s]

get_predictions:   0%|          | 0/2317680 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Entropy:   0%|          | 0/115884 [00:00<?, ?it/s]

Acquiring (label, score)s: 8 (0.8711), 8 (0.8687), 3 (0.876), 3 (0.8465), 3 (0.8811)
Training set size 25:


100%|##########| 1/1 [00:00<?, ?it/s]

[1/384]   0%|           [00:00<?]

[1/64]   2%|1          [00:00<?]

RestoringEarlyStopping: Restoring best parameters. (Score: -4.6851686127483845)
RestoringEarlyStopping: Restoring optimizer.


[1/157]   1%|           [00:00<?]

Perf after training {'accuracy': 0.6256, 'crossentropy': 4.484497045135498}
Done.


{'initial_training_set_indices': [38043,
  40091,
  17418,
  2094,
  39879,
  3133,
  5011,
  40683,
  54379,
  24287,
  9849,
  59305,
  39508,
  39356,
  8758,
  52579,
  13655,
  7636,
  21562,
  41329],
 'active_learning_steps': [{'training': {'epochs': [{'accuracy': 0.538818359375,
      'crossentropy': 6.529030114412308}],
    'best_epoch': 1},
   'evalution_metrics': {'accuracy': 0.5367,
    'crossentropy': 6.438035237884521},
   'pool_training': {'epochs': [{'accuracy': 0.531005859375,
      'crossentropy': 5.1637596152722836}],
    'best_epoch': 1},
   'acquisition': {'indices': [63338, 10856, 63452, 81864, 109287],
    'labels': [8, 8, 3, 3, 3],
    'scores': [0.8710822958846325,
     0.8687216999221631,
     0.8759664372823723,
     0.8464646732511746,
     0.8810812784952251]}},
  {'training': {'epochs': [{'accuracy': 0.62255859375,
      'crossentropy': 4.6851686127483845}],
    'best_epoch': 1},
   'evalution_metrics': {'accuracy': 0.6256,
    'crossentropy': 4.4844970451

In [None]:
# Every acquisition function except plain BALD.
{
    acquisition_function
    for acquisition_function in AcquisitionFunction
    if acquisition_function is not AcquisitionFunction.bald
}

{<AcquisitionFunction.batchbald: 'BatchBALD'>,
 <AcquisitionFunction.batchbaldical: 'BatchBALD-ICAL'>,
 <AcquisitionFunction.random: 'Random'>,
 <AcquisitionFunction.randombald: 'RandomBALD'>,
 <AcquisitionFunction.thompsonbald: 'ThompsonBALD'>}

In [None]:
# exports

if __name__ == "__main__":
    # Earlier sweeps, disabled with `if False:` rather than deleted so the seed offsets they used
    # (+120, +240, +340, +500, +600) stay on record. If any of them were enabled, it would still be
    # overwritten by the unconditional `configs = ...` assignment further down.
    if False:
        configs = (
            [
                Experiment(
                    seed=seed + 120,
                    acquisition_function=acquisition_function,
                    acquisition_size=acquisition_size,
                    num_pool_samples=100,
                )
                for seed in range(5)
                for acquisition_size in [5, 10]
                for acquisition_function in [AcquisitionFunction.batchbald, AcquisitionFunction.batchbaldical]
            ]
            + [
                Experiment(
                    seed=seed + 120,
                    acquisition_function=acquisition_function,
                    acquisition_size=acquisition_size,
                    num_pool_samples=max(20, acquisition_size),
                )
                for seed in range(5)
                for acquisition_size in [5, 10, 20, 50]
                for acquisition_function in [
                    AcquisitionFunction.bald,
                    AcquisitionFunction.thompsonbald,
                    AcquisitionFunction.randombald,
                ]
            ]
            + [
                Experiment(
                    seed=seed + 120,
                    acquisition_function=AcquisitionFunction.random,
                    acquisition_size=5,
                    num_pool_samples=20,
                )
                for seed in range(40)
                for acquisition_size in [5]
            ]
        )
    if False:
        configs = [
            Experiment(
                seed=seed + 240,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=max(20, acquisition_size),
            )
            for seed in range(5)
            for acquisition_size in [5, 10, 20, 50]
            for acquisition_function in [
                AcquisitionFunction.baldical,
                AcquisitionFunction.randombaldical,
            ]
        ]
    if False:
        configs = [
            Experiment(
                seed=seed + 340,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=20,
                temperature=temperature,
            )
            for seed in range(5)
            for acquisition_size in [5, 10, 20, 50]
            for acquisition_function in [
                AcquisitionFunction.temperedbald,
            ]
            for temperature in [13, 15, 18]
        ]
    if False:
        configs = [
            Experiment(
                seed=seed + 340,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=20,
                temperature=temperature,
            )
            for seed in range(5)
            for acquisition_size in [5, 10, 20, 50]
            for acquisition_function in [
                AcquisitionFunction.temperedbald,
            ]
            for temperature in [8, 10]
        ] + [
            Experiment(
                seed=seed + 340,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=20,
                temperature=temperature,
            )
            for seed in range(5)
            for acquisition_size in [5, 10, 20, 50]
            for acquisition_function in [
                AcquisitionFunction.temperedbaldical,
            ]
            for temperature in [8, 10, 13]
        ]
    if False:
        configs = [
            Experiment(
                seed=seed + 500,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=20,
                temperature=temperature,
            )
            for seed in range(5)
            for acquisition_size in [5, 10, 20, 50]
            for acquisition_function in [
                AcquisitionFunction.temperedical,
            ]
            for temperature in [5, 8, 11]
        ] + [
            Experiment(
                seed=seed + 600,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=20,
            )
            for seed in range(5)
            for acquisition_size in [5, 10, 20, 50]
            for acquisition_function in [
                AcquisitionFunction.ical,
            ]
        ]

    # Active sweep. `embedded_experiments` yields integer job ids that index into this list, so the
    # order of the groups (and of the nested `for` clauses) decides which config each job id runs.
    configs = (
        # Tempered scorers: 5 seeds x 4 acquisition sizes x 2 pool sizes x 3 functions x 3 temperatures.
        [
            Experiment(
                seed=seed + 1000,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=num_pool_samples,
                temperature=temperature,
            )
            for seed in range(5)
            for acquisition_size in [5, 10, 20, 50]
            for num_pool_samples in [20, 100]
            for acquisition_function in [
                AcquisitionFunction.temperedical,
                AcquisitionFunction.temperedbald,
                AcquisitionFunction.temperedbaldical,
            ]
            for temperature in [2, 5, 8]
        ]
        # Untempered scorers.
        + [
            Experiment(
                seed=seed + 2000,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=num_pool_samples,
            )
            for seed in range(5)
            for acquisition_size in [5, 10, 20, 50]
            for num_pool_samples in [20, 100]
            for acquisition_function in [
                AcquisitionFunction.ical,
                AcquisitionFunction.bald,
                AcquisitionFunction.baldical,
                AcquisitionFunction.randombald,
            ]
        ]
        # ThompsonBALD: the pool-sample count is raised to at least `acquisition_size`
        # (presumably so there is at least one sampled model per acquired point — confirm).
        + [
            Experiment(
                seed=seed + 2000,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=max(num_pool_samples, acquisition_size),
            )
            for seed in range(5)
            for acquisition_size in [5, 10, 20, 50]
            for num_pool_samples in [20, 100]
            for acquisition_function in [
                AcquisitionFunction.thompsonbald,
            ]
        ]
        # Batch variants are only run with acquisition size 5 and 100 pool samples.
        + [
            Experiment(
                seed=seed + 3000,
                acquisition_function=acquisition_function,
                acquisition_size=acquisition_size,
                num_pool_samples=100,
            )
            for seed in range(5)
            for acquisition_size in [5]
            for acquisition_function in [
                AcquisitionFunction.batchbaldical,
                AcquisitionFunction.batchbald,
            ]
        ]
    )

    # `embedded_experiments` needs this script's path. When this cell is executed interactively
    # (e.g. in Jupyter, where `__name__` is also "__main__"), `__file__` is undefined and the call
    # used to fail with a NameError. In that case only `configs` is built and no jobs are launched.
    if "__file__" in globals():
        for job_id, store in embedded_experiments(__file__, len(configs)):
            config = configs[job_id]
            # Offset the base seed by the job id; the stored config records the offset seed.
            # NOTE(review): this does not make seeds unique across groups that share a base offset,
            # e.g. job 516 (ICAL, base +2000, seed index 4) and job 520 (ThompsonBALD, base +2000,
            # seed index 0) both get seed 2520. Changing it would alter the seeds of already-run
            # jobs, so it is only flagged here.
            config.seed += job_id
            print(config)
            store["config"] = dataclasses.asdict(config)
            store["log"] = {}

            try:
                config.run(store=store)
            except Exception:
                # Record the traceback in the experiment store, then propagate the failure.
                store["exception"] = traceback.format_exc()
                raise

NameError: name '__file__' is not defined