### Simple MLP with JAX (using pytree-functionality for parameters)

(see JAX docs, e.g. Tutorial: JAX 101)

In [None]:
import numpy as np

import jax
import jax.numpy as jnp
from jax import jit
from jax import random

In [None]:
# Root directory passed to the MNIST dataset (used with download=True below).
PATH = 'data'
# Number of samples per batch for both the train and the test DataLoader.
BATCH_SIZE = 64
# Shape used to un-flatten a single image for plotting
# (presumably height, width, channels).
MNIST_IMG_SIZE = (28, 28, 1)

#### import MNIST Dataset with PyTorch...

... and set up Dataloaders

In [None]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader


def custom_transform(x):
    """Convert an image into a flat float32 vector with values divided by 255.

    Args:
        x: array-like image (a PIL Image when used as a torchvision transform).

    Returns:
        1-D float32 ndarray holding all pixels, each scaled by 1/255
        (which maps 8-bit pixel values into [0, 1]).
    """
    pixels = np.array(x, dtype=np.float32)
    flat_pixels = pixels.reshape(-1)
    return flat_pixels / 255.0


def custom_collate_fn(batch):
    """Collate a list of (image, label) tuples into two separate ndarrays.

    Args:
        batch: sequence of (image, label) pairs as produced by the dataset.

    Returns:
        Tuple (imgs, labels): the images stacked along a new leading axis
        and the labels as a 1-D array.
    """
    # Transpose [(img, label), ...] into (all images, all labels).
    columns = tuple(zip(*batch))
    image_column, label_column = columns[0], columns[1]

    return np.stack(image_column), np.array(label_column)


# Both splits are downloaded into PATH (if missing) and use the same
# flatten-and-scale transform.
train_dataset = MNIST(
    root=PATH,
    train=True,
    download=True,
    transform=custom_transform,
)
test_dataset = MNIST(
    root=PATH,
    train=False,
    download=True,
    transform=custom_transform,
)

# drop_last=True discards a trailing incomplete batch, so every batch has
# exactly BATCH_SIZE samples. Only the training data is shuffled.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=custom_collate_fn,
    drop_last=True,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=custom_collate_fn,
    drop_last=True,
)


#### check data

In [None]:
import matplotlib.pyplot as plt

# Take the first image of one training batch and restore its image shape
# so it can be displayed.
first_batch_imgs = next(iter(train_loader))[0]
img = np.reshape(first_batch_imgs[0], MNIST_IMG_SIZE)

plt.imshow(img)

#### initialize parameters of MLP 

In [None]:
def init_mlp_params(layer_widths, key):
  """Create the parameter pytree of an MLP.

  Args:
    layer_widths: sequence of layer sizes, input width first, output width last.
    key: JAX PRNG key; it is split once per layer.

  Returns:
    List with one dict per layer holding
      'weights': (n_in, n_out) normal samples scaled by sqrt(2 / n_in),
      'biases':  (n_out,) array of ones.
  """
  layers = []
  fan_ins, fan_outs = layer_widths[:-1], layer_widths[1:]
  for fan_in, fan_out in zip(fan_ins, fan_outs):
    # Use a fresh subkey per layer so the layers get independent samples.
    key, layer_key = random.split(key)
    scale = np.sqrt(2/fan_in)
    weights = random.normal(layer_key, shape=(fan_in, fan_out)) * scale
    biases = np.ones(shape=(fan_out,))
    layers.append(dict(weights=weights, biases=biases))
  return layers

# Fixed seed, so the parameter initialization is reproducible.
key = random.PRNGKey(42)

# Layer sizes: 784 inputs (flattened 28x28 image) -> 784 -> 512 -> 10 outputs.
params = init_mlp_params(layer_widths=[28*28, 28*28, 512, 10], key=key)

##### use JAX pytree-functionality to check parameters

In [None]:
# Replace every leaf array of the parameter pytree by its shape.
# `jax.tree_util.tree_map` is the stable location of this function; the
# top-level `jax.tree_map` alias is deprecated and removed in recent JAX.
jax.tree_util.tree_map(lambda x: x.shape, params)

#### define forward pass, loss (including the forward pass) & parameter update function

In [None]:
def forward(params, x):
  """Run a single input through the MLP.

  Args:
    params: list of layer dicts with 'weights' and 'biases' entries.
    x: input vector for one sample.

  Returns:
    Output of the last layer; no activation is applied to it.
  """
  *hidden_layers, output_layer = params
  activation = x
  for layer in hidden_layers:
    pre_activation = activation @ layer['weights'] + layer['biases']
    activation = jax.nn.relu(pre_activation)
  return activation @ output_layer['weights'] + output_layer['biases']

# Vectorize `forward` over the leading axis of `x`; the parameters
# (in_axes entry None) are shared by all samples of the batch.
batch_forward = jax.vmap(forward, in_axes=(None, 0))

def loss_fn(params, x, y):
  """Mean squared error between the batched network output and targets `y`.

  Args:
    params: MLP parameter pytree.
    x: batch of inputs (batch axis first).
    y: targets, broadcast-compatible with the network output.

  Returns:
    Scalar mean of the squared residuals over all elements.
  """
  residuals = batch_forward(params, x) - y
  return jnp.mean(residuals ** 2)

def predict(params, x): 
  """Return the predicted class index for every sample of the batch `x`.

  Args:
    params: MLP parameter pytree.
    x: batch of inputs (batch axis first).

  Returns:
    Integer array of shape (batch,) with the index of the largest network
    output per sample.
  """
  # Run the forward pass only once.
  y_hat = batch_forward(params, x)
  # Take the argmax along the output (last) axis. Without `axis`, argmax
  # works on the flattened (batch, classes) array and returns a wrong index
  # for every batch with more than one sample.
  return np.argmax(y_hat, axis=-1)

# Step size of the plain SGD update below.
LEARNING_RATE = 0.0001

@jit
def update(params, x, y):
  """Perform one SGD step and return the updated parameters.

  Args:
    params: MLP parameter pytree.
    x: batch of inputs.
    y: batch of targets passed on to `loss_fn`.

  Returns:
    New parameter pytree with the same structure as `params`.
  """
  grads = jax.grad(loss_fn)(params, x, y)
  # Note that `grads` is a pytree with the same structure as `params`.
  # `jax.grad` is one of the many JAX functions that has
  # built-in support for pytrees.

  # This is handy, because we can apply the SGD update using tree utils.
  # `jax.tree_multimap` was removed from JAX; `jax.tree_util.tree_map`
  # accepts several trees and applies the function leaf-wise.
  return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

In [None]:
def one_hot(x, k=10, dtype=jnp.float32):
  """One-hot encode the 1-D (j)np array `x` into vectors of length `k`.

  Returns an array of shape (len(x), k) and the given dtype whose entry
  [i, j] is 1 where x[i] == j and 0 elsewhere.
  """
  class_ids = jnp.arange(k)
  # Broadcast (n, 1) against (k,) to compare each label with every class id.
  matches = x[:, None] == class_ids
  return jnp.array(matches, dtype)

#### train

In [None]:
NUM_EPOCHS = 5

# Mini-batch SGD: one jitted parameter update per training batch,
# repeated for NUM_EPOCHS passes over the training loader.
for epoch in range(NUM_EPOCHS):
  for xs, ys in train_loader:
    # The MSE loss compares against one-hot targets, not integer labels.
    params = update(params, xs, one_hot(ys))

#### test

In [None]:

# Check the model on a sample from the test split. The original cell drew
# from `train_loader`, i.e. from data the model had already been trained on.
f_imgs, labels = next(iter(test_loader))

img_flat = f_imgs[0]
label = labels[0]

# `predict` expects a batch, so add a leading batch axis of size 1.
print(f'predicted: {predict(params, np.expand_dims(img_flat, axis=0))}, label: {label}')

# Restore the image shape for display.
img = np.reshape(img_flat, MNIST_IMG_SIZE)
plt.imshow(img)
