In [8]:
"""
Links: 

Jax:
https://github.com/google/jax/tree/main/jax/example_libraries
https://teddykoker.com/2022/04/learning-to-learn-jax/
https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html
https://jax.readthedocs.io/en/latest/notebooks/convolutions.html
https://coderzcolumn.com/tutorials/artificial-intelligence/jax-guide-to-create-convolutional-neural-networks

Optax:
https://github.com/deepmind/optax
https://optax.readthedocs.io/en/latest/optax-101.html

Flax:
https://github.com/google/flax
https://flax.readthedocs.io/en/latest/getting_started.html
https://coderzcolumn.com/tutorials/artificial-intelligence/flax-cnn

"""

'\nLinks: \n\nJax:\nhttps://github.com/google/jax/tree/main/jax/example_libraries\nhttps://teddykoker.com/2022/04/learning-to-learn-jax/\nhttps://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html\nhttps://jax.readthedocs.io/en/latest/notebooks/convolutions.html\nhttps://coderzcolumn.com/tutorials/artificial-intelligence/jax-guide-to-create-convolutional-neural-networks\n\nOptax:\nhttps://github.com/deepmind/optax\nhttps://optax.readthedocs.io/en/latest/optax-101.html\n\nFlax:\nhttps://github.com/google/flax\nhttps://flax.readthedocs.io/en/latest/getting_started.html\nhttps://coderzcolumn.com/tutorials/artificial-intelligence/flax-cnn\n\n'

In [45]:
"""
JAX CNN example using Fashion-MNIST
"""

import jax
from jax import numpy as jnp
from jax.example_libraries import stax, optimizers
from sklearn.model_selection import train_test_split
from tensorflow import keras


# Load the Fashion-MNIST train/test split bundled with Keras.
(X_train, Y_train), (X_test, Y_test) = keras.datasets.fashion_mnist.load_data()

# Images: float32 with an explicit trailing channel axis, i.e. (N, 28, 28, 1),
# then divided by 255 to rescale the pixel intensities.
X_train = jnp.array(X_train, dtype=jnp.float32).reshape(-1, 28, 28, 1) / 255.0
X_test = jnp.array(X_test, dtype=jnp.float32).reshape(-1, 28, 28, 1) / 255.0

# Labels are class indices, so they are kept as integers rather than floats;
# they are only ever used for one-hot encoding and counting classes.
Y_train = jnp.array(Y_train, dtype=jnp.int32)
Y_test = jnp.array(Y_test, dtype=jnp.int32)

# Distinct label values; len(classes) sizes the network's output layer.
classes = jnp.unique(Y_train)

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape

((60000, 28, 28, 1), (10000, 28, 28, 1), (60000,), (10000,))

In [41]:
# Network: two 3x3 "SAME"-padded convolutions with ReLU activations, followed
# by a flatten and a dense softmax layer with one output per class.
conv_init, conv_apply = stax.serial(
    stax.Conv(32, (3, 3), padding="SAME"),
    stax.Relu,
    stax.Conv(16, (3, 3), padding="SAME"),
    stax.Relu,

    stax.Flatten,
    stax.Dense(len(classes)),
    stax.Softmax
)

# Hyper-parameters and the PRNG key used for parameter initialisation.
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1e-4)
epochs = 10
batch_size = 128

# conv_init returns (output_shape, params). It is seeded with the key created
# in this cell so the cell does not depend on a name left over in the kernel.
weights = conv_init(seed, (batch_size, 28, 28, 1))
params = weights[1]

# Plain SGD; opt_state wraps the parameters for use with opt_update.
opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(params)

In [42]:
# Sanity check: forward pass on the first five training images.
preds = conv_apply(params, X_train[0:5])
print(preds)

[[0.1010583  0.08450612 0.08876862 0.10357413 0.09424469 0.06837884
  0.13585266 0.10559722 0.10213768 0.11588172]
 [0.10652138 0.08236858 0.11018915 0.11165013 0.08258919 0.0772452
  0.14238864 0.09115906 0.0884075  0.10748117]
 [0.09566505 0.09545238 0.10094573 0.10249694 0.09882495 0.08883354
  0.11344208 0.09690958 0.10275006 0.10467971]
 [0.10200436 0.08598677 0.10633809 0.10407417 0.09481844 0.0829622
  0.12135128 0.093064   0.09924015 0.11016052]
 [0.09390712 0.08359886 0.10012674 0.11463808 0.09753537 0.07206484
  0.12989205 0.08960547 0.10824842 0.11038305]]


In [43]:
from jax import value_and_grad

def CrossEntropyLoss(weights, input_data, actual):
    """Summed categorical cross-entropy of the network on one batch.

    Args:
        weights: Network parameters, passed straight to ``conv_apply``.
        input_data: Batch of inputs fed to ``conv_apply``.
        actual: Class index of each example in the batch.

    Returns:
        Scalar: the negative log-probability of the true class, summed
        (not averaged) over the batch.
    """
    probs = conv_apply(weights, input_data)
    # One column per network output, so the encoding always matches `probs`.
    target_mask = jax.nn.one_hot(actual, num_classes=probs.shape[-1])
    # Non-target entries contribute nothing to the loss, but a probability that
    # underflowed to 0 there would give 0 * log(0) = NaN (and a NaN gradient).
    # Swap those entries for 1 before the log so only the target term remains.
    safe_probs = jnp.where(target_mask > 0, probs, 1.0)
    return - jnp.sum(target_mask * jnp.log(safe_probs))

def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    """Train with mini-batch gradient steps and return the final optimizer state.

    For every epoch the data is walked in order in consecutive slices of
    ``batch_size`` examples (the last slice holds the remainder, if any). Each
    slice yields one loss/gradient evaluation of ``CrossEntropyLoss`` and one
    ``opt_update`` call. The mean of the per-batch losses is printed once per
    epoch.

    Args:
        X: Inputs, indexed by example along the first axis.
        Y: Targets aligned with ``X`` along the first axis.
        epochs: Number of full passes over the data.
        opt_state: Optimizer state to start from.
        batch_size: Number of examples per gradient step.

    Returns:
        The optimizer state after the last update.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Build the loss-and-gradient function once instead of once per batch.
    loss_and_grad = value_and_grad(CrossEntropyLoss)
    num_samples = X.shape[0]

    for i in range(1, epochs+1):
        losses = [] ## Record loss of each batch

        # Stepping over start offsets never produces an empty trailing batch,
        # even when the dataset size is an exact multiple of batch_size.
        for start in range(0, num_samples, batch_size):
            end = start + batch_size  # slicing clamps this on the final batch
            X_batch, Y_batch = X[start:end], Y[start:end]

            loss, gradients = loss_and_grad(opt_get_weights(opt_state), X_batch, Y_batch)

            ## Update Weights. The epoch number is what gets passed as the
            ## optimizer's step index.
            opt_state = opt_update(i, gradients, opt_state)

            losses.append(loss) ## Record Loss

        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))

    return opt_state

In [44]:
# Run the full training loop on the training split and keep the resulting optimizer state.
final_opt_state = TrainModelInBatches(
    X=X_train,
    Y=Y_train,
    epochs=epochs,
    opt_state=opt_state,
    batch_size=batch_size,
)

CrossEntropyLoss : 105.415
CrossEntropyLoss : 70.646
CrossEntropyLoss : 63.197
CrossEntropyLoss : 58.752
CrossEntropyLoss : 55.373
CrossEntropyLoss : 52.601
CrossEntropyLoss : 50.305
CrossEntropyLoss : 48.366
CrossEntropyLoss : 46.709
CrossEntropyLoss : 45.297
