In [8]:
"""
Links: 

Jax:
https://github.com/google/jax/tree/main/jax/example_libraries
https://teddykoker.com/2022/04/learning-to-learn-jax/
https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html
https://jax.readthedocs.io/en/latest/notebooks/convolutions.html
https://coderzcolumn.com/tutorials/artificial-intelligence/jax-guide-to-create-convolutional-neural-networks

Optax:
https://github.com/deepmind/optax
https://optax.readthedocs.io/en/latest/optax-101.html

Flax:
https://github.com/google/flax
https://flax.readthedocs.io/en/latest/getting_started.html
https://coderzcolumn.com/tutorials/artificial-intelligence/flax-cnn

"""

'\nLinks: \n\nJax:\nhttps://github.com/google/jax/tree/main/jax/example_libraries\nhttps://teddykoker.com/2022/04/learning-to-learn-jax/\nhttps://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html\nhttps://jax.readthedocs.io/en/latest/notebooks/convolutions.html\nhttps://coderzcolumn.com/tutorials/artificial-intelligence/jax-guide-to-create-convolutional-neural-networks\n\nOptax:\nhttps://github.com/deepmind/optax\nhttps://optax.readthedocs.io/en/latest/optax-101.html\n\nFlax:\nhttps://github.com/google/flax\nhttps://flax.readthedocs.io/en/latest/getting_started.html\nhttps://coderzcolumn.com/tutorials/artificial-intelligence/flax-cnn\n\n'

In [1]:
"""
Flax CNN Example using MNIST
"""

# Suppress warning and info messages from the TensorFlow/XLA native libraries.
# This has to be in the environment BEFORE those libraries are imported: their
# loader logs during import, so assigning the variable afterwards does not
# silence the import-time warnings.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'   # '2' hides INFO and WARNING, keeps ERROR

import jax
import jax.numpy as jnp                # JAX NumPy

from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

import numpy as np                     # Ordinary NumPy
import optax                           # Optimizers
import tensorflow_datasets as tfds     # TFDS for MNIST

2022-11-14 15:56:25.044840: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-11-14 15:56:25.080007: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-11-14 15:56:25.875154: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-11-14 15:56:25.875220: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
  from .autonotebook import tqdm as notebook_tqdm


In [11]:
class CNN(nn.Module):
    """Small convolutional classifier producing 10 logits per example.

    Architecture: two stages of 3x3 convolution -> ReLU -> 2x2 average pooling
    (32 then 64 feature maps), followed by a flatten, a 256-unit dense layer
    with ReLU, and a final 10-unit dense layer. No softmax is applied; the
    output is raw logits.
    """

    @nn.compact
    def __call__(self, x):
        # Convolutional feature extractor. Submodules are created in the same
        # order as before (Conv, Conv), so Flax's automatic names are unchanged.
        for num_filters in (32, 64):
            x = nn.Conv(features=num_filters, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Collapse everything except the leading (batch) axis.
        batch_dim = x.shape[0]
        x = x.reshape((batch_dim, -1))

        # Classifier head.
        hidden = nn.relu(nn.Dense(features=256)(x))
        logits = nn.Dense(features=10)(hidden)

        return logits

In [13]:
def cross_entropy_loss(*, logits, labels):
    """Mean softmax cross-entropy between `logits` and integer class `labels`.

    Args:
      logits: unnormalized class scores; the last axis indexes the classes.
      labels: integer class ids, one-hot encoded internally.

    Returns:
      The cross-entropy averaged over all examples.
    """
    # Take the class count from the logits instead of hard-coding 10, so the
    # loss stays correct if the model's output width changes.
    num_classes = logits.shape[-1]
    labels_onehot = jax.nn.one_hot(labels, num_classes=num_classes)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

def compute_metrics(*, logits, labels):
    """Return the loss and accuracy of `logits` against `labels`.

    Args:
      logits: class scores; the predicted class is the argmax over the last axis.
      labels: integer class ids compared element-wise with the predictions.

    Returns:
      A dict with keys 'loss' (from `cross_entropy_loss`) and 'accuracy'
      (fraction of predictions equal to the label).
    """
    predicted_classes = jnp.argmax(logits, -1)
    return {
        'loss': cross_entropy_loss(logits=logits, labels=labels),
        'accuracy': jnp.mean(predicted_classes == labels),
    }

def get_datasets():
    """Download (if needed) and load the MNIST train and test splits into memory.

    Returns:
      A `(train_ds, test_ds)` pair. Each is the full split loaded as a single
      batch, with its 'image' entry converted to float32 and divided by 255.
    """
    builder = tfds.builder('mnist')
    builder.download_and_prepare()

    def _load_split(split_name):
        # batch_size=-1 asks TFDS for the entire split as one batch.
        split = tfds.as_numpy(builder.as_dataset(split=split_name, batch_size=-1))
        split['image'] = jnp.float32(split['image']) / 255.
        return split

    train_ds = _load_split('train')
    test_ds = _load_split('test')

    return train_ds, test_ds

def create_train_state(rng, learning_rate, momentum):
    """Build the initial `TrainState` for the CNN.

    Args:
      rng: PRNG key used to initialise the model parameters.
      learning_rate: learning rate passed to `optax.sgd`.
      momentum: momentum passed to `optax.sgd`.

    Returns:
      A `TrainState` holding the model's apply function, its freshly
      initialised parameters and the SGD optimizer.
    """
    model = CNN()
    # A dummy batch of shape (1, 28, 28, 1) is enough to initialise the parameters.
    dummy_input = jnp.ones([1, 28, 28, 1])
    initial_variables = model.init(rng, dummy_input)
    optimizer = optax.sgd(learning_rate, momentum)

    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=initial_variables['params'],
        tx=optimizer,
    )

In [32]:
@jax.jit
def train_step(state, batch):
    """Apply one gradient update computed from a single batch.

    Args:
      state: current `TrainState` (parameters and optimizer state).
      batch: mapping with 'image' and 'label' entries.

    Returns:
      `(new_state, metrics)` where `metrics` is the loss/accuracy dict from
      `compute_metrics`, evaluated on the logits produced before the update.
    """
    images = batch['image']
    labels = batch['label']

    def loss_with_logits(params):
        # The logits are returned as auxiliary output so they can be reused
        # for the metrics without a second forward pass being written out.
        logits = CNN().apply({'params': params}, images)
        return cross_entropy_loss(logits=logits, labels=labels), logits

    grads, logits = jax.grad(loss_with_logits, has_aux=True)(state.params)
    updated_state = state.apply_gradients(grads=grads)

    return updated_state, compute_metrics(logits=logits, labels=labels)

@jax.jit
def eval_step(params, batch):
    """Run the CNN on `batch` with `params` and return its loss/accuracy metrics.

    Args:
      params: model parameters passed to `CNN().apply`.
      batch: mapping with 'image' and 'label' entries.

    Returns:
      The dict produced by `compute_metrics`.
    """
    images, labels = batch['image'], batch['label']
    predictions = CNN().apply({'params': params}, images)
    return compute_metrics(logits=predictions, labels=labels)

In [27]:
def train_epoch(state, train_ds, batch_size, epoch, rng):
    """Train over one shuffled pass of `train_ds` and print the averaged metrics.

    Args:
      state: train state fed to, and replaced by, each `train_step` call.
      train_ds: mapping of arrays sharing a leading example axis; must
        contain 'image', whose length defines the dataset size.
      batch_size: number of examples per step.
      epoch: epoch number, used only in the printed summary.
      rng: PRNG key used to shuffle the example order.

    Returns:
      The train state after the last step of the epoch.
    """
    num_examples = len(train_ds['image'])
    num_steps = num_examples // batch_size

    # Shuffle the example indices, drop the tail that would not fill a whole
    # batch, and lay the rest out as one row of indices per step.
    shuffled = jax.random.permutation(rng, num_examples)
    usable = shuffled[:num_steps * batch_size]
    index_rows = usable.reshape((num_steps, batch_size))

    per_step_metrics = []
    for row in index_rows:
        minibatch = {name: values[row, ...] for name, values in train_ds.items()}
        state, step_metrics = train_step(state, minibatch)
        per_step_metrics.append(step_metrics)

    # Pull the metrics to host memory and average each one over the steps.
    host_metrics = jax.device_get(per_step_metrics)
    metric_names = host_metrics[0]
    epoch_means = {
        name: np.mean([step[name] for step in host_metrics])
        for name in metric_names
    }

    summary = (epoch, epoch_means['loss'], epoch_means['accuracy'] * 100)
    print('train epoch: %d, loss: %.4f, accuracy: %.2f' % summary)

    return state

In [28]:
def eval_model(params, test_ds):
    """Evaluate `params` on `test_ds` in a single `eval_step` call.

    Args:
      params: model parameters forwarded to `eval_step`.
      test_ds: dataset mapping forwarded to `eval_step` as one batch.

    Returns:
      `(loss, accuracy)` as plain Python scalars.
    """
    # Copy the metrics to host memory before unwrapping them.
    host_metrics = jax.device_get(eval_step(params, test_ds))
    scalars = {name: value.item() for name, value in host_metrics.items()}
    return scalars['loss'], scalars['accuracy']

In [30]:
# Note: the log-level environment variable is configured once, in the import
# cell; it is not repeated here because the native libraries are already
# loaded by the time this cell runs.

# Load the MNIST train/test splits into memory.
train_ds, test_ds = get_datasets()

# Root PRNG key for the run; a dedicated sub-key is split off for parameter
# initialisation so the root key can keep being split for shuffling later.
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

# SGD hyper-parameters.
learning_rate = 0.1
momentum = 0.9

state = create_train_state(init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.

# Training schedule.
num_epochs = 10
batch_size = 32

In [33]:
# Main training loop: one pass over the training set per iteration, followed
# by an evaluation on the full test set. Epochs are numbered from 1.
for epoch in range(1, num_epochs + 1):
    # Use a separate PRNG key to permute image data during shuffling
    rng, input_rng = jax.random.split(rng)
    # Run one full training epoch (all complete mini-batches of the shuffled
    # training set); train_epoch prints the averaged training metrics itself
    state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
    # Evaluate on the test set after each training epoch
    test_loss, test_accuracy = eval_model(state.params, test_ds)
    print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (epoch, test_loss, test_accuracy * 100))

train epoch: 1, loss: 0.1355, accuracy: 95.83
 test epoch: 1, loss: 0.07, accuracy: 97.86
train epoch: 2, loss: 0.0506, accuracy: 98.50
 test epoch: 2, loss: 0.05, accuracy: 98.51
train epoch: 3, loss: 0.0339, accuracy: 98.98
 test epoch: 3, loss: 0.04, accuracy: 98.61
train epoch: 4, loss: 0.0253, accuracy: 99.23
 test epoch: 4, loss: 0.03, accuracy: 99.18
train epoch: 5, loss: 0.0208, accuracy: 99.40
 test epoch: 5, loss: 0.04, accuracy: 98.82
train epoch: 6, loss: 0.0170, accuracy: 99.48
 test epoch: 6, loss: 0.04, accuracy: 98.97
train epoch: 7, loss: 0.0160, accuracy: 99.51
 test epoch: 7, loss: 0.03, accuracy: 99.01
train epoch: 8, loss: 0.0146, accuracy: 99.54
 test epoch: 8, loss: 0.04, accuracy: 98.93
train epoch: 9, loss: 0.0154, accuracy: 99.56
 test epoch: 9, loss: 0.04, accuracy: 99.00
train epoch: 10, loss: 0.0112, accuracy: 99.66
 test epoch: 10, loss: 0.04, accuracy: 99.01
