# Unified Experiment Code
> Resistance is futile.

In [None]:
# default_exp sampling_unified_experiment

In [None]:
# hide
import blackhc.project.script

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


Import the modules and functions we are going to use.

In [None]:
# exports

import dataclasses
import traceback
from dataclasses import dataclass
from typing import Dict, List, Optional, Type, Union

import numpy as np
import torch
import torch.utils.data
from blackhc.project import is_run_from_ipython
from blackhc.project.experiment import embedded_experiments
from torch import nn
from torch.utils.data import Dataset

import batchbald_redux.acquisition_functions.batchbald
import batchbald_redux.acquisition_functions.epig
import wandb
from batchbald_redux import acquisition_functions, baseline_acquisition_functions
from batchbald_redux.acquisition_functions import (
    CandidateBatchComputer,
    EvalDatasetBatchComputer,
    EvalModelBatchComputer,
)
from batchbald_redux.active_learning import SampledActiveLearningData
from batchbald_redux.black_box_model_training import evaluate
from batchbald_redux.dataset_challenges import (
    NamedDataset,
    get_base_dataset_index,
    get_target,
)
from batchbald_redux.datasets import get_dataset
from batchbald_redux.di import DependencyInjection
from batchbald_redux.experiment_data import (
    ExperimentData,
    ExperimentDataConfig,
    OoDDatasetConfig,
    StandardExperimentDataConfig,
)
from batchbald_redux.experiment_logging import init_wandb, log2wandb
from batchbald_redux.models import MnistModelTrainer
from batchbald_redux.resnet_models import Cifar10ModelTrainer
from batchbald_redux.train_eval_model import (
    TrainEvalModel,
    TrainSelfDistillationEvalModel,
)
from batchbald_redux.trained_model import BayesianEnsembleModelTrainer, ModelTrainer

In [None]:
# exports


@dataclass
class SampledExperimentData:
    """Datasets, indices and settings consumed by ``SampledActiveLearner``.

    Produced by ``load_distribution_experiment_data``.
    """

    # Training split; the active learner builds its training set and pool from it.
    train_dataset: Dataset
    # Split used for model selection / early stopping during training.
    validation_dataset: Dataset
    # Optional dedicated evaluation set for acquisition functions that need one;
    # when it is empty the learner falls back to the pool instead.
    evaluation_dataset: Dataset
    # Held-out split used to report performance after each training round.
    test_dataset: Dataset

    # Augmentations applied during training (skipped when the learner disables them).
    train_augmentations: nn.Module

    # Indices into train_dataset that are acquired before the first round.
    initial_training_set_indices: List[int]
    # Indices backing evaluation_dataset (logged for reproducibility).
    evaluation_set_indices: List[int]

    # Out-of-distribution dataset mixed into the pool, or None when there is none.
    ood_dataset: Optional[NamedDataset]

    # TODO: replace this with dataset info on the targets
    # Whether acquired OoD samples are labelled and trained on (with a KL-divergence loss);
    # when False, OoD candidates are not added to the training set.
    ood_exposure: bool

    # Device reported by the loaded dataset split.
    device: str


@dataclass
class SampledExperimentDataConfig(ExperimentDataConfig):
    """Config for an in-distribution-only experiment with a sampled initial training set."""

    id_dataset_name: str
    initial_training_set_size: int
    validation_set_size: int
    validation_split_random_state: int

    def load(self, device) -> SampledExperimentData:
        """Materialize this config via ``load_distribution_experiment_data`` on ``device``."""
        config_kwargs = dict(
            id_dataset_name=self.id_dataset_name,
            initial_training_set_size=self.initial_training_set_size,
            validation_set_size=self.validation_set_size,
            validation_split_random_state=self.validation_split_random_state,
        )
        return load_distribution_experiment_data(device=device, **config_kwargs)


def get_class_indices_by_class_distribution(
    targets: torch.Tensor, *, threshold: float, class_counts: list, generator: np.random.Generator
) -> Dict[int, List[int]]:
    """Pick up to ``class_counts[c]`` confidently-labelled sample indices for every class ``c``.

    Rows of ``targets`` are visited in a random order drawn from ``generator``. Each row is
    assigned to the class of its largest entry along the last dimension (first one on ties).
    A row is kept only if that largest entry is at least ``threshold`` and its class still
    has quota left. The scan stops once the total quota is used up.

    Args:
        targets: per-sample target scores, presumably a ``(num_samples, num_classes)`` tensor
            of class probabilities (TODO confirm against the dataset's ``get_targets()``).
        threshold: minimum top score for a sample to be eligible.
        class_counts: requested number of samples per class; not modified.
        generator: NumPy generator used to shuffle the visiting order.

    Returns:
        Mapping from class label (``0 .. len(class_counts) - 1``) to the selected indices, as
        plain Python ints, in selection order. A class may receive fewer indices than
        requested if not enough eligible samples exist.
    """
    remaining_per_class = list(class_counts)

    subset_indices: Dict[int, List[int]] = {label: [] for label in range(len(remaining_per_class))}

    remaining_samples = sum(remaining_per_class)

    # Compute every row's top score/label in one vectorized call instead of one tiny tensor
    # op (plus two .item() syncs) per visited row. .tolist() yields Python floats/ints.
    max_scores, max_labels = targets.max(dim=-1)
    max_scores = max_scores.tolist()
    max_labels = max_labels.tolist()

    # .tolist() so the returned indices are plain ints (not numpy.int64), matching the
    # annotation and keeping them JSON/wandb-config friendly.
    for index in generator.permutation(len(targets)).tolist():
        if max_scores[index] < threshold:
            continue

        label = max_labels[index]
        if remaining_per_class[label] > 0:
            subset_indices[label].append(index)
            remaining_per_class[label] -= 1
            remaining_samples -= 1

            if remaining_samples <= 0:
                break

    return subset_indices


def get_class_indices_distribution(
    targets: torch.Tensor, *, threshold: float, class_counts: list, generator: np.random.Generator
) -> List[int]:
    """Flat variant of ``get_class_indices_by_class_distribution``.

    Returns the selected indices of all classes in a single list, grouped by class label in
    ascending order (the per-class dict is keyed ``0 .. len(class_counts) - 1``).
    """
    per_class = get_class_indices_by_class_distribution(
        targets=targets, threshold=threshold, class_counts=class_counts, generator=generator
    )
    flattened: List[int] = []
    for class_indices in per_class.values():
        flattened.extend(class_indices)
    return flattened


def get_balanced_indices_from_distribution(
    targets: torch.Tensor, *, threshold: float, num_classes, samples_per_class, seed: int
) -> List[int]:
    """Request ``samples_per_class`` indices for each of ``num_classes`` classes.

    A fresh NumPy generator is created from ``seed``, so the selection is reproducible for
    fixed ``targets``. See ``get_class_indices_distribution`` for the selection rule.
    """
    rng = np.random.default_rng(seed)
    per_class_quota = [samples_per_class for _ in range(num_classes)]
    return get_class_indices_distribution(
        targets=targets,
        threshold=threshold,
        class_counts=per_class_quota,
        generator=rng,
    )


def load_distribution_experiment_data(
    *,
    id_dataset_name: str,
    initial_training_set_size: int,
    validation_set_size: int,
    validation_split_random_state: int,
    device: str,
) -> SampledExperimentData:
    """Load ``id_dataset_name`` and choose a class-balanced initial training set.

    ``initial_training_set_size`` is split evenly across classes (integer division). Only
    samples whose top target score is at least 0.95 are eligible. The selection seed reuses
    ``validation_split_random_state``. No OoD data and no dedicated evaluation set are set up.
    """
    splits = get_dataset(
        id_dataset_name,
        root="data",
        validation_set_size=validation_set_size,
        validation_split_random_state=validation_split_random_state,
        normalize_like_cifar10=True,
        device_hint=device,
    )
    train_split = splits.train

    train_targets = train_split.get_targets()
    class_total = train_split.get_num_classes()

    initial_indices = get_balanced_indices_from_distribution(
        targets=train_targets,
        threshold=0.95,
        num_classes=class_total,
        samples_per_class=initial_training_set_size // class_total,
        seed=validation_split_random_state,
    )

    return SampledExperimentData(
        train_dataset=train_split,
        validation_dataset=splits.validation,
        test_dataset=splits.test,
        # Empty evaluation set: acquisition functions fall back to the pool.
        evaluation_dataset=train_split.subset([]),
        train_augmentations=splits.train_augmentations,
        initial_training_set_indices=initial_indices,
        evaluation_set_indices=[],
        # This loader has no out-of-distribution component.
        ood_dataset=None,
        ood_exposure=False,
        device=splits.device,
    )


@dataclass
class SampledActiveLearner:
    """Active-learning loop: train, evaluate, acquire, repeat.

    Each round trains a model on the current training set and evaluates it on the test set.
    Unless a stopping condition is met, it then asks ``acquisition_function`` for a batch of
    pool samples and acquires their labels. Per-round results are appended to
    ``log["active_learning_steps"]`` and mirrored to wandb.
    """

    acquisition_size: int
    max_training_set: int

    num_validation_samples: int

    acquisition_function: Union[CandidateBatchComputer, EvalModelBatchComputer]
    train_eval_model: TrainEvalModel
    model_trainer: ModelTrainer
    data: SampledExperimentData

    disable_training_augmentations: bool

    allow_repeated_acquisition: bool

    device: Optional

    def __call__(self, log):
        """Run the loop, writing results into the mutable mapping ``log``.

        Stops after evaluation once the training set has reached ``max_training_set``, or
        after ``int(1.5 * (max_training_set - initial_size) / acquisition_size)`` acquisition
        rounds, whichever comes first. The 1.5 slack presumably guards against rounds that
        add fewer samples than requested (TODO confirm).
        """
        # Record the seed in effect WITHOUT re-seeding. torch.seed() would replace the seed
        # set by the caller (torch.manual_seed in SampledUnifiedExperiment.run) with a
        # non-deterministic one, making runs irreproducible.
        log["seed"] = torch.initial_seed()

        data = self.data

        # Active Learning setup
        active_learning_data = SampledActiveLearningData(
            data.train_dataset, allow_repeated_acquisition=self.allow_repeated_acquisition
        )
        active_learning_data.acquire_base_indices(data.initial_training_set_indices, select_majority=True)

        train_augmentations = None if self.disable_training_augmentations else data.train_augmentations

        model_trainer = self.model_trainer
        train_eval_model = self.train_eval_model
        acquisition_function = self.acquisition_function

        # Loaders are built once. This relies on training_dataset/pool_dataset reflecting
        # later acquire() calls (the loop re-reads len(training_dataset) every round).
        train_loader = model_trainer.get_train_dataloader(active_learning_data.training_dataset)
        pool_loader = model_trainer.get_evaluation_dataloader(active_learning_data.pool_dataset)
        validation_loader = model_trainer.get_evaluation_dataloader(data.validation_dataset)
        test_loader = model_trainer.get_evaluation_dataloader(data.test_dataset)

        # TODO: this is a hack! :(
        # Only OoD exposure trains with a KL-divergence loss; validation always uses NLL.
        # The losses do not depend on the round, so build them once.
        if data.ood_dataset is not None and data.ood_exposure:
            loss = torch.nn.KLDivLoss(log_target=False, reduction="batchmean")
            validation_loss = torch.nn.NLLLoss()
        else:
            loss = validation_loss = torch.nn.NLLLoss()

        log["active_learning_steps"] = []
        active_learning_steps = log["active_learning_steps"]

        num_iterations = 0
        max_iterations = int(
            1.5 * (self.max_training_set - len(active_learning_data.training_dataset)) / self.acquisition_size
        )

        # Active Training Loop
        while True:
            training_set_size = len(active_learning_data.training_dataset)
            print(f"Training set size {training_set_size}:")

            iteration_log = {}
            active_learning_steps.append(iteration_log)
            iteration_log["training"] = {}

            trained_model = model_trainer.get_trained(
                train_loader=train_loader,
                train_augmentations=train_augmentations,
                validation_loader=validation_loader,
                log=iteration_log["training"],
                wandb_key_path="model_training",
                loss=loss,
                validation_loss=validation_loss,
            )

            evaluation_metrics = evaluate(
                model=trained_model,
                num_samples=self.num_validation_samples,
                loader=test_loader,
                device=self.device,
                storage_device="cpu",
            )
            iteration_log["evaluation_metrics"] = evaluation_metrics
            log2wandb(evaluation_metrics, commit=False)
            print(f"Perf after training {evaluation_metrics}")

            if training_set_size >= self.max_training_set or num_iterations >= max_iterations:
                log2wandb({}, commit=True)
                print("Done.")
                break

            if isinstance(acquisition_function, CandidateBatchComputer):
                candidate_batch = acquisition_function.compute_candidate_batch(trained_model, pool_loader, self.device)
            elif isinstance(acquisition_function, EvalDatasetBatchComputer):
                # Use the dedicated evaluation set when there is one, otherwise the pool itself.
                if len(data.evaluation_dataset) > 0:
                    eval_loader = model_trainer.get_evaluation_dataloader(data.evaluation_dataset)
                else:
                    eval_loader = pool_loader

                candidate_batch = acquisition_function.compute_candidate_batch(
                    model=trained_model, pool_loader=pool_loader, eval_loader=eval_loader, device=self.device
                )
            elif isinstance(acquisition_function, EvalModelBatchComputer):
                if len(data.evaluation_dataset) > 0:
                    eval_dataset = data.evaluation_dataset
                else:
                    eval_dataset = active_learning_data.pool_dataset

                iteration_log["eval_training"] = {}
                trained_eval_model = train_eval_model(
                    model_trainer=model_trainer,
                    training_dataset=active_learning_data.training_dataset,
                    train_augmentations=train_augmentations,
                    eval_dataset=eval_dataset,
                    validation_loader=validation_loader,
                    trained_model=trained_model,
                    storage_device=data.device,
                    device=self.device,
                    training_log=iteration_log["eval_training"],
                    wandb_key_path="eval_model_training",
                )

                candidate_batch = acquisition_function.compute_candidate_batch(
                    trained_model, trained_eval_model, pool_loader, device=self.device
                )
            else:
                raise ValueError(f"Unknown acquisition function {acquisition_function}!")

            # Map pool indices back to (dataset type, index in the underlying base dataset).
            candidate_global_dataset_indices = []
            candidate_images = []
            candidate_types = []
            for index in candidate_batch.indices:
                base_di = get_base_dataset_index(active_learning_data.pool_dataset, index)
                dataset_type = "ood" if base_di.dataset == data.ood_dataset else "id"
                candidate_types.append(dataset_type)

                candidate_global_dataset_indices.append((dataset_type, base_di.index))
                candidate_images.append(wandb.Image(active_learning_data.pool_dataset[index][0]))

            if data.ood_dataset is None or data.ood_exposure:
                candidate_labels = active_learning_data.acquire(candidate_batch.indices)
            else:
                # Without OoD exposure only in-distribution candidates are acquired;
                # OoD candidates keep a None label slot.
                id_candidate_indices = [
                    index
                    for index, dataset_type in zip(candidate_batch.indices, candidate_types)
                    if dataset_type == "id"
                ]
                id_labels = active_learning_data.acquire(id_candidate_indices)

                candidate_labels = [None] * len(candidate_types)
                id_positions = [position for position, dataset_type in enumerate(candidate_types) if dataset_type == "id"]
                for position, id_label in zip(id_positions, id_labels):
                    candidate_labels[position] = id_label

            iteration_log["acquisition"] = dict(
                indices=candidate_global_dataset_indices, labels=candidate_labels, scores=candidate_batch.scores
            )

            acquisition_batch_table = wandb.Table(
                data=[
                    (dataset_type, base_index, image, label, score)
                    for (dataset_type, base_index), image, label, score in zip(
                        candidate_global_dataset_indices, candidate_images, candidate_labels, candidate_batch.scores
                    )
                ],
                columns=["dataset", "index", "sample", "label", "score"],
            )
            log2wandb(dict(acquisition=acquisition_batch_table), commit=False)

            ls = ",\n".join(
                f"{global_index}: {label} ({score:.4})"
                for global_index, label, score in zip(
                    candidate_global_dataset_indices, candidate_labels, candidate_batch.scores
                )
            )
            print(f"Acquiring global_index: label (score):\n{ls}")

            num_iterations += 1
            log2wandb({}, commit=True)


@dataclass
class SampledUnifiedExperiment:
    """Experiment configuration; ``run`` executes one active-learning experiment.

    ``acquisition_function``, ``train_eval_model`` and ``model_trainer_factory`` are given as
    types. They are instantiated through ``DependencyInjection(vars(self))``, presumably by
    matching constructor arguments to this object's field names.
    """

    seed: int

    experiment_data_config: SampledExperimentDataConfig

    acquisition_size: int = 5
    max_training_set: int = 200

    max_training_epochs: int = 300

    num_pool_samples: int = 100
    num_validation_samples: int = 20
    num_training_samples: int = 1

    device: str = "cuda"
    acquisition_function: Union[
        Type[CandidateBatchComputer], Type[EvalModelBatchComputer]
    ] = None  # acquisition_functions.BALD
    train_eval_model: Type[TrainEvalModel] = TrainSelfDistillationEvalModel
    model_trainer_factory: Type[ModelTrainer] = None  # Cifar10ModelTrainer
    ensemble_size: int = 1

    temperature: float = 1.0
    coldness: float = 1.0
    stochastic_mode: acquisition_functions.StochasticMode = acquisition_functions.StochasticMode.TopK
    disable_training_augmentations: bool = False
    resnet18_dropout_head: bool = True

    allow_repeated_acquisition: bool = True

    def load_experiment_data(self) -> SampledExperimentData:
        """Print and load the configured experiment data on ``self.device``."""
        print(self.experiment_data_config)
        return self.experiment_data_config.load(self.device)

    # Simple Dependency Injection
    def create_acquisition_function(self):
        """Instantiate ``self.acquisition_function`` from this experiment's fields."""
        di = DependencyInjection(vars(self))
        return di.create_dataclass_type(self.acquisition_function)

    def create_train_eval_model(self) -> TrainEvalModel:
        """Instantiate ``self.train_eval_model`` from this experiment's fields."""
        di = DependencyInjection(vars(self))
        return di.create_dataclass_type(self.train_eval_model)

    def create_model_trainer(self) -> ModelTrainer:
        """Instantiate ``self.model_trainer_factory`` from this experiment's fields."""
        di = DependencyInjection(vars(self))
        return di.create_dataclass_type(self.model_trainer_factory)

    def run(self, store, project=None, entity=None):
        """Run the experiment, recording results in ``store`` and wandb.

        Args:
            store: mutable mapping that receives dataset info, the initial/evaluation indices
                and the per-step log written by ``SampledActiveLearner``.
            project: forwarded to ``init_wandb``.
            entity: forwarded to ``init_wandb``.

        The wandb run is always finished. On success it finishes normally; if an exception
        escapes, it finishes with ``exit_code=1`` and the exception is re-raised.
        """
        init_wandb(self, project=project, entity=entity)

        try:
            torch.manual_seed(self.seed)

            # Active Learning setup
            data = self.load_experiment_data()
            store["dataset_info"] = dict(training=repr(data.train_dataset), test=repr(data.test_dataset))
            store["initial_training_set_indices"] = data.initial_training_set_indices
            store["evaluation_set_indices"] = data.evaluation_set_indices

            print(wandb.config)

            wandb.config.initial_training_set_indices = data.initial_training_set_indices
            wandb.config.evaluation_set_indices = data.evaluation_set_indices
            wandb.config["dataset_info"] = store["dataset_info"]

            acquisition_function = self.create_acquisition_function()
            model_trainer = self.create_model_trainer()
            if self.ensemble_size > 1:
                model_trainer = BayesianEnsembleModelTrainer(
                    model_trainer=model_trainer, ensemble_size=self.ensemble_size
                )
            train_eval_model = self.create_train_eval_model()

            active_learner = SampledActiveLearner(
                acquisition_size=self.acquisition_size,
                max_training_set=self.max_training_set,
                num_validation_samples=self.num_validation_samples,
                disable_training_augmentations=self.disable_training_augmentations,
                acquisition_function=acquisition_function,
                train_eval_model=train_eval_model,
                model_trainer=model_trainer,
                data=data,
                allow_repeated_acquisition=self.allow_repeated_acquisition,
                device=self.device,
            )

            active_learner(store)
        except BaseException:
            # Do not leave the wandb run open (or let it look successful) after a failure.
            wandb.finish(exit_code=1)
            raise

        wandb.finish()

## MNIST only

In [None]:
# experiment
# MNIST experiment (ood_exposure=False)

mnist_data_config = SampledExperimentDataConfig(
    id_dataset_name="DistributionalAmbiguousMNIST",
    initial_training_set_size=20,
    validation_set_size=4096,
    validation_split_random_state=0,
)

experiment = SampledUnifiedExperiment(
    seed=1,
    experiment_data_config=mnist_data_config,
    device="cuda",
    model_trainer_factory=MnistModelTrainer,
    acquisition_function=acquisition_functions.BALD,
    acquisition_size=1,
    # 20 initial samples plus up to 10 acquired ones.
    max_training_set=20 + 10,
    max_training_epochs=5,
    num_pool_samples=2,
    allow_repeated_acquisition=True,
)

results = {}
experiment.run(results)
results

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33moatml-andreas-kirsch[0m (use `wandb login --relogin` to force relogin)


SampledExperimentDataConfig(id_dataset_name='DistributionalAmbiguousMNIST', initial_training_set_size=20, validation_set_size=4096, validation_split_random_state=0)
{'Dataclass': '__main__.SampledUnifiedExperiment', 'seed': 1, 'experiment_data_config': {'Dataclass': '__main__.SampledExperimentDataConfig', 'id_dataset_name': 'DistributionalAmbiguousMNIST', 'initial_training_set_size': 20, 'validation_set_size': 4096, 'validation_split_random_state': 0}, 'acquisition_size': 1, 'max_training_set': 30, 'max_training_epochs': 5, 'num_pool_samples': 2, 'num_validation_samples': 20, 'num_training_samples': 1, 'device': 'cuda', 'acquisition_function': 'batchbald_redux.acquisition_functions.bald.BALD', 'train_eval_model': 'batchbald_redux.train_eval_model.TrainSelfDistillationEvalModel', 'model_trainer_factory': 'batchbald_redux.models.MnistModelTrainer', 'ensemble_size': 1, 'temperature': 1.0, 'coldness': 1.0, 'stochastic_mode': 'StochasticMode.TopK', 'disable_training_augmentations': False, '

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.37890625, 'crossentropy': 1.7864514589309692}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.505859375, 'crossentropy': 1.4874533414840698}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.531494140625, 'crossentropy': 1.6502480506896973}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.510498046875, 'crossentropy': 1.7343323230743408}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.5224609375, 'crossentropy': 1.8292348384857178}
RestoringEarlyStopping: 2 / 20
RestoringEarlyStopping: Restoring best parameters. (Score: 0.531494140625)
RestoringEarlyStopping: Restoring optimizer.
{'model_training/val_metrics': <wandb.data_types.Table object at 0x7fafc6c49e50>, 'model_training/best_epoch': 2, 'model_training/best_val_accuracy': 0.531494140625, 'model_training/best_val_crossentropy': 1.6502480506896973}


get_predictions_labels:   0%|          | 0/1400000 [00:00<?, ?it/s]

Perf after training {'accuracy': 0.40954285714285715, 'crossentropy': tensor(1.8577), '_timestamp': 1652301450, '_runtime': 22}


get_predictions_labels:   0%|          | 0/12000 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Acquiring global_index: label (score):
('id', 5327): 7 (0.5811)

Training set size 21:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.295166015625, 'crossentropy': 1.9844443798065186}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.52197265625, 'crossentropy': 1.5216455459594727}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.509765625, 'crossentropy': 1.6709494590759277}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.526611328125, 'crossentropy': 1.7819267511367798}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.51708984375, 'crossentropy': 1.828109622001648}
RestoringEarlyStopping: 1 / 20
RestoringEarlyStopping: Restoring best parameters. (Score: 0.526611328125)
RestoringEarlyStopping: Restoring optimizer.
{'model_training/val_metrics': <wandb.data_types.Table object at 0x7fafacdb16a0>, 'model_training/best_epoch': 3, 'model_training/best_val_accuracy': 0.526611328125, 'model_training/best_val_crossentropy': 1.7819267511367798}


get_predictions_labels:   0%|          | 0/1400000 [00:00<?, ?it/s]

Perf after training {'accuracy': 0.42315714285714284, 'crossentropy': tensor(2.0497), '_timestamp': 1652301464, '_runtime': 36}


get_predictions_labels:   0%|          | 0/12000 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Acquiring global_index: label (score):
('id', 3019): 8 (0.6507)

Training set size 22:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.31494140625, 'crossentropy': 1.8965338468551636}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.505126953125, 'crossentropy': 1.4474931955337524}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.53857421875, 'crossentropy': 1.399573802947998}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.547607421875, 'crossentropy': 1.5304324626922607}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.544677734375, 'crossentropy': 1.5962694883346558}
RestoringEarlyStopping: 1 / 20
RestoringEarlyStopping: Restoring best parameters. (Score: 0.547607421875)
RestoringEarlyStopping: Restoring optimizer.
{'model_training/val_metrics': <wandb.data_types.Table object at 0x7fafacdeeca0>, 'model_training/best_epoch': 3, 'model_training/best_val_accuracy': 0.547607421875, 'model_training/best_val_crossentropy': 1.5304324626922607}


get_predictions_labels:   0%|          | 0/1400000 [00:00<?, ?it/s]

Perf after training {'accuracy': 0.4380857142857143, 'crossentropy': tensor(1.8986), '_timestamp': 1652301478, '_runtime': 50}


get_predictions_labels:   0%|          | 0/12000 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Acquiring global_index: label (score):
('id', 344): 7 (0.6915)

Training set size 23:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.29638671875, 'crossentropy': 2.002034902572632}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.433837890625, 'crossentropy': 1.6772496700286865}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.509765625, 'crossentropy': 1.57258939743042}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.51416015625, 'crossentropy': 1.6670124530792236}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.5234375, 'crossentropy': 1.7411009073257446}
RestoringEarlyStopping: Restoring best parameters. (Score: 0.5234375)
RestoringEarlyStopping: Restoring optimizer.
{'model_training/val_metrics': <wandb.data_types.Table object at 0x7fafac3990d0>, 'model_training/best_epoch': 4, 'model_training/best_val_accuracy': 0.5234375, 'model_training/best_val_crossentropy': 1.7411009073257446}


get_predictions_labels:   0%|          | 0/1400000 [00:00<?, ?it/s]

Perf after training {'accuracy': 0.4425, 'crossentropy': tensor(1.9377), '_timestamp': 1652301492, '_runtime': 64}


get_predictions_labels:   0%|          | 0/12000 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Acquiring global_index: label (score):
('id', 2673): 9 (0.6725)

Training set size 24:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.20458984375, 'crossentropy': 2.092069625854492}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.51806640625, 'crossentropy': 1.5272133350372314}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.530517578125, 'crossentropy': 1.5148138999938965}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.546875, 'crossentropy': 1.5188500881195068}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.5380859375, 'crossentropy': 1.6396946907043457}
RestoringEarlyStopping: 1 / 20
RestoringEarlyStopping: Restoring best parameters. (Score: 0.546875)
RestoringEarlyStopping: Restoring optimizer.
{'model_training/val_metrics': <wandb.data_types.Table object at 0x7fafacdacdf0>, 'model_training/best_epoch': 3, 'model_training/best_val_accuracy': 0.546875, 'model_training/best_val_crossentropy': 1.5188500881195068}


get_predictions_labels:   0%|          | 0/1400000 [00:00<?, ?it/s]

Perf after training {'accuracy': 0.44685714285714284, 'crossentropy': tensor(1.7587), '_timestamp': 1652301505, '_runtime': 77}


get_predictions_labels:   0%|          | 0/12000 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Acquiring global_index: label (score):
('id', 3157): 2 (0.6832)

Training set size 25:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.29248046875, 'crossentropy': 1.9117804765701294}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.509765625, 'crossentropy': 1.5092277526855469}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.508056640625, 'crossentropy': 1.5606906414031982}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.523193359375, 'crossentropy': 1.6402932405471802}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.5224609375, 'crossentropy': 1.6856915950775146}
RestoringEarlyStopping: 1 / 20
RestoringEarlyStopping: Restoring best parameters. (Score: 0.523193359375)
RestoringEarlyStopping: Restoring optimizer.
{'model_training/val_metrics': <wandb.data_types.Table object at 0x7fafac3ed130>, 'model_training/best_epoch': 3, 'model_training/best_val_accuracy': 0.523193359375, 'model_training/best_val_crossentropy': 1.6402932405471802}


get_predictions_labels:   0%|          | 0/1400000 [00:00<?, ?it/s]

Perf after training {'accuracy': 0.43584285714285714, 'crossentropy': tensor(1.7989), '_timestamp': 1652301518, '_runtime': 90}


get_predictions_labels:   0%|          | 0/12000 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Acquiring global_index: label (score):
('id', 2274): 0 (0.673)

Training set size 26:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.32666015625, 'crossentropy': 1.940726399421692}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.485107421875, 'crossentropy': 1.5525437593460083}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.5234375, 'crossentropy': 1.5748449563980103}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.527587890625, 'crossentropy': 1.662620186805725}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.51123046875, 'crossentropy': 1.7853187322616577}
RestoringEarlyStopping: 1 / 20
RestoringEarlyStopping: Restoring best parameters. (Score: 0.527587890625)
RestoringEarlyStopping: Restoring optimizer.
{'model_training/val_metrics': <wandb.data_types.Table object at 0x7fafac2b9a90>, 'model_training/best_epoch': 3, 'model_training/best_val_accuracy': 0.527587890625, 'model_training/best_val_crossentropy': 1.662620186805725}


get_predictions_labels:   0%|          | 0/1400000 [00:00<?, ?it/s]

Perf after training {'accuracy': 0.4385, 'crossentropy': tensor(1.8417), '_timestamp': 1652301531, '_runtime': 103}


get_predictions_labels:   0%|          | 0/12000 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Acquiring global_index: label (score):
('id', 2717): 9 (0.639)

Training set size 27:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.281005859375, 'crossentropy': 2.101107120513916}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.432373046875, 'crossentropy': 1.6511849164962769}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.503173828125, 'crossentropy': 1.5441358089447021}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.48291015625, 'crossentropy': 1.8059364557266235}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.512939453125, 'crossentropy': 1.8056018352508545}
RestoringEarlyStopping: Restoring best parameters. (Score: 0.512939453125)
RestoringEarlyStopping: Restoring optimizer.
{'model_training/val_metrics': <wandb.data_types.Table object at 0x7fafac1dc220>, 'model_training/best_epoch': 4, 'model_training/best_val_accuracy': 0.512939453125, 'model_training/best_val_crossentropy': 1.8056018352508545}


get_predictions_labels:   0%|          | 0/1400000 [00:00<?, ?it/s]

Perf after training {'accuracy': 0.4313142857142857, 'crossentropy': tensor(1.9129), '_timestamp': 1652301545, '_runtime': 117}


get_predictions_labels:   0%|          | 0/12000 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Acquiring global_index: label (score):
('id', 3384): 0 (0.6824)

Training set size 28:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.28125, 'crossentropy': 1.957114577293396}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.495361328125, 'crossentropy': 1.5366103649139404}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.523681640625, 'crossentropy': 1.5298080444335938}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.548583984375, 'crossentropy': 1.5535179376602173}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.520263671875, 'crossentropy': 1.7138733863830566}
RestoringEarlyStopping: 1 / 20
RestoringEarlyStopping: Restoring best parameters. (Score: 0.548583984375)
RestoringEarlyStopping: Restoring optimizer.
{'model_training/val_metrics': <wandb.data_types.Table object at 0x7fafac18a5e0>, 'model_training/best_epoch': 3, 'model_training/best_val_accuracy': 0.548583984375, 'model_training/best_val_crossentropy': 1.5535179376602173}


get_predictions_labels:   0%|          | 0/1400000 [00:00<?, ?it/s]

Perf after training {'accuracy': 0.4445857142857143, 'crossentropy': tensor(1.8181), '_timestamp': 1652301558, '_runtime': 130}


get_predictions_labels:   0%|          | 0/12000 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Acquiring global_index: label (score):
('id', 3295): 3 (0.6924)

Training set size 29:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.216552734375, 'crossentropy': 2.1377549171447754}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.4775390625, 'crossentropy': 1.5503466129302979}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.52587890625, 'crossentropy': 1.4880073070526123}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.5244140625, 'crossentropy': 1.5435246229171753}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.544189453125, 'crossentropy': 1.6195586919784546}
RestoringEarlyStopping: Restoring best parameters. (Score: 0.544189453125)
RestoringEarlyStopping: Restoring optimizer.
{'model_training/val_metrics': <wandb.data_types.Table object at 0x7fafac0ae6d0>, 'model_training/best_epoch': 4, 'model_training/best_val_accuracy': 0.544189453125, 'model_training/best_val_crossentropy': 1.6195586919784546}


get_predictions_labels:   0%|          | 0/1400000 [00:00<?, ?it/s]

Perf after training {'accuracy': 0.4448142857142857, 'crossentropy': tensor(1.7994), '_timestamp': 1652301572, '_runtime': 144}


get_predictions_labels:   0%|          | 0/12000 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Entropy:   0%|          | 0/6000 [00:00<?, ?it/s]

Acquiring global_index: label (score):
('id', 4887): 7 (0.6727)

Training set size 30:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.285400390625, 'crossentropy': 2.0492501258850098}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.422607421875, 'crossentropy': 1.6460908651351929}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.493408203125, 'crossentropy': 1.5688776969909668}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.502197265625, 'crossentropy': 1.605987787246704}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.509033203125, 'crossentropy': 1.734849214553833}
RestoringEarlyStopping: Restoring best parameters. (Score: 0.509033203125)
RestoringEarlyStopping: Restoring optimizer.
{'model_training/val_metrics': <wandb.data_types.Table object at 0x7fafacdb10a0>, 'model_training/best_epoch': 4, 'model_training/best_val_accuracy': 0.509033203125, 'model_training/best_val_crossentropy': 1.734849214553833}


get_predictions_labels:   0%|          | 0/1400000 [00:00<?, ?it/s]

Perf after training {'accuracy': 0.44098571428571426, 'crossentropy': tensor(1.9284), '_timestamp': 1652301586, '_runtime': 158}
Done.



VBox(children=(Label(value='0.106 MB of 0.106 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁▄▆▇█▆▆▅██▇
crossentropy,▃█▄▅▁▂▃▅▂▂▅
model_training/best_epoch,▁▅▅█▅▅▅█▅██
model_training/best_val_accuracy,▅▄█▄█▄▄▂█▇▁
model_training/best_val_crossentropy,▄▇▁▆▁▄▅█▂▃▆

0,1
accuracy,0.44099
crossentropy,1.92839
model_training/best_epoch,4.0
model_training/best_val_accuracy,0.50903
model_training/best_val_crossentropy,1.73485


{'dataset_info': {'training': "'DirtyMNIST (Train, seed=0, 6000 samples)'",
  'test': "'DirtyMNIST (Test)'"},
 'initial_training_set_indices': [3827,
  1923,
  2244,
  5491,
  824,
  1023,
  5995,
  1774,
  5205,
  534,
  3689,
  3845,
  2708,
  1830,
  5809,
  3871,
  3888,
  838,
  633,
  3633],
 'evaluation_set_indices': [],
 'seed': 3026687564475523902,
 'active_learning_steps': [{'training': {'epochs': [{'accuracy': 0.37890625,
      'crossentropy': 1.7864514589309692},
     {'accuracy': 0.505859375, 'crossentropy': 1.4874533414840698},
     {'accuracy': 0.531494140625, 'crossentropy': 1.6502480506896973},
     {'accuracy': 0.510498046875, 'crossentropy': 1.7343323230743408},
     {'accuracy': 0.5224609375, 'crossentropy': 1.8292348384857178}],
    'best_epoch': 2},
   'evaluation_metrics': {'accuracy': 0.40954285714285715,
    'crossentropy': tensor(1.8577),
    '_timestamp': 1652301450,
    '_runtime': 22},
   'acquisition': {'indices': [('id', 5327)],
    'labels': tensor([7], 