In [None]:
# Enable IPython's autoreload extension so that edits to imported project
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from collections import OrderedDict

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from jax import random

from bayesian_active_learning.acquisition_functions import (
    BALD,
    max_entropy,
    uniform,
)
from bayesian_active_learning.data_utils import NumpyDataset, NumpyLoader
from bayesian_active_learning.experiment import experiment_run
from bayesian_active_learning.utils import one_hot

# Summary

This notebook scans over the weight decay hyperparameter to find the value that maximises the validation accuracy. N.B. We assume that the weight decay is tuned against the validation set using the initial batch of training data only; it is not re-tuned each time new data is acquired from the pool set.

# Set seed

In [None]:
# Root PRNG key for the notebook; later cells derive sub-keys from it with
# `random.split`.
seed = 0
key = random.PRNGKey(seed)

# 1. Setup

## 1.1 Load + preprocess MNIST

In [None]:
# Load the MNIST training split (downloaded into ../datasets if not present).
full_train_dataset = torchvision.datasets.MNIST(
    "../datasets", train=True, download=True
)
# Load the held-out test split as well, so that the test-set size below is
# measured on the test data rather than on the training data.
full_test_dataset = torchvision.datasets.MNIST(
    "../datasets", train=False, download=True
)

num_classes = len(full_train_dataset.classes)
total_train_samples = len(full_train_dataset.data)
# Previously this was computed from `full_train_dataset`, which silently
# reported the training-set size as the test-set size.
total_test_samples = len(full_test_dataset.data)

# Scale the raw pixel values by 1/255 and one-hot encode the integer labels.
all_train_X = np.array(full_train_dataset.data) / 255.0
all_train_y = one_hot(np.array(full_train_dataset.targets), k=num_classes)

## 1.2 Split train set into initial train set, validation set and pool set

In [None]:
num_initial_train_points = 100
num_validation_points = 100

# Indices at which the data is cut into three contiguous pieces:
# initial training set, validation set, and the remaining pool set.
split_points = [
    num_initial_train_points,
    num_initial_train_points + num_validation_points,
]

# Inputs and labels are cut at the same indices so they stay aligned.
initial_train_X, val_X, initial_pool_X = np.split(all_train_X, split_points)
initial_train_y, val_y, initial_pool_y = np.split(all_train_y, split_points)

In [None]:
# Shuffled mini-batches over the initial training set.
initial_train_dataset = NumpyDataset(initial_train_X, initial_train_y)
training_generator = NumpyLoader(
    dataset=initial_train_dataset,
    batch_size=16,
    shuffle=True,
)

# Validation batches; no shuffle argument is passed here.
validation_dataset = NumpyDataset(val_X, val_y)
validation_generator = NumpyLoader(
    dataset=validation_dataset,
    batch_size=256,
)

# 2. Tuning the weight decay

In [None]:
from functools import partial

import haiku as hk
import optax

from bayesian_active_learning.losses import classification_loss
from bayesian_active_learning.models import model
from bayesian_active_learning.training import fit

# Create, transform and initialise the model (and the evaluation model).
# `num_classes` is the value derived from the dataset in section 1.1; it is
# deliberately not hard-coded again here.
dropout_rates = (0.25, 0.5)

# Network used for training, built with the dropout rates above.
base_training_model = partial(model, num_classes, dropout_rates)
stochastic_model = hk.transform(base_training_model)

# Network used for evaluation, built with both dropout rates set to zero.
# `hk.without_apply_rng` removes the RNG argument from `apply`.
base_eval_model = partial(model, num_classes, (0, 0))
eval_model = hk.without_apply_rng(hk.transform(base_eval_model))

# Initialise the parameters once, on a dummy input of shape (1, 28, 28).
# Every run of the scan starts from this same initialisation.
key, subkey = random.split(key)
initial_params = stochastic_model.init(subkey, jnp.zeros((1, 28, 28)))

loss = partial(classification_loss, stochastic_model)

tuning_num_epochs = 100
# Number of final epochs whose validation accuracy is averaged for the report.
# Assumes the history holds one entry per epoch -- TODO confirm against `fit`.
num_final_epochs_averaged = 20

# Split off a dedicated key for training instead of handing the notebook's
# live `key` to `fit` (which would reuse it in any later split). The same
# `fit_key` is used for every weight decay so that the runs differ only in the
# hyperparameter being scanned.
key, fit_key = random.split(key)

# Train the model on the initial training data, once per candidate weight decay.
for weight_decay in jnp.logspace(-3, -1.5, 5):
    optimiser = optax.adamw(1e-3, weight_decay=weight_decay)

    # Start from `initial_params` each time. Feeding the previous run's trained
    # parameters into the next run would make later weight decays continue
    # training rather than be evaluated on equal footing.
    params, metrics = fit(
        loss=loss,
        params=initial_params,
        eval_model=eval_model,
        optimiser=optimiser,
        num_epochs=tuning_num_epochs,
        train_generator=training_generator,
        validation_generator=validation_generator,
        key=fit_key,
    )

    # Average over the LAST epochs (`[-n:]`). The earlier `[:-20]` slice
    # averaged everything except the last 20 epochs, i.e. mostly early training.
    final_validation_accuracy = jnp.mean(
        metrics.validation_accuracy_history[-num_final_epochs_averaged:]
    )

    print(
        f"weight decay; {weight_decay}, validation accuracy: {final_validation_accuracy}"
    )