Up until now we have used JAX directly. Using JAX for machine learning is a natural idea; we already implemented an MLP model for MNIST in the HelloWorld_Cuda notebook. Of course, other people (with a lot more resources than we have) did this as well. One such stack on top of JAX is Google's FLAX, which we introduce in this notebook.

In [1]:
from flax import nnx
import jax
import jax.numpy as jnp
DEFAULTGPU = jax.devices()[0]

FLAX comes with different APIs. For one reason or another, the current recommendation is to use 'nnx'.
Note that in nnx, models are stateful! This is a big difference compared to the pure JAX model we implemented, or to Julia's Lux library (where the development went the other way, ironically dropping the 'F' from 'Flux').

'nnx' comes with a plethora of predefined layers. As an example, consider again the MNIST problem, where we defined an MLP using a "Dense Layer" function and a decoupled parameter constructor (again, see the Hello World notebook):

```python
import jax.numpy as jnp
from jax import random
import FunctionCollection as fnc  # our own helper module with the initializers

def denseLayer(params, image):
    W,b = params
    return jnp.dot(W, image) + b

def denseLayerConstructor(key, indims, outdims, weight=fnc.glorot_normal, bias=fnc.glorot_uniform):
    W_key, b_key = random.split(key)
    W = weight(W_key, (outdims, indims))
    b = bias(b_key, (outdims,))
    return (W,b)
```
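For comparison with the stateful version, here is a sketch of how this decoupled pair is used. Since `fnc` is our own helper module, the sketch assumes `jax.nn.initializers.glorot_normal()` for the weights and `jax.nn.initializers.zeros` for the bias instead; the batch of ones is a stand-in for real data. Because `denseLayer` handles a single image, `jax.vmap` maps it over the leading batch axis.

```python
import jax
import jax.numpy as jnp
from jax import random

def denseLayer(params, image):
    W, b = params
    return jnp.dot(W, image) + b

def denseLayerConstructor(key, indims, outdims,
                          weight=jax.nn.initializers.glorot_normal(),
                          bias=jax.nn.initializers.zeros):
    # Assumed stand-ins for the fnc initializers.
    W_key, b_key = random.split(key)
    return (weight(W_key, (outdims, indims)), bias(b_key, (outdims,)))

params = denseLayerConstructor(random.PRNGKey(0), 784, 128)
batch = jnp.ones((60, 784))

# Parameters are passed explicitly; vmap maps over axis 0 of the batch, not of params.
out = jax.vmap(denseLayer, in_axes=(None, 0))(params, batch)
```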

Remember, in 'nnx' layers are stateful: instead of the combination 'function + parameters', the Dense layer combines both into a single object. Hence the Dense layer would look something like:

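```python
from jax import random
from flax import nnx
import FunctionCollection as fnc  # our own helper module with the initializers

class NNXDenseLayer(nnx.Module):
    def __init__(self, key, indims, outdims, weight=fnc.glorot_normal, bias=fnc.glorot_uniform):
        W_key, B_key = random.split(key)
        # The parameters are stored on the module itself: the layer is stateful.
        self.w = nnx.Param(weight(W_key, (outdims, indims)))
        self.b = nnx.Param(bias(B_key, (outdims, )))
        self.din, self.dout = indims, outdims

    def __call__(self, x):
        # x has shape (batch, indims); W is stored as (outdims, indims), hence the transpose.
        return x @ self.w.value.T + self.b.value
```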


Let us now reimplement our MLP model using nnx's built-in dense layer, `nnx.Linear`.
(HA! We can't, at least not one-to-one, because nnx's Glorot initializers do not work on the one-dimensional bias terms. We could either use our own initializer that handles the bias, or fall back to the default zero initializer for the bias, as done below. Maybe someday they will fix it. (Or you can do so yourself.))

In [2]:
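from flax.nnx import initializers as nnx_init

class MLP(nnx.Module):
    def __init__(self, rng):
        # Only the kernels use Glorot; the biases keep nnx's default (zero) initializer.
        self.lin1 = nnx.Linear(28*28, 128, rngs=rng, kernel_init=nnx_init.glorot_normal())
        self.lin2 = nnx.Linear(128, 10, rngs=rng, kernel_init=nnx_init.glorot_normal())

    @nnx.jit
    def __call__(self, x):
        x = nnx.relu(self.lin1(x))
        # Return raw logits: optax.softmax_cross_entropy (used below) applies the softmax itself,
        # so an additional nnx.softmax here would normalize twice.
        return self.lin2(x)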

In [3]:
model = MLP(nnx.Rngs(jax.random.PRNGKey(0)))


To make sure the model lives on the GPU, we move its state there explicitly:

In [4]:
state = nnx.state(model)
state = jax.device_put(state, DEFAULTGPU)
nnx.update(model, state)

# Data Loading

We will reuse the same data-loading shenanigans as in the Hello World notebook.

In [5]:
from datasets import load_dataset, load_from_disk
import pathlib
# Download MNIST once and cache it next to the notebook.
if not (pathlib.Path.cwd() / "MNIST.hf").exists():
    MNIST = load_dataset("ylecun/mnist")
    MNIST.save_to_disk("MNIST.hf")
else:
    MNIST = load_from_disk("MNIST.hf")



In [6]:
training_data = MNIST["train"].with_format("jax")
test_data = MNIST["test"].with_format("jax")
training_inputs = training_data["image"].astype(jnp.float32) / 255
training_labels = jax.nn.one_hot(training_data["label"], 10)
test_inputs = test_data["image"].astype(jnp.float32) / 255  
test_labels = jax.nn.one_hot(test_data["label"], 10)

training_inputs = jnp.reshape(training_inputs, (60000, 28*28))
test_inputs = jnp.reshape(test_inputs, (10000, 28*28))
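The same preprocessing steps on two tiny 2×2 stand-in "images" with made-up pixel values: scale to $[0, 1]$, flatten each image, and one-hot encode the labels. Passing `-1` to `jnp.reshape` lets JAX infer the flattened size, which would avoid hard-coding 60000 and 10000 above.

```python
import jax
import jax.numpy as jnp

images = jnp.array([[[0, 255], [128, 64]],
                    [[255, 255], [0, 0]]], dtype=jnp.uint8)  # two made-up 2x2 "images"
labels = jnp.array([3, 7])

inputs = images.astype(jnp.float32) / 255            # scale pixels to [0, 1]
inputs = jnp.reshape(inputs, (images.shape[0], -1))  # flatten each image into a vector
one_hot = jax.nn.one_hot(labels, 10)                 # class index -> length-10 indicator
```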

In [7]:
class dataloader(object):
    def __init__(self, inputs, labels, batch_size):
        self.inputs = inputs
        self.labels = labels
        self.batch_size = batch_size
        self.num_batches = inputs.shape[0] // batch_size
        
    def __iter__(self):
        for i in range(self.num_batches):
            start = i * self.batch_size
            end = (i + 1) * self.batch_size
            yield (self.inputs[start:end], self.labels[start:end])

    def __len__(self):
        return self.num_batches
    
    def shuffle(self, rng_key):
        perm = jax.random.permutation(rng_key, self.inputs.shape[0])
        self.inputs = self.inputs[perm]
        self.labels = self.labels[perm]
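Two properties of this dataloader, sketched on toy data (10 samples whose label equals their input value) by inlining the logic rather than reusing the class: `shuffle` applies the same permutation to inputs and labels, so the pairs stay aligned, and `num_batches` uses floor division, so the last `N % batch_size` samples are skipped in each epoch. With 60000 images and a batch size of 128 that is 96 samples, a different set each epoch because of the shuffle.

```python
import jax
import jax.numpy as jnp

inputs = jnp.arange(10.0).reshape(10, 1)
labels = jnp.arange(10)

# Same permutation for both arrays, as in dataloader.shuffle.
perm = jax.random.permutation(jax.random.PRNGKey(0), inputs.shape[0])
inputs, labels = inputs[perm], labels[perm]

batch_size = 3
num_batches = inputs.shape[0] // batch_size  # 3 batches; one sample is left out this epoch
batches = [(inputs[i * batch_size:(i + 1) * batch_size],
            labels[i * batch_size:(i + 1) * batch_size]) for i in range(num_batches)]
```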

In [8]:
training_dataloader = dataloader(training_inputs, training_labels, 60)

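Note that, unlike in our pure JAX version, we do not have to iterate over the batch explicitly, i.e. call vmap. `nnx.Linear` contracts over the last axis of its input, so it can process a whole `(batch, 784)` array directly.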

In [9]:
for batch in training_dataloader:
    batch_in, batch_label = batch
    print(f"batch_in.shape {batch_in.shape}, type {type(batch_in)}")
    model(batch_in)
    break

batch_in.shape (60, 784), type <class 'jaxlib.xla_extension.ArrayImpl'>


# The Loss Function

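The loss function compares the model's predictions to the correct labels. Note that we use a one-hot encoding for the label classes. We again want to implement the cross-entropy loss $$ L = -\sum_x p(x)\log(q(x)) ,$$ where $p$ is the (one-hot) label distribution and $q$ the predicted distribution; however, <s>nnx</s> optax already contains it. `optax.softmax_cross_entropy` expects raw logits and computes $q$ with a softmax internally.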
Optax is Google's optimization package.

In [10]:
import optax

In [11]:
@nnx.jit
def lossfn(model, inputs, labels):
    prediction = model(inputs)
    return optax.softmax_cross_entropy(logits=prediction, labels=labels).mean(), prediction

# The Optimizer
In the HelloWorld notebook we implemented our own version of gradient descent and a custom `update` function. With optax this is of course not needed.

In [12]:
optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.001))

# Training Loop
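We start with a single training step: compute the loss and its gradients on one batch and let the optimizer update the model. The training loop further below then shuffles the data every epoch and iterates through the batches.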

In [13]:
# Note how we don't need @partial(jit, ...): nnx.jit handles nnx objects such as the model and the optimizer.
@nnx.jit
def __train_step(model, batch, optimizer):
    batch_in, batch_label = batch
    # Seems inefficient to create a function here, but it's compiled away by nnx.jit.
    gradfn = nnx.value_and_grad(lossfn, has_aux=True)  # has_aux: lossfn returns (x, aux) and only x is differentiated.
    (loss, logits), grads = gradfn(model, batch_in, batch_label)
    # The optimizer holds a reference to the model and updates its parameters in place, so nothing is returned.
    optimizer.update(grads)


nnx comes with its own metrics. However, to keep the notebooks comparable, we use our old accuracy function.

In [14]:
@nnx.jit
def accuracy(model, inputs, labels):
    predictions = model(inputs)
    return (predictions.argmax(axis=1) == labels.argmax(axis=1)).mean()
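A toy example of the accuracy computation with made-up scores: only the position of the maximum matters, and since softmax is monotonic, the result is the same whether the model outputs logits or probabilities.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[0.2, 1.5, -0.3],
                    [2.0, 0.1, 0.4],
                    [0.0, -1.0, 0.5]])
labels = jax.nn.one_hot(jnp.array([1, 0, 1]), 3)

def accuracy_from(scores, labels):
    # Fraction of rows whose highest score sits at the labelled class.
    return (scores.argmax(axis=1) == labels.argmax(axis=1)).mean()

acc_logits = accuracy_from(logits, labels)                 # rows 0 and 1 correct
acc_probs = accuracy_from(jax.nn.softmax(logits), labels)  # same argmax, same result
```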

Now let us define the training loop:

In [15]:
# Fixed random subset of 1000 training images, used to estimate the training accuracy.
idx = jax.random.choice(jax.random.PRNGKey(0), len(training_inputs), (1000,), replace=False)
train_acc_inputs = training_inputs[idx].to_device(DEFAULTGPU)
train_acc_labels = training_labels[idx].to_device(DEFAULTGPU)
test_labels = test_labels.to_device(DEFAULTGPU)
test_inputs = test_inputs.to_device(DEFAULTGPU)
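A sketch of drawing such a fixed evaluation subset: `jax.random.choice` with `replace=False` returns distinct indices, here 1000 of them drawn from all 60000 training indices.

```python
import jax
import jax.numpy as jnp

n_train = 60000
# An integer first argument means "sample from jnp.arange(n_train)".
idx = jax.random.choice(jax.random.PRNGKey(0), n_train, (1000,), replace=False)
```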


In [None]:
NUMEPOCHS = 1000


train_accs = []
test_accs  = []

model = MLP(nnx.Rngs(jax.random.PRNGKey(4)))
state = nnx.state(model)
state = jax.device_put(state, DEFAULTGPU)
nnx.update(model, state)    

training_dataloader = dataloader(training_inputs.to_device(DEFAULTGPU), training_labels.to_device(DEFAULTGPU), 128)
optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.001))
key = jax.random.PRNGKey(1234)
for epoch in range(NUMEPOCHS):
    key, subkey = jax.random.split(key)
    training_dataloader.shuffle(subkey)
    for batch in training_dataloader:
        # The batches are already on the GPU, since the dataloader holds device arrays.
        __train_step(model, batch, optimizer)

    train_acc = accuracy(model, train_acc_inputs, train_acc_labels)
    train_accs.append(train_acc)
    test_acc = accuracy(model, test_inputs, test_labels)
    test_accs.append(test_acc)
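The key handling in the loop follows JAX's usual pattern: split once per epoch, consume the subkey, and carry the other key forward. The sketch below, with a hypothetical helper `epoch_permutations`, shows that this gives a different shuffle each epoch but the same sequence of shuffles for the same seed. For scale: with batch size 128 an epoch has 60000 // 128 = 468 steps, so 1000 epochs mean 468,000 SGD updates.

```python
import jax
import jax.numpy as jnp

def epoch_permutations(seed, num_epochs, n):
    # Hypothetical helper mirroring the key handling of the training loop.
    key = jax.random.PRNGKey(seed)
    perms = []
    for _ in range(num_epochs):
        key, subkey = jax.random.split(key)  # fresh subkey per epoch, carry `key` forward
        perms.append(jax.random.permutation(subkey, n))
    return perms

run_a = epoch_permutations(1234, 3, 10)
run_b = epoch_permutations(1234, 3, 10)
steps_per_epoch = 60000 // 128
```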


In [16]:
import matplotlib.pyplot as plt
import os
import FunctionCollection as fnc
plt.rcParams.update({
    'font.size': 32,
    'axes.labelsize': 18,
    'axes.titlesize': 16,
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
    'legend.fontsize': 18,
    'figure.titlesize': 16
})

plt.figure(dpi=300, figsize=(16,9))
plt.plot(test_accs, label="Test Accuracy")
plt.plot(train_accs, label="Train Accuracy")
xlabel = plt.xlabel("Epoch")
ylabel = plt.ylabel("Accuracy")
plt.legend()
os.makedirs("plots", exist_ok=True)
plt.savefig("plots/accuracy.png")
plt.close()


plot_displayer = fnc.PlotDisplay()

In [17]:

plot_displayer.update() #This forces the Renderer to reload the plot
plot_displayer.show("plots/accuracy.png", width=800)