<a href="https://colab.research.google.com/github/PetchMa/deeplearning_fundamentals/blob/main/BNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Building Bayes Neural Network
In this notebook we will attempt to build a simple bayesian neural network from scratch using JAX and I will attempt to walk through the step process in doing so. 

First we import a number of packages, we will be using TORCH MNIST dataset as a starting ground as JAX doesnt handle data loading and all that. 

In [1]:
import jax
import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax import jit, vmap, grad, pmap,value_and_grad
from time import time
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader 

# Building Model

We start off by building the individual models by creating effectively empty "weights". Recall what a Bayesian neural network __IS__. A Bayesian neural net seeks to learn no weight parameters but distributions, and thus for each parameter we have a mean AND a log normal standard deviation. Thus we implement the following: first a sampling function from a guassian. Then we need to wrap that function within a tree map as shown below:

Notice now we have 4 parameters which we have to update, this is okay we just need to apply the correct feedforward implementation so our model will be split into two one section containing the mean ```mu``` and the other containing the standard deviation ```sigma``` which is stored together as a single dictionary!

ALSO NOTE, we needed to use a variance of approx ~ 0.001 at the start or else this screws up the rest of  ~0.001 variance around default means that we need a LOG VARIANCE of -7 approx. 

In [2]:
seed = 0

@jax.jit
def sample_gaussian(mu, logvar, rng):
  """Sample from a Gaussian distribution.

  NOTE: It uses reparameterization trick.
  """
  eps = jax.random.normal(rng, shape=mu.shape)
  return eps * jnp.exp(logvar / 2) + mu

@jax.jit
def sample_params(mu, sigma, rng):
  sample = jax.tree_multimap(sample_gaussian, mu, sigma, rng)
  return sample


def init_MLP(layer_widths,parent_key, scale =0.01):
  params = {}
  mu = []
  sigma =[]
  keys = jax.random.split(parent_key,num=len(layer_widths)-1)
  for in_width, out_width, key in zip(layer_widths[:-1], layer_widths[1:], keys):
    weight_key, bias_key = jax.random.split(key)
    mu.append([scale*jax.random.normal(weight_key, shape=(out_width, in_width)),
               scale*jax.random.normal(bias_key, shape=(out_width,))])
    
  params['mu'] = mu
  #see we placed a log variance of -7 everywhere because ln(0.001)~ -7 
  params['sigma'] = jax.tree_map(lambda x: -7 * jnp.ones_like(x), mu)


  return params

rng = jax.random.PRNGKey(seed)

MLP_params = init_MLP([784, 512, 256, 10], rng)
#this checks the shape of the model
print(jax.tree_map(lambda x: x.shape, MLP_params))
print(type(MLP_params))

{'mu': [[(512, 784), (512,)], [(256, 512), (256,)], [(10, 256), (10,)]], 'sigma': [[(512, 784), (512,)], [(256, 512), (256,)], [(10, 256), (10,)]]}
<class 'dict'>


# Feed forward
Now we want to make the model "alive" by implementing the actual feedforward aspect. Recall that feed forward is a simple linear combo of the model weights with some bias and the applied with a nonlinear activation function and passed on. 

However our neural network "parameters" describe an entire distribution of weights to choose from and thus we need to select those weights. 

The trick to remember is we can store the specific stuff within a dictionary within the layers. The issue with previous ML Libs is that these model structures are often abstracted away from us and makes it hard to manipulate the individual architecture. Now everything is up to you to implement correctly! 



In [3]:
from functools import partial 
@jax.jit
def MLP_predict(paramas_key, x ):
  params, rng = paramas_key
  hidden_layers_mu = params['mu'][:-1]
  hidden_layers_sigma = params['sigma'][:-1]
  # rng=0
  params_rng, rng = jax.random.split(rng)

  activation = x
  for mu, sigma in zip(hidden_layers_mu,hidden_layers_sigma) :
    weight_mu = mu[0]
    weight_sigma =  sigma[0]
    bias_mu =  mu[1]
    bias_sigma =  sigma[1]

    w= sample_params(weight_mu, weight_sigma, params_rng)
    b = sample_params(bias_mu, bias_sigma, params_rng)
    activation = jax.nn.relu(jnp.dot(w,activation)+b)

  hidden_layers_mu_last = params['mu'][-1]
  hidden_layers_sigma_last = params['sigma'][-1]

  weight_mu_last = hidden_layers_sigma_last[0]
  weight_sigma_last =  hidden_layers_sigma_last[0]
  bias_mu_last =  hidden_layers_mu_last[1]
  bias_sigma_last =  hidden_layers_sigma_last[1]

  w_last= sample_params(weight_mu_last, weight_sigma_last, params_rng)
  b_last = sample_params(bias_mu_last, bias_sigma_last, params_rng)
  logits = jnp.dot(w_last,activation)+b_last
  return logits-logsumexp(logits) # basically does softmax lol but its log of softmax

mnist_img_size = 784

# LOOK AT THIS VMAP FUNCTION AND REMEMBER IT CLEARLY
batched_MLP_predict = vmap(MLP_predict, in_axes=(None,0))
# small test
dummy_imgs_flat = np.random.randn(16, np.prod(mnist_img_size))

predictions = batched_MLP_predict((MLP_params,rng), dummy_imgs_flat)
print(predictions.shape)
print(type(MLP_params))

(16, 10)
<class 'dict'>


# Data Transformation 

Okay this part is boring but we basically need to transform the images by squashing them down to vectors and then feeding it into the neural network. And so its not very interesting. Its just data augmentation stuff.

In [4]:
# data loading
def custom_transform(x):
  return np.ravel(np.array(x, dtype=np.float32))
train_dataset = MNIST(root='train_mnist', train=True, download=True, transform=custom_transform)
test_dataset = MNIST(root='train_mnist', train=False, download=True, transform=custom_transform)

In [5]:
def custom_collate_fn(batch):
  transposed_data = list(zip(*batch))
  labels = np.array(transposed_data[1])
  imgs = np.stack(transposed_data[0])
  return imgs, labels


train_loader = DataLoader(train_dataset, batch_size=128, shuffle = True, collate_fn=custom_collate_fn)
batch_data = next(iter(train_loader))
imgs = batch_data[0]
labels = batch_data[1]
print(labels.shape)

(128,)


# Loss Function

Okay so another thing that is different about Bayes is that we need to compute what is called the evidence base lower bound or the ELBO function. Basically we just need to make sure the statistical difference between the layers needs to be great and the predictions need to be accurate. Thus we have to tack on this extra loss function. 

We compute the guassian kl divergence with the following function:

In [6]:
@jax.jit
def gaussian_kl(mu, logvar):
    """Computes mean KL between parameterized Gaussian and Normal distributions.

    Gaussian parameterized by mu and logvar. Mean over the batch.

    NOTE: See Appendix B from VAE paper (Kingma 2014):
          https://arxiv.org/abs/1312.6114
    """
    kl_divergence = jnp.sum(jnp.exp(logvar) + mu**2 - 1 - logvar) / 2
    kl_divergence /= mu.shape[0]

    return kl_divergence
@jax.jit
def elbo(params, imgs, gt_lbls,rng, beta=1):
    predictions = batched_MLP_predict((params,rng), imgs)
    # Compute log likelihood of batch.
    log_likelihood = jnp.mean(predictions * gt_lbls)
    # Compute the kl penalty on the approximate posterior.
    kl_divergence = jax.tree_util.tree_reduce(
        lambda a, b: a + b,
        jax.tree_multimap(gaussian_kl,
                          params['mu'],
                          params['sigma']),
    )
    elbo_ = log_likelihood - beta * kl_divergence
    return elbo_, log_likelihood, kl_divergence


# Training
Now to make the neural network learn we need a metric and this is the error function which is literally just the mean of the crossentropy error, since we've already gotten the log of the softmax we can just multiply it with the labels. note the labels are one-hot encodings and so  we dont need to loop through each nonzero value as those would get multiplied by 0. We can just take the average of this. We want to push this value up to 1. We basically did cross entropy loss

We then want to update the neural network. Value and GRAD helps return both the actual function f(x) value and the grad gives the gradient of the function. 

Now this tree_multimap is a bit confusing... here is how it works basically the gradients returned are python treemaps which you can basically think of as like nested lists, but can take arbitrary data types and stuff. 

Look at the first section, we have ```lambda p,g:p-lr*g``` this is a simple for loop function that takes the parameter p, takes the corresponding gradient g and applies the stochastic gradient decent. with p = p-lr*g. Now to properly index the correct parameters and the correct gradients in respect to the parameters, we use the tree map. Since the parameters are stored in thes nested list structure, the gradients in respect to those layers are also stored in a similar fashion. Thus to "unravel" and apply back propagation we use the tree multi map

However to test the actual accuracy we want to loop through multiple samplings of the data to get a reading on the final distribution.



In [None]:
num_epochs = 100


@jax.jit
def accuracy(logits, targets):
    """Returns classification accuracy."""
    # Return accuracy = how many predictions match the ground truth
    return jnp.mean(jnp.argmax(logits, axis=-1) == targets)

# @jax.jit
def predict(params, batch_image, rng):
    probs = []
    num_samples = batch_image.shape[0]  
    for i in range(num_samples):
        params_rng, rng = jax.random.split(rng)
        logits = batched_MLP_predict((params,params_rng), batch_image)
        probs.append(jax.nn.softmax(logits))
    stack_probs = jnp.stack(probs)
    return jnp.mean(stack_probs, axis=0), jnp.std(stack_probs, axis=0)

@jax.jit
def loss_fn(params, imgs, gt_lbls,rng):
  return -elbo(params, imgs, gt_lbls, rng)[1]

# @jax.jit
# def loss_fn(params, imgs, gt_lbls,rng):
#   return -elbo(params, imgs, gt_lbls, rng)[0]

@jax.jit
def update(params, imgs, gt_lbls, rng, lr = 0.01):
  loss, grads = value_and_grad(loss_fn)(params, imgs, gt_lbls, rng)
  return loss, jax.tree_multimap(lambda p, g: p - lr * g, params, grads)



MLP_params = init_MLP([784, 512, 256, 10], rng,scale=0.01)

for epochs in range(num_epochs):
  start = time()
  for count, (imgs, lbls) in enumerate(train_loader):
    gt_labels = jax.nn.one_hot(lbls, len(MNIST.classes))
    loss, MLP_params = update(MLP_params, imgs, gt_labels,rng)
    if epochs %10==0 and count==0:
      mean, var = predict(MLP_params, imgs, rng)
      classes = jnp.argmax(gt_labels, axis=-1)
      acc = accuracy(mean, classes)
      print(epochs, 'LOSS: ',loss," ACCURACY", acc, " time taken: ", round((time()-start)/60,5))


0 LOSS:  2412058.8  ACCURACY 0.1328125  time taken:  0.08032
10 LOSS:  nan  ACCURACY 0.046875  time taken:  0.00676
20 LOSS:  nan  ACCURACY 0.0625  time taken:  0.00691
30 LOSS:  nan  ACCURACY 0.078125  time taken:  0.00637
40 LOSS:  nan  ACCURACY 0.1015625  time taken:  0.0063
50 LOSS:  nan  ACCURACY 0.125  time taken:  0.00628
