In [1]:
# MNIST_BNN_SVGD_JAX.ipynb
#
# Setup cell: all imports, then notebook-wide configuration (warnings, import
# path, RNG seed).

# --- Imports: stdlib -> third-party -> local --------------------------------
import os
import sys
from datetime import date
from warnings import filterwarnings

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np  # used explicitly below; previously only reachable via the star import
from flax import linen as nn
from jax.scipy.special import logsumexp
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

# Silence library warnings (kept before the project import, as originally).
filterwarnings("ignore")

# Make the parent directory importable so the `modules.*` package resolves
# when the notebook is run from its own directory.
sys.path.insert(0, os.path.abspath(".."))
# NOTE(review): star import. Later cells rely on it for at least
# `fit_and_eval` and `logdensity_fn_of_bnn`; prefer importing those names
# explicitly once the full set of names used by this notebook is known.
from modules.evaluation_functions.bnn_functions_MNIST import *  # noqa: F401,F403

# --- Configuration ------------------------------------------------------------
# Fixed seed so Restart & Run All reproduces the same numbers. The previous
# seed was int(date.today().strftime("%Y%m%d")), which changed every day.
# Defined after the star import so the project module cannot shadow it.
SEED = 42
rng_key = jax.random.key(SEED)

In [2]:
# Load and preprocess the MNIST dataset.
# `as_frame=False` makes fetch_openml return plain NumPy arrays, so no
# pandas -> NumPy conversion is needed, and plain NumPy dtype strings are
# used instead of passing JAX scalar types to `.astype`.
TEST_SIZE = 0.2   # fraction of samples held out for testing
SPLIT_SEED = 42   # fixed so the train/test split is reproducible

mnist = fetch_openml("mnist_784", version=1, as_frame=False)
X = mnist.data.astype("float32") / 255.0  # scale pixel intensities 0..255 -> [0, 1]
# OpenML delivers the labels as strings ("0".."9"); cast to integer class ids.
y = mnist.target.astype("int32")

# Split dataset into train and test sets.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=TEST_SIZE, random_state=SPLIT_SEED
)

In [3]:
# Materialise every split as a NumPy array (the loader may hand back pandas
# objects); each one is copied via np.array.
X_train, y_train, X_test, y_test = (
    np.array(split) for split in (X_train, y_train, X_test, y_test)
)

In [4]:
hidden_layer_width = 100
n_hidden_layers = 2


class NN(nn.Module):
    """Fully connected MLP used as the BNN's base network.

    Applies ``n_hidden_layers`` blocks of ``Dense(layer_width) -> tanh`` and
    finishes with a linear ``Dense(n_outputs)`` read-out (no activation).

    Attributes:
        n_hidden_layers: number of hidden Dense+tanh layers.
        layer_width: number of units in each hidden layer.
        n_outputs: width of the final linear layer. Defaults to 1, which is
            the original behaviour of this notebook.
    """

    n_hidden_layers: int
    layer_width: int
    n_outputs: int = 1

    @nn.compact
    def __call__(self, x):
        for _ in range(self.n_hidden_layers):
            x = nn.tanh(nn.Dense(features=self.layer_width)(x))
        return nn.Dense(features=self.n_outputs)(x)


# NOTE(review): MNIST has 10 classes, but this model emits a single output
# unit, so it cannot represent a 10-way classifier; the ~12% train/test
# accuracy reported at the end of the notebook is close to chance (10%).
# Switch to `NN(n_hidden_layers, hidden_layer_width, n_outputs=10)` once
# `logdensity_fn_of_bnn` / `fit_and_eval` use a categorical (softmax)
# likelihood — TODO confirm against modules/evaluation_functions.
bnn = NN(n_hidden_layers, hidden_layer_width)

In [5]:
rng_key, eval_key = jax.random.split(rng_key)


def fit_and_eval_single_mlp(
    key,
    X_train,
    y_train,
    X_test,
    *,
    num_steps=40,
    batch_size_particles=20,
    batch_size_data=32,
    num_particles=100,
    model=None,
):
    """Fit the BNN with the project's ``fit_and_eval`` and return its result.

    Args:
        key: JAX PRNG key forwarded to ``fit_and_eval``.
        X_train, y_train: training inputs and labels.
        X_test: test inputs to predict on.
        num_steps, batch_size_particles, batch_size_data, num_particles:
            hyperparameters forwarded unchanged to ``fit_and_eval``; the
            defaults are the values this notebook previously hard-coded.
        model: Flax module to fit; ``None`` (default) uses the global ``bnn``,
            looked up at call time.

    Returns:
        Whatever ``fit_and_eval`` returns; the call in this cell unpacks it
        into five values (train predictions, test predictions, a grid PPC and
        two further outputs).
    """
    return fit_and_eval(
        key,
        bnn if model is None else model,
        logdensity_fn_of_bnn,
        X_train,
        y_train,
        X_test,
        grid=None,  # no evaluation grid is requested
        num_steps=num_steps,
        batch_size_particles=batch_size_particles,
        batch_size_data=batch_size_data,
        num_particles=num_particles,
    )


Ys_pred_train, Ys_pred_test, ppc_grid_single, _, _ = fit_and_eval_single_mlp(
    eval_key, X_train, y_train, X_test
)

100%|██████████| 40/40 [01:42<00:00,  2.56s/it]


In [6]:
# Flatten both sides so e.g. (N, 1) predictions cannot silently broadcast
# against (N,) labels into an (N, N) comparison and yield a bogus accuracy.
train_pred_flat = jnp.ravel(jnp.asarray(Ys_pred_train))
train_true_flat = jnp.ravel(jnp.asarray(y_train))
assert train_pred_flat.shape == train_true_flat.shape, (
    f"prediction/label size mismatch: {train_pred_flat.shape} vs {train_true_flat.shape}"
)
print(f"Train accuracy = {100 * jnp.mean(train_pred_flat == train_true_flat):.2f}%")

Train accuracy = 12.02%


In [7]:
# Flatten both sides so e.g. (N, 1) predictions cannot silently broadcast
# against (N,) labels into an (N, N) comparison and yield a bogus accuracy.
test_pred_flat = jnp.ravel(jnp.asarray(Ys_pred_test))
test_true_flat = jnp.ravel(jnp.asarray(y_test))
assert test_pred_flat.shape == test_true_flat.shape, (
    f"prediction/label size mismatch: {test_pred_flat.shape} vs {test_true_flat.shape}"
)
print(f"Test accuracy = {100 * jnp.mean(test_pred_flat == test_true_flat):.2f}%")

Test accuracy = 11.96%
