<h1>ALBench project: BatchBALD</h1>

<p>This Jupyter notebook demonstrates how to use the al_bench Active Learning Benchmark Tool with a Bayesian neural network and the BatchBALD active learning strategy.</p>

<h2>Overview</h2>

<p>The tool takes an input dataset, a Bayesian machine learning model, and an active learning strategy, and outputs information for evaluating how well the strategy performs with that model and dataset. By running the tool multiple times with different inputs, one can compare active learning strategies, models, and datasets. Researchers can test a proposed active learning strategy in the context of a specific model and dataset, or run it with multiple models and datasets to get a broader picture of its effectiveness in different contexts. Conversely, runs with different models and datasets can be compared to evaluate how compatible each is with a given active learning strategy.</p>

<p>In the present example, we show the use of the BatchBALD active learning strategy with the MNIST dataset and a Bayesian model. To do this, we fetch the dataset and provide it to a dataset handler, and we build a model and provide it to a model handler. Both handlers are then passed to the active learning strategy handler.</p>

<h2>Install needed Python packages</h2>

<p>If you haven't yet installed these packages, remove the "<code>#</code>" characters and run this code block.</p>

In [1]:
#!pip install -e ../../ALBench  # Installs al_bench and dependencies
#!pip install ipywidgets
#!jupyter labextension install @jupyter-widgets/jupyterlab-manager

<h2>Import Python packages</h2>

<p>We use fully qualified names in this notebook so that the provenance of each function is obvious.</p>

In [2]:
import al_bench as alb
import batchbald_redux as bbald
import batchbald_redux.active_learning
import batchbald_redux.consistent_mc_dropout
import batchbald_redux.repeated_mnist
import numpy as np
import os
import random
import shutil
import torch
from datetime import datetime


<h2>Fetch the dataset</h2>

<p>We fetch the labeled MNIST dataset of images of the handwritten digits 0 to 9. The benchmarking tool requires that all examples be labeled, although the labels are not used initially. The label of a sample is revealed to the machine learning training only when the active learning strategy indicates that the annotator (for example, a clinician in a medical imaging application) has been asked to label that sample.</p>
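<p>As a rough illustration of what "revealing" a label means in a simulation, the sketch below keeps every label in an array but lets each training step see only the entries whose indices have been queried so far. It is a minimal NumPy-only sketch, not al_bench's implementation: <code>select_random</code> and <code>simulate</code> are hypothetical helpers, and random selection stands in for a real strategy such as BatchBALD.</p>

```python
import numpy as np


def select_random(unlabeled, k, rng):
    """Hypothetical stand-in strategy: pick k unlabeled indices at random."""
    return rng.choice(unlabeled, size=k, replace=False)


def simulate(all_labels, initial, queries, per_query, seed=0):
    """Simulate pool-based active learning with an oracle that knows every label.

    Returns the number of labels visible to training at each training step.
    """
    rng = np.random.default_rng(seed)
    labeled = np.zeros(len(all_labels), dtype=bool)
    labeled[initial] = True
    sizes = []
    for step in range(queries + 1):
        # Only these labels are visible; a real run would train a model on them.
        visible_labels = all_labels[labeled]
        sizes.append(len(visible_labels))
        if step == queries:
            break
        unlabeled = np.flatnonzero(~labeled)
        chosen = select_random(unlabeled, per_query, rng)
        # "Ask the annotator": reveal the labels of the chosen samples.
        labeled[chosen] = True
    return sizes


labels = np.arange(500) % 10  # hypothetical labels for 500 pool samples
print(simulate(labels, initial=np.arange(20), queries=8, per_query=10))
```

<p>The hidden labels never change; only the mask of which ones the training step may use grows by <code>per_query</code> after each query.</p>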

In [3]:
mnist = bbald.repeated_mnist.create_repeated_MNIST_dataset(
    num_repetitions=1, add_noise=False
)
train_dataset, test_dataset = mnist

print("Dataset is read")

Dataset is read


<h2>Create a DatasetHandler</h2>

<p>In general, a dataset may come in many possible formats. This dataset is a PyTorch <code>Dataset</code>, which we convert to NumPy arrays because that is what the al_bench dataset handler expects. To keep the example fast, we use a random subset of 500 training images and 50 validation images.</p>
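<p>The shape of <code>my_feature_vectors</code> in the next cell depends on how the per-sample tensors are combined. The sketch below uses random arrays as stand-ins for MNIST images and assumes, as with torchvision's <code>ToTensor</code> transform, that each image arrives with shape <code>(1, 28, 28)</code>. <code>np.concatenate</code> joins along the existing first axis, so the single channel axis disappears into the sample axis, whereas <code>np.stack</code> adds a new sample axis and keeps the channel axis.</p>

```python
import numpy as np

# Stand-ins for three MNIST images, each shaped (channels, height, width).
images = [np.random.rand(1, 28, 28).astype(np.float32) for _ in range(3)]
labels = [7, 2, 1]

concatenated = np.concatenate(images)  # joins along axis 0
stacked = np.stack(images)  # adds a new leading sample axis

print(concatenated.shape)  # (3, 28, 28)
print(stacked.shape)  # (3, 1, 28, 28)

# One label per feature vector, stored as a column: shape (n_samples, 1).
label_array = np.array([[y] for y in labels])
print(label_array.shape)  # (3, 1)
```

<p>This notebook does not show which layout the al_bench model handler expects, so it is worth checking that before changing the conversion.</p>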

In [4]:
# al_bench datasets are supplied as numpy arrays.
# Build the numpy arrays from a subset of the data
train_dataset_list = random.sample([d for d in train_dataset], 500)
test_dataset_list = random.sample([d for d in test_dataset], 50)
num_training_indices = len(train_dataset_list)
num_validation_indices = len(test_dataset_list)
dataset_list = train_dataset_list + test_dataset_list
# Unzip the data set into separate (unlabeled) input data and their labels.
my_feature_vectors = np.concatenate([d[0].numpy() for d in dataset_list])  # data only
my_labels = np.array([[d[1]] for d in dataset_list])  # each is list of one label only
# This dataset is the digits "0" through "9" which we will enumerate with the
# values 0 through 9.
num_classes = 10
# We have one label per feature_vector so we need a list of one dictionary.
my_label_definitions = [{i: {"description": repr(i)} for i in range(num_classes)}]
# We will indicate the validation examples by their indices.
validation_indices = np.array(
    range(num_training_indices, num_training_indices + num_validation_indices)
)
print(f"feature_shape = {my_feature_vectors.shape[1:]}")
print("Dataset is ready as numpy")

# Tell al_bench about the dataset
my_dataset_handler = alb.dataset.GenericDatasetHandler()
my_dataset_handler.set_all_feature_vectors(my_feature_vectors)
my_dataset_handler.set_all_label_definitions(my_label_definitions)
my_dataset_handler.set_all_labels(my_labels)
my_dataset_handler.set_validation_indices(validation_indices)
print("DatasetHandler is initialized")

# We'll start the first pass of active learning with some randomly chosen samples
# from the training data set.
num_initial_training = 20
currently_labeled_examples = np.array(
    random.sample(range(num_training_indices), num_initial_training)
)

Dataset is converted to numpy
Datahandler is initialized


<h2>Create a model and a Model Handler</h2>

<p>Here we construct the model that we will train with active learning. The model definition is nearly verbatim from <a href="https://blackhc.github.io/batchbald_redux/example_experiment.html">https://blackhc.github.io/batchbald_redux/example_experiment.html</a>.</p>

<h3>Build a TensorFlow model and its Model Handler</h3>

In [12]:
print("batchbald_redux does not support tensorflow at this time")

batchbald_redux does not support tensorflow at this time


<h3>Build a Torch model and its Model Handler</h3>

In [7]:
class BayesianCNN(bbald.consistent_mc_dropout.BayesianModule):
    def __init__(self, num_classes=10):
        super().__init__()

        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=5)
        self.conv1_drop = bbald.consistent_mc_dropout.ConsistentMCDropout2d()
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=5)
        self.conv2_drop = bbald.consistent_mc_dropout.ConsistentMCDropout2d()
        self.fc1 = torch.nn.Linear(1024, 128)
        self.fc1_drop = bbald.consistent_mc_dropout.ConsistentMCDropout()
        self.fc2 = torch.nn.Linear(128, num_classes)

    def mc_forward_impl(self, input: torch.Tensor):
        input = torch.nn.functional.relu(
            torch.nn.functional.max_pool2d(self.conv1_drop(self.conv1(input)), 2)
        )
        input = torch.nn.functional.relu(
            torch.nn.functional.max_pool2d(self.conv2_drop(self.conv2(input)), 2)
        )
        input = input.view(-1, 1024)
        input = torch.nn.functional.relu(self.fc1_drop(self.fc1(input)))
        input = self.fc2(input)
        input = torch.nn.functional.log_softmax(input, dim=1)

        return input


my_pytorch_model = BayesianCNN(num_classes)
print("Created torch model")

# Tell al_bench about the model
my_pytorch_model_handler = alb.model.SamplingBayesianPyTorchModelHandler()
my_pytorch_model_handler.set_model(my_pytorch_model)
print("PyTorch model handler built")

Created torch model
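<p>The <code>1024</code> in <code>fc1</code> and in <code>input.view(-1, 1024)</code> follows from the layer sizes, assuming 28×28 inputs: each 5×5 convolution without padding shrinks each side by 4, and each 2×2 max pool (whose stride defaults to its kernel size) halves it. The arithmetic, as a plain-Python sketch with hypothetical helper names:</p>

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a convolution (PyTorch's formula with dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2):
    """Output side length of max pooling with stride equal to the kernel size."""
    return (size - kernel) // kernel + 1


side = 28
side = pool_out(conv_out(side, 5))  # conv1: 28 -> 24, pool: 24 -> 12
side = pool_out(conv_out(side, 5))  # conv2: 12 -> 8, pool: 8 -> 4
flattened = 64 * side * side  # conv2 produces 64 channels
print(side, flattened)  # 4 1024
```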


<h3>Choose one of the models to proceed with</h3>

<p>If we also had a TensorFlow model at this point, we could choose to proceed with it or with the PyTorch model.</p>

In [9]:
# my_model_handler = my_tensorflow_model_handler
my_model_handler = my_pytorch_model_handler

<h2>Make use of Strategy Handlers for active learning</h2>

<p>At this point we have one active learning strategy available for a Bayesian machine learning model: BatchBALD. The strategy evaluates the unlabeled samples and then selects a batch of samples whose predictions are individually uncertain and which are also diverse, that is, not redundant with one another, within the batch.</p>

<p>For more information about BatchBALD, see Kirsch A, van Amersfoort J, Gal Y. BatchBALD: Efficient and Diverse Batch Acquisition for Deep Bayesian Active Learning. 2019 Jun 19. <a href="https://arxiv.org/abs/1906.08158">arXiv:1906.08158</a>.</p>
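<p>To make "individually uncertain" and "diverse within the batch" concrete, here is a small NumPy sketch of greedy BatchBALD selection following the definitions in the paper cited above. It assumes the model has produced <code>K</code> Monte Carlo dropout samples of class probabilities for each of <code>N</code> candidates, as an array of shape <code>(K, N, C)</code>. The function names are hypothetical and are not part of al_bench or batchbald_redux. The joint entropy is computed exactly by enumerating all <code>C**m</code> label combinations for a batch of size <code>m</code>, so this is practical only for tiny batches; the paper describes sampling-based estimates for larger ones.</p>

```python
import numpy as np


def entropy(p, axis=-1):
    """Shannon entropy in nats; probabilities are clipped to avoid log(0)."""
    p = np.clip(p, 1e-12, 1.0)
    return -np.sum(p * np.log(p), axis=axis)


def bald_scores(probs):
    """BALD score per candidate: H[mean_k p_k] - mean_k H[p_k].

    probs: array (K, N, C) of class probabilities from K MC-dropout samples.
    """
    mean_probs = probs.mean(axis=0)  # (N, C)
    return entropy(mean_probs) - entropy(probs).mean(axis=0)


def greedy_batchbald(probs, batch_size):
    """Greedily select batch_size candidate indices by exact BatchBALD."""
    K, N, C = probs.shape
    # Expected conditional entropy of each candidate, E_k H[p_k(y_n)].
    # Given the parameters, labels are independent, so these terms add up.
    cond_entropy = entropy(probs).mean(axis=0)  # (N,)
    # Joint distribution of the labels chosen so far, per MC sample: (K, C**m).
    joint = np.ones((K, 1))
    selected = []
    for _ in range(batch_size):
        best_idx, best_score = -1, -np.inf
        for n in range(N):
            if n in selected:
                continue
            # Extend each MC sample's joint distribution with candidate n.
            candidate_joint = (joint[:, :, None] * probs[:, n, None, :]).reshape(K, -1)
            joint_entropy = entropy(candidate_joint.mean(axis=0))
            score = joint_entropy - cond_entropy[selected + [n]].sum()
            if score > best_score:
                best_idx, best_score = n, score
        selected.append(best_idx)
        joint = (joint[:, :, None] * probs[:, best_idx, None, :]).reshape(K, -1)
    return selected


rng = np.random.default_rng(0)
# 20 MC samples, 30 candidates, 4 classes; Dirichlet draws are valid distributions.
probs = rng.dirichlet(np.ones(4), size=(20, 30))
print(greedy_batchbald(probs, batch_size=3))
```

<p>The first pick is always the candidate with the highest BALD score. Later picks are penalized when the information their labels carry about the model parameters overlaps with that of labels already in the batch, which is where the diversity comes from.</p>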

In [10]:
all_logs_dir = "runs-SamplingBayesian"
# Delete old log files, if any
shutil.rmtree(all_logs_dir, ignore_errors=True)

for name, my_strategy_handler in (
    #   ("BALD", alb.strategy.BaldStrategyHandler()),
    ("BatchBALD", alb.strategy.BatchBaldStrategyHandler()),
):
    print(f"=== Begin Strategy {repr(name)} at {datetime.now()} ===")
    my_strategy_handler.set_dataset_handler(my_dataset_handler)
    my_strategy_handler.set_model_handler(my_model_handler)
    # We've supplied only one label per feature vector, so choose it with
    # label_of_interest=0
    my_strategy_handler.set_learning_parameters(
        label_of_interest=0, maximum_queries=8, number_to_select_per_query=10
    )

    # ################################################################
    # Simulate the strategy.
    my_strategy_handler.run(currently_labeled_examples)
    # ################################################################

    # We will write out collected information to disk.  First say where:
    log_dir = os.path.join(all_logs_dir, name)
    # Write accuracy and loss information during training
    my_strategy_handler.write_train_log_for_tensorboard(log_dir=log_dir)
    # Write confidence statistics during active learning
    my_strategy_handler.write_confidence_log_for_tensorboard(log_dir=log_dir)
print(f"=== Done at {datetime.now()} ===")

=== Begin Strategy 'BatchBALD' at 2023-02-13 14:48:56.298307 ===
Training with 20 examples
Predicting for 550 examples
Training with 30 examples
Predicting for 550 examples
Training with 40 examples
Predicting for 550 examples
Training with 50 examples
Predicting for 550 examples
Training with 60 examples
Predicting for 550 examples
Training with 70 examples
Predicting for 550 examples
Training with 80 examples
Predicting for 550 examples
Training with 90 examples
Predicting for 550 examples
Training with 100 examples
Predicting for 550 examples
=== Done at 2023-02-13 14:53:21.120731 ===
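<p>The counts in this output follow from the parameters above: 500 training candidates, 50 validation examples, 20 initially labeled examples, <code>maximum_queries=8</code>, and <code>number_to_select_per_query=10</code>. A plain-Python check of that bookkeeping, assuming that predictions are made for all 550 examples and that each query chooses from the training examples not yet labeled:</p>

```python
num_training = 500
num_validation = 50
initial_labeled = 20
maximum_queries = 8
per_query = 10

# (labeled examples used for training, unlabeled training examples remaining)
schedule = [
    (initial_labeled + q * per_query, num_training - initial_labeled - q * per_query)
    for q in range(maximum_queries + 1)
]
for labeled, pool in schedule:
    print(f"train on {labeled}, predict {num_training + num_validation}, pool {pool}")
```

<p>The last training round uses 20 + 8 × 10 = 100 labeled examples, matching the final "Training with 100 examples" line.</p>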


<h2>Use with TensorBoard</h2>

<p>TensorBoard graphs the information from the log files we have written. If it is not blocked by a firewall, the TensorBoard display will appear in this Jupyter notebook. Otherwise, you can view it in any web browser by running "<code>tensorboard --logdir runs-SamplingBayesian</code>" from a command prompt in the notebook's directory and then loading "<code>http://localhost:6006/</code>" in the browser.</p>

<p>Because these simulations are randomized, you will see different output each time you run them. On the "Scalars" tab you can change the smoothing of the displayed curves, e.g., set it to 0 to show unsmoothed values.</p>

<p>The Confidence graphs show how the certainty of predictions for samples that are (simulated as) not yet labeled changes during the active learning process, as a function of the number of samples labeled so far. For example, Confidence/margin/10% measures a sample's certainty as the difference between the scores of its two highest-scoring labels, and the 10% indicates the 10th percentile among all unlabeled samples, i.e., a value near the least confident end of that group. Confidence/entropy/50% shows the median, among unlabeled samples, of the negative entropy of the predicted label distribution. Confidence/maximum/5% shows the 5th percentile, again near the least confident end, of a sample's maximum label score, which is the score of its predicted label.</p>
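<p>The statistics described above can be sketched in NumPy for a matrix of predicted class probabilities. This illustrates the definitions only and is not al_bench's implementation: the function name <code>confidence_percentiles</code> is hypothetical, and al_bench may differ in details such as whether it uses probabilities or raw scores and how percentiles are interpolated.</p>

```python
import numpy as np


def confidence_percentiles(probs, percentiles=(5, 10, 50)):
    """Summarize prediction confidence over a set of (unlabeled) samples.

    probs: array (N, C) of class probabilities, one row per sample.
    Returns {statistic: {percentile: value}}; lower values mean less confidence.
    """
    sorted_probs = np.sort(probs, axis=1)
    maximum = sorted_probs[:, -1]  # score of the predicted label
    margin = sorted_probs[:, -1] - sorted_probs[:, -2]  # gap between top two labels
    clipped = np.clip(probs, 1e-12, 1.0)  # avoid log(0)
    negative_entropy = np.sum(clipped * np.log(clipped), axis=1)  # -H(p)
    stats = {"maximum": maximum, "margin": margin, "entropy": negative_entropy}
    return {
        name: {p: float(np.percentile(values, p)) for p in percentiles}
        for name, values in stats.items()
    }


rng = np.random.default_rng(1)
probs = rng.dirichlet(np.ones(10), size=400)  # hypothetical predictions, 10 classes
summary = confidence_percentiles(probs)
print(summary["margin"][10], summary["entropy"][50], summary["maximum"][5])
```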

In [11]:
%load_ext tensorboard
%tensorboard --logdir {all_logs_dir}