In [1]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import numpy as np
import optax
import tensorflow_datasets as tfds

In [2]:
# Builder for the MNIST dataset; `.info` carries its metadata (features, splits, citation).
mnist_builder = tfds.builder(name="mnist")
mnist_info = mnist_builder.info

In [3]:
# Display the dataset metadata: feature spec, split sizes and citation.
mnist_info


tfds.core.DatasetInfo(
    name='mnist',
    full_name='mnist/3.0.1',
    description="""
    The MNIST database of handwritten digits.
    """,
    homepage='http://yann.lecun.com/exdb/mnist/',
    data_path='/Users/rajathbharadwaj/tensorflow_datasets/mnist/3.0.1',
    file_format=tfrecord,
    download_size=11.06 MiB,
    dataset_size=21.00 MiB,
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=60000, num_shards=1>,
    },
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
      volume={2},
      year={2010}
  

In [4]:
# Prepare MNIST on local disk first (as_dataset needs the prepared files), then load
# every split as a dict keyed by split name ('train', 'test'), each a tf.data pipeline.
# NOTE(review): the name `datasets` is rebound later by `def datasets(name)`, so cells
# that index this dict fail if re-run after that definition cell.
mnist_builder.download_and_prepare()
datasets = mnist_builder.as_dataset()

In [5]:
# Display the split-name -> tf.data dataset mapping returned by `as_dataset()`.
datasets

{'test': <PrefetchDataset element_spec={'image': TensorSpec(shape=(28, 28, 1), dtype=tf.uint8, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None)}>,
 'train': <PrefetchDataset element_spec={'image': TensorSpec(shape=(28, 28, 1), dtype=tf.uint8, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None)}>}

In [6]:
# Unpack the two splits from the split-name -> dataset mapping.
train, test = [datasets[split_name] for split_name in ('train', 'test')]

In [7]:
# Endless, shuffled training stream batched 128 examples at a time.
# NOTE(review): shuffling after `repeat()` lets examples from neighbouring epochs
# mix; move `shuffle` before `repeat` if clean epoch boundaries matter.
train_ds = (
    train.repeat()
    .shuffle(buffer_size=1024)
    .batch(batch_size=128)
)

In [8]:
# `Dataset.prefetch` returns a NEW dataset instead of modifying `train_ds` in place,
# so the result must be reassigned for prefetching to take effect downstream.
# Re-running this cell stacks another prefetch stage (harmless, but not idempotent).
train_ds = train_ds.prefetch(2)
train_ds

<PrefetchDataset element_spec={'image': TensorSpec(shape=(None, 28, 28, 1), dtype=tf.uint8, name=None), 'label': TensorSpec(shape=(None,), dtype=tf.int64, name=None)}>

In [9]:
import tensorflow as tf

In [10]:
# Take one batch from the pipeline for inspection. In TF2 eager mode a tf.data.Dataset
# is a Python iterable, so the legacy tf.compat.v1 one-shot-iterator API is unnecessary.
features = next(iter(train_ds))

2022-10-14 12:47:07.502582: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


In [11]:
class CNN(nn.Module):
    """Small convolutional classifier producing 10 logits per example.

    Three conv stages (32, 64 and 128 filters, 3x3 kernels), each followed by a
    ReLU and a 2x2 / stride-2 pooling (max, average, max respectively). The
    result is flattened per example and passed through Dense(512) + ReLU and a
    final Dense(10). Input is presumably channels-last image batches
    (e.g. (batch, 28, 28, 1)) — TODO confirm for other datasets.
    """

    @nn.compact
    def __call__(self, x):
        # (filters, pooling op) per stage; submodules are created in the same
        # order as an unrolled version, so parameter names are Conv_0..Conv_2.
        conv_stages = (
            (32, nn.max_pool),
            (64, nn.avg_pool),
            (128, nn.max_pool),
        )
        for n_filters, pool_op in conv_stages:
            x = nn.relu(nn.Conv(n_filters, kernel_size=(3, 3))(x))
            x = pool_op(x, window_shape=(2, 2), strides=(2, 2))

        # Flatten everything except the leading batch axis.
        x = x.reshape((x.shape[0], -1))
        x = nn.relu(nn.Dense(512)(x))
        return nn.Dense(10)(x)

In [12]:
def cross_entropy_loss(logits, labels=None, num_classes=10, *, lables=None):
    """Mean softmax cross-entropy between `logits` and integer class labels.

    Args:
        logits: unnormalized class scores; the last axis indexes classes.
        labels: integer class ids whose shape equals ``logits.shape[:-1]``.
        num_classes: size of the one-hot encoding; must match the last axis of
            `logits` (default 10).
        lables: deprecated keyword-only alias for `labels`, kept so calls using
            the previous misspelled parameter name keep working.

    Returns:
        Scalar cross-entropy averaged over all examples.

    Raises:
        TypeError: if neither `labels` nor `lables` is provided.
    """
    if labels is None:
        labels = lables
    if labels is None:
        raise TypeError("cross_entropy_loss() missing required argument: 'labels'")
    labels_onehot = jax.nn.one_hot(labels, num_classes=num_classes)
    return optax.softmax_cross_entropy(logits, labels_onehot).mean()

In [13]:
import wandb

In [14]:
# `wandb.init`'s first positional parameter is `job_type`, not `project`; pass the
# project name by keyword so the run is logged under this project.
wandb.init(project='mnist_with_flax&jax', config={})

[34m[1mwandb[0m: Currently logged in as: [33mrajathdb[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [15]:
def compute_metrics(logits, labels=None, *, lables=None):
    """Return ``(loss, accuracy)`` for a batch of logits and integer labels.

    Args:
        logits: unnormalized class scores; the last axis indexes classes.
        labels: integer class ids whose shape equals ``logits.shape[:-1]``.
        lables: deprecated keyword-only alias for `labels`, kept so calls using
            the previous misspelled parameter name keep working.

    Returns:
        Tuple of the mean cross-entropy loss and the fraction of examples whose
        arg-max logit equals the label.

    Raises:
        TypeError: if neither `labels` nor `lables` is provided.
    """
    if labels is None:
        labels = lables
    if labels is None:
        raise TypeError("compute_metrics() missing required argument: 'labels'")
    # Positional call: independent of how the loss helper names its label parameter.
    loss = cross_entropy_loss(logits, labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return loss, accuracy

In [16]:
def datasets(name):
    """Load the 'train' and 'test' splits of TFDS dataset `name` fully into memory.

    Each split comes back as a dict of arrays keyed by feature name (via
    ``tfds.as_numpy(..., batch_size=-1)``); its 'image' entry is converted to
    float32 and scaled by 1/255.

    NOTE(review): this definition rebinds the module-level name `datasets`,
    which an earlier cell used for the dict returned by
    ``mnist_builder.as_dataset()``.

    Returns:
        ``(train_ds, test_ds)`` tuple of feature dicts.
    """
    builder = tfds.builder(name)
    builder.download_and_prepare()

    def _load_split(split_name):
        # Per the TFDS docs, batch_size=-1 returns the whole split as one batch.
        arrays = tfds.as_numpy(builder.as_dataset(split=split_name, batch_size=-1))
        arrays['image'] = jnp.float32(arrays['image']) / 255.
        return arrays

    return _load_split('train'), _load_split('test')

In [17]:
def create_train_state(random_number_gen, lr, m):
    """Initialize CNN parameters and bundle them with an SGD optimizer.

    Deliberately NOT wrapped in ``jax.jit``: under jit, `lr` and `m` would be
    tracers, and the optax transformation built from them closes over those
    values, leaking tracers into later, separately-jitted ``apply_gradients``
    calls.

    Args:
        random_number_gen: JAX PRNG key used to initialize the parameters.
        lr: SGD learning rate (first argument of ``optax.sgd``).
        m: momentum (second argument of ``optax.sgd``).

    Returns:
        A ``train_state.TrainState`` holding the initialized 'params',
        ``cnn.apply`` as ``apply_fn`` and the SGD optimizer as ``tx``.
    """
    cnn = CNN()
    # A dummy batch of one 28x28 single-channel image fixes the parameter shapes.
    weight_params = cnn.init(random_number_gen, jnp.ones([1, 28, 28, 1]))
    optim = optax.sgd(lr, m)
    return train_state.TrainState.create(
        apply_fn=cnn.apply, params=weight_params['params'], tx=optim
    )

In [19]:
@jax.jit
def train_step(current_state, img_batch):
    """Apply one optimizer step on a batch and report training metrics.

    Args:
        current_state: TrainState whose ``apply_fn`` maps ``{'params': ...}``
            plus an image batch to logits.
        img_batch: dict with an 'image' entry (model input) and a 'label'
            entry (integer class ids).

    Returns:
        ``(new_state, metrics)`` where metrics is ``{'loss': ..., 'acc': ...}``
        computed from the logits of the parameters *before* the update.
    """
    def loss_fn(params):
        logits = current_state.apply_fn({'params': params}, img_batch['image'])
        # Positional call: independent of how the loss helper names its label parameter.
        loss = cross_entropy_loss(logits, img_batch['label'])
        return loss, logits

    # has_aux=True: loss_fn returns (loss, logits) and only loss is differentiated.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(current_state.params)
    new_state = current_state.apply_gradients(grads=grads)
    loss, accuracy = compute_metrics(logits, img_batch['label'])
    return new_state, {'loss': loss, 'acc': accuracy}

In [20]:
@jax.jit
def eval_step(params, batch):
    """Compute loss and accuracy of the CNN with `params` on one batch.

    Args:
        params: the CNN's 'params' collection.
        batch: dict with an 'image' entry and a 'label' entry (integer class ids).

    Returns:
        ``{'loss': ..., 'acc': ...}``.
    """
    logits = CNN().apply({'params': params}, batch['image'])
    # The dataset's label feature is named 'label'.
    loss, acc = compute_metrics(logits, batch['label'])
    return {'loss': loss, 'acc': acc}

In [None]:
def train_epoch(state, train_ds, batch_size, epoch, rng):
    train_ds_size = train_ds.shape[0]
    sps = train_ds_size // batch_size
    