<a href="https://colab.research.google.com/github/PetchMa/deeplearning_fundamentals/blob/main/MLP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Building an MLP
In this notebook we build a simple neural network (a multilayer perceptron, or MLP) from scratch using JAX, and I'll walk through the process step by step.

First we import a number of packages. We use torchvision's MNIST dataset and PyTorch's ```DataLoader``` as a starting point, because JAX doesn't handle dataset downloading, batching and all that.

```python
import jax
import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax import jit, vmap, grad, pmap, value_and_grad

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
```

# Building the Model

We start by creating the model's parameters: a weight matrix and a bias vector for each layer, initialized with small random values (scaled by 0.01). These parameters live in a plain Python list of ```[weight, bias]``` pairs, built by the function below.

I like to see this as constructing the skeleton of the model. To make it "live" we need operations that actually __do__ something with these weights, which is the feed-forward part.

To initialize the model, we start from a single random key and split it into one key per layer. Then we iterate through consecutive pairs of layer widths in ```[784, 512, 256, 10]```. Each layer's key is split again into a weight key and a bias key, and the weight matrix gets shape ```(out_width, in_width)``` so it can multiply an input vector directly.

```python
seed = 0

def init_MLP(layer_widths, parent_key, scale=0.01):
  # One [weight, bias] pair per layer; weights have shape (out_width, in_width)
  params = []
  keys = jax.random.split(parent_key, num=len(layer_widths)-1)
  for in_width, out_width, key in zip(layer_widths[:-1], layer_widths[1:], keys):
    weight_key, bias_key = jax.random.split(key)

    params.append(
        [scale*jax.random.normal(weight_key, shape=(out_width, in_width)),
         scale*jax.random.normal(bias_key, shape=(out_width,))]
    )
  return params

rng = jax.random.PRNGKey(seed)

MLP_params = init_MLP([784, 512, 256, 10], rng)
# this checks the shape of every parameter array in the model
# (jax.tree_map is deprecated in recent JAX; jax.tree_util.tree_map is the stable name)
print(jax.tree_util.tree_map(lambda x: x.shape, MLP_params))
```

```
[[(512, 784), (512,)], [(256, 512), (256,)], [(10, 256), (10,)]]
```
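As a quick sanity check on that structure (a sketch, not part of the original notebook), we can flatten the parameter pytree with ```jax.tree_util.tree_leaves``` and count the weights. The ```init_MLP``` below mirrors the one above so the snippet runs on its own.

```python
import jax

def init_MLP(layer_widths, parent_key, scale=0.01):
    # Same scheme as above: one [weight, bias] pair per layer, weight shape (out, in)
    params = []
    keys = jax.random.split(parent_key, num=len(layer_widths) - 1)
    for in_w, out_w, key in zip(layer_widths[:-1], layer_widths[1:], keys):
        w_key, b_key = jax.random.split(key)
        params.append([scale * jax.random.normal(w_key, (out_w, in_w)),
                       scale * jax.random.normal(b_key, (out_w,))])
    return params

params = init_MLP([784, 512, 256, 10], jax.random.PRNGKey(0))

# The nested list is a pytree; its leaves come out in order w1, b1, w2, b2, w3, b3
leaves = jax.tree_util.tree_leaves(params)
total = sum(x.size for x in leaves)
print(len(leaves))  # 6 arrays: 3 weight matrices + 3 bias vectors
print(total)        # 784*512+512 + 512*256+256 + 256*10+10 = 535818
```

So this small MLP has a little over half a million trainable numbers, almost all of them in the first layer.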


# Feed forward
Now we make the model "alive" by implementing the feed-forward pass. Recall that each layer computes a linear combination of its inputs using the layer's weights, adds a bias, then applies a nonlinear activation function and passes the result on to the next layer.

We do exactly that: we start with the input ```x``` as the current activation, then update it in the for loop by taking the weights times the activation plus the bias and applying the ReLU activation function. The last layer skips the ReLU and produces the raw class scores (logits).
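Written out as equations (our own notation, matching the code below), with $a^{(0)} = x$ the flattened 784-pixel image and three layers:

$$
a^{(l)} = \mathrm{ReLU}\left(W^{(l)} a^{(l-1)} + b^{(l)}\right), \qquad l = 1, 2
$$

$$
z = W^{(3)} a^{(2)} + b^{(3)}, \qquad \log p_k = z_k - \log \sum_{j=1}^{10} e^{z_j}
$$

where $W^{(1)} \in \mathbb{R}^{512 \times 784}$, $W^{(2)} \in \mathbb{R}^{256 \times 512}$, $W^{(3)} \in \mathbb{R}^{10 \times 256}$, and $\log p_k$ is the log-probability of digit $k$.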

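At the end we don't return a plain softmax. Instead we return the log of the softmax, computed as ```logits - logsumexp(logits)```. This is numerically stable and is exactly what the cross-entropy loss needs later; just remember to exponentiate the output if you want actual probabilities. (JAX also provides this directly as ```jax.nn.log_softmax```.)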
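To see why ```logits - logsumexp(logits)``` is used instead of taking ```log``` of a hand-written softmax, here is a small sketch with made-up logits: it matches ```jax.nn.log_softmax```, exponentiating it gives probabilities that sum to 1, and unlike the naive formula it stays finite for large logits.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

logits = jnp.array([2.0, 1.0, 0.1])        # made-up scores for 3 classes
log_probs = logits - logsumexp(logits)     # same trick as MLP_predict
probs = jnp.exp(log_probs)                 # exponentiate to recover probabilities
print(jnp.allclose(log_probs, jax.nn.log_softmax(logits)))
print(probs.sum())                         # approximately 1.0

# Naive log(softmax) overflows: exp(1000) is inf in float32, and inf/inf is nan
big = jnp.array([1000.0, 0.0])
naive = jnp.log(jnp.exp(big) / jnp.exp(big).sum())  # contains nan and -inf
stable = big - logsumexp(big)                        # finite: approximately [0, -1000]
print(naive, stable)
```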

NOTE: the nice thing about JAX is ```vmap```. We write ```MLP_predict``` for a SINGLE example (one 784-vector), and ```vmap``` turns it into a function that processes a whole batch at once as an array operation, which is sweet. ```in_axes=(None, 0)``` means the parameters are shared across the batch (not mapped), while the images are mapped along their first axis.
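A minimal illustration of ```vmap``` with ```in_axes=(None, 0)```, using a toy function ```affine``` (hypothetical, standing in for ```MLP_predict```): the weights are shared and only the inputs are batched, and the result matches an explicit Python loop.

```python
import jax
import jax.numpy as jnp

def affine(w, x):
    # Written for ONE example: x has shape (in_dim,)
    return jnp.dot(w, x)

w = jnp.arange(6.0).reshape(2, 3)     # shared weights, shape (2, 3)
xs = jnp.arange(12.0).reshape(4, 3)   # batch of 4 examples, shape (4, 3)

batched = jax.vmap(affine, in_axes=(None, 0))  # don't map w, map x over axis 0
out = batched(w, xs)
print(out.shape)  # (4, 2)

# Same result as looping over the batch by hand
looped = jnp.stack([affine(w, x) for x in xs])
print(jnp.allclose(out, looped))
```

The loop version calls ```affine``` once per example from Python, while the vmapped version is traced once and runs as a single batched computation.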

```python
def MLP_predict(params, x):
  # x is a SINGLE flattened image of shape (784,)
  hidden_layers = params[:-1]

  activation = x
  for w, b in hidden_layers:
    activation = jax.nn.relu(jnp.dot(w, activation) + b)

  w_last, b_last = params[-1]
  logits = jnp.dot(w_last, activation) + b_last
  return logits - logsumexp(logits)  # log-softmax: log(softmax(logits)), computed stably

mnist_img_size = 784

# LOOK AT THIS VMAP FUNCTION AND REMEMBER IT CLEARLY:
# params are shared (None), images are mapped over axis 0
batched_MLP_predict = vmap(MLP_predict, in_axes=(None, 0))
# small test
dummy_imgs_flat = np.random.randn(16, np.prod(mnist_img_size))
print(dummy_imgs_flat.shape)
predictions = batched_MLP_predict(MLP_params, dummy_imgs_flat)
print(predictions.shape)
```

```
(16, 784)
(16, 10)
```


# Data Transformation 

Okay, this part is boring, but we need to turn each 28x28 image into a flat vector of 784 ```float32``` values before feeding it into the neural network. It's just preprocessing (not data augmentation), so it's not very interesting. Note that the pixel values are kept in their raw 0 to 255 range.
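A small sketch of what ```custom_transform``` does to one image. torchvision hands the transform a PIL image, but ```np.array``` treats a 28x28 NumPy array the same way, so a hypothetical array of fake pixels stands in here to avoid the download.

```python
import numpy as np

def custom_transform(x):
    # Same transform as the notebook: 28x28 image -> flat float32 vector
    return np.ravel(np.array(x, dtype=np.float32))

# Hypothetical stand-in for an MNIST image: uint8 pixels in 0..255
fake_img = (np.arange(28 * 28) % 256).astype(np.uint8).reshape(28, 28)

vec = custom_transform(fake_img)
print(vec.shape, vec.dtype)  # (784,) float32
# np.ravel uses row-major order: pixel (row r, col c) lands at index r * 28 + c
print(vec[1 * 28 + 5] == fake_img[1, 5])
```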

```python
# data loading
def custom_transform(x):
  # PIL image (28x28) -> flat float32 NumPy array of shape (784,)
  return np.ravel(np.array(x, dtype=np.float32))

train_dataset = MNIST(root='train_mnist', train=True, download=True, transform=custom_transform)
test_dataset = MNIST(root='train_mnist', train=False, download=True, transform=custom_transform)
```

```python
def custom_collate_fn(batch):
  # batch is a list of (image, label) pairs; regroup it into NumPy arrays
  # (images, labels) instead of the torch tensors the default collate would build
  transposed_data = list(zip(*batch))
  labels = np.array(transposed_data[1])
  imgs = np.stack(transposed_data[0])
  return imgs, labels


train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=custom_collate_fn)
batch_data = next(iter(train_loader))
imgs = batch_data[0]
labels = batch_data[1]
print(labels.shape)
```

```
(128,)
```
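The ```DataLoader``` calls ```custom_collate_fn``` with a plain list of samples, each an ```(image, label)``` pair. This sketch feeds it a hypothetical mini-batch of three fake samples, so the regrouping can be seen without torch.

```python
import numpy as np

def custom_collate_fn(batch):
    # batch: list of (flat_image, label) pairs, as produced by the Dataset
    transposed_data = list(zip(*batch))  # -> (tuple_of_images, tuple_of_labels)
    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])
    return imgs, labels

# Hypothetical mini-batch: image i is filled with the value i
fake_batch = [(np.full(784, i, dtype=np.float32), label)
              for i, label in enumerate([7, 2, 1])]
imgs, labels = custom_collate_fn(fake_batch)
print(imgs.shape, labels.shape)  # (3, 784) (3,)
print(labels)                    # [7 2 1]
```

The ```zip(*batch)``` call is the "transpose": a list of pairs becomes a pair of tuples, which are then stacked into one image array and one label array.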


# Training
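Now, to make the neural network learn we need a loss function, and here we use the cross-entropy error. Since the model already outputs the log of the softmax, we can just multiply it element-wise with the labels. The labels are one-hot encodings, so every entry except the true class is 0 and drops out of the product: only the log-probability of the correct class survives, and we don't need to loop over the classes. We then take the negative mean. Minimizing this pushes the log-probability of the correct class toward 0, i.e. its probability toward 1. Note that ```jnp.mean``` here averages over both the batch and the 10 classes, so the value is the usual cross-entropy divided by 10; this only rescales the loss and its gradient.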
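The sketch below checks that claim on random fake predictions (hypothetical data): the one-hot product picks out exactly the true-class log-probability, and the notebook-style ```-jnp.mean(...)``` is the per-example cross-entropy averaged over the batch and then divided by the number of classes.

```python
import jax
import jax.numpy as jnp

num_classes = 10
logits = jax.random.normal(jax.random.PRNGKey(0), (4, num_classes))  # fake batch of 4
log_probs = jax.nn.log_softmax(logits, axis=-1)  # what batched_MLP_predict returns
labels = jnp.array([3, 0, 9, 1])
one_hot = jax.nn.one_hot(labels, num_classes)

notebook_loss = -jnp.mean(log_probs * one_hot)           # mean over batch AND classes
per_example_ce = -jnp.sum(log_probs * one_hot, axis=1)   # zeros drop out, true class stays
true_ce = jnp.mean(per_example_ce)                       # standard mean cross-entropy

# Indexing the true-class log-prob directly gives the same per-example values
picked = -log_probs[jnp.arange(4), labels]
print(jnp.allclose(per_example_ce, picked))
print(jnp.allclose(notebook_loss * num_classes, true_ce))
```

With near-uniform predictions the cross-entropy is about ln(10) ≈ 2.30, so the notebook's loss starts roughly around 0.23, which is in line with the first printed value of 0.242.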

We then want to update the neural network. ```value_and_grad``` returns both the value of the loss function f(x) and its gradient with respect to the function's first argument, which here is ```params```.
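A toy example of ```value_and_grad``` (the function ```f``` is hypothetical): when the first argument is a list, the gradient comes back as a list with the same layout, just like ```params``` in the MLP.

```python
import jax
import jax.numpy as jnp

def f(params, x):
    # Toy "loss"; differentiated with respect to params (argument 0 by default)
    w, b = params
    return jnp.sum((w * x + b) ** 2)

params = [jnp.array(2.0), jnp.array(1.0)]
x = jnp.array(3.0)

loss, grads = jax.value_and_grad(f)(params, x)
# loss  = (2*3 + 1)**2 = 49
# grads = [d/dw, d/db] = [2*7*3, 2*7] = [42, 14], same list structure as params
print(loss, grads)
```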

Now this ```tree_multimap``` is a bit confusing (it has since been removed from JAX; ```jax.tree_util.tree_map``` now accepts several trees and does the same job), so here is how it works. The gradients come back as a pytree with the same structure as ```params```. You can think of a pytree as a nested list (or tuple, or dict) whose leaves are arrays.

Look at the function first: ```lambda p,g:p-lr*g``` takes a parameter array ```p``` and its corresponding gradient ```g``` and applies one step of stochastic gradient descent, p = p - lr*g. Because the parameters are stored in a nested list, the gradients with respect to each layer are stored in exactly the same nested layout. The tree map walks both trees in lockstep, pairs every parameter with its gradient, and applies the update to each pair. (The gradients themselves were already computed by backpropagation inside ```value_and_grad```; this step only applies them.)
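A minimal sketch of that update on a hand-made one-layer "model" (hypothetical numbers), using ```jax.tree_util.tree_map```, the current replacement for ```jax.tree_multimap```:

```python
import jax
import jax.numpy as jnp

params = [[jnp.array([1.0, 2.0]), jnp.array(0.5)]]  # one layer: [w, b]
grads = [[jnp.array([1.0, 1.0]), jnp.array(1.0)]]   # same nesting as params
lr = 0.1

# tree_map walks both trees together, pairing each param leaf with its gradient leaf
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
# new_params ~ [[[0.9, 1.9], 0.4]], still a list of [w, b] lists
print(new_params)
```

If the two trees had different structures, ```tree_map``` would raise an error instead of silently mismatching parameters and gradients.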


```python
num_epochs = 100

def loss_fn(params, imgs, gt_lbls):
  predictions = batched_MLP_predict(params, imgs)
  # mean over batch AND classes, i.e. cross-entropy / 10
  return -jnp.mean(predictions * gt_lbls)

def update(params, imgs, gt_lbls, lr=0.01):
  loss, grads = value_and_grad(loss_fn)(params, imgs, gt_lbls)
  # jax.tree_multimap was removed from JAX; tree_util.tree_map accepts multiple trees
  return loss, jax.tree_util.tree_map(lambda p, g: p - lr*g, params, grads)

for epochs in range(num_epochs):
  for count, (imgs, lbls) in enumerate(train_loader):
    gt_labels = jax.nn.one_hot(lbls, len(MNIST.classes))
    loss, MLP_params = update(MLP_params, imgs, gt_labels)
    if count % 10 == 0:
      print(loss)
  break  # stop after the first epoch for this demo
```

```
0.24227884
0.20041227
0.16782306
0.13546939
0.11797513
0.09290477
0.08728522
0.07453238
0.06636019
0.072071545
0.060662486
0.050317425
0.06171403
0.04537266
0.059020735
0.04427728
0.038150944
0.036000844
0.045533556
0.044534184
0.04663719
0.05466785
0.03303972
0.031562883
0.045496844
0.04389519
0.038473535
0.030264331
0.04470369
0.046309512
0.039135084
0.027539542
0.03618552
0.043680187
0.023523778
0.027202338
0.03461301
0.043001037
0.024820391
0.026264424
0.040121146
0.020373339
0.027725805
0.035998005
0.037409402
0.035020955
0.024504013
```
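The notebook imports ```jit``` but never uses it. Below is a minimal sketch, assuming the same model, loss and step size as above, of a jit-compiled version of ```update```. It is self-contained and runs on random data standing in for one MNIST batch; ```LEARNING_RATE``` is a name introduced here, not from the original.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

LEARNING_RATE = 0.01  # same step size as the notebook's default lr

def init_MLP(layer_widths, parent_key, scale=0.01):
    params = []
    keys = jax.random.split(parent_key, num=len(layer_widths) - 1)
    for in_w, out_w, key in zip(layer_widths[:-1], layer_widths[1:], keys):
        w_key, b_key = jax.random.split(key)
        params.append([scale * jax.random.normal(w_key, (out_w, in_w)),
                       scale * jax.random.normal(b_key, (out_w,))])
    return params

def MLP_predict(params, x):
    activation = x
    for w, b in params[:-1]:
        activation = jax.nn.relu(jnp.dot(w, activation) + b)
    w_last, b_last = params[-1]
    logits = jnp.dot(w_last, activation) + b_last
    return logits - logsumexp(logits)

batched_MLP_predict = jax.vmap(MLP_predict, in_axes=(None, 0))

def loss_fn(params, imgs, gt_lbls):
    return -jnp.mean(batched_MLP_predict(params, imgs) * gt_lbls)

@jax.jit
def update(params, imgs, gt_lbls):
    # Traced and compiled once per input shape; the Python layer loop is unrolled
    loss, grads = jax.value_and_grad(loss_fn)(params, imgs, gt_lbls)
    return loss, jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# Hypothetical random data standing in for one batch of 128 MNIST images
params = init_MLP([784, 512, 256, 10], jax.random.PRNGKey(0))
imgs = jax.random.normal(jax.random.PRNGKey(1), (128, 784))
labels = jax.nn.one_hot(jnp.arange(128) % 10, 10)

loss, params = update(params, imgs, labels)  # first call compiles
loss, params = update(params, imgs, labels)  # later calls reuse the compiled code
print(loss)
```

One thing to keep in mind with this design: a new input shape triggers a fresh compilation. In the training loop above, the last batch of each epoch holds only 96 images (60000 is not a multiple of 128), so it would be compiled once more.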
