# Example Experiment
> Experiment using Repeated MNIST and BatchBALD vs BALD vs random sampling

This notebook ties everything together and runs an active learning (AL) loop.

In [None]:
import blackhc.project.script
from tqdm.auto import tqdm

In [None]:
import torch
from torch import nn as nn
from torch.nn import functional as F

from batchbald_redux import active_learning, batchbald, consistent_mc_dropout, joint_entropy, repeated_mnist

Let's define the Bayesian CNN that we will train on MNIST.

In [None]:
class BayesianCNN(consistent_mc_dropout.BayesianModule):
    """Small two-conv / two-linear CNN with consistent MC dropout, returning log-probabilities.

    Dropout layers from ``consistent_mc_dropout`` follow conv1, conv2 and fc1; the final
    linear layer feeds a ``log_softmax`` over ``num_classes`` outputs.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Convolutional feature extractor (registration order kept stable for parameters()).
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv1_drop = consistent_mc_dropout.ConsistentMCDropout2d()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.conv2_drop = consistent_mc_dropout.ConsistentMCDropout2d()

        # Classifier head; 1024 = 64 channels * 4 * 4, which holds for 28x28 single-channel input.
        self.fc1 = nn.Linear(1024, 128)
        self.fc1_drop = consistent_mc_dropout.ConsistentMCDropout()
        self.fc2 = nn.Linear(128, num_classes)

    def mc_forward_impl(self, input: torch.Tensor):
        """Run the network on ``input`` and return log-probabilities along dim 1."""
        # conv -> MC dropout -> 2x2 max-pool -> ReLU, twice.
        features = F.relu(F.max_pool2d(self.conv1_drop(self.conv1(input)), 2))
        features = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(features)), 2))

        # Flatten spatial maps for the fully connected head.
        hidden = features.view(-1, 1024)
        hidden = F.relu(self.fc1_drop(self.fc1(hidden)))

        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

Grab our dataset: we'll use Repeated-MNIST, here with a single repetition and no added noise.

In [None]:
# One repetition and no noise: the Repeated-MNIST loader with its augmentation switched off.
train_dataset, test_dataset = repeated_mnist.create_repeated_MNIST_dataset(
    num_repetitions=1,
    add_noise=False,
)

In [None]:
num_initial_samples = 20
num_classes = 10

active_learning_data = active_learning.ActiveLearningData(train_dataset)

# The initial labelled set takes the same number of samples from every class, so the
# requested size must divide evenly; derive the per-class count instead of hard-coding it.
assert num_initial_samples % num_classes == 0, "num_initial_samples must be a multiple of num_classes"

initial_samples = active_learning.get_balanced_sample_indices(
    repeated_mnist.get_targets(train_dataset),
    num_classes=num_classes,
    n_per_digit=num_initial_samples // num_classes,
)

# Split off the initial samples first.
active_learning_data.acquire(initial_samples)

In [None]:
test_batch_size = 512
batch_size = 64
scoring_batch_size = 512
# Target number of samples per training epoch, handed to RandomFixedLengthSampler below.
epoch_samples = 8192 * 4

use_cuda = False

# Extra loader options are only enabled when running on CUDA.
if use_cuda:
    kwargs = {"num_workers": 1, "pin_memory": True}
else:
    kwargs = {}

# Test set, iterated in a fixed order.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
    **kwargs,
)

# Acquired (labelled) set, resampled to `epoch_samples` items per epoch.
train_loader = torch.utils.data.DataLoader(
    active_learning_data.training_dataset,
    sampler=active_learning.RandomFixedLengthSampler(
        active_learning_data.training_dataset,
        epoch_samples,
    ),
    batch_size=batch_size,
    **kwargs,
)

# Unlabelled pool; shuffle=False so batch i lines up with rows [i*bs, (i+1)*bs) of the
# prediction buffer filled in the acquisition loop.
pool_loader = torch.utils.data.DataLoader(
    active_learning_data.pool_dataset,
    batch_size=scoring_batch_size,
    shuffle=False,
    **kwargs,
)

In [None]:
# experiment
# Active-learning loop: each round retrains a fresh model on the acquired set, reports test
# accuracy/loss, then scores the pool with BatchBALD and acquires `acquisition_batch_size` points.
max_training_samples = 150
acquisition_batch_size = 10
num_inference_samples = 20

# Honour the `use_cuda` flag from the loader configuration (no-op on CPU).
device = torch.device("cuda" if use_cuda else "cpu")

test_accs = []
test_loss = []
added_indices = []

while True:
    print(f"Num acquired samples: {len(active_learning_data.training_dataset)}")

    # Retrain from scratch every round.
    model = BayesianCNN(num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters())

    model.train()

    # Train
    for data, target in tqdm(train_loader, desc="Training"):
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()

        # One sample per input (second argument); squeeze that axis to get (batch, classes).
        prediction = model(data, 1).squeeze(1)
        loss = F.nll_loss(prediction, target)

        loss.backward()
        optimizer.step()

    # Test
    loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(test_loader, desc="Testing"):
            data = data.to(device)
            target = target.to(device)

            prediction = model(data, 1).squeeze(1)
            # Accumulate a Python float rather than a tensor so `test_loss` holds plain numbers.
            loss += F.nll_loss(prediction, target, reduction="sum").item()

            prediction = prediction.max(1)[1]
            correct += prediction.eq(target.view_as(prediction)).sum().item()

    loss /= len(test_loader.dataset)
    test_loss.append(loss)

    percentage_correct = 100.0 * correct / len(test_loader.dataset)
    test_accs.append(percentage_correct)

    print("Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)".format(
        loss, correct, len(test_loader.dataset), percentage_correct))

    if len(active_learning_data.training_dataset) >= max_training_samples:
        break

    # Acquire pool predictions
    N = len(active_learning_data.pool_dataset)
    # Allocate in double: predictions are cast to double and BatchBALD is run in double below,
    # so a float32 buffer would silently throw that precision away.
    probs_N_K_C = torch.empty((N, num_inference_samples, num_classes), dtype=torch.double)

    with torch.no_grad():
        model.eval()

        # Iterate the pool loader (shuffle=False) so batch i fills rows [lower, upper).
        for i, (data, _) in enumerate(
                tqdm(pool_loader, desc="Evaluating Acquisition Set")):
            data = data.to(device)

            lower = i * pool_loader.batch_size
            upper = min(lower + pool_loader.batch_size, N)
            # The model returns log-probabilities (log_softmax); exponentiate to probabilities.
            probs_N_K_C[lower:upper] = model(
                data, num_inference_samples).double().exp_()

    candidate_batch = batchbald.get_batchbald_batch(probs_N_K_C,
                                                    acquisition_batch_size,
                                                    dtype=torch.double)

    # Look up labels and dataset indices while `candidate_batch.indices` still refer to the
    # current pool (presumably `acquire` below removes them from it).
    targets = repeated_mnist.get_targets(
        active_learning_data.pool_dataset)
    dataset_indices = active_learning_data.get_dataset_indices(
        candidate_batch.indices)

    print("Dataset indices: ", dataset_indices)
    print("Scores: ", candidate_batch.scores)
    print("Labels: ", targets[candidate_batch.indices])

    active_learning_data.acquire(candidate_batch.indices)
    added_indices.append(dataset_indices)

Num acquired samples: 20


HBox(children=(FloatProgress(value=0.0, description='Training', max=512.0, style=ProgressStyle(description_wid…




KeyboardInterrupt: 

In [None]:
# experiment