# Jax - Understanding the Library and Implementing an MLP

In [None]:
import time
import numpy as np

import matplotlib.pyplot as plt

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.scipy.special import logsumexp

import tensorflow as tf
import tensorflow_datasets as tfds

### Importing the Dataset

Let's use the MNIST dataset. To download and process it, we will use the built-in functions of TensorFlow Datasets (`tfds`).

In [None]:
def one_hot(x, k, dtype=jnp.float32):
    """
    Create a one-hot encoding of x with k classes.

    x: array-like of integers
        The class indices to be one-hot encoded. Any shape is accepted
        (including plain Python lists and scalars); the encoding is added
        as a new trailing axis of size k.
    k: integer
        The number of classes
    dtype: jnp.dtype, optional(default=float32)
        The dtype to be used on the encoding

    Returns an array of shape x.shape + (k,). Indices outside [0, k) are
    encoded as an all-zero row.
    """
    # Compare every index against 0..k-1: broadcasting over the new trailing
    # axis yields True exactly at the position equal to each index.
    return jnp.array(jnp.asarray(x)[..., None] == jnp.arange(k), dtype)

In [None]:
data_dir = '/tmp/tfds'

# Load both MNIST splits in full (batch_size=-1 returns each split as one
# single batch) and convert them from tf.Tensors to NumPy arrays.
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data = mnist_data['train']
test_data = mnist_data['test']

# Dataset metadata taken from the builder info.
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c


def _flatten_and_encode(split):
    """Return (images reshaped to (len, num_pixels), one-hot labels) for a split."""
    images = jnp.reshape(split['image'], (len(split['image']), num_pixels))
    labels = one_hot(split['label'], num_labels)
    return images, labels


# Full train set
train_images, train_labels = _flatten_and_encode(train_data)

# Full test set
test_images, test_labels = _flatten_and_encode(test_data)

In [None]:
# Unlike NumPy's global seed, Jax randomness is driven by an explicit key
# that has to be handed to every random function.
random_state = 42
key = random.PRNGKey(seed=random_state)

OK, but why is this random state handling different from NumPy's? Well, NumPy uses a global random state that is shared by all of its random functions. This is not a problem when code runs sequentially; however, it becomes a problem when we parallelize our functions.

Seeding that global state only makes the results reproducible as long as the random functions are always called in the same order. See, for example, the following code snippet:

In [None]:
np.random.seed(0)

def bar():
    return np.random.uniform()

def baz():
    return np.random.uniform()

def foo():
    # bar() runs first, so it consumes the first draw after seeding
    first = bar()
    second = baz()
    return first + 2 * second

print(foo())

In [None]:
np.random.seed(0)

def bar():
    return np.random.uniform()

def baz():
    return np.random.uniform()

def foo():
    # This time baz() runs first, so it consumes the first draw after seeding
    doubled = 2 * baz()
    return doubled + bar()

print(foo())

If we change the order of execution, the result is different. If we parallelize the calculation of these functions, we cannot guarantee which of them will run first, so our results will be non-deterministic. To deal with that, every random function in Jax must explicitly receive a key, and each key should be used only once. This way, the order in which our calculations happen does not matter, and the result will always be the same.

The way to make this work is to never use the same key twice. But how are we going to create new keys? Jax has a built-in function for that, called `split`.

This function receives a key and generates new keys (two by default, or `num` keys if requested) that can be used by subsequent random functions. The split is deterministic, so we do not have to worry about it changing our results.

In [None]:
# Display the PRNG key created from random_state above.
key

In [None]:
# Derive three subkeys from the original key. `key` itself is not modified,
# and the returned array of subkeys is displayed as the cell output.
random.split(key, 3)

## MLP Implementation

Now, let's implement a Multilayer Perceptron (MLP) using Jax to get a better grasp of how we can use it for Machine Learning.

### Parameters initialization

Let's define some functions to initialize our neural network parameters with random normal values.

In [None]:
def random_layer_params(m, n, key, scale=1e-2):
    """
    Randomly initialize the weights and bias of one dense layer.

    Returns a tuple (W, b), where W has shape (m, n) and b has shape (n,).
    Both are standard-normal samples multiplied by scale. With this layout,
    the layer maps an m-dimensional input x to an n-dimensional output as
    W.T @ x + b.

    m: integer
        The number of inputs of the layer (first dimension of W)
    n: integer
        The number of outputs of the layer (second dimension of W and
        length of b)
    key: PRNGKey
        A Jax PRNGKey; it is split internally so W and b get independent draws
    scale: float, optional(default=1e-2)
        Multiplier applied to the standard-normal samples
    """
    # Never reuse a key: derive one subkey per random array.
    w_key, b_key = random.split(key, num=2)
    weights = scale * random.normal(w_key, (m, n))
    bias = scale * random.normal(b_key, (n,))
    return weights, bias

In [None]:
def init_network_params(layers_sizes, key, scale=1e-2):
    """
    Initialize the (W, b) parameters of every layer of a fully connected network.

    layers_sizes: list of integers
        The number of neurons on each layer of the network, input layer
        first; each consecutive pair (m, n) defines one layer
    key: PRNGKey
        A Jax PRNGKey
    scale: float, optional(default=1e-2)
        Forwarded to random_layer_params to scale the random values

    Returns a list with one (W, b) tuple per consecutive pair (m, n) of
    layer sizes, as produced by random_layer_params(m, n, ...).
    """
    # One subkey per entry of layers_sizes. Only len(layers_sizes) - 1 of them
    # are consumed (zip stops at the shorter sequence); the count is kept so
    # the generated parameters stay identical to previous runs.
    keys = random.split(key, len(layers_sizes))
    layer_shapes = zip(layers_sizes[:-1], layers_sizes[1:])
    return [random_layer_params(m, n, k, scale=scale)
            for (m, n), k in zip(layer_shapes, keys)]

### Creating a prediction function

Now we will create a prediction function for our MLP. For the activation function we are going to use the ReLU.

In [None]:
def relu(x):
    """Rectified linear unit: the element-wise maximum of x and 0."""
    return jnp.maximum(x, 0)

In [None]:
def predict(params, x):
    """
    Compute the log-probability of every class for ONE input example.

    params: list of (w, b) tuples
        One tuple per layer. Each w is applied transposed, so a w of shape
        (m, n) maps m input features to n outputs.
    x: vector
        A single flattened input (not a batch); see batched_predict for batches.

    Returns the log-softmax of the output layer's logits.
    """
    hidden_layers, (out_w, out_b) = params[:-1], params[-1]

    activations = x
    for weight, bias in hidden_layers:
        # Affine transform followed by the ReLU non-linearity
        activations = relu(weight.T @ activations + bias)

    logits = out_w.T @ activations + out_b
    # Subtracting logsumexp normalizes the logits into log-probabilities
    return logits - logsumexp(logits)

Notice that we created a function that outputs the result for only one image. As one might expect, if you pass a batch of, say, 30 images to this function, it would crash because of a shape mismatch in our matrix operations. The beauty of Jax is its ability to automatically vectorize functions for us, so we don't need to handle batch dimensions ourselves.

In [None]:
# Map `predict` over the leading axis of x while sharing the same params
# (in_axes=None) across every example; results are stacked along axis 0.
batched_predict = vmap(predict, in_axes=(None, 0), out_axes=0)

This `batched_predict` function can now deal with a batch of images for us. To do so, we must pass the `in_axes` parameter, which tells `vmap` how each argument of our function must be mapped.

Notice we have two parameters, `params` and `x`: the weights of our neural network and the image on which they should be applied. The `None` for the first parameter tells `vmap` not to map this argument, i.e., the same weights are shared by every image in the batch. The `0` tells `vmap` that `x` (the images) should be mapped over axis 0, i.e., the rows.

### Utility and loss

In [None]:
def accuracy(params, images, targets):
    """
    Return the fraction of images whose predicted class equals the true class.

    params: list of (w, b) tuples
        The weights and biases of every layer of the network
    images: array
        The flattened images, one per row
    targets: array
        One-hot encoded labels, one row per image
    """
    # The most likely class is the argmax of the per-class log-probabilities.
    predicted = jnp.argmax(batched_predict(params, images), axis=1)
    expected = jnp.argmax(targets, axis=1)
    return jnp.mean(predicted == expected)

def loss(params, images, targets):
    """
    Return the negative mean of the target-weighted log-probabilities.

    The mean is taken over both the batch and the class axis. With one-hot
    targets, this is the cross-entropy divided by the number of classes.
    """
    log_probs = batched_predict(params, images)
    return -(targets * log_probs).mean()

def update(params, x, y, learning_rate=None):
    """
    Take one gradient-descent step on `loss` and return the new parameters.

    params: list of (w, b) tuples
        The current weights and biases of every layer
    x: array
        A batch of flattened images, one per row
    y: array
        The matching one-hot encoded labels
    learning_rate: float, optional(default=None)
        The step size for this update. When None, the global `step_size` is
        used. It is looked up at call time; under `jit` it is read when the
        function is traced, so later changes to `step_size` do not affect
        already-compiled versions.

    Returns a new list of (w, b) tuples; the input params are not modified.
    """
    if learning_rate is None:
        learning_rate = step_size
    grads = grad(loss)(params, x, y)
    return [(w - learning_rate * dw, b - learning_rate * db)
            for (w, b), (dw, db) in zip(params, grads)]

In [None]:
# Our MLP maps each flattened image (num_pixels inputs, 784 for MNIST) through
# two hidden layers of 512 neurons each to one output per class (num_labels).
layer_sizes = [num_pixels, 512, 512, num_labels]

# Training parameters
step_size = 0.01
num_epochs = 10
batch_size = 128

# Number of labels (kept for compatibility; identical to num_labels)
n_targets = num_labels

# Initializing the network parameters with random values
params = init_network_params(layer_sizes, random.PRNGKey(0))

In [None]:
def get_train_batches(batch_size):
    """
    Return an iterable over the MNIST training split, in mini-batches.

    Each element is an (images, labels) pair of NumPy arrays holding up to
    batch_size examples. No shuffling is applied here. The data is read from
    the global data_dir, and one batch is prefetched ahead.

    batch_size: integer
        The number of examples per batch
    """
    train_split = tfds.load(
        name='mnist',
        split='train',
        as_supervised=True,
        data_dir=data_dir,
    )
    batched = train_split.batch(batch_size)
    return tfds.as_numpy(batched.prefetch(1))

In [None]:
# Re-initialize so this cell always trains from the same starting point and
# can safely be re-run.
params = init_network_params(layer_sizes, random.PRNGKey(0))

non_jit_time = []

for epoch in range(num_epochs):
    start_time = time.perf_counter()
    for x, y in get_train_batches(batch_size):
        x = jnp.reshape(x, (len(x), num_pixels))
        y = one_hot(y, num_labels)
        params = update(params, x, y)
    # Jax dispatches work asynchronously: wait until the final update has
    # actually finished so epoch_time covers all of the epoch's computation.
    for layer_params in params:
        for param_array in layer_params:
            param_array.block_until_ready()
    epoch_time = time.perf_counter() - start_time
    non_jit_time.append(epoch_time)

    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}\n".format(test_acc))

In [None]:
jit_update = jit(update)

# Start from freshly initialized parameters. Otherwise this loop would keep
# training whatever `params` the previous cell left behind, and its accuracies
# (and the number of epochs trained) would not be comparable.
params = init_network_params(layer_sizes, random.PRNGKey(0))

jit_time = []

for epoch in range(num_epochs):
    start_time = time.perf_counter()
    # The first epoch also pays for tracing/compilation. A smaller final batch
    # has a new shape and triggers one more compilation.
    for x, y in get_train_batches(batch_size):
        x = jnp.reshape(x, (len(x), num_pixels))
        y = one_hot(y, num_labels)
        params = jit_update(params, x, y)
    # Jax dispatches work asynchronously: wait until the final update has
    # actually finished so epoch_time covers all of the epoch's computation.
    for layer_params in params:
        for param_array in layer_params:
            param_array.block_until_ready()
    epoch_time = time.perf_counter() - start_time
    jit_time.append(epoch_time)

    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}\n".format(test_acc))

In [None]:
# Compare the per-epoch wall-clock time of both training runs.
fig, ax = plt.subplots(figsize=(11, 7))
ax.plot(non_jit_time)
ax.plot(jit_time)
ax.set_xlabel('Epochs')
ax.set_ylabel('Elapsed Time')
ax.set_title('Time per epoch (s)')
ax.legend(['Without Jit on Update', 'With Jit on Update'])

print('Non Jit Average epoch time (s): ', np.mean(non_jit_time))
print('Jit Average epoch time (s): ', np.mean(jit_time))