In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from functools import lru_cache
from pathlib import Path
from typing import Callable

import jax.numpy as jnp
import numpy as np
import torchvision
from jax import random
from sklearn.model_selection import ParameterGrid
from tqdm import tqdm

from bayesian_active_learning.acquisition_functions import (
    get_acquisition_function,
)
from bayesian_active_learning.experiment import experiment_run
from bayesian_active_learning.utils import one_hot

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme()

# Summary

- This notebook attempts to reproduce Figure 1 of "Deep Bayesian Active Learning with Image Data" (Gal, Islam & Ghahramani, 2017) on MNIST.
- We start with an initial training set of 100 labelled points (plus a separate 100-point validation set). At each of 5 acquisition steps, we select 100 new points from the unlabelled pool and add them to the training set.


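- The notebook saves the resulting figure to `figures/01_reproducing_plot/plot.png`.
- All three acquisition functions (BALD, Max Entropy and Random) perform similarly in this regime. This might be expected given the large acquisition size: when 100 points are selected at once by scoring each point independently, the chosen batch can be highly redundant. Further experiments explore this idea (and provide the motivation for BatchBALD).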

# 1. Experimental setup

In [None]:
@lru_cache(maxsize=None)
def _load_mnist(data_dir: str) -> tuple:
    """Load MNIST from ``data_dir`` (downloading it if needed) and preprocess it.

    Pixel values are divided by 255 and labels are one-hot encoded. The result
    is memoised per ``data_dir``, so a grid of runs reads and converts the
    dataset once instead of once per run. The arrays stay in memory for the
    kernel's lifetime; call ``_load_mnist.cache_clear()`` to release them.

    Returns:
        ``(train_X, train_y, test_X, test_y)``.
    """
    train_dataset = torchvision.datasets.MNIST(data_dir, train=True, download=True)
    test_dataset = torchvision.datasets.MNIST(data_dir, train=False, download=True)
    num_classes = len(train_dataset.classes)

    arrays = (
        np.array(train_dataset.data) / 255.0,
        one_hot(np.array(train_dataset.targets), k=num_classes),
        np.array(test_dataset.data) / 255.0,
        one_hot(np.array(test_dataset.targets), k=num_classes),
    )
    for array in arrays:
        # Every run receives the same cached objects: make accidental in-place
        # modification fail loudly instead of silently leaking between runs.
        # (one_hot's return type is not visible here, so only numpy arrays are frozen.)
        if isinstance(array, np.ndarray):
            array.setflags(write=False)
    return arrays


def full_experiment(
    acquisition_function_name: str,
    seed: int,
    *,
    data_dir: str = "../datasets",
    num_initial_train_points: int = 100,
    num_validation_points: int = 100,
    weight_decay: float = 1e-2,
    num_predictive_samples: int = 100,
    num_acquired_points_per_iteration: int = 100,
    num_iterations: int = 5,
) -> jnp.ndarray:
    """Run one active-learning experiment on MNIST and return its test-accuracy history.

    The MNIST training set is shuffled with ``seed``. It is then split into an
    initial labelled training set, a validation set and an unlabelled pool.
    The MNIST test set is used for evaluation. ``experiment_run`` then
    acquires points from the pool using the named acquisition function.

    Args:
        acquisition_function_name: Name passed to ``get_acquisition_function``
            (e.g. ``"BALD"``, ``"Random"``, ``"Max Entropy"``).
        seed: Seeds both the numpy shuffle of the training data and the JAX
            PRNG key handed to ``experiment_run``.
        data_dir: Directory where torchvision stores / downloads MNIST.
        num_initial_train_points: Size of the initial labelled training set.
        num_validation_points: Size of the validation set.
        weight_decay, num_predictive_samples, num_acquired_points_per_iteration,
        num_iterations: Forwarded unchanged to ``experiment_run``.

    Returns:
        The test-accuracy history returned by ``experiment_run``.

    Raises:
        ValueError: If the initial training and validation sets would use up
            every training image, leaving the pool empty.
    """
    acquisition_function = get_acquisition_function(acquisition_function_name)

    all_train_X, all_train_y, all_test_X, all_test_y = _load_mnist(data_dir)

    num_held_out = num_initial_train_points + num_validation_points
    if num_held_out >= len(all_train_X):
        raise ValueError(
            f"num_initial_train_points + num_validation_points ({num_held_out}) "
            f"must be smaller than the number of MNIST training images "
            f"({len(all_train_X)}) so that the pool is non-empty"
        )

    # Seed-specific shuffle. Fancy indexing returns fresh copies, so the cached
    # arrays from _load_mnist are never modified by the split below.
    shuffle_idx = np.random.default_rng(seed).permutation(len(all_train_X))
    split_points = [num_initial_train_points, num_held_out]
    initial_train_X, val_X, initial_pool_X = np.split(all_train_X[shuffle_idx], split_points)
    initial_train_y, val_y, initial_pool_y = np.split(all_train_y[shuffle_idx], split_points)

    return experiment_run(
        train_set=(initial_train_X, initial_train_y),
        val_set=(val_X, val_y),
        pool_set=(initial_pool_X, initial_pool_y),
        test_set=(all_test_X, all_test_y),
        weight_decay=weight_decay,
        acquisition_fn=acquisition_function,
        num_predictive_samples=num_predictive_samples,
        num_acquired_points_per_iteration=num_acquired_points_per_iteration,
        num_iterations=num_iterations,
        key=random.PRNGKey(seed),
    )

# Set up the parameter grid

In [None]:
# Each (acquisition function, seed) pair below is one independent active-learning run.
acquisition_functions = ["BALD", "Random", "Max Entropy"]
seeds = np.arange(4)

param_grid = dict(acquisition_function_name=acquisition_functions, seed=seeds)
arg_list = [*ParameterGrid(param_grid)]
print(arg_list)

In [None]:
# An explicit loop (rather than a comprehension) keeps the runs completed so far
# in `results` if the cell is interrupted part-way through the grid.
results = []
for experiment_kwargs in tqdm(arg_list):
    accuracy_history = full_experiment(**experiment_kwargs)
    results.append(accuracy_history)

# Plotting

In [None]:
# Bucket each run's accuracy history under the acquisition function that produced it,
# preserving the order of `acquisition_functions` and of the runs within each bucket.
grouped_results = {name: [] for name in acquisition_functions}
for history, run_args in zip(results, arg_list):
    bucket = grouped_results[run_args["acquisition_function_name"]]
    bucket.append(history)

# Stack each bucket so axis 0 indexes the individual runs (one per seed).
grouped_results = {
    name: jnp.stack(histories) for name, histories in grouped_results.items()
}

In [None]:
# Acquisition batch size used by `full_experiment` (keep in sync): converts the
# iteration index of each accuracy history into a number of acquired points.
num_acquired_points_per_iteration = 100
figure_path = Path("../figures/01_reproducing_plot/plot.png")

fig, ax = plt.subplots()
for label, accuracy_histories in grouped_results.items():
    # Rows are runs (one per seed): summarise across them, in percent.
    median = 100 * np.median(accuracy_histories, axis=0)
    lower_quartile, upper_quartile = 100 * np.percentile(
        accuracy_histories, [25, 75], axis=0
    )
    num_acquired = num_acquired_points_per_iteration * np.arange(len(median))

    ax.plot(num_acquired, median, label=label)
    ax.fill_between(num_acquired, lower_quartile, upper_quartile, alpha=0.25)

ax.set(
    xlabel="Number of acquired points",
    ylabel="Test accuracy (%)",
    title="MNIST test accuracy (median and interquartile range over seeds)",
    ylim=(75, 100),
)
ax.legend()

# savefig does not create missing directories, so make sure the target exists.
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_path, bbox_inches="tight")
plt.show()