# Example Experiment
> Active learning experiment on Repeated-MNIST using BatchBALD acquisition (BALD and random sampling are the natural baselines to compare against)

This notebook ties everything together and runs an active learning (AL) loop with BatchBALD acquisition.

In [8]:
import blackhc.project.script
from tqdm.auto import tqdm

In [9]:
import torch
import math
from torch import nn as nn
from torch.nn import functional as F

from batchbald_redux import active_learning, batchbald, consistent_mc_dropout, joint_entropy, repeated_mnist

Let's define the Bayesian CNN model that we will train on MNIST.

In [10]:
class BayesianCNN(consistent_mc_dropout.BayesianModule):
    """Small CNN with MC dropout for single-channel digit images.

    Two 5x5 conv layers (each followed by dropout, 2x2 max-pooling and ReLU)
    feed a two-layer fully connected head. All dropout layers come from
    ``consistent_mc_dropout``; repeated stochastic sampling is driven by the
    ``BayesianModule`` base class, which calls ``mc_forward_impl``.

    Args:
        num_classes: number of output classes (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Convolutional feature extractor: 1 -> 32 -> 64 channels.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv1_drop = consistent_mc_dropout.ConsistentMCDropout2d()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.conv2_drop = consistent_mc_dropout.ConsistentMCDropout2d()

        # Classifier head. For 28x28 inputs the conv stack yields 64 x 4 x 4
        # = 1024 features per example (28 -5+1-> 24 -pool-> 12 -5+1-> 8 -pool-> 4).
        self.fc1 = nn.Linear(1024, 128)
        self.fc1_drop = consistent_mc_dropout.ConsistentMCDropout()
        self.fc2 = nn.Linear(128, num_classes)

    def mc_forward_impl(self, input: torch.Tensor):
        """Run one stochastic forward pass.

        Args:
            input: image batch with 1 channel; assumes 28x28 images so the
                flattened conv features are exactly 1024 wide — TODO confirm
                for other datasets.

        Returns:
            Log-probabilities over classes, one row per flattened example.
        """
        features = input
        conv_stages = (
            (self.conv1, self.conv1_drop),
            (self.conv2, self.conv2_drop),
        )
        for conv, dropout in conv_stages:
            # conv -> dropout -> 2x2 max-pool -> ReLU
            features = F.relu(F.max_pool2d(dropout(conv(features)), 2))

        flat = features.view(-1, 1024)
        hidden = F.relu(self.fc1_drop(self.fc1(flat)))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

Grab our dataset; we'll use Repeated-MNIST (here with a single repetition and no added noise).

In [11]:
# Build train/test splits: one repetition of each image, no added noise.
train_dataset, test_dataset = repeated_mnist.create_repeated_MNIST_dataset(
    num_repetitions=1,
    add_noise=False,
)

In [12]:
# Initial labelled set: a class-balanced sample of `num_initial_samples` points.
num_initial_samples = 20
num_classes = 10

# Derive the per-class count from the totals instead of hard-coding it, so the
# two settings cannot drift apart.
assert num_initial_samples % num_classes == 0, "num_initial_samples must be a multiple of num_classes"

initial_samples = active_learning.get_balanced_sample_indices(
    repeated_mnist.get_targets(train_dataset),
    num_classes=num_classes,
    n_per_digit=num_initial_samples // num_classes,
)

In [None]:
# experiment
max_training_samples = 150  # stop once the training set reaches this size
acquisition_batch_size = 5  # points acquired per active learning step
num_inference_samples = 100  # MC dropout samples per pool point when scoring
num_test_inference_samples = 5  # MC dropout samples per test point
num_samples = 100000  # sample budget passed to get_batchbald_batch

test_batch_size = 512
batch_size = 64
scoring_batch_size = 128
epoch_samples = 4096 * 6  # length given to RandomFixedLengthSampler per training round

use_cuda = torch.cuda.is_available()

print(f"use_cuda: {use_cuda}")

device = "cuda" if use_cuda else "cpu"

kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=test_batch_size,
                                          shuffle=False,
                                          **kwargs)

active_learning_data = active_learning.ActiveLearningData(train_dataset)

# Split off the initial samples first.
active_learning_data.acquire(initial_samples)

train_loader = torch.utils.data.DataLoader(
    active_learning_data.training_dataset,
    sampler=active_learning.RandomFixedLengthSampler(
        active_learning_data.training_dataset, epoch_samples),
    batch_size=batch_size,
    **kwargs,
)

pool_loader = torch.utils.data.DataLoader(active_learning_data.pool_dataset,
                                          batch_size=scoring_batch_size,
                                          shuffle=False,
                                          **kwargs)

# Run experiment
test_accs = []
test_loss = []  # average test NLL per round, as Python floats
added_indices = []

pbar = tqdm(initial=len(active_learning_data.training_dataset),
            total=max_training_samples,
            desc="Training Set Size")

while True:
    # Each round trains a fresh model from scratch on the current training set.
    model = BayesianCNN(num_classes).to(device=device)
    optimizer = torch.optim.Adam(model.parameters())

    model.train()

    # Train
    for data, target in tqdm(train_loader, desc="Training", leave=False):
        data = data.to(device=device)
        target = target.to(device=device)

        optimizer.zero_grad()

        # One MC sample per input during training; drop the sample dimension.
        prediction = model(data, 1).squeeze(1)
        loss = F.nll_loss(prediction, target)

        loss.backward()
        optimizer.step()

    # Test
    loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(test_loader, desc="Testing", leave=False):
            data = data.to(device=device)
            target = target.to(device=device)

            # Log of the mean predicted probability over the MC samples (dim 1).
            prediction = torch.logsumexp(
                model(data, num_test_inference_samples),
                dim=1) - math.log(num_test_inference_samples)
            # .item() keeps the running total a Python float rather than a
            # device tensor, so `test_loss` stores plain numbers.
            loss += F.nll_loss(prediction, target, reduction="sum").item()

            prediction = prediction.max(1)[1]
            correct += prediction.eq(target.view_as(prediction)).sum().item()

    loss /= len(test_loader.dataset)
    test_loss.append(loss)

    percentage_correct = 100.0 * correct / len(test_loader.dataset)
    test_accs.append(percentage_correct)

    print("Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)".format(
        loss, correct, len(test_loader.dataset), percentage_correct))

    if len(active_learning_data.training_dataset) >= max_training_samples:
        break

    # Acquire pool predictions
    N = len(active_learning_data.pool_dataset)
    logits_N_K_C = torch.empty((N, num_inference_samples, num_classes),
                               dtype=torch.double,
                               pin_memory=use_cuda)

    with torch.no_grad():
        model.eval()

        for i, (data, _) in enumerate(
                tqdm(pool_loader,
                     desc="Evaluating Acquisition Set",
                     leave=False)):
            data = data.to(device=device)

            # pool_loader is not shuffled, so batch i covers rows [lower, upper).
            lower = i * pool_loader.batch_size
            upper = min(lower + pool_loader.batch_size, N)
            logits_N_K_C[lower:upper].copy_(model(
                data, num_inference_samples).double(),
                                            non_blocking=True)

    if use_cuda:
        # The device-to-host copies above are non_blocking into pinned memory
        # and may still be in flight; wait for them before the CPU reads
        # logits_N_K_C inside get_batchbald_batch.
        torch.cuda.synchronize()

    with torch.no_grad():
        candidate_batch = batchbald.get_batchbald_batch(logits_N_K_C,
                                                        acquisition_batch_size,
                                                        num_samples,
                                                        dtype=torch.double,
                                                        device=device)

    targets = repeated_mnist.get_targets(active_learning_data.pool_dataset)
    # Map pool-relative indices to indices of the underlying dataset before
    # acquiring (acquisition changes the pool).
    dataset_indices = active_learning_data.get_dataset_indices(
        candidate_batch.indices)

    print("Dataset indices: ", dataset_indices)
    print("Scores: ", candidate_batch.scores)
    print("Labels: ", targets[candidate_batch.indices])

    active_learning_data.acquire(candidate_batch.indices)
    added_indices.append(dataset_indices)
    pbar.update(len(dataset_indices))

pbar.close()

use_cuda: True


HBox(children=(IntProgress(value=0, description='Training Set Size', max=150, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Training', max=384, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Testing', max=20, style=ProgressStyle(description_width='init…

Test set: Average loss: 1.8308, Accuracy: 6338/10000 (63.38%)


HBox(children=(IntProgress(value=0, description='Evaluating Acquisition Set', max=157, style=ProgressStyle(des…

HBox(children=(IntProgress(value=0, description='Conditional Entropy', max=19980, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='BatchBALD', max=5, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19980, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19980, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19980, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19980, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19980, style=ProgressSt…

Dataset indices:  [ 8289  3582 53863 25823  8257]
Scores:  [1.3557736757796552, 2.519825665002863, 3.4085050062563003, 3.9695525038521966, 4.2919329737575325]
Labels:  tensor([0, 2, 3, 0, 2])


HBox(children=(IntProgress(value=0, description='Training', max=384, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Testing', max=20, style=ProgressStyle(description_width='init…

Test set: Average loss: 1.4175, Accuracy: 6985/10000 (69.85%)


HBox(children=(IntProgress(value=0, description='Evaluating Acquisition Set', max=157, style=ProgressStyle(des…

HBox(children=(IntProgress(value=0, description='Conditional Entropy', max=19975, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='BatchBALD', max=5, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19975, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19975, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19975, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19975, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19975, style=ProgressSt…

Dataset indices:  [52012 41383 13682 42198  6185]
Scores:  [1.3061743472090992, 2.3644483581086218, 3.221506615442796, 3.8194846029077567, 4.183660860241003]
Labels:  tensor([8, 0, 8, 4, 3])


HBox(children=(IntProgress(value=0, description='Training', max=384, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Testing', max=20, style=ProgressStyle(description_width='init…

Test set: Average loss: 1.2743, Accuracy: 7269/10000 (72.69%)


HBox(children=(IntProgress(value=0, description='Evaluating Acquisition Set', max=157, style=ProgressStyle(des…

HBox(children=(IntProgress(value=0, description='Conditional Entropy', max=19970, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='BatchBALD', max=5, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19970, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19970, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19970, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19970, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19970, style=ProgressSt…

Dataset indices:  [11657 37137 14866 28222 34614]
Scores:  [1.2659351842478004, 2.3131665302748754, 3.202945514749244, 3.8116075718264586, 4.159970164562334]
Labels:  tensor([0, 5, 7, 6, 2])


HBox(children=(IntProgress(value=0, description='Training', max=384, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Testing', max=20, style=ProgressStyle(description_width='init…

Test set: Average loss: 1.1519, Accuracy: 7435/10000 (74.35%)


HBox(children=(IntProgress(value=0, description='Evaluating Acquisition Set', max=156, style=ProgressStyle(des…

HBox(children=(IntProgress(value=0, description='Conditional Entropy', max=19965, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='BatchBALD', max=5, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19965, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19965, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19965, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19965, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19965, style=ProgressSt…

Dataset indices:  [39411 13642 19396  8488 16077]
Scores:  [1.3163655603620135, 2.405610005940731, 3.1817393501560343, 3.758214079939448, 4.122080265207265]
Labels:  tensor([2, 5, 5, 6, 6])


HBox(children=(IntProgress(value=0, description='Training', max=384, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Testing', max=20, style=ProgressStyle(description_width='init…

Test set: Average loss: 0.9635, Accuracy: 7809/10000 (78.09%)


HBox(children=(IntProgress(value=0, description='Evaluating Acquisition Set', max=156, style=ProgressStyle(des…

HBox(children=(IntProgress(value=0, description='Conditional Entropy', max=19960, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='BatchBALD', max=5, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19960, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19960, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19960, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19960, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19960, style=ProgressSt…

Dataset indices:  [40057  4606 55743 26444 37870]
Scores:  [1.209489071500093, 2.283754189343815, 3.102376195203646, 3.6698826219743843, 4.038958571081122]
Labels:  tensor([5, 9, 3, 1, 8])


HBox(children=(IntProgress(value=0, description='Training', max=384, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Testing', max=20, style=ProgressStyle(description_width='init…

Test set: Average loss: 0.7743, Accuracy: 8067/10000 (80.67%)


HBox(children=(IntProgress(value=0, description='Evaluating Acquisition Set', max=156, style=ProgressStyle(des…

HBox(children=(IntProgress(value=0, description='Conditional Entropy', max=19955, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='BatchBALD', max=5, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19955, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19955, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19955, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19955, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19955, style=ProgressSt…

Dataset indices:  [ 2748 25910 24223 32954 20110]
Scores:  [1.123391331671529, 2.140858614880564, 3.012231409824741, 3.66999233531654, 4.095480093557085]
Labels:  tensor([2, 1, 8, 5, 4])


HBox(children=(IntProgress(value=0, description='Training', max=384, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Testing', max=20, style=ProgressStyle(description_width='init…

Test set: Average loss: 0.6680, Accuracy: 8267/10000 (82.67%)


HBox(children=(IntProgress(value=0, description='Evaluating Acquisition Set', max=156, style=ProgressStyle(des…

HBox(children=(IntProgress(value=0, description='Conditional Entropy', max=19950, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='BatchBALD', max=5, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19950, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19950, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19950, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19950, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19950, style=ProgressSt…

Dataset indices:  [56615 37249 50461   384 32509]
Scores:  [1.2427624719236614, 2.2830006810904986, 3.1605394537163587, 3.7497607635339576, 4.128859502902091]
Labels:  tensor([3, 5, 7, 7, 8])


HBox(children=(IntProgress(value=0, description='Training', max=384, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Testing', max=20, style=ProgressStyle(description_width='init…

Test set: Average loss: 0.6514, Accuracy: 8397/10000 (83.97%)


HBox(children=(IntProgress(value=0, description='Evaluating Acquisition Set', max=156, style=ProgressStyle(des…

HBox(children=(IntProgress(value=0, description='Conditional Entropy', max=19945, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='BatchBALD', max=5, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19945, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19945, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19945, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19945, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19945, style=ProgressSt…

Dataset indices:  [47695 30925 50010  3916 45114]
Scores:  [1.190816694089635, 2.2611525375451773, 3.0947280938353208, 3.7115958775092626, 4.084116729465448]
Labels:  tensor([4, 2, 5, 7, 7])


HBox(children=(IntProgress(value=0, description='Training', max=384, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Testing', max=20, style=ProgressStyle(description_width='init…

Test set: Average loss: 0.5701, Accuracy: 8386/10000 (83.86%)


HBox(children=(IntProgress(value=0, description='Evaluating Acquisition Set', max=156, style=ProgressStyle(des…

HBox(children=(IntProgress(value=0, description='Conditional Entropy', max=19940, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='BatchBALD', max=5, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19940, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19940, style=ProgressSt…

HBox(children=(IntProgress(value=0, description='ExactJointEntropy.compute_batch', max=19940, style=ProgressSt…