# Image Classification in JAX using `TrainState`

> It was easier to follow the guide before trying to turn it into something different.

In [None]:
import tensorflow as tf
# Hide every GPU from TensorFlow. TF is only used in this notebook for
# dataset loading / tf.data pipelines, so presumably this is to stop TF from
# reserving GPU memory that JAX should own — confirm on the target machine.
tf.config.set_visible_devices([], device_type='GPU')

2023-04-26 12:28:31.204444: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-04-26 12:28:31.258188: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
from typing import Any, Callable, Sequence, Union
import numpy as np

import jax
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze
from flax import linen as nn
from flax import struct
from flax.training import train_state

import optax

from clu import metrics
from ml_collections import ConfigDict

from einops import reduce

from iqadatasets.datasets import *

## Get the data

> We'll be using MNIST from Keras.

In [None]:
(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.mnist.load_data()

# Add a trailing channel axis (NHWC, as expected by `nn.Conv`) and scale the
# integer pixel values to [0, 1]. Cast to float32 *before* dividing: dividing
# the integer images by 255.0 would promote to float64, doubling memory for
# data that JAX (x64 disabled by default) computes on in float32 anyway.
X_train = X_train[:, :, :, None].astype(np.float32) / 255.0
X_test = X_test[:, :, :, None].astype(np.float32) / 255.0
# Integer class ids, as required by
# `optax.softmax_cross_entropy_with_integer_labels`.
Y_train = Y_train.astype(np.int32)
Y_test = Y_test.astype(np.int32)

X_train.shape, Y_train.shape, X_test.shape, Y_test.shape

((60000, 28, 28, 1), (60000,), (10000, 28, 28, 1), (10000,))

In [None]:
# Turn each (images, labels) pair of arrays into a tf.data pipeline that
# yields one (image, label) example at a time.
dst_train, dst_val = (
    tf.data.Dataset.from_tensor_slices(split)
    for split in ((X_train, Y_train), (X_test, Y_test))
)

In [None]:
# Run hyper-parameters, wrapped in a ConfigDict for attribute-style access.
config = ConfigDict(dict(
    BATCH_SIZE=256,
    EPOCHS=50,
    LEARNING_RATE=3e-4,
))
config

BATCH_SIZE: 256
EPOCHS: 50
LEARNING_RATE: 0.0003

In [None]:
# Shuffle the training examples before batching. Without this every epoch
# visits exactly the same batches in the same order. The buffer spans the
# whole (in-memory) training set for a uniform shuffle;
# `reshuffle_each_iteration=True` draws a new order each time the loop calls
# `as_numpy_iterator()`, and the fixed seed keeps runs reproducible.
dst_train_rdy = (
    dst_train
    .shuffle(buffer_size=len(X_train), seed=0, reshuffle_each_iteration=True)
    .batch(config.BATCH_SIZE)
)
# Evaluation order does not matter, so the validation set is only batched.
dst_val_rdy = dst_val.batch(config.BATCH_SIZE)

## Define the model we're going to use

> It's going to be a very simple model just for demonstration purposes.

In [None]:
class Model(nn.Module):
    """Small CNN classifier used for demonstration purposes.

    Architecture: two [Conv(3x3) -> ReLU -> MaxPool(2x2, stride 2)] stages,
    a mean over the spatial axes, then a Dense layer producing logits.

    Attributes:
        num_classes: number of output logits. Defaults to 10 (MNIST digits),
            so `Model()` behaves exactly as before.
    """
    num_classes: int = 10

    @nn.compact
    def __call__(self,
                 inputs,
                 **kwargs,
                 ):
        """Return unnormalised class logits of shape (batch, num_classes).

        Args:
            inputs: image batch in NHWC layout (the einops pattern below
                requires a rank-4 "b h w c" input).
            **kwargs: accepted for call-site compatibility; currently unused.
        """
        outputs = nn.Conv(features=32, kernel_size=(3,3))(inputs)
        outputs = nn.relu(outputs)
        outputs = nn.max_pool(outputs, window_shape=(2,2), strides=(2,2))
        outputs = nn.Conv(features=64, kernel_size=(3,3))(outputs)
        outputs = nn.relu(outputs)
        outputs = nn.max_pool(outputs, window_shape=(2,2), strides=(2,2))
        # Global average pooling: collapse height and width, keep channels.
        outputs = reduce(outputs, "b h w c -> b c", reduction="mean")
        outputs = nn.Dense(self.num_classes)(outputs)
        return outputs

## Define the metrics with `clu`

In [None]:
@struct.dataclass
class Metrics(metrics.Collection):
    """Collection of metrics to be tracked during training.

    Declared with `struct.dataclass` so it can be stored as a field of the
    `TrainState` that is passed through `jax.jit`-compiled functions.
    """
    # Classification accuracy, built from the `logits` and `labels` keyword
    # arguments given to `single_from_model_output`.
    accuracy: metrics.Accuracy
    # Running average of the value passed as the `loss` keyword argument.
    loss: metrics.Average.from_output("loss")

By default, `TrainState` doesn't include metrics, but it's very easy to subclass it so that it does:

In [None]:
class TrainState(train_state.TrainState):
    """Flax `TrainState` extended with a running `Metrics` collection.

    `metrics` starts as `Metrics.empty()`, is accumulated batch by batch by
    `compute_metrics`, and is reset to empty by the training loop after each
    train/validation pass.
    """
    metrics: Metrics

We'll define a function that initializes the `TrainState` from a module, an RNG key, an optimizer and an example input shape:

In [None]:
def create_train_state(module, key, tx, input_shape):
    """Initialise `module` and bundle it with an optimizer in a `TrainState`.

    Args:
        module: Flax module whose parameters are initialised.
        key: PRNG key used by `module.init`.
        tx: optax gradient transformation used for updates.
        input_shape: shape of the all-ones dummy input used to trace the
            module during initialisation.

    Returns:
        A `TrainState` holding the new params, `module.apply` as its
        `apply_fn`, the optimizer, and an empty `Metrics` collection.
    """
    dummy_input = jnp.ones(input_shape)
    variables = module.init(key, dummy_input)
    return TrainState.create(
        apply_fn=module.apply,
        params=variables["params"],
        tx=tx,
        metrics=Metrics.empty(),
    )

## Defining the training step

> We want to write a function that takes the `TrainState` and a batch of data and performs an optimization step.

In [None]:
@jax.jit
def train_step(state, batch):
    """Apply one gradient step on `batch` and return the updated state.

    `batch` is an `(inputs, labels)` pair. The objective is the mean softmax
    cross-entropy between the model's logits and the integer labels. Metrics
    are not touched here; see `compute_metrics`.
    """
    inputs, labels = batch

    def loss_fn(params):
        logits = state.apply_fn({"params": params}, inputs)
        per_example = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=labels
        )
        return per_example.mean()

    gradients = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=gradients)

In the Flax guide, the metrics are not computed inside the training step. I think this is somewhat wasteful, because it means performing an extra forward pass, but we'll follow the guide for now. Let's define a function that computes the metrics:

In [None]:
@jax.jit
def compute_metrics(*, state, batch):
    """Accumulate the metrics of `batch` into `state.metrics`.

    Runs a forward pass with the current params, computes the per-example
    softmax cross-entropy, and merges this batch's accuracy and loss into the
    running `Metrics` collection carried by `state`.

    Args:
        state: `TrainState` whose `metrics` field is updated.
        batch: `(inputs, labels)` pair with integer labels.

    Returns:
        A copy of `state` with merged metrics; params are unchanged.
    """
    inputs, labels = batch
    logits = state.apply_fn({"params": state.params}, inputs)
    # Hand clu the per-example losses instead of their batch mean. `Average`
    # then weights every example equally, as `Accuracy` does, so the short
    # final batch of an epoch (dataset sizes are not multiples of the batch
    # size) is not over-weighted in the reported epoch loss.
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    )
    # Named `batch_metrics` so it does not shadow the imported clu `metrics`.
    batch_metrics = state.metrics.single_from_model_output(
        logits=logits, labels=labels, loss=per_example_loss,
    )
    return state.replace(metrics=state.metrics.merge(batch_metrics))

## Train the model!

In [None]:
# Adam optimizer; params are initialised from a dummy one-image 28x28x1 batch.
optimizer = optax.adam(config.LEARNING_RATE)
state = create_train_state(
    Model(),
    random.PRNGKey(0),
    optimizer,
    input_shape=(1, 28, 28, 1),
)

In [None]:
# One fresh list per "<split>_<metric>" key, filled once per epoch by the
# training loop.
metrics_history = {
    f"{split}_{metric}": []
    for metric in ("loss", "accuracy")
    for split in ("train", "val")
}

In [None]:
%%time
def _flush_metrics(state, split):
    """Append the accumulated metrics to `metrics_history` under the
    `<split>_` prefix, and return `state` with its metrics reset."""
    for metric_name, metric_value in state.metrics.compute().items():
        metrics_history[f"{split}_{metric_name}"].append(metric_value)
    return state.replace(metrics=state.metrics.empty())


for epoch in range(config.EPOCHS):
    # Training: one optimisation step per batch, then an extra forward pass
    # with the updated params to accumulate metrics.
    for batch in dst_train_rdy.as_numpy_iterator():
        state = train_step(state, batch)
        state = compute_metrics(state=state, batch=batch)
    state = _flush_metrics(state, "train")

    # Evaluation: accumulate metrics only, no parameter updates.
    for batch in dst_val_rdy.as_numpy_iterator():
        state = compute_metrics(state=state, batch=batch)
    state = _flush_metrics(state, "val")

    print(f'Epoch {epoch} -> [Train] Loss: {metrics_history["train_loss"][-1]} | Accuracy: {metrics_history["train_accuracy"][-1]} [Val] Loss: {metrics_history["val_loss"][-1]} | Accuracy: {metrics_history["val_accuracy"][-1]}')

Epoch 0 -> [Train] Loss: 2.179863691329956 | Accuracy: 0.26034998893737793 [Val] Loss: 1.9363536834716797 | Accuracy: 0.4207000136375427
Epoch 1 -> [Train] Loss: 1.6733933687210083 | Accuracy: 0.501966655254364 [Val] Loss: 1.4757570028305054 | Accuracy: 0.5743000507354736
Epoch 2 -> [Train] Loss: 1.3742271661758423 | Accuracy: 0.5971333384513855 [Val] Loss: 1.263370394706726 | Accuracy: 0.6460000276565552
Epoch 3 -> [Train] Loss: 1.2035051584243774 | Accuracy: 0.6574167013168335 [Val] Loss: 1.11570143699646 | Accuracy: 0.6895000338554382
Epoch 4 -> [Train] Loss: 1.0782545804977417 | Accuracy: 0.6978999972343445 [Val] Loss: 1.0012279748916626 | Accuracy: 0.7230000495910645
Epoch 5 -> [Train] Loss: 0.9798782467842102 | Accuracy: 0.7282666563987732 [Val] Loss: 0.9096516966819763 | Accuracy: 0.7509000301361084
Epoch 6 -> [Train] Loss: 0.9003434777259827 | Accuracy: 0.7511667013168335 [Val] Loss: 0.8355976343154907 | Accuracy: 0.7704000473022461
Epoch 7 -> [Train] Loss: 0.8347769975662231 |