In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import seaborn as sns

import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from batchbald_redux import repeated_mnist, active_learning, batchbald
from main.models import ConvNet
from main.training_models import test_performance
from main.utils import save_experiment, load_experiment, log_experiment, generate_experiment_id
from laplace.curvature import AsdlGGN, AsdlGGN
from main.laplace_batch import get_laplace_batch
from dataclasses import dataclass
from main.active_learning import run_active_learning

sns.set_palette(sns.color_palette("Spectral"))

%reload_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


## Settings for Active Learning

In [3]:
# set configurations
@dataclass
class ActiveLearningConfig:
    subset_of_weights: str = 'last_layer'
    hessian_structure: str = 'kron'
    backend: str = 'AsdlGGN'
    temperature: float = 1
    max_training_samples: int = 100
    acquisition_batch_size: int = 5
    al_method: str = 'entropy'
    test_batch_size: int = 512
    num_classes: int = 10
    num_initial_samples: int = 40
    training_iterations: int = 4096 * 6
    scoring_batch_size: int = 64
    train_batch_size: int = 64
    extract_pool: int = 55000  # number of samples to extract from the dataset (bit of a hack)

experiment_name = 'lowtemperature_'  # provide descriptive name for the experiment
experiment_name += generate_experiment_id()

config = ActiveLearningConfig()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
kwargs = {"num_workers": 1, "pin_memory": True}

save_results = False

Using device: cpu


## Load and Prepare Data

In [4]:
# loading data
train_dataset, test_dataset = repeated_mnist.create_MNIST_dataset()

# get indices of initial samples
initial_samples = active_learning.get_balanced_sample_indices(
    repeated_mnist.get_targets(train_dataset), num_classes=config.num_classes, n_per_digit=config.num_initial_samples / config.num_classes
)

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=config.test_batch_size, shuffle=False, **kwargs)

active_learning_data = active_learning.ActiveLearningData(train_dataset)

# Split off the initial samples first.
active_learning_data.acquire(initial_samples)

# THIS REMOVES MOST OF THE POOL DATA. REMOVE THIS LINE TO USE THE FULL POOL
active_learning_data.extract_dataset_from_pool(config.extract_pool)

train_loader = torch.utils.data.DataLoader(
    active_learning_data.training_dataset,
    sampler=active_learning.RandomFixedLengthSampler(active_learning_data.training_dataset, config.training_iterations),
    batch_size=config.train_batch_size,
    **kwargs,
)

pool_loader = torch.utils.data.DataLoader(
    active_learning_data.pool_dataset, batch_size=config.scoring_batch_size, shuffle=False, **kwargs
)



## Runs Active Learning with settings

In [4]:
results = run_active_learning(
    train_loader=train_loader,
    test_loader=test_loader, 
    pool_loader=pool_loader,
    active_learning_data=active_learning_data,
    model_constructor=ConvNet, 
    config=config, 
    device=device
    )

Training Set Size:  40%|████      | 40/100 [00:00<?, ?it/s]2024-07-28 11:14:36,897 - INFO - Training set size: 40, Test set accuracy: 76.53, Test set loss: -0.0117
Computing entropies: 78it [00:04, 17.98it/s]
Training Set Size:  45%|████▌     | 45/100 [00:35<06:29,  7.08s/it]

Dataset indices:  [26852  8713 13063  1804 56874]
Scores:  [1.9689174890518188, 1.9535801410675049, 1.9389352798461914, 1.9305953979492188, 1.9218602180480957]
Labels:  tensor([8, 5, 9, 8, 8])


2024-07-28 11:15:15,546 - INFO - Training set size: 45, Test set accuracy: 77.00, Test set loss: -0.0225
Computing entropies: 78it [00:04, 15.73it/s]
Training Set Size:  50%|█████     | 50/100 [01:15<06:22,  7.64s/it]

Dataset indices:  [48680 12076 12359 32012  8786]
Scores:  [1.9359958171844482, 1.6774601936340332, 1.6574766635894775, 1.6573704481124878, 1.6566517353057861]
Labels:  tensor([5, 9, 6, 8, 3])


2024-07-28 11:15:53,676 - INFO - Training set size: 50, Test set accuracy: 79.31, Test set loss: -0.0166
Computing entropies: 78it [00:04, 16.19it/s]
Training Set Size:  55%|█████▌    | 55/100 [01:52<05:37,  7.50s/it]

Dataset indices:  [49302  3104 30159 26981 32524]
Scores:  [1.8859899044036865, 1.8834879398345947, 1.8813992738723755, 1.8368420600891113, 1.8345773220062256]
Labels:  tensor([8, 9, 3, 3, 4])


2024-07-28 11:16:29,306 - INFO - Training set size: 55, Test set accuracy: 82.56, Test set loss: -0.0244
Computing entropies: 78it [00:04, 18.16it/s]
Training Set Size:  60%|██████    | 60/100 [02:26<04:50,  7.27s/it]

Dataset indices:  [44250 51283 37341 49012 30485]
Scores:  [1.6477454900741577, 1.6442286968231201, 1.6419825553894043, 1.6105098724365234, 1.59601628780365]
Labels:  tensor([6, 5, 5, 7, 2])


2024-07-28 11:17:04,536 - INFO - Training set size: 60, Test set accuracy: 79.98, Test set loss: -0.0243
Computing entropies: 78it [00:04, 17.67it/s]
Training Set Size:  65%|██████▌   | 65/100 [03:02<04:12,  7.22s/it]

Dataset indices:  [53696 48010 13081 25429 22147]
Scores:  [2.0079548358917236, 1.9639463424682617, 1.931840181350708, 1.9001073837280273, 1.8774399757385254]
Labels:  tensor([5, 7, 0, 8, 5])


2024-07-28 11:17:40,184 - INFO - Training set size: 65, Test set accuracy: 82.69, Test set loss: -0.0303
Computing entropies: 78it [00:04, 17.47it/s]
Training Set Size:  70%|███████   | 70/100 [03:38<03:37,  7.24s/it]

Dataset indices:  [24633 16797 41611 37005 31962]
Scores:  [1.965768575668335, 1.9328879117965698, 1.92270028591156, 1.9129570722579956, 1.9026795625686646]
Labels:  tensor([2, 8, 8, 8, 3])


2024-07-28 11:18:16,124 - INFO - Training set size: 70, Test set accuracy: 79.37, Test set loss: -0.0290
Computing entropies: 78it [00:04, 17.22it/s]
Training Set Size:  75%|███████▌  | 75/100 [04:14<02:59,  7.17s/it]

Dataset indices:  [25576 23578 57597 14295 11391]
Scores:  [1.9196363687515259, 1.907682180404663, 1.9068281650543213, 1.8995343446731567, 1.875787615776062]
Labels:  tensor([0, 2, 2, 2, 0])


2024-07-28 11:18:51,583 - INFO - Training set size: 75, Test set accuracy: 82.79, Test set loss: -0.0242
Computing entropies: 77it [00:04, 17.59it/s]
Training Set Size:  80%|████████  | 80/100 [04:49<02:23,  7.16s/it]

Dataset indices:  [49009  9924 12812 49002 27822]
Scores:  [1.8265278339385986, 1.8036136627197266, 1.7363224029541016, 1.7327243089675903, 1.7306407690048218]
Labels:  tensor([2, 8, 3, 1, 0])


2024-07-28 11:19:27,706 - INFO - Training set size: 80, Test set accuracy: 81.00, Test set loss: -0.0323
Computing entropies: 77it [00:05, 15.24it/s]
Training Set Size:  85%|████████▌ | 85/100 [05:26<01:48,  7.22s/it]

Dataset indices:  [ 1088 55244 55064 19934  5381]
Scores:  [1.6665570735931396, 1.6560964584350586, 1.6349788904190063, 1.6254900693893433, 1.607519268989563]
Labels:  tensor([7, 7, 9, 9, 4])


2024-07-28 11:20:05,104 - INFO - Training set size: 85, Test set accuracy: 82.86, Test set loss: -0.0233
Computing entropies: 77it [00:04, 16.26it/s]
Training Set Size:  90%|█████████ | 90/100 [06:04<01:13,  7.35s/it]

Dataset indices:  [39834 31339 33369 30139 33581]
Scores:  [1.717010736465454, 1.7047741413116455, 1.696811556816101, 1.6967148780822754, 1.6739047765731812]
Labels:  tensor([9, 6, 1, 6, 9])


2024-07-28 11:20:42,566 - INFO - Training set size: 90, Test set accuracy: 83.68, Test set loss: -0.0266
Computing entropies: 77it [00:04, 17.16it/s]
Training Set Size:  95%|█████████▌| 95/100 [06:40<00:36,  7.30s/it]

Dataset indices:  [13997 14355 36760 11536 41283]
Scores:  [1.7375097274780273, 1.7242283821105957, 1.7170239686965942, 1.690377116203308, 1.6881840229034424]
Labels:  tensor([9, 2, 7, 9, 3])


2024-07-28 11:21:18,216 - INFO - Training set size: 95, Test set accuracy: 83.04, Test set loss: -0.0249
Computing entropies: 77it [00:04, 16.78it/s]
Training Set Size: 100%|██████████| 100/100 [07:17<00:00,  7.30s/it]

Dataset indices:  [38631  9860 24587  1724 14124]
Scores:  [1.6712132692337036, 1.629883885383606, 1.543965458869934, 1.4847278594970703, 1.471351981163025]
Labels:  tensor([4, 6, 8, 2, 2])


2024-07-28 11:21:56,252 - INFO - Training set size: 100, Test set accuracy: 79.66, Test set loss: -0.0177
Training Set Size: 100%|██████████| 100/100 [07:34<00:00,  7.57s/it]


## Analyse Results

In [5]:
sns.lineplot(x=np.arange(start=config.num_initial_samples, stop=config.max_training_samples + 1, step=config.acquisition_batch_size),
            y=results['test_accs'])
plt.title('Active Learning Performance')
plt.xlabel('Number of training samples')
plt.ylabel('Test Accuracy')

NameError: name 'results' is not defined

## Save Experiment Results

In [7]:
if True:
    save_experiment(config, results, experiment_name)

Experiment saved in experiments\temperature_20240728-111421
