In [34]:
import jax
import jax.numpy as jnp
import flax
from flax import nnx
import optax
import tensorflow_datasets as tfds  # TFDS to download MNIST.
import tensorflow as tf  # TensorFlow / `tf.data` operations.
from functools import partial
from models import TestConvNet
from utils import clipping_ste

import matplotlib.pyplot as plt
import seaborn as sns
# from IPython.display import display, clear_output

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Test network

In [35]:
# Build the network under test; the `rngs` it receives is an `nnx.Rngs` seeded with 0.
testnet = TestConvNet(rngs=nnx.Rngs(0))
# Show the module in the notebook so its structure can be inspected.
nnx.display(testnet)

In [37]:
# Smoke test: push a random batch of shape (10, 28, 28, 1) through the network
# and print the output shape.
rngs = nnx.Rngs(keys=0)  # notebook-level RNG container with a single `keys` stream
x_test = jax.random.normal(key=rngs.keys(), shape=(10, 28, 28, 1))
y_test = testnet(x_test)
print(y_test.shape)

(10, 10)


In [3]:
# Probe `clipping_ste` on an evenly spaced ramp: 0.0 up to (excluding) 1.5 in steps of 0.01.
x = jnp.arange(0., 1.5, 0.01)
# NOTE: this rebinds the notebook-level `rngs`; the new container only defines an `activation` stream.
rngs = nnx.Rngs(activation=0)
# NOTE(review): 0.0 and 1.0 look like the lower/upper clipping bounds, and the key
# presumably drives a stochastic step inside the helper — confirm against utils.clipping_ste.
y = clipping_ste(x, 0.0, 1.0, rngs.activation())
print(y)

[1. 0. 0. 0. 0. 1. 1. 0. 1. 0. 0. 1. 0. 1. 1. 1. 1. 0. 1. 1. 0. 0. 1. 1.
 1. 0. 1. 1. 0. 1. 1. 0. 0. 1. 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 0. 1. 1. 1. 1. 0. 1. 0. 1. 1. 1. 1. 0. 1. 0. 0. 1. 1. 0. 0. 0. 0.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 0. 1. 1. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1.
 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1.]


## Dataset

In [38]:
tf.random.set_seed(0)  # Fix TensorFlow's global seed for reproducibility.

# Schedule settings; `train_steps` and `eval_every` are also read by the training loop.
train_steps = 5000
eval_every = 500
batch_size = 32

# Directory handed to TFDS for the MNIST files; defined once so both splits agree.
mnist_data_dir = '/local_disk/vikrant/datasets'


def _normalize_sample(sample):
  """Cast the image to float32 and divide by 255; the label passes through unchanged.

  Presumably this maps 8-bit pixel values into [0, 1] — the input dtype is not
  visible here.
  """
  return {
    'image': tf.cast(sample['image'], tf.float32) / 255,
    'label': sample['label'],
  }


train_ds: tf.data.Dataset = tfds.load('mnist', split='train', data_dir=mnist_data_dir)
test_ds: tf.data.Dataset = tfds.load('mnist', split='test', data_dir=mnist_data_dir)

# Apply the same normalisation to both splits.
train_ds = train_ds.map(_normalize_sample)
test_ds = test_ds.map(_normalize_sample)

# Training stream: repeat indefinitely, shuffle through a 1024-element buffer,
# form full batches only, stop after `train_steps` batches, and prefetch one
# batch ahead to hide input latency.
train_ds = (
  train_ds.repeat()
  .shuffle(1024)
  .batch(batch_size, drop_remainder=True)
  .take(train_steps)
  .prefetch(1)
)
# Test stream: full batches only (the incomplete trailing batch is dropped), prefetch one ahead.
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)

## Optimizer

In [39]:
# Optimizer hyperparameters.
learning_rate = 0.005
momentum = 0.9

# `momentum` fills the second positional slot of `optax.adamw`, which is the
# first-moment decay rate `b1`; it is passed by keyword here to make that explicit.
optimizer = nnx.Optimizer(
  testnet,
  optax.adamw(learning_rate=learning_rate, b1=momentum),
)

# One metrics object shared by the train and eval steps: classification accuracy
# plus a running average of the value passed as `loss`.
metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average('loss'),
)

nnx.display(optimizer)

# Loss function

In [40]:
def loss_fn(model: TestConvNet, batch):
  """Compute the mean softmax cross-entropy of `model` on one batch.

  Args:
    model: network called on `batch['image']` to produce logits.
    batch: mapping with an 'image' entry (model input) and a 'label' entry
      (integer class labels).

  Returns:
    A `(loss, logits)` pair; the logits are returned alongside the loss so
    callers can use them as auxiliary output (e.g. for accuracy).
  """
  logits = model(batch['image'])
  per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits,
    labels=batch['label'],
  )
  return per_example_loss.mean(), logits

@nnx.jit
def train_step(model: TestConvNet, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
  """Run one optimisation step on `batch`.

  Differentiates `loss_fn` with respect to the model, records the batch loss
  and logits in `metrics`, then applies the gradients through `optimizer`.
  Both updates happen in place; nothing is returned.
  """
  # `has_aux=True` because `loss_fn` returns (loss, logits) rather than a bare scalar.
  (batch_loss, batch_logits), gradients = nnx.value_and_grad(loss_fn, has_aux=True)(model, batch)
  metrics.update(loss=batch_loss, logits=batch_logits, labels=batch['label'])
  optimizer.update(gradients)

@nnx.jit
def eval_step(model: TestConvNet, metrics: nnx.MultiMetric, batch):
  """Evaluate `model` on one batch and fold the result into `metrics` in place.

  No gradients are taken and nothing is returned.
  """
  batch_loss, batch_logits = loss_fn(model, batch)
  metrics.update(loss=batch_loss, logits=batch_logits, labels=batch['label'])


## Training

In [41]:
# History of every metric; one value is appended per evaluation point.
metrics_history = dict(
  train_loss=[],
  train_accuracy=[],
  test_loss=[],
  test_accuracy=[],
)

for step, batch in enumerate(train_ds.as_numpy_iterator()):
  # One optimisation step. This updates, in place, the model parameters, the
  # optimizer state and the running training metrics.
  train_step(testnet, optimizer, metrics, batch)

  # Evaluation happens every `eval_every` steps and on the final step, but
  # never on step 0; every other step goes straight to the next batch.
  if step == 0 or (step % eval_every != 0 and step != train_steps - 1):
    continue

  # Record the training metrics accumulated since the previous evaluation point.
  for metric, value in metrics.compute().items():
    metrics_history[f'train_{metric}'].append(value)
  metrics.reset()  # The same metrics object is reused for the test pass.

  # Full pass over the test set.
  for test_batch in test_ds.as_numpy_iterator():
    eval_step(testnet, metrics, test_batch)

  # Record the test metrics, then clear the accumulator for the next training window.
  for metric, value in metrics.compute().items():
    metrics_history[f'test_{metric}'].append(value)
  metrics.reset()

  # Report the most recent value of every tracked metric.
  print(f'Step {step}:')
  for metric, value in metrics_history.items():
    print(f'{metric}: {value[-1]}')


Step 500:
train_loss: 3.3003435134887695
train_accuracy: 0.10104790329933167
test_loss: 2.808133602142334
test_accuracy: 0.09595352411270142
Step 1000:
train_loss: 3.177182674407959
train_accuracy: 0.10012500733137131
test_loss: 3.0817670822143555
test_accuracy: 0.11348157376050949
Step 1500:
train_loss: 3.207362651824951
train_accuracy: 0.09987500309944153
test_loss: 3.0334861278533936
test_accuracy: 0.11348157376050949
Step 2000:
train_loss: 3.194047212600708
train_accuracy: 0.0988750010728836
test_loss: 3.223853588104248
test_accuracy: 0.10096153616905212
Step 2500:
train_loss: 3.167471408843994
train_accuracy: 0.09943750500679016
test_loss: 3.066194534301758
test_accuracy: 0.0974559336900711
Step 3000:
train_loss: 3.1213741302490234
train_accuracy: 0.10206250846385956
test_loss: 2.9976444244384766
test_accuracy: 0.0974559336900711
Step 3500:
train_loss: 3.212129592895508
train_accuracy: 0.10225000232458115
test_loss: 3.5534276962280273
test_accuracy: 0.09815705567598343
Step 4000:


# Random testing

In [7]:
# Scratch comparison of two ways to widen a (5, 10) array to (5, 16).
# NOTE: this rebinds the notebook-level `x`.
key = jax.random.key(0)
x = jax.random.normal(key, shape=(5, 10))
print(f"input: {x}")

# `jnp.resize` refills the new shape from the flattened input, so values from
# the following row spill into each widened row (visible in the printed output).
x_resize = jnp.resize(x, (5, 16))
print(f"Resize: {x_resize}")

# `jnp.pad` keeps every row intact and appends 16 - 10 = 6 entries on the right
# of the last axis only (zeros under the default padding mode).
x_pad = jnp.pad(x, pad_width=((0, 0), (0, 6)))
print(f"Padded: {x_pad}")

input: [[ 1.6226422   2.0252647  -0.43359444 -0.07861735  0.1760909  -0.97208923
  -0.49529874  0.4943786   0.6643493  -0.9501635 ]
 [ 2.1795304  -1.9551506   0.35857072  0.15779513  1.2770847   1.5104648
   0.970656    0.59960806  0.0247007  -1.9164772 ]
 [-1.8593491   1.728144    0.04719035  0.814128    0.13132767  0.28284705
   1.2435943   0.6902801  -0.80073744 -0.74099   ]
 [-1.5388287   0.30269185 -0.02071605  0.11328721 -0.2206547   0.07052256
   0.8532958  -0.8217738  -0.01461421 -0.15046217]
 [-0.9001352  -0.7590727   0.33309513  0.80924904  0.04269255 -0.57767123
  -0.41439894 -1.9412533   1.3161184   0.7542728 ]]
Resize: [[ 1.6226422   2.0252647  -0.43359444 -0.07861735  0.1760909  -0.97208923
  -0.49529874  0.4943786   0.6643493  -0.9501635   2.1795304  -1.9551506
   0.35857072  0.15779513  1.2770847   1.5104648 ]
 [ 0.970656    0.59960806  0.0247007  -1.9164772  -1.8593491   1.728144
   0.04719035  0.814128    0.13132767  0.28284705  1.2435943   0.6902801
  -0.80073744 -0.