This notebook contains a Hello World example using JAX on the CUDA backend.

In [1]:
import jax

devices = jax.devices()
print(devices)
runtime_device = devices[0]

[CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]


In [2]:
devices

[CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]

If multiple `CudaDevice` entries are listed here, the installation was successful. (That is the expected output on my workstation, which has 4 GPUs.)
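
The same check can be done programmatically rather than by eye. This is a minimal sketch using `jax.default_backend()` and the `platform` attribute of each device. The exact platform string depends on the installed JAX version, so it is only printed here and not compared against a fixed value:

```python
import jax

# Name of the backend JAX dispatches to by default.
backend = jax.default_backend()

# Count the visible devices per platform.
device_counts = {}
for device in jax.devices():
    device_counts[device.platform] = device_counts.get(device.platform, 0) + 1

print("default backend:", backend)
print("devices per platform:", device_counts)
```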

### JAX NumPy

JAX's main features are exposed through a NumPy-like API, accessible via `jax.numpy`.

In [3]:
import jax.numpy as jnp
import FunctionCollection as fnc
from jax import random
rng = random.PRNGKey(0)
JaxArray = random.normal(rng,(10000,10000))

We can verify that `JaxArray` is indeed a JAX array.

In [4]:
type(JaxArray)

jaxlib.xla_extension.ArrayImpl

We can place the array on a specific device. Once it lives on the GPU, function calls on it are executed as GPU kernels.

In [5]:
JaxArrayCPU = JaxArray.copy().to_device(jax.devices('cpu')[0])
JaxArrayGPU = JaxArray.copy().to_device(runtime_device)
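
Placement can be verified directly. The sketch below uses `jax.device_put`, the functional counterpart of `to_device`, and the `devices()` method of an array, which returns the set of devices holding it. Only the CPU is used here so that the snippet also runs on a machine without a GPU:

```python
import jax
import jax.numpy as jnp

cpu_device = jax.devices("cpu")[0]

x = jnp.ones((4, 4))
# device_put returns an array that is committed to the given device.
x_cpu = jax.device_put(x, cpu_device)

print(x_cpu.devices())
```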

In the following we demonstrate a runtime comparison, computing the product $x x^T$ on the CPU and the GPU, respectively.
The corresponding function is defined in the FunctionCollection (`fnc`) as `runtime_example`.
Before timing, we run the function once so that compilation is not included in the measurement.

In [6]:
import time
fnc.runtime_example(JaxArrayCPU)
fnc.runtime_example(JaxArrayGPU)
#begin Timing on the CPU
start = time.time()
result_cpu = fnc.runtime_example(JaxArrayCPU)
result_cpu.block_until_ready()
end = time.time()
print("Cpu Time: ", end-start)

start = time.time()
result_gpu = fnc.runtime_example(JaxArrayGPU)
result_gpu.block_until_ready()
end = time.time()
print("Gpu Time: ", end-start)

Cpu Time:  1.5748882293701172
Gpu Time:  0.02760004997253418


The speed-up suggests that the computation is indeed running on the GPU.
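
`FunctionCollection` is not shown in this notebook, so the following is only a guess at what such a benchmark can look like. `gram_matrix` is a hypothetical stand-in for `fnc.runtime_example`, and `timed` is a hypothetical helper. The important detail is `block_until_ready()`: JAX dispatches work asynchronously, so without it the timer would stop before the computation has finished.

```python
import time

import jax
import jax.numpy as jnp
from jax import random


@jax.jit
def gram_matrix(x):
    # Hypothetical stand-in for fnc.runtime_example: computes x @ x^T.
    return jnp.dot(x, x.T)


def timed(fn, *args):
    # Hypothetical helper: run fn once and wait for the result before stopping the clock.
    start = time.perf_counter()
    result = fn(*args).block_until_ready()
    return result, time.perf_counter() - start


x = random.normal(random.PRNGKey(0), (512, 512))

# The first call includes tracing and compilation, the second one does not.
_, first_call = timed(gram_matrix, x)
result, second_call = timed(gram_matrix, x)
print("first call: ", first_call)
print("second call:", second_call)
```

A small array is used here so that the snippet finishes quickly on any backend; the absolute timings are not comparable to the ones reported above.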

### MNIST
Now that we have seen that JAX is indeed running on the GPU, let's do the ML Hello World: MNIST classification!
Note that JAX is an autodiff library, NOT a neural network library; hence it does not come with many quality-of-life features.

<a href="#neuralnetwork"> Neural Network </a> <br>
<a href="#dataloader"> Data Loading </a> <br>
<a href="#training"> Training Loop </a>



#### Writing the Neural Network
<a id="neuralnetwork"> </a>
The most important difference to an implementation in Torch, TensorFlow, or Flux is that here we do not write the network for a whole batch, i.e. for a tensor of size (batch, datadims) (or the other way around). Instead we define the network on a single data point and vectorize it ("batching it up") later.
Further notice that the parameters and the model architecture are decoupled, which makes inspecting the parameters, as well as replacing them, straightforward.

In [7]:
import jax.nn as jnn
@jax.jit
def denseLayer(params, image):
    W,b = params
    return jnp.dot(W, image) + b

def denseLayerConstructor(key, indims, outdims, weight=fnc.glorot_normal, bias=fnc.glorot_uniform):
    W_key, b_key = random.split(key)
    W = weight(W_key, (outdims, indims))
    b = bias(b_key, (outdims,))
    return (W,b)

@jax.jit
def model(params, image):
    activations = image
    for param in params[:-1]:
        activations = denseLayer(param, activations)
        activations = jnn.relu(activations)
    activations = denseLayer(params[-1], activations)
    return jnn.softmax(activations)

def modelConstructor(key, layerdims):
    return [denseLayerConstructor(key, layerdims[i], layerdims[i+1]) for i in range(len(layerdims)-1)]

In [8]:
modelParameters = modelConstructor(rng, (28*28,128,10));
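
Because the parameters are a plain list of `(W, b)` tuples, they can be inspected with the pytree utilities. The sketch below builds a stand-in parameter list, since `fnc.glorot_normal` and `fnc.glorot_uniform` are not shown here. It assumes `jax.nn.initializers.glorot_normal` for the weights and zeros for the biases, which may differ from what `FunctionCollection` does. `init_params` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax import random


def init_params(key, layerdims):
    # Assumption: Glorot-normal weights and zero biases.
    weight_init = jax.nn.initializers.glorot_normal()
    # One independent subkey per layer.
    keys = random.split(key, len(layerdims) - 1)
    params = []
    for k, indims, outdims in zip(keys, layerdims[:-1], layerdims[1:]):
        W = weight_init(k, (outdims, indims), jnp.float32)
        b = jnp.zeros((outdims,), jnp.float32)
        params.append((W, b))
    return params


params = init_params(random.PRNGKey(0), (28 * 28, 128, 10))

# Map every array in the nested structure to its shape.
shapes = jax.tree_util.tree_map(lambda p: p.shape, params)
print(shapes)

# Total number of trainable scalars.
num_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(num_params)
```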

To check that it indeed works, let's generate a random image in vector form.

In [9]:
img = jax.random.uniform(rng, (28*28,), minval=0, maxval=1)  # shape given as a tuple
model(modelParameters, img)

Array([0.16604736, 0.2734799 , 0.09123447, 0.08354109, 0.12440692,
       0.06872484, 0.0261414 , 0.1224124 , 0.02137834, 0.02263325],      dtype=float32)

To apply this function to a batch of data, we can use JAX's `vmap`.
`vmap` uses the keyword `in_axes` to specify which inputs we want to batch over.
The signature of the function remains the same, so you can think of it as
```
predictor(params, batch) = [model(params, image) for image in batch]
```
Since we do not want to iterate over `params`, but instead pass them unchanged to every call, we set `in_axes=(None, ...)`.
The second entry specifies the axis of the batch we want to iterate over. Since our batches have shape (batchsize, imagesize), we iterate over the 0th axis.


In [10]:
predictor = jax.vmap(model, in_axes=(None,0))

We can see that this now indeed works on a batch of images.

In [11]:
batched_images = jax.random.uniform(rng, (10,28*28), minval=0, maxval=1)
predictions = predictor(modelParameters, batched_images)
predictions.shape

(10, 10)
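
To make the role of `in_axes` concrete, the sketch below compares `jax.vmap` with an explicit Python loop on a toy function. `affine` is a hypothetical single-example function, not part of the notebook's model.

```python
import jax
import jax.numpy as jnp
from jax import random


def affine(params, x):
    # Works on a single input vector x of shape (3,).
    W, b = params
    return jnp.dot(W, x) + b


key_w, key_x = random.split(random.PRNGKey(0))
params = (random.normal(key_w, (2, 3)), jnp.zeros((2,)))
batch = random.normal(key_x, (5, 3))

# None: pass params unchanged to every call. 0: map over the leading axis of batch.
batched_affine = jax.vmap(affine, in_axes=(None, 0))

vmapped = batched_affine(params, batch)
looped = jnp.stack([affine(params, x) for x in batch])

print(vmapped.shape)
print(jnp.allclose(vmapped, looped, atol=1e-5))
```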

##### Writing the Loss Function
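We will implement the categorical cross-entropy, where $p$ is the target distribution (the one-hot label) and $q$ is the predicted distribution:
    $$ L = -\sum_x p(x) \log(q(x)) $$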
    

In [12]:
def CCEloss(params, batch):
    inputs, targets = batch
    predictions = predictor(params, inputs)
    # Natural logarithm, as in the cross-entropy definition; averaged over the batch.
    return -jnp.mean(jnp.sum(targets * jnp.log(predictions), axis=1))
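
Taking the logarithm of a softmax output can underflow to `-inf` once a predicted probability rounds to zero. One common way around this, sketched here, is to compute the loss from the pre-softmax activations (logits) with `jax.nn.log_softmax`. `cross_entropy_from_logits` is a hypothetical helper. Using it would require the model to return logits instead of probabilities, which the notebook's `model` does not do.

```python
import jax
import jax.numpy as jnp


def cross_entropy_from_logits(logits, targets):
    # log_softmax evaluates log(softmax(logits)) in a numerically stable way.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(targets * log_probs, axis=-1))


# Uniform logits over 10 classes: every class has probability 1/10.
logits = jnp.zeros((4, 10))
targets = jax.nn.one_hot(jnp.array([0, 3, 5, 9]), 10)

loss = cross_entropy_from_logits(logits, targets)
print(loss)  # ln(10), approximately 2.3026
```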

## Data Loading <a id="dataloader"></a>

In [13]:
from datasets import load_dataset, load_from_disk;



In [14]:
import pathlib
if not (pathlib.Path(pathlib.Path.cwd() / "MNIST.hf")).exists():
    MNIST = load_dataset("ylecun/mnist")
    MNIST.save_to_disk("MNIST.hf")
else:
    MNIST = load_from_disk("MNIST.hf")

Now we have to implement our own custom dataloader. Luckily, this is just a generator with the added capability of shuffling our data. We prepare the data on the CPU and shift it over to the GPU afterwards.

In [15]:
training_data = MNIST["train"].with_format("jax")
test_data = MNIST["test"].with_format("jax")
training_inputs = training_data["image"].astype(jnp.float32) / 255
training_labels = jax.nn.one_hot(training_data["label"], 10)
test_inputs = test_data["image"].astype(jnp.float32) / 255  
test_labels = jax.nn.one_hot(test_data["label"], 10)

In this simple example we use a dense NN on the MNIST data, hence we need to flatten the inputs first.

In [16]:
training_inputs = training_inputs.reshape(training_inputs.shape[0], -1)
test_inputs = test_inputs.reshape(test_inputs.shape[0], -1)
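
The preprocessing steps can be checked on a tiny synthetic batch. The array below is random data standing in for two 28x28 grayscale images, not actual MNIST samples:

```python
import jax
import jax.numpy as jnp
from jax import random

# Two fake 8-bit images and their labels.
fake_images = random.randint(random.PRNGKey(0), (2, 28, 28), 0, 256).astype(jnp.uint8)
fake_labels = jnp.array([3, 7])

# Scale pixel values to [0, 1], then flatten each image to a vector.
inputs = fake_images.astype(jnp.float32) / 255
inputs = inputs.reshape(inputs.shape[0], -1)

# One row per label with a single 1 at the label's index.
labels = jax.nn.one_hot(fake_labels, 10)

print(inputs.shape, labels.shape)
```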

Now that we have the data, let's write a dataloader to feed batches into our network.

In [17]:
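class dataloader(object):
    def __init__(self, inputs, labels, batch_size, key=None):
        self.inputs = inputs
        self.labels = labels
        self.batch_size = batch_size
        # Integer division: a trailing incomplete batch is dropped.
        self.num_batches = inputs.shape[0] // batch_size
        # PRNG key used for shuffling; defaults to the notebook-level key.
        self.key = rng if key is None else key

    def __iter__(self):
        for i in range(self.num_batches):
            start = i * self.batch_size
            end = (i + 1) * self.batch_size
            yield (self.inputs[start:end], self.labels[start:end])

    def __len__(self):
        return self.num_batches

    def shuffle(self):
        # Split the key on every call so that each shuffle draws a new permutation.
        self.key, perm_key = jax.random.split(self.key)
        perm = jax.random.permutation(perm_key, self.inputs.shape[0])
        self.inputs = self.inputs[perm]
        self.labels = self.labels[perm]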

This very basic dataloader is an iterable that is capable of shuffling the data. Here the data must fit into RAM; this is a more or less artificial constraint, which we will lift as soon as we have larger datasets.
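
An alternative to a stateful `shuffle` method is a generator that receives a PRNG key and permutes the indices on the fly. This is only a sketch; `iterate_batches` is a hypothetical name, and like the class above it drops the last incomplete batch.

```python
import jax
import jax.numpy as jnp
from jax import random


def iterate_batches(key, inputs, labels, batch_size):
    # Draw one permutation of the sample indices for this pass over the data.
    perm = random.permutation(key, inputs.shape[0])
    num_batches = inputs.shape[0] // batch_size
    for i in range(num_batches):
        idx = perm[i * batch_size:(i + 1) * batch_size]
        yield inputs[idx], labels[idx]


inputs = jnp.arange(20, dtype=jnp.float32).reshape(10, 2)
labels = jnp.arange(10)

key = random.PRNGKey(0)
for epoch in range(2):
    # A fresh subkey per epoch gives an independently drawn order each time.
    key, epoch_key = random.split(key)
    batches = list(iterate_batches(epoch_key, inputs, labels, batch_size=3))
    print(epoch, [b_labels.tolist() for _, b_labels in batches])
```

Passing the key explicitly keeps the randomness reproducible and avoids hidden state in the loader.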

## Training Loop <a id="training"></a>

Equipped with a dataloader and a neural network, we can now start to think about training the model.
For this we need to be able to backpropagate through our neural network and collect the gradients.
This is the job of JAX, and the whole reason we use it: it can automatically step through a function's call tree and (efficiently) compute gradients with respect to the input parameters. This is nice, because it means we don't have to derive them by hand. As far as I know, hand-derived derivatives can in theory be more efficient, but the cost in additional development time is almost never worth it.

In [18]:
grad = jax.grad(CCEloss)(modelParameters, (training_inputs[:10], training_labels[:10]))
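
As a sanity check of what `jax.grad` returns, the sketch below differentiates a function whose gradient is known in closed form. `sum_of_squares` is a toy function chosen for illustration: for $f(x) = \sum_i x_i^2$ the gradient is $2x$.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(x):
    return jnp.sum(x ** 2)


x = jnp.array([1.0, -2.0, 3.0])

# jax.grad returns a new function that evaluates df/dx.
auto_grad = jax.grad(sum_of_squares)(x)
manual_grad = 2 * x

print(auto_grad)
print(jnp.allclose(auto_grad, manual_grad))
```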

We can use this gradient in the update function, which is accelerated by JAX's JIT compilation as well.

In [19]:
@jax.jit
def update(params, batch, lr=0.001):
    grads = jax.grad(CCEloss)(params, batch)
    return [(w - lr *dw, b -lr*db) for (w,b), (dw,db) in zip(params, grads)]
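
The same kind of update can be written without unpacking each layer by hand, using `jax.tree_util.tree_map`. In addition, `jax.value_and_grad` returns the loss together with the gradients, which is handy for logging. The following is a sketch on a toy least-squares problem; `toy_loss` and `toy_update` are hypothetical names and not part of the notebook:

```python
import jax
import jax.numpy as jnp
from jax import random


def toy_loss(params, batch):
    # Mean squared error of a single linear layer.
    (W, b), = params
    x, y = batch
    return jnp.mean((x @ W.T + b - y) ** 2)


@jax.jit
def toy_update(params, batch, lr=0.01):
    loss, grads = jax.value_and_grad(toy_loss)(params, batch)
    # Apply p <- p - lr * g to every array in the parameter structure.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


key_x, key_y = random.split(random.PRNGKey(0))
batch = (random.normal(key_x, (8, 3)), random.normal(key_y, (8, 2)))
params = [(jnp.zeros((2, 3)), jnp.zeros((2,)))]

# Two consecutive steps; the returned loss is the one before each step.
params, loss_before = toy_update(params, batch)
_, loss_after = toy_update(params, batch)
print(loss_before, loss_after)
```

Because `tree_map` walks the whole parameter structure, this form keeps working if the layout of the parameters changes, as long as the gradients have the same structure.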

The training loop is then a straightforward iteration through our epochs and batches, applying the update function to each batch.

In [20]:
#SETTINGS FOR THE TRAINING
lr = 0.001
NUMEPOCHS = 100
BATCHSIZE = 60

train_loader = dataloader(training_inputs, training_labels, BATCHSIZE)

In [21]:
modelParameters = [(w.to_device(runtime_device), b.to_device(runtime_device)) for (w,b) in modelParameters]

In [22]:
for epoch in range(1):
    train_loader.shuffle()
    for batch in train_loader:
        input, labels = batch
        input  = input.to_device(runtime_device)
        labels = labels.to_device(runtime_device)
        modelParameters = update(modelParameters, (input, labels), lr=lr)

We probably also want to measure how well our model is performing. For this, let us define the accuracy function:

In [23]:
@jax.jit
def accuracy(parameters, batch):
    inputs, targets = batch
    predictions = predictor(parameters, inputs)
    return jnp.mean(jnp.argmax(predictions, axis=1) == jnp.argmax(targets, axis=1))
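
A small hand-checkable example of the same computation, using made-up predictions for three samples and two classes:

```python
import jax
import jax.numpy as jnp

predictions = jnp.array([[0.9, 0.1],
                         [0.2, 0.8],
                         [0.6, 0.4]])
targets = jax.nn.one_hot(jnp.array([0, 1, 1]), 2)

# argmax turns probabilities and one-hot rows back into class indices.
correct = jnp.argmax(predictions, axis=1) == jnp.argmax(targets, axis=1)
acc = jnp.mean(correct)

print(correct)  # the first two samples are classified correctly, the third is not
print(acc)      # 2 of 3 correct, approximately 0.6667
```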

Adding this to the training loop above, keeping track of the accuracy, and actually running the loop for more than one epoch leads to:

In [24]:
# Indices of 1000 random training samples, drawn without replacement.
keys = jax.random.choice(rng, len(training_inputs), (1000,), replace=False)
train_acc_inputs = training_inputs[keys].to_device(runtime_device)
train_acc_labels = training_labels[keys].to_device(runtime_device)
test_labels = test_labels.to_device(runtime_device)
test_inputs = test_inputs.to_device(runtime_device)


In [25]:


train_accs = []
test_accs  = []
for epoch in range(NUMEPOCHS):
    train_loader.shuffle()
    for batch in train_loader:
        input, labels = batch
        input  = input.to_device(runtime_device)
        labels = labels.to_device(runtime_device)
        modelParameters = update(modelParameters, (input, labels), lr=lr)

    train_acc = accuracy(modelParameters, (training_inputs.to_device(runtime_device), training_labels.to_device(runtime_device)))
    train_accs.append(train_acc)
    test_acc = accuracy(modelParameters, (test_inputs, test_labels))
    test_accs.append(test_acc)



In [26]:
import matplotlib.pyplot as plt
import seaborn as sns
import os
import time

In [27]:
plt.rcParams.update({
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
    'figure.titlesize': 16,
    'text.usetex': True,
    'font.family': 'serif',
    'font.serif': ['Computer Modern Roman']
})

plt.figure(dpi=300, figsize=(16,9))
plt.plot(test_accs, label="Test Accuracy")
plt.plot(train_accs, label="Train Accuracy")
xlabel = plt.xlabel("Epoch")
ylabel = plt.ylabel("Accuracy")
plt.legend()
os.makedirs("plots", exist_ok=True)
plt.savefig("plots/accuracy.png")
plt.close()
plot_displayer = fnc.PlotDisplay()

In [30]:
plot_displayer.update() #This forces the Renderer to reload the plot
plot_displayer.show("plots/accuracy.png", width=800)

Great, we have successfully implemented an MNIST classifier, just like a few billion people before us! But fear not, this is still an achievement: it is the "Hello World!" of machine learning, and it shows that we successfully installed all required packages and understood the very basics of our library (in this case JAX).