In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from jax import random

from bayesian_active_learning.data_utils import NumpyDataset, NumpyLoader
from bayesian_active_learning.utils import one_hot

# Set seed

In [None]:
# Master PRNG key for the notebook; later cells derive new keys from it with random.split.
key = random.PRNGKey(seed=0)

# 1. Setup

## 1.1 Load and preprocess MNIST

In [None]:
# Load (downloading on first use) both MNIST splits into ../datasets.
full_train_dataset = torchvision.datasets.MNIST(
    "../datasets", train=True, download=True
)
full_test_dataset = torchvision.datasets.MNIST(
    "../datasets", train=False, download=True
)

num_classes = len(full_train_dataset.classes)
total_train_samples = len(full_train_dataset.data)
# Count the *test* split here, not the train split.
total_test_samples = len(full_test_dataset.data)

# Pixels are divided by 255 (assumes raw uint8 intensities in 0-255 — TODO confirm);
# labels are one-hot encoded over `num_classes`.
all_train_X = np.array(full_train_dataset.data) / 255.0
all_train_y = one_hot(np.array(full_train_dataset.targets), k=num_classes)

all_test_X = np.array(full_test_dataset.data) / 255.0
all_test_y = one_hot(np.array(full_test_dataset.targets), k=num_classes)

# Cheap invariants: the recorded split sizes match the arrays actually built.
assert len(all_train_X) == total_train_samples
assert len(all_test_X) == total_test_samples

## 1.2 Split train set into initial train set, validation set and pool set

In [None]:
num_initial_train_points = 100
num_validation_points = 100

# Cut the full training set (inputs and labels alike) at the same two points:
#   [0, train)               -> initial labelled training set
#   [train, train + val)     -> validation set
#   [train + val, end)       -> unlabelled pool for acquisition
# The `inital_*` spelling is kept so any later cells using these names still work.
(initial_train_X, val_X, inital_pool_X), (initial_train_y, val_y, inital_pool_y) = (
    np.split(
        array,
        [num_initial_train_points, num_initial_train_points + num_validation_points],
    )
    for array in (all_train_X, all_train_y)
)

In [None]:
# Mini-batch loaders for each split. The training loader drives the optimizer
# updates; the validation and test loaders are only iterated for evaluation.
training_generator = NumpyLoader(
    batch_size=16,
    shuffle=True,
    dataset=NumpyDataset(initial_train_X, initial_train_y),
)
validation_generator = NumpyLoader(
    batch_size=256,
    shuffle=True,
    dataset=NumpyDataset(val_X, val_y),
)
# No `shuffle` argument is passed for the test loader.
test_generator = NumpyLoader(
    batch_size=256,
    dataset=NumpyDataset(all_test_X, all_test_y),
)

# 2. Model setup

In [None]:
from functools import partial
from typing import Callable, Tuple

import haiku as hk
import jax.numpy as jnp
from jax import nn, random

from bayesian_active_learning.models import BayesianConvNet

# Re-seed the master key with the same seed used at the top of the notebook.
key = random.PRNGKey(seed=0)


def forward(
    num_classes: int,
    dropout_rates: Tuple[float, float],
    activation: Callable[[jnp.ndarray], jnp.ndarray],
    x: jnp.ndarray,
) -> jnp.ndarray:
    """Run a ``BayesianConvNet`` forward pass on ``x``.

    Intended to be curried over its first three arguments with
    ``functools.partial`` and wrapped in ``hk.transform``, because Haiku modules
    can only be constructed inside a transformed function.

    Args:
        num_classes: Number of output logits.
        dropout_rates: Dropout rates passed to the network call; ``(0, 0)`` is
            used for the deterministic evaluation model.
        activation: Nonlinearity used inside the network.
        x: Input batch (assumed ``(batch, 28, 28)`` images — TODO confirm
            against ``BayesianConvNet``).

    Returns:
        The network output for ``x`` (treated as logits by the loss function).
    """
    net = BayesianConvNet(num_classes=num_classes, activation=activation)
    return net(dropout_rates, x)


# Network hyper-parameters. `num_classes` is hard-coded here; it must agree with
# the class count derived from the dataset in the loading cell.
num_classes = 10
dropout_rates = (0.25, 0.5)
activation = nn.relu

# Training model: nonzero dropout rates, so `apply` takes an rng.
model = hk.transform(partial(forward, num_classes, dropout_rates, activation))

# Evaluation model: zero dropout, and no rng argument needed at apply time.
eval_model = hk.without_apply_rng(
    hk.transform(partial(forward, num_classes, (0, 0), activation))
)

# Initialise with the freshly split subkey (previously `subkey` was unused and
# `key` itself was consumed, then reused by later splits).
key, subkey = random.split(key)
params = model.init(subkey, jnp.zeros((1, 28, 28)))

## 2.1 Train the model

In [None]:
import optax


def loss(
    params: optax.Params, xs: jnp.ndarray, labels: jnp.ndarray, key
) -> jnp.ndarray:
    """Mean softmax cross-entropy of the stochastic ``model`` on one batch.

    Args:
        params: Haiku parameters for ``model``.
        xs: Input batch.
        labels: Target distributions, one row per input (one-hot in training).
        key: PRNG key handed to ``model.apply`` (presumably drives dropout).

    Returns:
        Scalar loss averaged over the batch.
    """
    logits = model.apply(params, x=xs, rng=key)
    return jnp.mean(optax.softmax_cross_entropy(logits, labels))

In [None]:
# Smoke test: two all-ones inputs with uniform targets over the 10 classes.
dummy_x = jnp.ones((2, 28, 28))
dummy_y = jnp.ones((2, 10)) / 10
print(loss(params, dummy_x, dummy_y, key))

In [None]:
import jax
from tqdm.notebook import trange

from bayesian_active_learning.metrics import compute_model_accuracy

# Start training from fresh parameters. Split off a dedicated init key so that
# `key` (split again in the training loop) is never reused after being consumed.
key, init_key = jax.random.split(key)
params = model.init(init_key, jnp.zeros((1, 28, 28)))

# AdamW: learning rate 1e-3, weight decay 1e-3.
optimizer = optax.adamw(1e-3, weight_decay=1e-3)
opt_state = optimizer.init(params)


@jax.jit
def step(params, optimizer_state, xs, labels, key):
    """Apply one jitted optimizer update to ``params`` on a single mini-batch.

    Differentiates ``loss`` (with ``key`` supplying the model's rng) and feeds
    the gradients through the global ``optimizer``.

    Returns:
        ``(new_params, new_optimizer_state)``.
    """
    gradients = jax.grad(loss)(params, xs, labels, key)
    param_updates, new_optimizer_state = optimizer.update(
        gradients, optimizer_state, params
    )
    return optax.apply_updates(params, param_updates), new_optimizer_state


# Number of passes over the initial training set.
num_epochs = 100

validation_accuracy_history = []
train_accuracy_history = []

for epoch in trange(num_epochs):
    for xs, labels in training_generator:
        # Fully qualified `jax.random` keeps this loop working even if the bare
        # name `random` is rebound by a later cell.
        key, sub_key = jax.random.split(key, 2)
        params, opt_state = step(params, opt_state, xs, labels, sub_key)

    # Per-epoch accuracy on train and validation splits with the zero-dropout
    # `eval_model`; the partial is built once and reused for both splits.
    predict = partial(eval_model.apply, params)
    train_accuracy = compute_model_accuracy(predict, training_generator)
    validation_accuracy = compute_model_accuracy(predict, validation_generator)

    train_accuracy_history.append(train_accuracy)
    validation_accuracy_history.append(validation_accuracy)

In [None]:
import matplotlib.pyplot as plt

# Train vs. validation accuracy per epoch; labelled so the two curves can be told apart.
fig, ax = plt.subplots()
ax.plot(train_accuracy_history, label="train")
ax.plot(validation_accuracy_history, label="validation")
ax.set(title="Accuracy during training", xlabel="epoch", ylabel="accuracy")
ax.legend();

# 3. Acquisition functions 

In [None]:
from jax import vmap


def generate_logit_samples(
    model: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
    xs: jnp.ndarray,
    num_samples: int,
    key: jnp.ndarray,
) -> jnp.ndarray:
    """Run ``num_samples`` stochastic forward passes of ``model`` on ``xs``.

    Args:
        model: Callable ``model(key, xs)``, e.g. ``partial(model.apply, params)``
            of a Haiku-transformed network.
        xs: Input batch shared by every sample.
        num_samples: Number of independent passes; each gets its own key.
        key: PRNG key that is split into ``num_samples`` per-pass keys.

    Returns:
        The model outputs with the sample axis inserted after the batch axis:
        ``(batch, num_samples, ...)`` for a model returning ``(batch, ...)``.
    """
    # Qualified `jax.random`: this cell also defines an acquisition function
    # named `random`, which rebinds the bare module name at call time.
    keys = jax.random.split(key, num_samples)
    # out_axes=1 puts the sample axis right after the batch axis for any output
    # rank (the former transpose((1, 0, 2)) only handled 3-D results).
    return vmap(model, in_axes=(0, None), out_axes=1)(keys, xs)


def entropy(dist: jnp.ndarray) -> jnp.ndarray:
    """Shannon entropy (in nats) of categorical distributions over the last axis.

    Args:
        dist: Probabilities, e.g. ``(batch, num_classes)``; any leading axes are
            treated as batch axes.

    Returns:
        ``-sum(p * log p)`` over the last axis, using the convention
        ``0 * log 0 = 0``.
    """
    # Softmax probabilities can underflow to exactly 0, and the naive
    # 0 * log(0) = 0 * -inf is NaN. The inner `where` keeps log() away from 0 so
    # the masked branch also has finite gradients.
    positive = dist > 0
    plogp = jnp.where(positive, dist * jnp.log(jnp.where(positive, dist, 1.0)), 0.0)
    return -jnp.sum(plogp, axis=-1)


@jax.jit
def BALD(logit_samples: jnp.ndarray) -> jnp.ndarray:
    """BALD acquisition score for each input.

    Entropy of the sample-averaged predictive distribution minus the average
    entropy of the individual per-sample distributions.

    Args:
        logit_samples: Logits shaped ``(batch, num_samples, num_classes)``.

    Returns:
        One score per input, shape ``(batch,)``.
    """
    sample_probs = nn.softmax(logit_samples, axis=-1)
    predictive_entropy = entropy(sample_probs.mean(axis=1))
    expected_entropy = entropy(sample_probs).mean(axis=1)
    return predictive_entropy - expected_entropy


def max_entropy(logit_samples: jnp.ndarray) -> jnp.ndarray:
    """Max-entropy acquisition score: entropy of the sample-averaged prediction.

    Args:
        logit_samples: Logits shaped ``(batch, num_samples, num_classes)``.

    Returns:
        One score per input, shape ``(batch,)``.
    """
    averaged_probs = nn.softmax(logit_samples, axis=-1).mean(axis=1)
    return entropy(averaged_probs)


def random_acquisition(logit_samples: jnp.ndarray, key=None) -> jnp.ndarray:
    """Random-acquisition baseline score for each input.

    Deliberately not named ``random``. This notebook imports ``jax.random`` as
    ``random``, and a function with that name would shadow the module for every
    later ``random.split`` / ``random.PRNGKey`` call.

    Args:
        logit_samples: Logits shaped ``(batch, num_samples, num_classes)``; only
            the batch size is used.
        key: Optional PRNG key. When given, scores are i.i.d. uniform in
            ``[0, 1)``. When omitted, every score is 1, so the selection is only
            random if the caller breaks ties randomly.

    Returns:
        One score per input, shape ``(batch,)``.
    """
    num_points = logit_samples.shape[0]
    if key is None:
        return jnp.ones(num_points)
    return jax.random.uniform(key, (num_points,))

In [None]:
# Smoke test: 2 stochastic forward passes of the trained model on 2 dummy inputs.
predictive_model = partial(model.apply, params)

# Use a fresh subkey rather than passing (and later reusing) `key` directly.
key, sample_key = jax.random.split(key)
samples = generate_logit_samples(
    predictive_model, jnp.ones((2, 28, 28)), 2, sample_key
)
assert samples.shape == (2, 2, num_classes)  # (batch, num_samples, num_classes)

# 4. Putting it all together