### Simple MLP MNIST Classification with JAX

(see JAX docs, e.g. ADVANCED JAX TUTORIALS)

In [None]:
import numpy as np

import jax
import jax.numpy as jnp
from jax import jit
from jax import random

In [None]:
PATH = "data"                 # root directory handed to torchvision's MNIST dataset
BATCH_SIZE = 64               # images per DataLoader batch
MNIST_IMG_SIZE = 28, 28, 1    # shape used to reshape flat image vectors for plotting
STEP_SIZE = 1e-2              # SGD step size applied in update()

#### Import the MNIST dataset with PyTorch...

... and set up the DataLoaders

In [None]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader


def custom_transform(x):
    """Turn an image into a flat float32 vector divided by 255.

    ``x`` (e.g. a PIL image) is converted to a float32 ndarray, flattened to
    one dimension and scaled by 1/255 (values land in [0, 1] for 8-bit images).
    """
    flat = np.array(x, dtype=np.float32).reshape(-1)
    return flat / 255.0


def custom_collate_fn(batch):
    """Collate a list of (image, label) tuples into two separate ndarrays.

    Returns ``(imgs, labels)``: the images stacked along a new leading axis
    and the labels gathered into a single array.
    """
    columns = list(zip(*batch))
    label_array = np.array(columns[1])
    image_stack = np.stack(columns[0])
    return image_stack, label_array


train_dataset = MNIST(root=PATH, train=True, download=True, transform=custom_transform)
test_dataset = MNIST(root=PATH, train=False, download=True, transform=custom_transform)

# drop_last=True on the training loader keeps every batch the same shape, so
# the jitted `update` does not get re-traced for a smaller final batch.
train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, collate_fn=custom_collate_fn, drop_last=True)
# Keep the final partial batch for evaluation: with drop_last=True the last
# len(test_dataset) % BATCH_SIZE test images would never be evaluated.
test_loader = DataLoader(test_dataset, BATCH_SIZE, shuffle=False, collate_fn=custom_collate_fn, drop_last=False)


#### check data

In [None]:
import matplotlib.pyplot as plt

# Sanity check: show the first image of one shuffled training batch.
# Reshape to (height, width) only -- older matplotlib releases reject
# (M, N, 1) arrays -- and use a gray colormap for the grayscale digit.
img = np.reshape(next(iter(train_loader))[0][0], MNIST_IMG_SIZE[:2])

plt.imshow(img, cmap='gray')

#### initialize the MLP parameters

In [None]:
def random_layer_params(m, n, key, scale=1e-2):
  """Randomly initialize one dense layer mapping m inputs to n outputs.

  Returns ``(weights, biases)`` with shapes ``(n, m)`` and ``(n,)``. Both are
  standard-normal samples (from independent subkeys of ``key``) multiplied
  by ``scale``.
  """
  weight_key, bias_key = random.split(key)
  weights = scale * random.normal(weight_key, (n, m))
  biases = scale * random.normal(bias_key, (n,))
  return weights, biases

# # better initialization: (better predictions)
# def random_layer_params(m, n, key):
#  return np.sqrt(2/m) * random.normal(key, (n, m)), jnp.ones(shape=(n,))


def init_network_params(sizes, key):
  """Initialize every layer of a fully-connected network.

  ``sizes`` lists the layer widths including input and output, so the result
  is a list of ``len(sizes) - 1`` ``(weights, biases)`` pairs produced by
  ``random_layer_params``.
  """
  # NOTE: len(sizes) subkeys are drawn but only the first len(sizes) - 1 are
  # consumed; the split count is kept as-is so the generated values stay the same.
  layer_keys = random.split(key, len(sizes))
  fan_pairs = zip(sizes[:-1], sizes[1:])
  return [random_layer_params(fan_in, fan_out, layer_key)
          for (fan_in, fan_out), layer_key in zip(fan_pairs, layer_keys)]

# Widths: flattened 28x28 input, two hidden layers of 512 units, 10 outputs.
LAYER_SIZES = [28 * 28, 512, 512, 10]

params = init_network_params(LAYER_SIZES, key=random.PRNGKey(0))

##### use JAX pytree-functionality to check parameters

In [None]:
# Display the shape of every leaf in the params pytree. `jax.tree_map` is
# deprecated (removed in newer JAX releases); `jax.tree_util.tree_map` is the
# long-standing, stable spelling.
jax.tree_util.tree_map(lambda x: x.shape, params)

#### define predict/forward function

In [None]:
from jax.scipy.special import logsumexp

def relu(x):
  """Rectified linear unit: elementwise maximum of ``x`` and 0."""
  return jnp.maximum(x, 0)

def predict(params, image):
  """Forward pass of the MLP for a single (unbatched) input vector.

  Every ``(w, b)`` pair but the last is applied as ``relu(w . a + b)``; the
  last pair is applied without an activation, and its output (raw logits)
  is returned.
  """
  hidden_layers, (out_w, out_b) = params[:-1], params[-1]
  hidden = image
  for weight, bias in hidden_layers:
    hidden = relu(jnp.dot(weight, hidden) + bias)
  return jnp.dot(out_w, hidden) + out_b

# Batched forward pass: `params` is shared across the batch (axis None) while
# `image` is mapped over its leading axis (axis 0).
batch_predict = jax.vmap(predict, (None, 0))

#### define loss and update functions

The loss function used here is the negative log-likelihood, also known as the cross-entropy loss. Let $\sigma$ denote the softmax, and let $y_{i_k} \in \{0, 1\}$ be the one-hot encoded label of sample $i$. The likelihood of the data is then

$$ L(\theta) = p_{(Y|X)}(y|x) = \prod_{i=1}^{n} \prod_{k=1}^{l} \sigma(Net(x_i))_k^{y_{i_k}} $$

Taking the logarithm and multiplying by $-1$ gives the negative log-likelihood, using $\log \sigma(z)_k = z_k - \log \sum_{j=1}^{l} e^{z_j}$:

$$ -\log(L(\theta)) = - \sum_{i=1}^{n} \sum_{k=1}^{l} y_{i_k} \left(Net(x_i)_k - \log\left( \sum_{j=1}^l e^{Net(x_i)_j} \right)\right) $$

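<font size="2">(Dividing by the number of samples $n$ gives the mean cross-entropy. The code below uses `jnp.mean` over all $n \cdot l$ entries, which additionally divides by the number of classes $l$. This only rescales the loss and its gradient.)</font>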



In [None]:
N_TARGETS = 10

def one_hot(x, k=N_TARGETS, dtype=jnp.float32):
  """Encode an array of integer labels ``x`` as a ``(len(x), k)`` one-hot matrix.

  Row i is 1 in column ``x[i]`` and 0 elsewhere, cast to ``dtype``; a label
  outside ``[0, k)`` yields an all-zero row. Expects ``x`` to be a 1-D
  (j)np array, since it is indexed with ``[:, None]``.
  """
  class_ids = jnp.arange(k)
  hits = x[:, jnp.newaxis] == class_ids
  return jnp.asarray(hits, dtype=dtype)
  
def accuracy(params, images, targets):
  """Fraction of ``images`` whose highest logit matches the integer label in ``targets``."""
  logits = batch_predict(params, images)
  hits = jnp.argmax(logits, axis=1) == targets
  return hits.mean()

def loss(params, images, targets):
  """Negative log-likelihood (softmax cross-entropy) of a batch.

  Args:
    params: list of (w, b) layer pairs, as consumed by ``predict``.
    images: batch of flattened images; ``batch_predict`` maps over axis 0.
    targets: one-hot labels (e.g. from ``one_hot``) broadcastable to the
      (batch, classes) logits.

  Returns:
    Scalar ``-mean(log_softmax(logits) * targets)``. The mean runs over all
    batch * classes entries, i.e. it is the mean per-sample cross-entropy
    divided by the number of classes.
  """
  logits = batch_predict(params, images)
  # Normalize each example on its own: logsumexp must reduce over the class
  # axis only. Without `axis` it reduces over the whole batch, producing one
  # scalar normaliser shared by every example, which is not a softmax.
  log_probs = logits - logsumexp(logits, axis=-1, keepdims=True)
  return -jnp.mean(log_probs * targets)

@jit
def update(params, x, y):
  """Take one plain SGD step on ``loss``; returns ``param - STEP_SIZE * grad`` leafwise.

  The returned pytree has the same list-of-(w, b)-tuples structure as ``params``.
  """
  grads = jax.grad(loss)(params, x, y)
  return jax.tree_util.tree_map(lambda p, g: p - STEP_SIZE * g, params, grads)

#### train

In [None]:
NUM_EPOCHS = 3

for _ in range(NUM_EPOCHS):
  # One pass over the shuffled training loader per epoch.
  for xs, ys in train_loader:
    # The loader yields integer labels; the loss expects one-hot targets.
    params = update(params, xs, one_hot(ys))

#### test

In [None]:

# Evaluate on the first (unshuffled) test batch and inspect one random example.
f_imgs, labels = next(iter(test_loader))

# np.random.randint excludes its upper bound, so use len(labels) (not
# BATCH_SIZE - 1) to make every index in the batch, including the last, eligible.
r_int = np.random.randint(0, len(labels))
img_flat = f_imgs[r_int]
label = labels[r_int]

# batch_predict needs a leading batch axis; take the single row of logits.
out = batch_predict(params, np.expand_dims(img_flat, axis=0))
prediction = jnp.argmax(out[0])

print('accuracy on sample batch: ', accuracy(params, f_imgs, labels))
print(f'example \n predicted: {prediction}, label: {label}')

# (height, width) only: older matplotlib releases reject (M, N, 1) arrays.
img = np.reshape(img_flat, MNIST_IMG_SIZE[:2])
plt.imshow(img, cmap='gray')
