In [None]:
# Install the notebook's third-party dependencies; pip's stdout is discarded.
# NOTE(review): the import cell also uses haiku, jax and chex, which are not
# installed by this line — confirm the runtime already provides them.
!pip install dm-tree dm-sonnet tensorflow tensorflow_datasets ipywidgets matplotlib >/dev/null

In [None]:
import os
from functools import partial

import chex
import haiku as hk
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import sonnet as snt
import tensorflow as tf
import tensorflow_datasets as tfds
import tree

In [None]:
def f(x):
  """Applies a Haiku MLP with layer sizes 300, 100 and 10 to `x`.

  Args:
    x: Input passed straight to the MLP.

  Returns:
    Whatever the MLP returns for `x`.
  """
  return hk.nets.MLP([300, 100, 10])(x)

# Rebind `f` to its transformed form, which exposes `init` and `apply`.
f = hk.transform(f)

# Fixed seed plus a dummy all-ones batch of shape (1, 784); both are used here
# only to initialise the parameters.
rng = jax.random.PRNGKey(42)
x = jnp.ones((1, 28 * 28 * 1))
params = f.init(rng, x)

In [None]:
def create_variable(path, value):
  """Builds a `tf.Variable` holding `value`, named after its tree `path`.

  Args:
    path: Sequence of path components; each is stringified and the pieces are
      joined with '/'. Every '~' in the joined name is replaced by '_'.
    value: Initial value handed to `tf.Variable`.

  Returns:
    The newly created `tf.Variable`.
  """
  segments = [str(part) for part in path]
  variable_name = '/'.join(segments).replace('~', '_')
  return tf.Variable(value, name=variable_name)

class JaxModule(snt.Module):
  """Sonnet module that evaluates a JAX apply function through `jax2tf`.

  The parameter tree is mirrored as `tf.Variable`s (one per leaf, created by
  `create_variable`) and passed to the converted function on every call.
  """

  def __init__(self, params, apply_fn, name=None):
    """Creates the variables and the TensorFlow-callable apply function.

    Args:
      params: Nested structure of parameter values.
      apply_fn: Callable invoked as `apply_fn(params, None, inputs)`.
      name: Optional module name forwarded to `snt.Module`.
    """
    super().__init__(name=name)
    self._params = tree.map_structure_with_path(create_variable, params)
    # The second positional argument of `apply_fn` is always passed as None.
    converted = jax2tf.convert(lambda p, x: apply_fn(p, None, x))
    self._apply = tf.autograph.experimental.do_not_convert(converted)

  def __call__(self, inputs):
    """Runs the converted function on `inputs` with the stored variables."""
    return self._apply(self._params, inputs)

# Wrap the initial parameters and list the names of the module's trainable
# variables as the cell output.
net = JaxModule(params=params, apply_fn=f.apply)
[variable.name for variable in net.trainable_variables]

In [None]:

# Load the MNIST 'train' and 'test' splits as supervised pairs.
ds_train, ds_test = tfds.load(
    'mnist',
    split=('train', 'test'),
    shuffle_files=True,
    as_supervised=True,
)

def normalize_img(image, label):
  """Casts `image` to `float32` and divides it by 255; `label` is untouched.

  Args:
    image: Image to rescale (presumably `uint8` pixel data — confirm).
    label: Label returned as-is.

  Returns:
    A `(scaled_image, label)` tuple.
  """
  scaled = tf.cast(image, tf.float32) / 255.
  return scaled, label

# NOTE: both names are rebound in place, so re-running this cell without
# reloading the datasets applies `normalize_img` a second time.

# Training pipeline: normalise, cache, shuffle, batch by 100, repeat, prefetch.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .shuffle(60000)
    .batch(100)
    .repeat()
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Test pipeline: normalise, batch by 100, cache the batches, prefetch.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(100)
    .cache()
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [None]:
# Build a new module from `params` together with an Adam optimizer (learning
# rate 1e-3); `train_step` reads both as globals.
net = JaxModule(params=params, apply_fn=f.apply)
opt = snt.optimizers.Adam(learning_rate=1e-3)

# `jit_compile` is the current name of the deprecated `experimental_compile`.
@tf.function(jit_compile=True, autograph=False)
def train_step(images, labels):
  """Performs one optimizer step on a single mini-batch.

  Uses the notebook-level `net` and `opt`.

  Args:
    images: Batch of images; flattened with `snt.flatten` before the forward
      pass.
    labels: Integer class labels for the sparse softmax cross-entropy.

  Returns:
    The scalar loss: mean cross-entropy plus 1e-4 times the summed L2 loss of
    every variable watched by the gradient tape.
  """
  with tf.GradientTape() as tape:
    logits = net(snt.flatten(images))
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                       labels=labels))
    # Named `variables` so it does not shadow the notebook-level `params`.
    # Queried after the forward pass so the tape has already seen them.
    variables = tape.watched_variables()
    loss += 1e-4 * sum(map(tf.nn.l2_loss, variables))

  grads = tape.gradient(loss, variables)
  opt.apply(grads, variables)
  return loss

# Run 6001 training steps and report the loss on every 1000th step.
for step, (images, labels) in enumerate(ds_train.take(6001)):
  loss = train_step(images, labels)
  if step % 1000:
    continue
  print(f"Step {step}: {loss.numpy()}")

In [None]:
def accuracy(model):
  """Prints how many `ds_test` examples `model` classifies correctly.

  Args:
    model: Callable mapping a batch of flattened images to per-class scores;
      the predicted class is the argmax over axis 1.

  Returns:
    None. The count, total and percentage are printed.
  """
  n_seen = 0
  n_right = 0
  for batch_images, batch_labels in ds_test:
    scores = model(snt.flatten(batch_images))
    predicted = tf.argmax(scores, axis=1)
    hits = tf.equal(predicted, batch_labels)
    n_right += tf.math.count_nonzero(hits)
    n_seen += batch_images.shape[0]

  percent = n_right / n_seen * 100.
  print("Got %d/%d (%.02f%%) correct" % (n_right, n_seen, percent))

# Print the accuracy of the notebook-level `net` on `ds_test`.
accuracy(net)

In [None]:
import matplotlib.pyplot as plt

def sample(correct, rows, cols, model=None):
  """Utility function to show a sample of images.

  Walks `ds_test` and draws the first `rows * cols` examples whose prediction
  matches (`correct=True`) or differs from (`correct=False`) the label, each
  titled with the predicted and actual class. Stops early once the grid is
  full; if fewer examples qualify, the remaining axes stay empty.

  Args:
    correct: Whether to show correctly or incorrectly classified examples.
    rows: Number of subplot rows.
    cols: Number of subplot columns.
    model: Callable mapping flattened images to per-class scores. Defaults to
      the notebook-level `net`.

  Returns:
    None.
  """
  if model is None:
    model = net
  capacity = rows * cols

  # squeeze=False always returns a 2-D array of axes, so 1x1 and Nx1 grids
  # are indexed the same way as every other shape after ravel().
  fig, axes = plt.subplots(rows, cols, squeeze=False)
  axes = axes.ravel()
  fig.set_figwidth(14)
  fig.set_figheight(4 * rows)

  n = 0
  for images, labels in ds_test:
    predictions = tf.argmax(model(snt.flatten(images)), axis=1)
    eq = tf.equal(predictions, labels)
    for i, is_match in enumerate(eq):
      if bool(is_match) == correct:
        # int() makes the title show the digit rather than a tensor repr.
        prediction = int(predictions[i])
        label = int(labels[i])

        axes[n].imshow(tf.squeeze(images[i]))
        axes[n].set_title("Prediction:{}\nActual:{}".format(prediction, label))

        n += 1
        if n == capacity:
          return

# Show one row of five test examples whose prediction equals the label.
sample(correct=True, rows=1, cols=5)

In [None]:
# Show two rows of five test examples whose prediction differs from the label.
sample(correct=False, rows=2, cols=5)