In [None]:
!pip install -U -q pip jax jaxlib
!pip install -U -q git+https://github.com/google/flax.git
!pip install -U torch torchvision

In [3]:
from tqdm import tqdm
import jax
import jax.numpy as jnp
from jax import random


from flax import linen as nn
from flax.training import train_state


import numpy as np
import optax


from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

In [22]:
class CNN(nn.Module):
  """Small convolutional classifier.

  Two conv -> ReLU -> 2x2 average-pool stages (32 then 64 features) feed a
  256-unit dense layer and a 10-way output layer. The output is passed through
  `log_softmax`, so the module returns log-probabilities rather than raw scores.
  """

  @nn.compact
  def __call__(self, x):
    # Feature extractor: the same conv/activation/pool pattern, twice.
    for channels in (32, 64):
      x = nn.Conv(features=channels, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

    # Flatten everything except the leading axis before the dense head.
    flat = x.reshape((x.shape[0], -1))

    hidden = nn.relu(nn.Dense(features=256)(flat))
    scores = nn.Dense(features=10)(hidden)
    return nn.log_softmax(scores)

In [28]:
def mnist_transform(x):
  """Convert an image to a float64 array with a channel axis, divided by 255.

  Args:
    x: Anything `np.array` accepts (e.g. a 2-D image). Values are presumably
      in 0..255, in which case the result lies in [0, 1].

  Returns:
    A float64 NumPy array with a new axis inserted at position 2, scaled by 1/255.
  """
  # `np.float` was only an alias for the builtin `float` (i.e. float64) and was
  # removed in NumPy 1.24; name the dtype explicitly so this works on current NumPy.
  return np.expand_dims(np.array(x, dtype=np.float64), axis=2) / 255.

def mnist_collate_fn(batch):
  """Collate a sequence of samples into two stacked NumPy arrays.

  Args:
    batch: Sequence of samples; the first two fields of every sample are used
      (presumably image and label — confirm against the dataset).

  Returns:
    Tuple `(x, y)` where each element stacks the corresponding field of all
    samples along a new leading axis.
  """
  # Transpose "list of samples" into "one tuple per field".
  fields = tuple(zip(*batch))
  return np.stack(fields[0]), np.stack(fields[1])

In [29]:
# Training samples pass through `mnist_transform` (channel axis added, scaled by
# 1/255) each time the DataLoader indexes the dataset.
train = MNIST(root='train', train=True, transform=mnist_transform, download=True)
test = MNIST(root='test', train=False, transform=mnist_transform, download=True)
train_loader = DataLoader(train, batch_size=64, shuffle=True, collate_fn=mnist_collate_fn)

# `test.data` is the raw uint8 pixel tensor: the dataset's `transform` only runs
# on item access, never on `.data`. Apply the same 1/255 scaling here so the
# model is evaluated on inputs in the same range it was trained on.
test_images = np.expand_dims(test.data.numpy(), axis=3).astype(np.float32) / 255.
test_lbls = jnp.array(test.targets.numpy())

In [35]:
def create_train_state(key, learning_rate, momentum):
  """Build a fresh Flax `TrainState` for `CNN` with an SGD optimizer.

  Args:
    key: PRNG key handed to `CNN.init` for parameter initialisation.
    learning_rate: First positional argument to `optax.sgd`.
    momentum: Second positional argument to `optax.sgd`.

  Returns:
    A `train_state.TrainState` holding the model's apply function, the
    initial parameters and the optimizer.
  """
  model = CNN()
  # A dummy all-ones input is enough for `init`; only its shape matters here.
  dummy_input = jnp.ones([1, 28, 28, 1])
  variables = model.init(key, dummy_input)
  return train_state.TrainState.create(
      apply_fn=model.apply,
      params=variables['params'],
      tx=optax.sgd(learning_rate, momentum),
  )

In [39]:
def compute_accuracy(logits, y):
  """Return the fraction of entries whose arg-max over the last axis equals `y`.

  Args:
    logits: Per-class scores; the class axis is the last one.
    y: Integer labels compared element-wise against the arg-max.

  Returns:
    Scalar mean of the element-wise matches.
  """
  predicted = jnp.argmax(logits, axis=-1)
  hits = predicted == y
  return jnp.mean(hits)

@jax.jit
def train_step(state, x, y):
  """Run one jitted optimisation step on a single batch.

  Args:
    state: Train state providing `params` and `apply_gradients`.
    x: Batch of model inputs.
    y: Integer class labels (one-hot encoded into 10 classes for the loss).

  Returns:
    Tuple `(new_state, metrics)` where `metrics` is a dict with keys
    `'loss'` and `'accuracy'` for this batch.
  """
  def batch_loss(params):
    # CNN ends in log_softmax, so its output is already log-probabilities.
    log_probs = CNN().apply({'params': params}, x)
    targets = jax.nn.one_hot(y, num_classes=10)
    # Negative log-likelihood of the true class, averaged over the batch.
    nll = -jnp.mean(jnp.sum(targets * log_probs, axis=-1))
    return nll, log_probs

  # has_aux=True: the model outputs ride along with the loss for the accuracy metric.
  grad_fn = jax.value_and_grad(batch_loss, has_aux=True)
  (loss, log_probs), grads = grad_fn(state.params)
  new_state = state.apply_gradients(grads=grads)
  accuracy = compute_accuracy(log_probs, y)
  return new_state, {'loss': loss, 'accuracy': accuracy}

@jax.jit
def eval_step(state, x, y):
  """Return the accuracy of `CNN` with `state.params` on the batch `(x, y)`.

  Args:
    state: Train state whose `params` are used for the forward pass.
    x: Batch of model inputs.
    y: Integer class labels.

  Returns:
    Scalar accuracy as computed by `compute_accuracy`.
  """
  variables = {'params': state.params}
  outputs = CNN().apply(variables, x)
  return compute_accuracy(outputs, y)

In [32]:
def train_epoch(state, dataloader, epoch):
  """Train for one full pass over `dataloader` and average the batch metrics.

  Args:
    state: Train state fed to, and replaced by the result of, `train_step`.
    dataloader: Sized iterable yielding `(x, y)` batches.
    epoch: Not read by this function; kept as part of the call signature.

  Returns:
    Tuple `(state, epoch_metrics_np)` where `epoch_metrics_np` maps each metric
    name reported by `train_step` to its mean over all batches.
  """
  history = []
  with tqdm(total=len(dataloader)) as progress:
    for x, y in dataloader:
      progress.update(1)
      state, step_metrics = train_step(state, x, y)
      history.append(step_metrics)

  # Pull all per-batch metrics to the host in one go before averaging.
  history_np = jax.device_get(history)
  epoch_metrics_np = {}
  for name in history_np[0]:
    epoch_metrics_np[name] = np.mean([entry[name] for entry in history_np])
  return state, epoch_metrics_np


def evaluate_model(state, x, y):
  """Evaluate the model on `(x, y)` and return the result as plain Python scalars.

  Args:
    state: Train state forwarded to `eval_step`.
    x: Batch of model inputs.
    y: Integer class labels.

  Returns:
    The output of `eval_step` with every leaf fetched to the host and
    converted via `.item()` (a single Python float for the current `eval_step`).
  """
  metrics = eval_step(state, x, y)
  metrics = jax.device_get(metrics)
  # Convert each leaf to a Python scalar. The original passed `item()` as a
  # separate positional argument, which raised NameError before doing anything.
  metrics = jax.tree_util.tree_map(lambda leaf: leaf.item(), metrics)
  return metrics

In [None]:
# Hyperparameters for the optimizer and the number of passes over the training set.
learning_rate = 0.1
momentum = 0.9
num_epochs = 3
# NOTE(review): `batch_size` is not read anywhere in this cell; `train_loader`
# was constructed with a hard-coded batch_size=64 — confirm which value is intended.
batch_size = 32
# Fixed PRNG seed for parameter initialisation.
key = random.PRNGKey(0)
state = create_train_state(key, learning_rate, momentum)


# Epochs are numbered from 1 for display purposes.
for epoch in range(1, num_epochs + 1):
  state, train_metrics = train_epoch(state, train_loader, epoch)
  print(f"Train epoch: {epoch}, loss: {train_metrics['loss']:.4}, accuracy: {train_metrics['accuracy'] * 100:.4}")


  # Accuracy on the whole test set in a single call, using the parameters after this epoch.
  test_metrics = eval_step(state, test_images, test_lbls)
  print(f"Test epoch: {epoch}, accuracy: {test_metrics * 100:.4}")