In [None]:
import torchvision
import jax.numpy as jnp
import jax.numpy as jnp
from flax import linen as nn
import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
from flax import linen as nn
from flax.training import train_state
from clu import metrics
import flax
import optax

In [None]:
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import MNIST

def numpy_collate(batch):
    """Collate a batch of samples into NumPy arrays.

    The first sample decides how the whole batch is handled:

    * ``np.ndarray`` samples are stacked along a new leading axis;
    * tuple/list samples are regrouped field by field and every field is
      collated recursively, yielding a list with one entry per field;
    * anything else is passed to ``np.array`` unchanged.
    """
    head = batch[0]
    if isinstance(head, np.ndarray):
        return np.stack(batch)
    if isinstance(head, (tuple, list)):
        # zip(*batch) turns [(a0, b0), (a1, b1)] into (a0, a1), (b0, b1).
        return [numpy_collate(field) for field in zip(*batch)]
    return np.array(batch)

class NumpyLoader(data.DataLoader):
  """``DataLoader`` whose batches are always collated with ``numpy_collate``.

  Every constructor argument is forwarded unchanged to
  ``torch.utils.data.DataLoader``; only ``collate_fn`` is fixed here and
  cannot be overridden by the caller.
  """

  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    # Zero-argument super() is resolved relative to this class. The earlier
    # super(self.__class__, self) recursed forever as soon as NumpyLoader was
    # subclassed, because self.__class__ is then the subclass and the lookup
    # lands on this very __init__ again.
    super().__init__(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

# Number of samples per batch; passed to both NumpyLoader instances below.
batch_size = 32
# Presumably the number of label classes (it matches the 10-way output layer
# of the model); not referenced by any other visible code -- TODO confirm.
n_targets = 10
from jax.nn import one_hot
class FlattenAndCast(object):
  """Dataset transform: image -> float32 NumPy array with a trailing channel axis.

  The input is converted to ``float32``, divided by 255 and given an extra
  last axis of size 1. Despite the class name nothing is flattened.
  """

  def __call__(self, pic):
    # Stay in NumPy here. Building a JAX array per sample (jnp.expand_dims)
    # costs a host->device transfer that numpy_collate immediately undoes
    # when it turns the batch back into a NumPy array, and it also drags the
    # JAX runtime into the data-loading path. The collated batch is the same
    # either way.
    scaled = np.array(pic, dtype=np.float32) / 255
    return np.expand_dims(scaled, -1)
# Training split: downloaded to /tmp/mnist/ if missing, every image passed
# through FlattenAndCast. The loader keeps the DataLoader default order
# (shuffle=False), so each epoch visits the samples in the same sequence.
mnist_dataset = MNIST(
    '/tmp/mnist/',
    train=True,
    transform=FlattenAndCast(),
    download=True,
)
training_generator = NumpyLoader(
    mnist_dataset,
    batch_size=batch_size,
    num_workers=0,
)

# Held-out split, same location and transform.
mnist_dataset_test = MNIST(
    '/tmp/mnist/',
    train=False,
    transform=FlattenAndCast(),
    download=True,
)
test_generator = NumpyLoader(
    mnist_dataset_test,
    batch_size=batch_size,
    num_workers=0,
)

In [None]:
class Network(nn.Module):
  """Small CNN classifier: two conv/ReLU/max-pool stages and a dense head.

  Call arguments:
    x: input batch; this notebook initialises the model with an array of
      shape (1, 28, 28, 1).
    training: when True dropout is active and a ``'dropout'`` RNG has to be
      passed to ``apply``; when False dropout is deterministic (a no-op).

  Returns:
    One row of 10 unnormalised scores (logits) per example.
  """
  @nn.compact
  def __call__(self, x, training: bool = True):
    x = nn.Conv(features=32, kernel_size=(5,5))(x)
    x = nn.relu(x)
    # strides must be given explicitly: flax's max_pool defaults to stride 1,
    # which slides the 2x2 window one pixel at a time and barely shrinks the
    # feature map. With stride 2 each pool halves the spatial size.
    x = nn.max_pool(x, window_shape=(2,2), strides=(2,2))
    x = nn.Conv(features=64, kernel_size=(5,5))(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2,2), strides=(2,2))
    # Flatten everything except the batch axis for the dense head.
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=1024)(x)
    x = nn.Dropout(rate=0.5, deterministic=not training)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x

In [None]:
model = Network()

# A single root key is split into three independent streams: one for the
# dummy input, one for parameter initialisation and one reserved for dropout.
root_key = jax.random.key(0)
key1, key2, dropout_key = jax.random.split(root_key, 3)

# Dummy batch holding one 28x28 single-channel image, used only to trace shapes.
random_flattened_image = jax.random.normal(key1, shape=(1, 28, 28, 1))

# training=False keeps dropout deterministic, so init needs no dropout RNG.
variables = model.init(key2, random_flattened_image, training=False)
params = variables['params']

In [None]:
# Display the shape of every parameter array in the tree.
tree_map(lambda leaf: leaf.shape, params)

In [None]:
# Forward pass in training mode on the dummy input; dropout is active, so a
# 'dropout' RNG stream has to be supplied.
model.apply(
    {'params': params},
    random_flattened_image,
    rngs={'dropout': dropout_key},
    training=True,
)

In [None]:
@flax.struct.dataclass
class Metrics(metrics.Collection):
  """Collection of the metrics accumulated over batches (accuracy and loss)."""
  # clu Accuracy metric.
  accuracy: metrics.Accuracy
  # Average of the value supplied under the 'loss' keyword of the model output.
  loss: metrics.Average.from_output('loss')
class TrainState(train_state.TrainState):
  """Flax ``TrainState`` extended with a metrics collection and a PRNG key."""
  # Metric accumulator carried along with the parameters and optimizer state.
  metrics: Metrics
  # PRNG key stored with the state.
  # NOTE(review): the visible training step receives its dropout key as an
  # explicit argument and does not read this field -- confirm it is needed.
  key: jax.Array
# Initial training state: freshly initialised parameters, SGD with momentum,
# an empty metrics accumulator and the dropout key.
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.sgd(learning_rate=0.01, momentum=0.9),
    metrics=Metrics.empty(),
    key=dropout_key,
)


In [None]:
@jax.jit
def compute_metrics(state, x, y):
  """Evaluate one batch and fold the result into ``state.metrics``.

  The model runs with ``training=False``. The mean softmax cross-entropy
  against the integer labels ``y`` is passed, together with the logits and
  labels, to the metrics collection, and the per-batch metrics are merged
  into the running ones.

  Returns:
    A copy of ``state`` whose ``metrics`` field holds the merged metrics.
  """
  variables = {'params': state.params}
  logits = state.apply_fn(variables, x, training=False)
  per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=y)
  batch_metrics = state.metrics.single_from_model_output(
      logits=logits, labels=y, loss=per_example_loss.mean())
  merged = state.metrics.merge(batch_metrics)
  return state.replace(metrics=merged)

In [None]:
@jax.jit
def update(train_state, x, y, dropout_key):
  """Run one optimisation step on a single batch.

  Args:
    train_state: current training state (parameters, optimizer, step).
    x: batch of inputs fed to the model.
    y: integer class labels for ``x``.
    dropout_key: base PRNG key for dropout; the current step counter is
      folded in so that every step draws a different dropout mask.

  Returns:
    ``(new_train_state, loss_value)`` where ``loss_value`` is the mean
    cross-entropy computed with the parameters from before the step.
  """
  step_key = jax.random.fold_in(dropout_key, train_state.step)

  def mean_cross_entropy(params):
    # Closes over the batch; only ``params`` is differentiated.
    logits = train_state.apply_fn(
        {'params': params},
        x,
        training=True,
        rngs={'dropout': step_key},
    )
    per_example = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=y)
    return per_example.mean()

  loss_value, grads = jax.value_and_grad(mean_cross_entropy)(train_state.params)
  return train_state.apply_gradients(grads=grads), loss_value

In [None]:
# Number of full passes over the training loader made by the training loop.
num_epochs = 25

In [None]:
import time

for epoch in range(num_epochs):
  # perf_counter is monotonic and meant for measuring durations; time.time()
  # can jump when the system clock is adjusted.
  start_time = time.perf_counter()
  for x, y in training_generator:
    y = y.astype(jnp.int32)
    state, loss_value = update(state, x, y, dropout_key)
    # Training metrics are computed in eval mode, after the parameter update.
    state = compute_metrics(state, x, y)
  # JAX dispatches work asynchronously: without this barrier the clock could
  # be read while the last steps are still running on the device.
  jax.block_until_ready(state)
  epoch_time = time.perf_counter() - start_time
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))

  for metric,value in state.metrics.compute().items():
    print(f"Training set {metric} {value}")
  # Reset the accumulator so the next epoch (and the test pass below) starts
  # from empty metrics.
  state = state.replace(metrics=state.metrics.empty())

  # Evaluate on the test set using a separate name, so the test metrics never
  # leak into the training state.
  test_state = state
  for x, y in test_generator:
    y = y.astype(jnp.int32)
    test_state = compute_metrics(test_state, x, y)

  for metric,value in test_state.metrics.compute().items():
    print(f"Test set {metric} {value}")