In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import seaborn as sns

import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from batchbald_redux import repeated_mnist, active_learning, batchbald
from main.models import ConvNet
from main.training_models import test_performance
from main.utils import save_experiment, load_experiment, log_experiment, generate_experiment_id
from laplace.curvature import AsdlGGN, AsdlGGN
from main.laplace_batch import get_laplace_batch
from dataclasses import dataclass
from main.active_learning import run_active_learning

sns.set_palette(sns.color_palette("Spectral"))

%reload_ext autoreload
%autoreload 2

## Settings for Active Learning

In [None]:
# Every tunable setting of the experiment lives in this one dataclass.
@dataclass
class ActiveLearningConfig:
    """Settings for a single active-learning run.

    The instance is handed as a whole to ``run_active_learning``; several
    fields are additionally read by the data-preparation and plotting cells
    of this notebook.
    """

    # --- Laplace-approximation options -------------------------------------
    # NOTE(review): presumably forwarded to the `laplace` library inside
    # run_active_learning -- confirm there.
    subset_of_weights: str = "last_layer"
    hessian_structure: str = "kron"
    backend: str = "AsdlGGN"  # backend given by name (cf. the AsdlGGN import)
    temperature: float = 1.0

    # --- Acquisition schedule ----------------------------------------------
    # Used as the (inclusive) upper end of the x-axis in the accuracy plot;
    # presumably the total labelling budget.
    max_training_samples: int = 100
    # Used as the x-axis step in the accuracy plot; presumably the number of
    # samples acquired per round.
    acquisition_batch_size: int = 5
    # Presumably the name of the acquisition function -- TODO confirm the
    # accepted values in main.active_learning.
    al_method: str = "entropy"

    # --- Data / loader options ---------------------------------------------
    test_batch_size: int = 512  # batch size of the test DataLoader
    num_classes: int = 10  # classes used for balanced sampling and the histogram
    num_initial_samples: int = 40  # initial labelled set size, split across classes
    training_iterations: int = 4096 // 64  # passed to RandomFixedLengthSampler
    scoring_batch_size: int = 64  # batch size of the pool DataLoader
    train_batch_size: int = 64  # batch size of the training DataLoader
    extract_pool: int = 59900  # passed to extract_dataset_from_pool

# Flag meant to gate whether the results are written to disk.
save_results = False

# Descriptive experiment name with a generated experiment id appended.
experiment_name = 'active_learning' + generate_experiment_id()

config = ActiveLearningConfig()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

## Load and Prepare Data

In [None]:
# Load the MNIST train and test datasets.
train_dataset, test_dataset = repeated_mnist.create_MNIST_dataset()

# Options shared by every DataLoader created in this cell.
kwargs = {"num_workers": 1, "pin_memory": True}

# Indices of a class-balanced initial labelled set. `n_per_digit` is a
# per-class count, so integer division is used: true division would hand the
# sampler a float (e.g. 4.0) instead of an int.
initial_samples = active_learning.get_balanced_sample_indices(
    repeated_mnist.get_targets(train_dataset),
    num_classes=config.num_classes,
    n_per_digit=config.num_initial_samples // config.num_classes,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=config.test_batch_size, shuffle=False, **kwargs
)

active_learning_data = active_learning.ActiveLearningData(train_dataset)

# Split off the initial samples first.
active_learning_data.acquire(initial_samples)

# THIS REMOVES MOST OF THE POOL DATA. REMOVE THIS LINE TO USE THE FULL POOL
active_learning_data.extract_dataset_from_pool(config.extract_pool)

# Training loader: the sampler is constructed with `training_iterations` as its
# length argument rather than relying on the size of the labelled set.
train_loader = torch.utils.data.DataLoader(
    active_learning_data.training_dataset,
    sampler=active_learning.RandomFixedLengthSampler(
        active_learning_data.training_dataset, config.training_iterations
    ),
    batch_size=config.train_batch_size,
    **kwargs,
)

# Pool loader: unshuffled, so the order of the pool dataset is preserved.
pool_loader = torch.utils.data.DataLoader(
    active_learning_data.pool_dataset,
    batch_size=config.scoring_batch_size,
    shuffle=False,
    **kwargs,
)

## Run Active Learning with Settings

In [None]:
# Run the active-learning experiment with the loaders, data split and settings
# prepared in the cells above. The returned `results` is read by the analysis
# cells below (keys 'test_accs' and 'added_labels').
results = run_active_learning(
    train_loader=train_loader,  # was `traind_loader`, an undefined name (NameError)
    test_loader=test_loader,
    active_learning_data=active_learning_data,
    model_constructor=ConvNet,
    config=config,
    device=device,
)

## Analyse Results

In [None]:
# Test accuracy against labelled-set size: the x-axis starts at the initial set
# size and advances one acquisition batch at a time up to the maximum
# (the `+ 1` makes the end point inclusive).
sns.lineplot(
    x=np.arange(
        config.num_initial_samples,
        config.max_training_samples + 1,
        config.acquisition_batch_size,
    ),
    y=results['test_accs'],
)
plt.xlabel('Number of training samples')
plt.title('Active Learning Performance')
plt.ylabel('Test Accuracy')

In [None]:
# Combine the labels recorded for the added samples into a single tensor.
labels = torch.stack(results['added_labels'])

# How many added samples fall into each class, stored as a tensor of the
# default floating-point dtype (the dtype `torch.zeros` would produce).
counts = torch.tensor(
    [(labels == class_id).sum().item() for class_id in range(config.num_classes)],
    dtype=torch.get_default_dtype(),
)

sns.barplot(x=np.arange(config.num_classes), y=counts)
plt.title('Class distribution of added samples')

## Save Experiment Results

In [None]:
# Persist the configuration and results only when the `save_results` flag from
# the settings cell is enabled (the original `if True:` ignored that flag and
# always saved).
if save_results:
    save_experiment(config, results, experiment_name)