In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import torchvision
from jax import random

from bayesian_active_learning.utils import one_hot

# Set seed

In [None]:
# Root PRNG key for the notebook; later cells split it to obtain fresh subkeys.
key = random.PRNGKey(0)

# 1. Setup

## 1.1 Load + preprocess MNIST

In [None]:
# Load the MNIST train and test splits, downloading into ../datasets if needed.
full_train_dataset = torchvision.datasets.MNIST(
    "../datasets", train=True, download=True
)
full_test_dataset = torchvision.datasets.MNIST(
    "../datasets", train=False, download=True
)

num_classes = len(full_train_dataset.classes)
total_train_samples = len(full_train_dataset.data)
# Count the test split from the test dataset itself, not from the train set.
total_test_samples = len(full_test_dataset.data)

# Divide the raw image data by 255.0 and encode the targets with the
# project's `one_hot` helper (k = number of classes).
all_train_X = jnp.array(full_train_dataset.data) / 255.0
all_train_y = one_hot(jnp.array(full_train_dataset.targets), k=num_classes)

all_test_X = jnp.array(full_test_dataset.data) / 255.0
all_test_y = one_hot(jnp.array(full_test_dataset.targets), k=num_classes)

## 1.2 Split train set into initial train set and pool set

In [None]:
num_initial_train_points = 100

# Advance the notebook key and reserve a subkey for this draw.
key, subkey = random.split(key)

# Draw distinct row indices (no replacement) for the initial training set.
initial_training_indices = random.choice(
    subkey,
    jnp.arange(total_train_samples),
    shape=(num_initial_train_points,),
    replace=False,
)
initial_train_X, initial_train_y = (
    all_train_X[initial_training_indices],
    all_train_y[initial_training_indices],
)

# Pool = every training row whose index was not drawn above.
mask = (
    jnp.ones(total_train_samples, dtype=jnp.bool_)
    .at[initial_training_indices]
    .set(False)
)
initial_pool_X, initial_pool_y = all_train_X[mask], all_train_y[mask]

# 2. Model setup

In [None]:
from functools import partial
from typing import Callable, Tuple

import haiku as hk
import jax.numpy as jnp
from jax import nn, random

from bayesian_active_learning.models import BayesianConvNet

# NOTE(review): this rebuilds the same PRNGKey(0) as the seed cell, so the
# split further down in this cell reproduces the key/subkey pair already
# derived for the train/pool split. Confirm the reset is intended rather than
# continuing from the existing `key`.
key = random.PRNGKey(0)


def forward(
    num_classes: int,
    dropout_rates: Tuple[float, float],
    activation: Callable[[jnp.ndarray], jnp.ndarray],
    x: jnp.ndarray,
):
    """Build a ``BayesianConvNet`` and apply it to ``x``.

    Intended to be wrapped by ``hk.transform`` after the configuration
    arguments have been bound (e.g. with ``functools.partial``).

    Args:
        num_classes: Forwarded to ``BayesianConvNet`` as ``num_classes``.
        dropout_rates: Forwarded to ``BayesianConvNet`` as ``dropout_rates``.
        activation: Forwarded to ``BayesianConvNet`` as ``activation``.
        x: Input passed to the network.

    Returns:
        Whatever the network returns for ``x``.
    """
    # Pass the caller's activation through; it was previously ignored in
    # favour of a hard-coded nn.relu.
    net = BayesianConvNet(
        num_classes=num_classes,
        dropout_rates=dropout_rates,
        activation=activation,
    )

    return net(x)


# Hyperparameters handed to `forward`.
num_classes = 10
dropout_rates = (0.25, 0.5)
activation = nn.relu

# Bind the hyperparameters so the transformed function only takes the input.
model = hk.transform(partial(forward, num_classes, dropout_rates, activation))

# Initialise with the freshly split subkey; the carried `key` is left
# untouched so it can be split again later without reusing randomness.
key, subkey = random.split(key)
params = model.init(subkey, jnp.zeros((1, 28, 28)))

In [None]:
import haiku as hk
import jax
import jax.numpy as jnp


def forward(x):
    """Apply a Haiku MLP with output sizes 300, 100 and 10 to ``x``."""
    return hk.nets.MLP(output_sizes=[300, 100, 10])(x)


# NOTE(review): this rebinds `forward` (also defined in the model-setup cell)
# and, below, overwrites the `params` produced there. Confirm this cell is
# meant to replace the BayesianConvNet setup rather than sit alongside it.
forward = hk.transform(forward)

# Key sequence seeded with 42; each next(rng) call yields a new key.
rng = hk.PRNGSequence(jax.random.PRNGKey(42))
# Dummy batch of ones: 8 rows of 28 * 28 features.
x = jnp.ones([8, 28 * 28])
# Separate keys are drawn for init and for apply.
params = forward.init(next(rng), x)
logits = forward.apply(params, next(rng), x)