In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from batchbald_redux import repeated_mnist, active_learning, batchbald
from main.models import ConvNet
from main.training_models import test_performance
from main.utils import save_experiment, load_experiment
from laplace.curvature import AsdlGGN, AsdlGGN
from main.laplace_batch import get_laplace_batch
from dataclasses import dataclass
from main.active_learning import run_active_learning

%reload_ext autoreload
%autoreload 2

In [None]:
# Load the MNIST train/test splits once; every later cell reuses these two datasets.
(train_dataset, test_dataset) = repeated_mnist.create_MNIST_dataset()

In [None]:
# Configuration: everything a reader might want to tune for this run lives in this cell.
@dataclass
class ActiveLearningConfig:
    """Hyperparameters for one active-learning run.

    An instance is passed as ``config`` to ``run_active_learning``
    (``main.active_learning``), which defines how each field is consumed. The
    per-field notes below are hedged wherever that use is not visible here.
    """

    # Laplace-approximation settings. NOTE(review): names/values mirror
    # laplace-torch's ``subset_of_weights`` / ``hessian_structure`` / curvature
    # ``backend`` options — confirm how run_active_learning forwards them.
    subset_of_weights: str = 'last_layer'
    hessian_structure: str = 'kron'
    backend: str = 'AsdlGGN'
    # presumably a posterior/likelihood temperature (1.0 = untempered) — TODO confirm
    temperature: float = 1.0
    # Labelled-set budget and acquisition schedule (presumably: start from
    # num_initial_samples labels, acquire acquisition_batch_size per round, stop at
    # max_training_samples) — TODO confirm in run_active_learning.
    max_training_samples: int = 100
    acquisition_batch_size: int = 5
    # presumably the acquisition-function name; valid choices live in run_active_learning
    al_method: str = 'entropy'
    test_batch_size: int = 512
    # number of target classes (10 for MNIST)
    num_classes: int = 10
    num_initial_samples: int = 40
    # presumably optimisation steps per training round — TODO confirm
    training_iterations: int = 4096
    # presumably mini-batch sizes for scoring the pool and for training — TODO confirm
    scoring_batch_size: int = 64
    train_batch_size: int = 64
    # NOTE(review): semantics not visible from this notebook — confirm in run_active_learning.
    extract_pool: int = 59000

experiment_id = 'active_learning'

config = ActiveLearningConfig()

# Seed the torch and NumPy RNGs so re-running the notebook from a fresh kernel gives
# the same initial sample, model init and acquisitions (modulo non-deterministic CUDA
# kernels). torch.manual_seed seeds every device, including all GPUs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set to True to persist `results` with save_experiment at the end of the notebook.
save_results = False

## Run active learning with the configured settings

In [None]:
# Execute the active-learning loop from main.active_learning with the settings above;
# `results` is read by the plotting and saving cells below.
results = run_active_learning(
    model_constructor=ConvNet,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    device=device,
    config=config,
)

In [None]:
# Use the (6-colour) "Spectral" palette as the default colour cycle for the figures below.
sns.set_palette("Spectral")

In [None]:
# Learning curve: test accuracy as the labelled set grows.
test_accs = results['test_accs']

# Derive the x-axis from the number of recorded accuracies so x and y always have the
# same length, even when (max_training_samples - num_initial_samples) is not a
# multiple of acquisition_batch_size.
# Assumes test_accs[0] is measured on the initial labelled set and every later entry
# follows one acquisition of `acquisition_batch_size` samples — TODO confirm against
# run_active_learning.
num_train_samples = (
    config.num_initial_samples
    + config.acquisition_batch_size * np.arange(len(test_accs))
)

fig, ax = plt.subplots()
sns.lineplot(x=num_train_samples, y=test_accs, ax=ax)
ax.set(
    title='Active Learning Performance',
    xlabel='Number of training samples',
    ylabel='Test Accuracy',
);

In [None]:
# Class balance of the samples acquired during active learning.
added = results['added_labels']

# Entries may be scalars or per-round batches of labels (TODO confirm against
# run_active_learning). Flattening each entry and concatenating handles both cases, and
# it also tolerates rounds of different sizes, which torch.stack rejects. Moving to
# CPU keeps the counts plottable.
labels = (
    torch.cat([torch.as_tensor(batch_labels).reshape(-1).cpu() for batch_labels in added])
    if len(added) > 0
    else torch.empty(0, dtype=torch.long)
)

# Number of acquired samples per class; labels outside [0, num_classes) are not counted.
counts = torch.stack([(labels == class_idx).sum() for class_idx in range(config.num_classes)])

fig, ax = plt.subplots()
sns.barplot(x=np.arange(config.num_classes), y=counts.numpy(), ax=ax)
ax.set(
    title='Class distribution of added samples',
    xlabel='Class',
    ylabel='Number of acquired samples',
);

In [None]:
# Persisting is opt-in: flip `save_results` in the configuration cell to keep this run
# under `experiment_id`.
if save_results:
    save_experiment(
        experiment_id,
        results,
    )