<a href="https://colab.research.google.com/github/Dmitrii173173/Awesome-ML/blob/main/MultilayerPerceptron_for_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from flax import linen as nn
import jax
import jax.numpy as jnp
import optax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from typing import Sequence

In [2]:
# Training hyperparameters. The trailing `# @param{...}` comments are left
# untouched (presumably Colab form-field annotations -- confirm before editing).

# Step size handed to the Adam optimizer.
LEARNING_RATE = 0.002 # @param{type:"number"}
# Examples per mini-batch, used by both the train and the test pipeline.
BATCH_SIZE = 128 # @param{type:"integer"}
# Number of passes of the training loop over the training pipeline.
N_EPOCHS = 1 # @param{type:"integer"}

In [3]:
# Load the MNIST train/test splits as (image, label) pairs, together with the
# dataset metadata used below for the class count and image shape.
(train_loader, test_loader), info = tfds.load(
    "mnist", split=["train", "test"], as_supervised=True, with_info=True
)


def min_max_rgb(image, label):
  """Casts the image to float32 and divides by 255; the label is unchanged.

  Assumes 8-bit pixel values, so the result lies in [0, 1] -- TODO confirm.
  """
  return tf.cast(image, tf.float32) / 255., label


train_loader = train_loader.map(min_max_rgb)
test_loader = test_loader.map(min_max_rgb)

NUM_CLASSES = info.features["label"].num_classes
IMG_SIZE = info.features["image"].shape

# A fixed shuffle seed makes the batch order repeatable across fresh runs,
# while `reshuffle_each_iteration=True` still gives each epoch its own order.
train_loader_batched = train_loader.shuffle(
    buffer_size=10_000, seed=0, reshuffle_each_iteration=True
).batch(BATCH_SIZE, drop_remainder=True)

# `drop_remainder=True` keeps every batch the same shape; any trailing partial
# batch is therefore excluded from evaluation.
test_loader_batched = test_loader.batch(BATCH_SIZE, drop_remainder=True)

Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...


Dl Completed...:   0%|          | 0/5 [00:00<?, ? file/s]

Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.


In [4]:
class MLP(nn.Module):
  """A simple multilayer perceptron model for image classification.

  Attributes:
    hidden_sizes: width of each hidden layer, in order. One Dense + ReLU pair
      is created per entry, so any number of hidden layers (including none)
      is supported.
  """
  hidden_sizes: Sequence[int] = (1000, 1000)

  @nn.compact
  def __call__(self, x):
    """Maps a batch of inputs to `NUM_CLASSES` unnormalized scores per example.

    Args:
      x: batch of inputs; everything after the leading (batch) axis is
        flattened into a single feature axis.

    Returns:
      Array with one row per example and `NUM_CLASSES` columns (no softmax
      is applied here).
    """
    # Flattens images in the batch.
    x = x.reshape((x.shape[0], -1))
    # One Dense + ReLU per configured hidden width. Layers are created in the
    # same order as before, so the default configuration yields the same
    # sequence of submodules.
    for width in self.hidden_sizes:
      x = nn.Dense(features=width)(x)
      x = nn.relu(x)
    # Output layer: no activation, the loss consumes raw logits.
    x = nn.Dense(features=NUM_CLASSES)(x)
    return x

In [5]:
# Single module instance shared by the jitted helpers defined after it.
net = MLP()


@jax.jit
def predict(params, inputs):
  """Applies `net` to `inputs` using `params` and returns its output."""
  variables = {"params": params}
  return net.apply(variables, inputs)


@jax.jit
def loss_accuracy(params, data):
  """Computes loss and accuracy over a mini-batch.

  Args:
    params: parameters of the model.
    data: tuple of (inputs, labels); labels are integer class ids.

  Returns:
    A `(loss, aux)` pair. `loss` is the softmax cross-entropy averaged over
    the batch. `aux` is a dict whose "accuracy" entry is the fraction of
    examples whose arg-max logit equals the label. Returning the metrics as a
    second element lets the function be differentiated with `has_aux=True`.
  """
  inputs, labels = data
  logits = predict(params, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=labels
  ).mean()
  # Mean of a boolean "prediction == label" vector is the hit rate.
  accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
  return loss, {"accuracy": accuracy}

@jax.jit
def update_model(state, grads):
  """Returns whatever `state.apply_gradients(grads=grads)` produces.

  NOTE(review): no caller is visible in this notebook; the training loop
  updates parameters through `train_step` instead -- confirm before relying
  on this helper.
  """
  new_state = state.apply_gradients(grads=grads)
  return new_state

In [6]:
# Fixed PRNG key, so parameter initialization is the same on every run.
rng = jax.random.PRNGKey(0)
# One all-ones input of shape (1, *IMG_SIZE); only used to trace the model
# during initialization.
dummy_data = jnp.ones((1, *IMG_SIZE), dtype=jnp.float32)
params = net.init({"params": rng}, dummy_data)["params"]

# Adam optimizer and the optimizer state matching `params`.
solver = optax.adam(LEARNING_RATE)
solver_state = solver.init(params)

def dataset_stats(params, data_loader):
  """Computes loss and accuracy over the dataset `data_loader`.

  Per-batch means are combined with weights equal to the batch size, so the
  result is the per-example average even when the last batch is smaller than
  the others. With equally sized batches this equals the plain mean of the
  per-batch values.

  Args:
    params: parameters of the model.
    data_loader: batched dataset exposing `as_numpy_iterator()`, yielding
      (inputs, labels) tuples.

  Returns:
    Dict with float64 scalars under "loss" and "accuracy". Both are NaN when
    `data_loader` yields no batches.
  """
  all_loss = []
  all_accuracy = []
  batch_sizes = []
  for batch in data_loader.as_numpy_iterator():
    batch_loss, batch_aux = loss_accuracy(params, batch)
    all_loss.append(batch_loss)
    all_accuracy.append(batch_aux["accuracy"])
    # Number of examples in this batch, taken from the labels array.
    batch_sizes.append(len(batch[1]))

  if not batch_sizes:
    # Nothing to average: report NaN explicitly instead of letting numpy
    # emit an "empty slice" warning.
    nan = np.float64("nan")
    return {"loss": nan, "accuracy": nan}

  return {
      "loss": np.average(np.asarray(all_loss), weights=batch_sizes),
      "accuracy": np.average(np.asarray(all_accuracy), weights=batch_sizes),
  }

In [7]:
# Per-epoch training metrics; the training loop appends one entry per epoch.
train_accuracy, train_losses = [], []

# Evaluate once before any update, so index 0 of the test histories holds the
# metrics of the freshly initialized model.
test_stats = dataset_stats(params, test_loader_batched)
test_accuracy, test_losses = [test_stats["accuracy"]], [test_stats["loss"]]


@jax.jit
def train_step(params, solver_state, batch):
  """Runs one optimizer update on a single mini-batch.

  Args:
    params: current model parameters.
    solver_state: current state of `solver`.
    batch: tuple of (inputs, labels) forwarded to `loss_accuracy`.

  Returns:
    Tuple `(params, solver_state, loss, aux)` holding the updated parameters,
    the updated optimizer state, and the loss / metrics dict computed on
    `batch` with the parameters that were passed in.
  """
  # `has_aux=True`: loss_accuracy returns (loss, metrics); only the loss is
  # differentiated.
  grad_fn = jax.value_and_grad(loss_accuracy, has_aux=True)
  (loss, aux), grad = grad_fn(params, batch)
  updates, new_solver_state = solver.update(grad, solver_state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_solver_state, loss, aux


for epoch in range(N_EPOCHS):
  # Per-batch metrics of the current epoch only; reset on every epoch.
  train_accuracy_epoch, train_losses_epoch = [], []

  train_batches = train_loader_batched.as_numpy_iterator()
  for step, train_batch in enumerate(train_batches):
    params, solver_state, train_loss, train_aux = train_step(
        params, solver_state, train_batch
    )
    train_losses_epoch.append(train_loss)
    train_accuracy_epoch.append(train_aux["accuracy"])
    # Only report progress on every 20th step.
    if step % 20 != 0:
      continue
    print(
        f"step {step}, train loss: {train_loss:.2e}, train accuracy:"
        f" {train_aux['accuracy']:.2f}"
    )

  # End of epoch: record the epoch-averaged training metrics and re-evaluate
  # on the test pipeline with the updated parameters.
  train_losses.append(np.mean(train_losses_epoch))
  train_accuracy.append(np.mean(train_accuracy_epoch))
  test_stats = dataset_stats(params, test_loader_batched)
  test_losses.append(test_stats["loss"])
  test_accuracy.append(test_stats["accuracy"])

step 0, train loss: 2.29e+00, train accuracy: 0.14
step 20, train loss: 2.55e-01, train accuracy: 0.92
step 40, train loss: 3.55e-01, train accuracy: 0.89
step 60, train loss: 2.48e-01, train accuracy: 0.93
step 80, train loss: 1.35e-01, train accuracy: 0.95
step 100, train loss: 2.12e-01, train accuracy: 0.94
step 120, train loss: 1.85e-01, train accuracy: 0.95
step 140, train loss: 1.87e-01, train accuracy: 0.95
step 160, train loss: 2.22e-01, train accuracy: 0.95
step 180, train loss: 8.08e-02, train accuracy: 0.98
step 200, train loss: 1.11e-01, train accuracy: 0.97
step 220, train loss: 1.16e-01, train accuracy: 0.96
step 240, train loss: 1.85e-01, train accuracy: 0.95
step 260, train loss: 2.08e-01, train accuracy: 0.92
step 280, train loss: 1.40e-01, train accuracy: 0.95
step 300, train loss: 1.69e-01, train accuracy: 0.95
step 320, train loss: 1.06e-01, train accuracy: 0.95
step 340, train loss: 1.08e-01, train accuracy: 0.97
step 360, train loss: 1.56e-01, train accuracy: 0.95

In [8]:
# Summary shown as the cell output: test accuracy recorded before training
# (index 0) versus after the last epoch (index -1).
f"Improved accuracy on test DS from {test_accuracy[0]} to {test_accuracy[-1]}"

'Improved accuracy on test DS from 0.14463141560554504 to 0.9708533883094788'