In [3]:
"""
Links: 

Jax:
https://github.com/google/jax/tree/main/jax/example_libraries
https://teddykoker.com/2022/04/learning-to-learn-jax/
https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html
https://jax.readthedocs.io/en/latest/notebooks/convolutions.html
https://coderzcolumn.com/tutorials/artificial-intelligence/jax-guide-to-create-convolutional-neural-networks

Optax:
https://github.com/deepmind/optax
https://optax.readthedocs.io/en/latest/optax-101.html

Flax:
https://github.com/google/flax
https://flax.readthedocs.io/en/latest/getting_started.html
https://coderzcolumn.com/tutorials/artificial-intelligence/flax-cnn

"""

'\nLinks: \n\nJax:\nhttps://github.com/google/jax/tree/main/jax/example_libraries\nhttps://teddykoker.com/2022/04/learning-to-learn-jax/\nhttps://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html\nhttps://jax.readthedocs.io/en/latest/notebooks/convolutions.html\nhttps://coderzcolumn.com/tutorials/artificial-intelligence/jax-guide-to-create-convolutional-neural-networks\n\nOptax:\nhttps://github.com/deepmind/optax\nhttps://optax.readthedocs.io/en/latest/optax-101.html\n\nFlax:\nhttps://github.com/google/flax\nhttps://flax.readthedocs.io/en/latest/getting_started.html\nhttps://coderzcolumn.com/tutorials/artificial-intelligence/flax-cnn\n\n'

In [4]:
"""
Flax CNN Example using MNIST
"""

import jax
import jax.numpy as jnp                # JAX NumPy

from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

import numpy as np                     # Ordinary NumPy
import optax                           # Optimizers
import tensorflow_datasets as tfds     # TFDS for MNIST

# Suppress warning and info messages from jax
import os  
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

2023-01-23 18:39:33.515226: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-01-23 18:39:33.540994: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-01-23 18:39:34.009815: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-01-23 18:39:34.009869: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
  from .autonotebook import tqdm as notebook_tqdm


In [5]:
class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        
        return x

In [6]:
def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=10)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

def compute_metrics(*, logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
      'loss': loss,
      'accuracy': accuracy,
    }
    return metrics

def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    
    return train_ds, test_ds

def create_train_state(rng, learning_rate, momentum):
    """Creates initial `TrainState`."""
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.sgd(learning_rate, momentum)
    
    return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)

In [7]:
@jax.jit
def train_step(state, batch):
    """Train for a single step."""
    
    def loss_fn(params):
        logits = CNN().apply({'params': params}, batch['image'])
        loss = cross_entropy_loss(logits=logits, labels=batch['label'])
        return loss, logits
    
    grad_fn = jax.grad(loss_fn, has_aux=True)
    grads, logits = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=batch['label'])
    
    return state, metrics

@jax.jit
def eval_step(params, batch):
    logits = CNN().apply({'params': params}, batch['image'])
    return compute_metrics(logits=logits, labels=batch['label'])

In [8]:
def train_epoch(state, train_ds, batch_size, epoch, rng):
    """Train for a single epoch."""
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    batch_metrics = []
    
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    # compute mean of metrics across each batch in epoch.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]
    }

    print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))

    return state

In [9]:
def eval_model(params, test_ds):
    metrics = eval_step(params, test_ds)
    metrics = jax.device_get(metrics)
    summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['accuracy']

In [10]:
# Suppress warning and info messages from jax
import os  
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

train_ds, test_ds = get_datasets()
print(train_ds['image'].shape)

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

learning_rate = 0.1
momentum = 0.9

state = create_train_state(init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.

num_epochs = 10
batch_size = 32

2023-01-23 18:39:35.277165: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-01-23 18:39:35.277241: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublas.so.11'; dlerror: libcublas.so.11: cannot open shared object file: No such file or directory
2023-01-23 18:39:35.277288: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublasLt.so.11'; dlerror: libcublasLt.so.11: cannot open shared object file: No such file or directory
2023-01-23 18:39:35.277324: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcufft.so.10'; dlerror: libcufft.so.10: cannot open shared object file: No such file or directory
2023-01-23 18:39:35.277360: W tensorflow/stream_executor/platform/default/dso_loader.cc:64

(60000, 28, 28, 1)


In [11]:
for epoch in range(1, num_epochs + 1):
    # Use a separate PRNG key to permute image data during shuffling
    rng, input_rng = jax.random.split(rng)
    # Run an optimization step over a training batch
    state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
    # Evaluate on the test set after each training epoch
    test_loss, test_accuracy = eval_model(state.params, test_ds)
    print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (epoch, test_loss, test_accuracy * 100))

train epoch: 1, loss: 0.1424, accuracy: 95.62
 test epoch: 1, loss: 0.06, accuracy: 98.24
train epoch: 2, loss: 0.0485, accuracy: 98.51
 test epoch: 2, loss: 0.05, accuracy: 98.53
train epoch: 3, loss: 0.0362, accuracy: 98.88
 test epoch: 3, loss: 0.03, accuracy: 99.10
train epoch: 4, loss: 0.0253, accuracy: 99.21
 test epoch: 4, loss: 0.04, accuracy: 98.99
train epoch: 5, loss: 0.0230, accuracy: 99.31
 test epoch: 5, loss: 0.03, accuracy: 99.15
train epoch: 6, loss: 0.0169, accuracy: 99.50
 test epoch: 6, loss: 0.04, accuracy: 98.81
train epoch: 7, loss: 0.0149, accuracy: 99.53
 test epoch: 7, loss: 0.03, accuracy: 99.09
train epoch: 8, loss: 0.0130, accuracy: 99.59
 test epoch: 8, loss: 0.05, accuracy: 98.63
train epoch: 9, loss: 0.0118, accuracy: 99.63
 test epoch: 9, loss: 0.05, accuracy: 98.93
train epoch: 10, loss: 0.0101, accuracy: 99.69
 test epoch: 10, loss: 0.04, accuracy: 99.14


In [24]:
def get_datasets():
    """Load CIFAR10 train and test datasets into memory."""
    
    ds_builder = tfds.builder('cifar10')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    
    return train_ds, test_ds

In [17]:
class ResNetV2(nn.Module):
    @nn.compact
    def __call__(self, x):
        
        # Input layer
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 3) -> (4, 32, 32, 16)
        x = nn.relu(x)
        
        # Block 1
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        
        # Block 2
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        
        # Block 3
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        
        # Block 4
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1, strides=2)(x)  # (4, 32, 32, 16) -> (4, 16, 16, 32)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1)(x)  # (4, 16, 16, 32) -> (4, 16, 16, 32)
        x = nn.relu(x)
        
        # Block 5
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1)(x)  # (4, 16, 16, 32) -> (4, 16, 16, 32)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1)(x)  # (4, 16, 16, 32) -> (4, 16, 16, 32)
        x = nn.relu(x)
        
        # Block 6
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1)(x)  # (4, 16, 16, 32) -> (4, 16, 16, 32)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1)(x)  # (4, 16, 16, 32) -> (4, 16, 16, 32)
        x = nn.relu(x)
        
        # Block 7
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1, strides=2)(x)  # (4, 16, 16, 32) -> (4, 8, 8, 64)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1)(x)  # (4, 8, 8, 64) -> (4, 8, 8, 64)
        x = nn.relu(x)
        
        # Block 8
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1)(x)  # (4, 8, 8, 64) -> (4, 8, 8, 64)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1)(x)  # (4, 8, 8, 64) -> (4, 8, 8, 64)
        x = nn.relu(x)
        
        # Block 9
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1)(x)  # (4, 8, 8, 64) -> (4, 8, 8, 64)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1)(x)  # (4, 8, 8, 64) -> (4, 8, 8, 64)
        x = nn.relu(x)
        
        # Pooling 
        x = nn.avg_pool(x, window_shape=(8, 8)) # (4, 8, 8, 64) -> (4, 1, 1, 64)
        x = x.flatten()  # flatten (4, 1, 1, 64) -> (4, 64)
        
        # Output 
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x

In [31]:
def create_train_state(rng, learning_rate, momentum):
    """Creates initial `TrainState`."""
    cnn = ResNetV2()
    params = cnn.init(rng, jnp.ones([1, 32, 32, 3]))['params']
    tx = optax.sgd(learning_rate, momentum)
    
    return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)

@jax.jit
def train_step(state, batch):
    """Train for a single step."""
    
    def loss_fn(params):
        logits = ResNetV2().apply({'params': params}, batch['image'])
        loss = cross_entropy_loss(logits=logits, labels=batch['label'])
        return loss, logits
    
    grad_fn = jax.grad(loss_fn, has_aux=True)
    grads, logits = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=batch['label'])
    
    return state, metrics

@jax.jit
def eval_step(params, batch):
    logits = ResNetV2().apply({'params': params}, batch['image'])
    return compute_metrics(logits=logits, labels=batch['label'])

In [32]:
train_ds, test_ds = get_datasets()
print(train_ds['image'].shape)

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

learning_rate = 0.1
momentum = 0.9

state = create_train_state(init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.

num_epochs = 3
batch_size = 256

(50000, 32, 32, 3)


In [33]:
for epoch in range(1, num_epochs + 1):
    
    # Use a separate PRNG key to permute image data during shuffling
    rng, input_rng = jax.random.split(rng)
    
    # Run an optimization step over a training batch
    state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
    
    # Evaluate on the test set after each training epoch
    test_loss, test_accuracy = eval_model(state.params, test_ds)
    
    print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (epoch, test_loss, test_accuracy * 100))

AssertionError: [b'train_00026' b'train_08926' b'train_36578' b'train_37193'
 b'train_25123' b'train_11103' b'train_00248' b'train_41776'
 b'train_05751' b'train_18224' b'train_20734' b'train_36163'
 b'train_37135' b'train_12608' b'train_49609' b'train_33290'
 b'train_27556' b'train_45969' b'train_14729' b'train_43136'
 b'train_27632' b'train_32697' b'train_37493' b'train_28656'
 b'train_23750' b'train_11580' b'train_04447' b'train_20043'
 b'train_13623' b'train_49115' b'train_49848' b'train_29417'
 b'train_21753' b'train_49643' b'train_39746' b'train_25872'
 b'train_16616' b'train_10649' b'train_24252' b'train_15527'
 b'train_30433' b'train_39122' b'train_12877' b'train_06903'
 b'train_36079' b'train_35962' b'train_05118' b'train_07392'
 b'train_16140' b'train_26435' b'train_47985' b'train_24918'
 b'train_36567' b'train_20746' b'train_46881' b'train_06261'
 b'train_11011' b'train_34093' b'train_23841' b'train_20323'
 b'train_45329' b'train_02794' b'train_46546' b'train_39896'
 b'train_33435' b'train_19078' b'train_27684' b'train_28542'
 b'train_31430' b'train_43474' b'train_20629' b'train_32453'
 b'train_20287' b'train_28801' b'train_16736' b'train_38079'
 b'train_34681' b'train_39609' b'train_29730' b'train_15414'
 b'train_39033' b'train_08409' b'train_37153' b'train_33968'
 b'train_49605' b'train_01034' b'train_18833' b'train_19936'
 b'train_46795' b'train_44068' b'train_22096' b'train_12910'
 b'train_23917' b'train_20540' b'train_26634' b'train_46994'
 b'train_24772' b'train_11476' b'train_01836' b'train_03829'
 b'train_08626' b'train_20234' b'train_12339' b'train_39595'
 b'train_01960' b'train_19396' b'train_20925' b'train_01635'
 b'train_10292' b'train_07263' b'train_41143' b'train_41793'
 b'train_05250' b'train_41491' b'train_34292' b'train_38887'
 b'train_11270' b'train_16006' b'train_39211' b'train_02668'
 b'train_43279' b'train_47732' b'train_46061' b'train_45053'
 b'train_15867' b'train_29419' b'train_01208' b'train_06841'
 b'train_07431' b'train_46674' b'train_17468' b'train_44065'
 b'train_44806' b'train_41527' b'train_04714' b'train_25564'
 b'train_35175' b'train_36564' b'train_44179' b'train_19423'
 b'train_24387' b'train_34771' b'train_48542' b'train_36400'
 b'train_08192' b'train_48607' b'train_44442' b'train_06062'
 b'train_14172' b'train_33054' b'train_05017' b'train_48154'
 b'train_04992' b'train_49040' b'train_44615' b'train_32510'
 b'train_46171' b'train_47596' b'train_42941' b'train_09260'
 b'train_17986' b'train_27938' b'train_26685' b'train_00584'
 b'train_07900' b'train_06967' b'train_13778' b'train_46311'
 b'train_08924' b'train_28844' b'train_39441' b'train_47543'
 b'train_47273' b'train_42377' b'train_08435' b'train_40420'
 b'train_41561' b'train_10408' b'train_11191' b'train_35551'
 b'train_45648' b'train_33154' b'train_02883' b'train_25050'
 b'train_00918' b'train_01412' b'train_48663' b'train_05266'
 b'train_46576' b'train_37752' b'train_08625' b'train_05078'
 b'train_00763' b'train_14982' b'train_15461' b'train_12765'
 b'train_07287' b'train_42926' b'train_22380' b'train_32854'
 b'train_38790' b'train_47926' b'train_33659' b'train_31426'
 b'train_09579' b'train_34882' b'train_09008' b'train_38034'
 b'train_14178' b'train_35073' b'train_00150' b'train_04191'
 b'train_49381' b'train_17084' b'train_13838' b'train_14745'
 b'train_05635' b'train_11864' b'train_47984' b'train_11367'
 b'train_11311' b'train_24270' b'train_22774' b'train_39602'
 b'train_09811' b'train_30465' b'train_20927' b'train_45931'
 b'train_13341' b'train_49500' b'train_22588' b'train_08604'
 b'train_03564' b'train_05421' b'train_36463' b'train_32404'
 b'train_11283' b'train_18021' b'train_45348' b'train_30212'
 b'train_44157' b'train_26257' b'train_17518' b'train_39870'
 b'train_23327' b'train_25393' b'train_39192' b'train_04489'
 b'train_36268' b'train_01280' b'train_14141' b'train_39401'
 b'train_44999' b'train_23205' b'train_11207' b'train_18613']