In [1]:
%%bash
pip install -U dm-haiku optax

Collecting dm-haiku
  Downloading dm_haiku-0.0.6-py3-none-any.whl (309 kB)
Collecting optax
  Downloading optax-0.1.2-py3-none-any.whl (140 kB)
Collecting jmp>=0.0.2
  Downloading jmp-0.0.2-py3-none-any.whl (16 kB)
Collecting chex>=0.0.4
  Downloading chex-0.1.3-py3-none-any.whl (72 kB)
Installing collected packages: jmp, chex, optax, dm-haiku
Successfully installed chex-0.1.3 dm-haiku-0.0.6 jmp-0.0.2 optax-0.1.2


In [2]:
from numpy.lib.npyio import BagObj
from typing import Iterator, Mapping, Tuple
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds

Batch = Mapping[str, np.ndarray]


def net_fn(batch: Batch) -> jnp.ndarray:
  """Standard LeNet-300-100 MLP network."""
  x = batch["image"].astype(jnp.float32) / 255.
  mlp = hk.Sequential([
      hk.Flatten(),
      hk.Linear(1024), jax.nn.tanh,
      hk.Linear(1024), jax.nn.tanh,
      hk.Linear(10),
  ])
  return mlp(x)


def load_dataset(
    split: str,
    *,
    is_training: bool,
    batch_size: int,
) -> Iterator[Batch]:
  """Loads the dataset as a generator of batches."""
  ds = tfds.load("mnist:3.*.*", split=split).cache().repeat()
  if is_training:
    ds = ds.shuffle(10 * batch_size, seed=0)
  ds = ds.batch(batch_size)
  return iter(tfds.as_numpy(ds))


def main(_):
  # Make the network and optimiser.
  net = hk.without_apply_rng(hk.transform(net_fn))
  opt = optax.adam(1e-3)

  # Training loss (cross-entropy).
  def loss(params: hk.Params, batch: Batch) -> jnp.ndarray:
    """Compute the loss of the network, including L2."""
    logits = net.apply(params, batch)
    labels = jax.nn.one_hot(batch["label"], 10)

    l2_loss = 0.5 * sum(jnp.sum(jnp.square(p)) for p in jax.tree_leaves(params))
    softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
    softmax_xent /= labels.shape[0]

    return softmax_xent + 1e-4 * l2_loss
    
  def accuracy(params: hk.Params, batch: Batch) -> jnp.ndarray:
    predictions = net.apply(params, batch)
    return jnp.mean(jnp.argmax(predictions, axis=-1) == batch["label"])

  def update(
      params: hk.Params,
      opt_state: optax.OptState,
      batch: Batch,
  ) -> Tuple[hk.Params, optax.OptState]:
    """Learning rule (stochastic gradient descent)."""
    loss_fn = lambda params, bb=batch: loss(params, bb)
    tangents = {}
    for layer_name, layer in params.items():
      tangents[layer_name] = {}
      for param_name, p in layer.items():
        tangents[layer_name][param_name] = jax.random.normal(jax.random.PRNGKey(42), jnp.shape(p))
    JVP = jax.jvp(loss_fn, (params,), (tangents,))
    updates = {}
    for layer_name, layer in tangents.items():
      updates[layer_name] = {}
      for param_name, p in layer.items():
        updates[layer_name][param_name] = - 0.001 * JVP[1] * p
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state

  # We maintain avg_params, the exponential moving average of the "live" params.
  # avg_params is used only for evaluation (cf. https://doi.org/10.1137/0330046)
  @jax.jit
  def ema_update(params, avg_params):
    return optax.incremental_update(params, avg_params, step_size=0.001)

  # Make datasets.
  train = load_dataset("train", is_training=True, batch_size=64)
  train_eval = load_dataset("train", is_training=False, batch_size=64)
  test_eval = load_dataset("test", is_training=False, batch_size=64)

  # Initialize network and optimiser; note we draw an input to get shapes.
  params = avg_params = net.init(jax.random.PRNGKey(42), next(train))
  opt_state = opt.init(params)

  # Train/eval loop.
  for step in range(100001):
    if step % 1000 == 0:
      # Periodically evaluate classification accuracy on train & test sets.
      train_accuracy = accuracy(avg_params, next(train_eval))
      test_accuracy = accuracy(avg_params, next(test_eval))
      train_accuracy, test_accuracy = jax.device_get(
          (train_accuracy, test_accuracy))
      print(f"[Step {step}] Train / Test accuracy: {train_accuracy:.3f} / {test_accuracy:.3f}.")

    # Do SGD on a batch of training examples.
    params, opt_state = update(params, opt_state, next(train))
    avg_params = ema_update(params, avg_params)

main(1)

[Step 0] Train / Test accuracy: 0.109 / 0.109.
[Step 1000] Train / Test accuracy: 0.078 / 0.141.
[Step 2000] Train / Test accuracy: 0.141 / 0.141.
[Step 3000] Train / Test accuracy: 0.078 / 0.094.
[Step 4000] Train / Test accuracy: 0.125 / 0.109.
[Step 5000] Train / Test accuracy: 0.062 / 0.125.
[Step 6000] Train / Test accuracy: 0.141 / 0.141.
[Step 7000] Train / Test accuracy: 0.188 / 0.125.
[Step 8000] Train / Test accuracy: 0.078 / 0.125.
[Step 9000] Train / Test accuracy: 0.094 / 0.094.
[Step 10000] Train / Test accuracy: 0.078 / 0.125.
[Step 11000] Train / Test accuracy: 0.156 / 0.141.
[Step 12000] Train / Test accuracy: 0.125 / 0.156.
[Step 13000] Train / Test accuracy: 0.062 / 0.125.
[Step 14000] Train / Test accuracy: 0.188 / 0.172.
[Step 15000] Train / Test accuracy: 0.109 / 0.109.
[Step 16000] Train / Test accuracy: 0.156 / 0.125.


KeyboardInterrupt: ignored