In [1]:
import torch
import torch.nn as nn
import torch.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from batchbald_redux import repeated_mnist, active_learning, batchbald
from main.models import BayesianConvNet
from main.utils import save_experiment, load_experiment


%reload_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# loading data
train_dataset, test_dataset = repeated_mnist.create_MNIST_dataset()

# number of initial samples 
num_initial_samples = 40
num_classes = 10

# get indices of initial samples
initial_samples = active_learning.get_balanced_sample_indices(
    repeated_mnist.get_targets(train_dataset), num_classes=num_classes, n_per_digit=num_initial_samples / num_classes
)

# Experiment parameters
max_training_samples = 100  # Maximum number of samples to acquire from the pool dataset 
acquisition_batch_size = 5  # Number of samples to acquire in each acquisition step
num_inference_samples = 50  # Number of samples to use for inference in MC-Dropout
num_test_inference_samples = 5  
num_samples = 100000 # Number of samples to use for estimation in batchbald

test_batch_size = 512  # Batch size for testing
batch_size = 64  # Batch size for training
scoring_batch_size = 128  # Batch size for scoring 
training_iterations = 4096 * 6 # Number of training iterations (batches) to run

kwargs = {"num_workers": 1, "pin_memory": True}

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

print(f"use_cuda: {use_cuda}")

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, **kwargs)

active_learning_data = active_learning.ActiveLearningData(train_dataset)

# Split off the initial samples first.
active_learning_data.acquire(initial_samples)

# THIS REMOVES MOST OF THE POOL DATA. UNCOMMENT THIS TO TAKE ALL UNLABELLED DATA INTO ACCOUNT!
active_learning_data.extract_dataset_from_pool(55000)

train_loader = torch.utils.data.DataLoader(
    active_learning_data.training_dataset,
    sampler=active_learning.RandomFixedLengthSampler(active_learning_data.training_dataset, training_iterations),
    batch_size=batch_size,
    **kwargs,
)

pool_loader = torch.utils.data.DataLoader(
    active_learning_data.pool_dataset, batch_size=scoring_batch_size, shuffle=False, **kwargs
)

use_cuda: False




In [3]:
# Run experiment
test_accs = []
test_loss = []
added_indices = []

pbar = tqdm(initial=len(active_learning_data.training_dataset), total=max_training_samples, desc="Training Set Size")
loss_fn = nn.NLLLoss()

while True:
    model = BayesianConvNet(num_classes).to(device=device)
    optimizer = torch.optim.Adam(model.parameters())

    model.train()

    # Train
    for data, target in tqdm(train_loader, desc="Training", leave=False):
        data = data.to(device=device)
        target = target.to(device=device)

        optimizer.zero_grad()

        prediction = model(data, 1).squeeze(1)
        loss = loss_fn(prediction, target)

        loss.backward()
        optimizer.step()

    # Test
    loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(test_loader, desc="Testing", leave=False):
            data = data.to(device=device)
            target = target.to(device=device)

            prediction = torch.logsumexp(model(data, num_test_inference_samples), dim=1) - math.log(
                num_test_inference_samples
            )
            loss += loss_fn(prediction, target)

            prediction = prediction.max(1)[1]
            correct += prediction.eq(target.view_as(prediction)).sum().item()

    loss /= len(test_loader.dataset)
    test_loss.append(loss)

    percentage_correct = 100.0 * correct / len(test_loader.dataset)
    test_accs.append(percentage_correct)

    print(
        "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)".format(
            loss, correct, len(test_loader.dataset), percentage_correct
        )
    )

    if len(active_learning_data.training_dataset) >= max_training_samples:
        break

    # Acquire pool predictions
    N = len(active_learning_data.pool_dataset)
    logits_N_K_C = torch.empty((N, num_inference_samples, num_classes), dtype=torch.double, pin_memory=use_cuda)

    with torch.no_grad():
        model.eval()

        for i, (data, _) in enumerate(tqdm(pool_loader, desc="Evaluating Acquisition Set", leave=False)):
            data = data.to(device=device)

            lower = i * pool_loader.batch_size
            upper = min(lower + pool_loader.batch_size, N)
            logits_N_K_C[lower:upper].copy_(model(data, num_inference_samples).double(), non_blocking=True)

    with torch.no_grad():
        candidate_batch = batchbald.get_bald_batch(
            logits_N_K_C, acquisition_batch_size, dtype=torch.double, device=device
        )

    targets = repeated_mnist.get_targets(active_learning_data.pool_dataset)
    dataset_indices = active_learning_data.get_dataset_indices(candidate_batch.indices)

    print("Dataset indices: ", dataset_indices)
    print("Scores: ", candidate_batch.scores)
    print("Labels: ", targets[candidate_batch.indices])

    active_learning_data.acquire(candidate_batch.indices)
    added_indices.append(dataset_indices)
    pbar.update(len(dataset_indices))

Training Set Size:  40%|████      | 40/100 [00:00<?, ?it/s]

Test set: Average loss: 0.0028, Accuracy: 7141/10000 (71.41%)


Conditional Entropy: 100%|██████████| 4960/4960 [00:00<00:00, 7657.08it/s]
Entropy: 100%|██████████| 4960/4960 [00:00<00:00, 9624.61it/s] 


torch.return_types.topk(
values=tensor([1.9551, 1.9063, 1.9008, 1.8993, 1.8902, 1.8838, 1.8633, 1.8504, 1.8457,
        1.8390], dtype=torch.float64),
indices=tensor([3019, 1456, 2373, 3520,  517, 2681,  286, 3108,  907,  430]))
torch.return_types.topk(
values=tensor([1.2654, 1.1127, 1.0929, 1.0828, 1.0283, 0.9894, 0.9873, 0.9672, 0.9671,
        0.9641], dtype=torch.float64),
indices=tensor([3019, 1696, 3520, 1086, 2976, 3206, 2681,  517, 2357,  430]))


Entropy: 100%|██████████| 4960/4960 [00:00<00:00, 9483.36it/s] 
Training Set Size:  45%|████▌     | 45/100 [01:04<11:47, 12.86s/it]

Dataset indices:  [23472 28944 42327 41217 46580]
Scores:  [1.25087769893981, 1.19062686371772, 1.1860350331984852, 1.1561641162276248, 1.1340433182090535]
Labels:  tensor([2, 2, 2, 2, 6])




Test set: Average loss: 0.0025, Accuracy: 7268/10000 (72.68%)


Conditional Entropy: 100%|██████████| 4955/4955 [00:00<00:00, 8842.26it/s] 
Entropy: 100%|██████████| 4955/4955 [00:00<00:00, 9157.48it/s] 


torch.return_types.topk(
values=tensor([1.9814, 1.9566, 1.9539, 1.8923, 1.8859, 1.8807, 1.8695, 1.8576, 1.8460,
        1.8431], dtype=torch.float64),
indices=tensor([1411, 4318,  598, 1696, 3503, 3930, 2974, 4785, 4638,  967]))
torch.return_types.topk(
values=tensor([1.3288, 1.2746, 1.2679, 1.1955, 1.1505, 1.1452, 1.1250, 1.1045, 1.0791,
        1.0724], dtype=torch.float64),
indices=tensor([1696, 2974, 4318, 1086, 3017,  289, 1311, 4508, 1790,  598]))


Entropy: 100%|██████████| 4955/4955 [00:00<00:00, 8931.24it/s] 
Training Set Size:  50%|█████     | 50/100 [02:05<10:25, 12.52s/it]

Dataset indices:  [50075 52669 54740 14104  3895]
Scores:  [1.2328700166982527, 1.169212225878496, 1.1631054219063395, 1.1572460527253698, 1.151517323101395]
Labels:  tensor([7, 7, 7, 7, 7])




Test set: Average loss: 0.0027, Accuracy: 7158/10000 (71.58%)


Conditional Entropy: 100%|██████████| 4950/4950 [00:00<00:00, 8557.80it/s] 
Entropy: 100%|██████████| 4950/4950 [00:00<00:00, 9089.77it/s] 


torch.return_types.topk(
values=tensor([1.9419, 1.9084, 1.9045, 1.8867, 1.8861, 1.8844, 1.8672, 1.8422, 1.8408,
        1.8295], dtype=torch.float64),
indices=tensor([3515,   75,  966, 1975, 1918, 4633, 1694, 1745, 2133, 3308]))
torch.return_types.topk(
values=tensor([1.3825, 1.2546, 1.1980, 1.1734, 1.1660, 1.1361, 1.1272, 1.1015, 1.0772,
        1.0626], dtype=torch.float64),
indices=tensor([1694, 4504,  289, 4633, 1309, 4315, 3015, 1085, 1788, 1006]))


Entropy: 100%|██████████| 4950/4950 [00:00<00:00, 9317.37it/s] 
Training Set Size:  55%|█████▌    | 55/100 [03:09<09:30, 12.67s/it]

Dataset indices:  [ 9994 16756 13096 53017 39835]
Scores:  [1.2516065547946824, 1.2395376705061483, 1.2328000787566271, 1.1818893445182925, 1.1347621975508049]
Labels:  tensor([0, 7, 9, 2, 7])




Test set: Average loss: 0.0023, Accuracy: 7352/10000 (73.52%)




KeyboardInterrupt: 

In [None]:
plt.plot(np.arange(start=num_initial_samples, stop=max_training_samples + acquisition_batch_size, step=acquisition_batch_size), test_accs)
plt.xlabel("Training Set Size")
plt.ylabel("Test Accuracy")
plt.hlines(90, num_initial_samples, max_training_samples, colors='r', linestyles='dashed')

plt.show()

NameError: name 'plt' is not defined

## Storing results

In [None]:
assert False

AssertionError: 

In [None]:
params_dict = {
    'num_initial_samples': num_initial_samples,
    'num_classes': num_classes,
    'max_training_samples': max_training_samples,
    'acquisition_batch_size': acquisition_batch_size,
    'num_inference_samples': num_inference_samples,
    'num_test_inference_samples': num_test_inference_samples,
    'num_samples': num_samples,
    'test_batch_size': test_batch_size,
    'batch_size': batch_size,
    'scoring_batch_size': scoring_batch_size,
    'training_iterations': training_iterations
}

save_experiment('Lenet5-simple_BALD', params_dict, {
    'test_accs': test_accs,
    'test_loss': test_loss,
    'added_indices': added_indices
})