In [1]:
from typing import Iterator, NamedTuple

from absl import app
!pip install dm-haiku
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds

NUM_CLASSES = 10  # MNIST has 10 classes (hand-written digits).


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting dm-haiku
  Downloading dm_haiku-0.0.9-py3-none-any.whl (352 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m352.1/352.1 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jmp>=0.0.2
  Downloading jmp-0.0.4-py3-none-any.whl (18 kB)
Installing collected packages: jmp, dm-haiku
Successfully installed dm-haiku-0.0.9 jmp-0.0.4


In [2]:
class Batch(NamedTuple):
  image: np.ndarray  # [B, H, W, 1]
  label: np.ndarray  # [B]


class TrainingState(NamedTuple):
  params: hk.Params
  avg_params: hk.Params
  opt_state: optax.OptState


In [3]:
def net_fn(images: jax.Array) -> jax.Array:
  """Standard LeNet-300-100 MLP network."""
  x = images.astype(jnp.float32) / 255.
  mlp = hk.Sequential([
      hk.Flatten(),
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(NUM_CLASSES),
  ])
  return mlp(x)


def load_dataset(
    split: str,
    *,
    shuffle: bool,
    batch_size: int,
) -> Iterator[Batch]:
  """Loads the MNIST dataset."""
  ds = tfds.load("mnist:3.*.*", split=split).cache().repeat()
  if shuffle:
    ds = ds.shuffle(10 * batch_size, seed=0)
  ds = ds.batch(batch_size)
  ds = ds.map(lambda x: Batch(**x))
  return iter(tfds.as_numpy(ds))

In [7]:
# First, make the network and optimiser.
network = hk.without_apply_rng(hk.transform(net_fn))
optimiser = optax.adam(1e-3)

# Make datasets.
train_dataset = load_dataset("train", shuffle=True, batch_size=1_000)
eval_datasets = {
    split: load_dataset(split, shuffle=False, batch_size=10_000)
    for split in ("train", "test")
}

# Initialise network and optimiser; note we draw an input to get shapes.
initial_params = network.init(
    jax.random.PRNGKey(seed=0), next(train_dataset).image)
initial_opt_state = optimiser.init(initial_params)
state = TrainingState(initial_params, initial_params, initial_opt_state)

In [19]:
def loss(params: hk.Params, batch: Batch) -> jax.Array:
  """Cross-entropy classification loss, regularised by L2 weight decay."""
  l2_regulariser = 0.5 * sum(
      jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))
  log_likelihood = jnp.sum(labels * jax.nn.log_softmax(logits))

  return -log_likelihood / batch_size + 1e-4 * l2_regulariser

def train_step(params: hk.Params, batch: Batch):
  batch_size, *_ = batch.image.shape
  logits = network.apply(params, batch.image)
  labels = jax.nn.one_hot(batch.label, NUM_CLASSES)
  return batch_size, logits, labels

# @jax.jit
def evaluate(params: hk.Params, batch: Batch) -> jax.Array:
  """Evaluation metric (classification accuracy)."""
  logits = network.apply(params, batch.image)
  predictions = jnp.argmax(logits, axis=-1)
  return jnp.mean(predictions == batch.label)

# @jax.jit
def update(state: TrainingState, batch: Batch) -> TrainingState:
  """Learning rule (stochastic gradient descent)."""
  grads = jax.grad(loss)(state.params, batch)
  print(f'grads type: {type(grads)}, opt_state type: {type(state.opt_state)}')
  print(f'grads: {grads.keys()}')
  sys.exit()
  # print(f'params shape: {params.shape}, batch shape: {batch.shape}')

  updates, opt_state = optimiser.update(grads, state.opt_state)
  params = optax.apply_updates(state.params, updates)
  # Compute avg_params, the exponential moving average of the "live" params.
  # We use this only for evaluation (cf. https://doi.org/10.1137/0330046).
  avg_params = optax.incremental_update(
      params, state.avg_params, step_size=0.001)
  return TrainingState(params, avg_params, opt_state)

# Training & evaluation loop.
for step in range(10):
  if step % 100 == 0:
    # Periodically evaluate classification accuracy on train & test sets.
    # Note that each evaluation is only on a (large) batch.
    for split, dataset in eval_datasets.items():
      accuracy = np.array(evaluate(state.avg_params, next(dataset))).item()
      print({"step": step, "split": split, "accuracy": f"{accuracy:.3f}"})

  # Do SGD on a batch of training examples.
  state = update(state, next(train_dataset))


{'step': 0, 'split': 'train', 'accuracy': '0.197'}
{'step': 0, 'split': 'test', 'accuracy': '0.200'}
grads type: <class 'dict'>, opt_state type: <class 'tuple'>
grads: dict_keys(['linear', 'linear_1', 'linear_2'])


SystemExit: ignored

In [18]:
import sys
def loss(params: hk.Params, batch: Batch) -> jax.Array:
  """Cross-entropy classification loss, regularised by L2 weight decay."""
  batch_size, *_ = batch.image.shape
  logits = network.apply(params, batch.image)
  labels = jax.nn.one_hot(batch.label, NUM_CLASSES)

  l2_regulariser = 0.5 * sum(
      jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))
  log_likelihood = jnp.sum(labels * jax.nn.log_softmax(logits))

  return -log_likelihood / batch_size + 1e-4 * l2_regulariser

# @jax.jit
def evaluate(params: hk.Params, batch: Batch) -> jax.Array:
  """Evaluation metric (classification accuracy)."""
  logits = network.apply(params, batch.image)
  predictions = jnp.argmax(logits, axis=-1)
  return jnp.mean(predictions == batch.label)

# @jax.jit
def update(state: TrainingState, batch: Batch) -> TrainingState:
  """Learning rule (stochastic gradient descent)."""
  grads = jax.grad(loss)(state.params, batch)
  print(f'grads type: {type(grads)}, opt_state type: {type(state.opt_state)}')
  print(f'grads: {grads.keys()}')
  sys.exit()
  # print(f'params shape: {params.shape}, batch shape: {batch.shape}')

  updates, opt_state = optimiser.update(grads, state.opt_state)
  params = optax.apply_updates(state.params, updates)
  # Compute avg_params, the exponential moving average of the "live" params.
  # We use this only for evaluation (cf. https://doi.org/10.1137/0330046).
  avg_params = optax.incremental_update(
      params, state.avg_params, step_size=0.001)
  return TrainingState(params, avg_params, opt_state)

# Training & evaluation loop.
for step in range(10):
  if step % 100 == 0:
    # Periodically evaluate classification accuracy on train & test sets.
    # Note that each evaluation is only on a (large) batch.
    for split, dataset in eval_datasets.items():
      accuracy = np.array(evaluate(state.avg_params, next(dataset))).item()
      print({"step": step, "split": split, "accuracy": f"{accuracy:.3f}"})

  # Do SGD on a batch of training examples.
  state = update(state, next(train_dataset))


{'step': 0, 'split': 'train', 'accuracy': '0.197'}
{'step': 0, 'split': 'test', 'accuracy': '0.200'}
grads type: <class 'dict'>, opt_state type: <class 'tuple'>
grads: dict_keys(['linear', 'linear_1', 'linear_2'])


SystemExit: ignored

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
