# **Neural Network (Multilayer perceptron) using plain JAX**



In [27]:
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax import jit, vmap, pmap, grad, value_and_grad
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

In [28]:
# Reproducibility: a single root PRNG key derived from a fixed seed. The key
# is later handed to `init_NN` to initialise the network parameters.
seed = 2137
key = jax.random.PRNGKey(seed)

# Height and width of one MNIST image; their product is used as the width of
# the network's input layer.
mnist_img_size = (28, 28)

# Training hyperparameters.
batch_size = 128  # samples per mini-batch in both DataLoaders
num_epochs = 5    # full passes over the training loader
lr = 0.01         # step size; bound as the default `lr` of `update`

In [29]:
def custom_transform(x):
    """Convert an image-like object into a flat float32 NumPy vector.

    Args:
        x: Anything `np.array` can convert (e.g. a PIL image or nested list).

    Returns:
        A 1-D `np.float32` array holding the elements of `x` in row-major
        order.
    """
    as_float = np.array(x, dtype=np.float32)
    return as_float.ravel()

def custom_collate_fn(batch):
    """Collate a list of (image, label) samples into two NumPy arrays.

    Args:
        batch: Sequence of samples where element 0 is the image and
            element 1 is the label.

    Returns:
        Tuple `(imgs, labels)`: the images stacked along a new leading axis
        and the labels as a 1-D array, both as NumPy arrays.
    """
    # Transpose "list of samples" into "tuple of per-field columns".
    columns = tuple(zip(*batch))
    image_column, label_column = columns[0], columns[1]

    return np.array(image_column), np.array(label_column)

In [30]:
# MNIST splits; `custom_transform` is applied to every image on access.
train_dataset = MNIST(
    root='train_mnist',
    train=True,
    download=True,
    transform=custom_transform,
)
test_dataset = MNIST(
    root='test_mnist',
    train=False,
    download=True,
    transform=custom_transform,
)

# Mini-batch iterators yielding NumPy `(imgs, labels)` pairs via
# `custom_collate_fn`. Only the training loader shuffles; both drop the
# final incomplete batch so every batch has exactly `batch_size` samples.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=custom_collate_fn,
    drop_last=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=custom_collate_fn,
    drop_last=True,
)

In [31]:
# Load both full splits into memory as JAX arrays, flattening every image
# into one row (shape: number of samples x pixels per image). These arrays
# feed the per-epoch `accuracy` evaluation, not the mini-batch updates.
# NOTE(review): this reads `dataset.data` directly and so bypasses
# `custom_transform` -- presumably raw integer pixel values; confirm the
# dtype/range is compatible with what the training batches contain.
train_images = jnp.array(train_dataset.data).reshape(len(train_dataset), -1)
train_lbls = jnp.array(train_dataset.targets)

test_images = jnp.array(test_dataset.data).reshape(len(test_dataset), -1)
test_lbls = jnp.array(test_dataset.targets)

In [32]:
def init_NN(layer_widths, parent_key, scale=0.01):
    """Create randomly initialised parameters for a fully connected network.

    Args:
        layer_widths: Sequence of layer sizes, input layer first and output
            layer last. Each consecutive pair defines one dense layer.
        parent_key: JAX PRNG key; it is split into one sub-key per layer.
        scale: Multiplier applied to the standard-normal samples.

    Returns:
        A list with one `[weights, biases]` entry per layer, where `weights`
        has shape `(out_width, in_width)` and `biases` has shape
        `(out_width,)`.
    """
    num_layers = len(layer_widths) - 1
    layer_keys = jax.random.split(parent_key, num=num_layers)

    def _init_layer(layer_key, fan_in, fan_out):
        # Independent sub-keys so weights and biases are not correlated.
        w_key, b_key = jax.random.split(layer_key)
        weights = scale * jax.random.normal(w_key, shape=(fan_out, fan_in))
        biases = scale * jax.random.normal(b_key, shape=(fan_out,))
        return [weights, biases]

    return [
        _init_layer(layer_key, fan_in, fan_out)
        for layer_key, fan_in, fan_out in zip(
            layer_keys, layer_widths[:-1], layer_widths[1:]
        )
    ]

In [33]:
def NN_predict(params, x):
    """Forward pass of the MLP for a single (unbatched) input vector.

    Every layer except the last applies an affine map followed by ReLU; the
    last layer is affine only, and its output is normalised so the result is
    a vector of log-probabilities.

    Args:
        params: List of `[weights, biases]` pairs as produced by `init_NN`.
        x: 1-D input vector whose length matches the first layer's input
            width.

    Returns:
        1-D array of log-probabilities (log-softmax of the final logits).
    """
    hidden, (w_out, b_out) = params[:-1], params[-1]

    h = x
    for weights, bias in hidden:
        pre_activation = jnp.dot(weights, h) + bias
        h = jax.nn.relu(pre_activation)

    logits = jnp.dot(w_out, h) + b_out
    # Subtracting logsumexp turns raw logits into normalised log-probabilities.
    log_normalizer = logsumexp(logits)

    return logits - log_normalizer

In [34]:
# Batched forward pass: `vmap` maps `NN_predict` over the leading axis of the
# inputs (in_axes 0) while sharing the same `params` across the whole batch
# (in_axes None).
batched_NN_predict = vmap(NN_predict, in_axes=(None, 0))

In [35]:
def loss_fn(params, imgs, gt_lbls):
    """Negative mean of the label-weighted log-probabilities for a batch.

    Args:
        params: Network parameters (list of `[weights, biases]` pairs).
        imgs: Batch of flattened images, one per row.
        gt_lbls: Target weights with the same shape as the network output;
            the training loop passes one-hot encoded labels.

    Returns:
        Scalar loss. The mean runs over every element of the
        (batch, classes) product, not only over the batch axis.
    """
    log_probs = batched_NN_predict(params, imgs)
    weighted_log_probs = log_probs * gt_lbls

    return -jnp.mean(weighted_log_probs)

def accuracy(params, dataset_imgs, dataset_lbls):
    """Fraction of samples whose highest-scoring class equals the label.

    Args:
        params: Network parameters (list of `[weights, biases]` pairs).
        dataset_imgs: Flattened images, one per row.
        dataset_lbls: Integer class labels, one per image.

    Returns:
        Scalar accuracy (mean of the per-sample correctness indicator).
    """
    log_probs = batched_NN_predict(params, dataset_imgs)
    pred_classes = jnp.argmax(log_probs, axis=1)
    is_correct = dataset_lbls == pred_classes
    return jnp.mean(is_correct)

@jit
def update(params, imgs, gt_lbls, lr=lr):
    """Perform one plain gradient-descent step on a mini-batch.

    Args:
        params: Network parameters (list of `[weights, biases]` pairs).
        imgs: Batch of flattened images, one per row.
        gt_lbls: One-hot encoded labels for the batch.
        lr: Step size. The default is the value of the module-level `lr`
            at the time this function is defined.

    Returns:
        Tuple `(loss, new_params)`: the loss evaluated at the incoming
        `params`, and the parameters after the gradient step, with the same
        pytree structure as `params`.
    """
    loss, grads = value_and_grad(loss_fn)(params, imgs, gt_lbls)

    # `jax.tree_multimap` was deprecated and later removed from JAX;
    # `jax.tree_util.tree_map` accepts several pytrees and applies the
    # function leaf-wise across them.
    new_params = jax.tree_util.tree_map(
        lambda p, gradient: p - lr * gradient, params, grads
    )

    return loss, new_params

# Create an MLP: flattened-image input, two hidden layers (512 and 256
# units), and one output unit per MNIST class.
num_classes = len(MNIST.classes)  # loop-invariant, so computed once
NN_params = init_NN([np.prod(mnist_img_size), 512, 256, num_classes], key)

for epoch in range(num_epochs):

    # One pass over the shuffled training loader, one SGD step per batch.
    for imgs, lbls in train_loader:
        gt_labels = jax.nn.one_hot(lbls, num_classes)
        loss, NN_params = update(NN_params, imgs, gt_labels)

    # Evaluate on the full in-memory train and test sets after each epoch.
    print(f'Epoch {epoch}, train acc = {accuracy(NN_params, train_images, train_lbls)} test acc = {accuracy(NN_params, test_images, test_lbls)}')



Epoch 0, train acc = 0.9097833633422852 test acc = 0.9146999716758728
Epoch 1, train acc = 0.9337999820709229 test acc = 0.9323999881744385
Epoch 2, train acc = 0.9447667002677917 test acc = 0.9429999589920044
Epoch 3, train acc = 0.951200008392334 test acc = 0.9490999579429626
Epoch 4, train acc = 0.959766685962677 test acc = 0.9559999704360962
