In [1]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import numpy as np
import optax
import tensorflow_datasets as tfds

In [2]:
mnist_builder = tfds.builder("mnist")
mnist_info = mnist_builder.info

In [3]:
mnist_info


tfds.core.DatasetInfo(
    name='mnist',
    full_name='mnist/3.0.1',
    description="""
    The MNIST database of handwritten digits.
    """,
    homepage='http://yann.lecun.com/exdb/mnist/',
    data_path='/Users/rajathbharadwaj/tensorflow_datasets/mnist/3.0.1',
    file_format=tfrecord,
    download_size=11.06 MiB,
    dataset_size=21.00 MiB,
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=60000, num_shards=1>,
    },
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
      volume={2},
      year={2010}
  

In [4]:
mnist_builder.download_and_prepare()
datasets = mnist_builder.as_dataset()

In [5]:
datasets

{'test': <PrefetchDataset element_spec={'image': TensorSpec(shape=(28, 28, 1), dtype=tf.uint8, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None)}>,
 'train': <PrefetchDataset element_spec={'image': TensorSpec(shape=(28, 28, 1), dtype=tf.uint8, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None)}>}

In [6]:
train, test = datasets['train'], datasets['test']

In [7]:
train_ds = train.repeat().shuffle(1024).batch(128)

In [8]:
train_ds.prefetch(2)

<PrefetchDataset element_spec={'image': TensorSpec(shape=(None, 28, 28, 1), dtype=tf.uint8, name=None), 'label': TensorSpec(shape=(None,), dtype=tf.int64, name=None)}>

In [9]:
import tensorflow as tf

In [10]:
features = tf.compat.v1.data.make_one_shot_iterator(train_ds).get_next()

2022-10-14 14:17:34.526338: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


In [11]:
class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(32, kernel_size=(3, 3), )(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(128, kernel_size=(3, 3), )(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(512)(x)
        x = nn.relu(x)
        x = nn.Dense(10)(x)
        return x        

In [12]:
def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=10)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

In [13]:
import wandb

In [14]:
wandb.init('mnist_with_flax&jax', config={})

[34m[1mwandb[0m: Currently logged in as: [33mrajathdb[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [43]:
def compute_metrics(logits, lables):
    loss = cross_entropy_loss(logits=logits, labels=lables)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == lables)
    return loss, accuracy

In [16]:
def datasets(name):
    ds_builder = tfds.builder(name)
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image'])/ 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds

In [17]:

def create_train_state(random_number_gen, lr, m):
    cnn = CNN()
    weight_params = cnn.init(random_number_gen, jnp.ones([1, 28, 28, 1]))
    optim = optax.sgd(lr, m)
    return train_state.TrainState.create(
        apply_fn=cnn.apply, params=weight_params['params'], tx=optim
    )

In [39]:
@jax.jit
def train_step(current_state, img_batch):
    def loss_fn(params):
        logits = CNN().apply({'params': params}, img_batch['image'])
        loss = cross_entropy_loss(logits=logits, labels=img_batch['label'])
        return loss, logits
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(current_state.params)
    current_state = current_state.apply_gradients(grads=grads)
    loss, accuracy = compute_metrics(logits=logits, lables==img_batch['label'])
    metrics = {
        'loss': loss,
        'accuracy': accuracy
    }
    return current_state, metrics

In [45]:
@jax.jit
def eval_step(params, batch):
    logits = CNN().apply({'params': params}, batch['image'])
    loss, acc = compute_metrics(logits=logits, lables=batch['label'])
    return {'loss': loss, 'acc': acc}

In [41]:
def train_epoch(state, train_ds, batch_size, epoch, rng):
    train_ds_size = train_ds['image'].shape[0]
    sps = train_ds_size // batch_size
    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:sps * batch_size]  # skip incomplete batch
    perms = perms.reshape((sps, batch_size))
    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch)
        wandb.log({
            'training_loss': metrics['loss'],
            "training_acc": metrics['accuracy']
        })
        batch_metrics.append(metrics)
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]}
    print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
    epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))

    return state

In [47]:
def eval_model(params, test_ds):
    metrics = eval_step(params, test_ds)
    metrics = jax.device_get(metrics)
    summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['acc']

In [22]:
train_ds, test_ds = datasets('mnist')

In [23]:
train_ds

{'image': DeviceArray([[[[0.],
                [0.],
                [0.],
                ...,
                [0.],
                [0.],
                [0.]],
 
               [[0.],
                [0.],
                [0.],
                ...,
                [0.],
                [0.],
                [0.]],
 
               [[0.],
                [0.],
                [0.],
                ...,
                [0.],
                [0.],
                [0.]],
 
               ...,
 
               [[0.],
                [0.],
                [0.],
                ...,
                [0.],
                [0.],
                [0.]],
 
               [[0.],
                [0.],
                [0.],
                ...,
                [0.],
                [0.],
                [0.]],
 
               [[0.],
                [0.],
                [0.],
                ...,
                [0.],
                [0.],
                [0.]]],
 
 
              [[[0.],
        

In [24]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

In [25]:
rng

DeviceArray([4146024105,  967050713], dtype=uint32)

In [26]:
init_rng

DeviceArray([2718843009, 1272950319], dtype=uint32)

In [27]:
learning_rate = 0.1
momentum = 0.9

In [28]:
state = create_train_state(init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.

In [29]:
state

TrainState(step=0, apply_fn=<bound method Module.apply of CNN()>, params=FrozenDict({
    Conv_0: {
        kernel: DeviceArray([[[[ 3.32821965e-01, -2.82645077e-01, -1.80686668e-01,
                        -6.79024681e-02, -7.08380580e-01,  2.77155995e-01,
                         2.83575892e-01,  5.28992593e-01, -3.09524119e-01,
                         5.76644421e-01,  3.28146845e-01, -7.18250945e-02,
                        -4.51020032e-01, -5.43700039e-01, -3.07050925e-02,
                        -4.92292941e-01,  7.26196110e-01,  1.11178897e-01,
                         3.07978690e-01, -7.01449156e-01, -7.72421584e-02,
                        -4.72714186e-01,  1.41713694e-01, -1.60780072e-01,
                        -1.76688924e-01,  1.37713790e-01, -5.79013899e-02,
                         1.17106691e-01,  3.02598000e-01,  2.02198133e-01,
                         3.30374926e-01,  7.13031411e-01]],
        
                      [[-1.29809931e-01, -3.06068122e-01,  4.04507458e-01

In [30]:
num_epochs = 10
batch_size = 32

In [31]:
wandb.config = {
  "learning_rate": learning_rate,
  "epochs": num_epochs,
  "batch_size": batch_size,
}

In [32]:
wandb.config

{'learning_rate': 0.1, 'epochs': 10, 'batch_size': 32}

In [48]:
for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
    rng, input_rng = jax.random.split(rng)
  # Run an optimization step over a training batch
    state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch
    test_loss, test_accuracy = eval_model(state.params, test_ds)
    print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
      epoch, test_loss, test_accuracy * 100))
    wandb.log({
        'test_loss': test_loss,
        'test_accuracy': test_accuracy
    })

train epoch: 1, loss: 0.2005, accuracy: 95.42
 test epoch: 1, loss: 0.28, accuracy: 93.08
train epoch: 2, loss: 0.3859, accuracy: 91.78
 test epoch: 2, loss: 0.52, accuracy: 88.11
train epoch: 3, loss: 2.1672, accuracy: 18.07
 test epoch: 3, loss: 2.31, accuracy: 9.82
train epoch: 4, loss: 2.3088, accuracy: 10.60
 test epoch: 4, loss: 2.30, accuracy: 10.32
train epoch: 5, loss: 2.3077, accuracy: 10.26
 test epoch: 5, loss: 2.31, accuracy: 10.09
train epoch: 6, loss: 2.3083, accuracy: 10.52
 test epoch: 6, loss: 2.31, accuracy: 10.10
train epoch: 7, loss: 2.3079, accuracy: 10.35
 test epoch: 7, loss: 2.31, accuracy: 9.82
train epoch: 8, loss: 2.3080, accuracy: 10.33
 test epoch: 8, loss: 2.31, accuracy: 10.28
train epoch: 9, loss: 2.3080, accuracy: 10.32
 test epoch: 9, loss: 2.31, accuracy: 11.35
train epoch: 10, loss: 2.3088, accuracy: 10.28
 test epoch: 10, loss: 2.30, accuracy: 11.35
