# Image Classification in JAX using `TrainState` (and `WandB`!)

> Adding `wandb` logging into the basic example.

In [1]:
import tensorflow as tf
# Hide every GPU from TensorFlow by giving it an empty list of visible GPU
# devices. In this notebook TF is only used to load MNIST and to build the
# `tf.data` input pipelines.
# NOTE(review): presumably this keeps TF from reserving GPU memory that JAX needs — confirm.
tf.config.set_visible_devices([], device_type='GPU')

2023-04-26 17:33:24.463105: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-04-26 17:33:24.513451: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
from typing import Any, Callable, Sequence, Union
import numpy as np

import jax
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze
from flax import linen as nn
from flax import struct
from flax.training import train_state

import optax

from clu import metrics
from ml_collections import ConfigDict

from einops import reduce
import wandb

from iqadatasets.datasets import *

## Get the data

> We'll be using MNIST from Keras.

In [3]:
# Load the MNIST train/test splits through Keras.
(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.mnist.load_data()

# Add a trailing channel axis and rescale the pixel values by 1/255.
# The division yields float64 arrays; casting to float32 once, here, halves
# their memory footprint and avoids a float64 -> float32 down-cast on every
# batch (JAX computes in 32-bit unless x64 is explicitly enabled), so the
# values that reach the model are unchanged.
X_train = (X_train[:,:,:,None]/255.0).astype(np.float32)
X_test = (X_test[:,:,:,None]/255.0).astype(np.float32)
# Integer class labels, consumed later by the integer-label cross-entropy.
Y_train = Y_train.astype(np.int32)
Y_test = Y_test.astype(np.int32)

X_train.shape, Y_train.shape, X_test.shape, Y_test.shape

((60000, 28, 28, 1), (60000,), (10000, 28, 28, 1), (10000,))

In [4]:
# Pair each image array with its label array and wrap both splits as
# `tf.data` datasets.
train_arrays = (X_train, Y_train)
val_arrays = (X_test, Y_test)
dst_train = tf.data.Dataset.from_tensor_slices(train_arrays)
dst_val = tf.data.Dataset.from_tensor_slices(val_arrays)

In [5]:
# Hyperparameters of the run, wrapped in a `ConfigDict` so they can be read
# with attribute access (`config.BATCH_SIZE`, ...).
config = ConfigDict(dict(
    BATCH_SIZE=256,
    EPOCHS=50,
    LEARNING_RATE=3e-4,
))
config

BATCH_SIZE: 256
EPOCHS: 50
LEARNING_RATE: 0.0003

In [6]:
# Start the W&B run that receives the metrics logged in the training loop.
# The hyperparameters are handed over as a plain dict (`ConfigDict.to_dict()`)
# so each one is stored as a regular top-level entry of the run configuration.
wandb.init(project="MNIST_JAX",
           name="Single_Forward",
           job_type="training",
           config=config.to_dict(),
           mode="online",
           )
# `config` remains the local ConfigDict defined above; the previous
# `config = config` self-assignment had no effect and has been removed.
config

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjorgvt[0m. Use [1m`wandb login --relogin`[0m to force relogin


BATCH_SIZE: 256
EPOCHS: 50
LEARNING_RATE: 0.0003

In [7]:
# Group the examples into mini-batches and prefetch them, so that the input
# pipeline prepares the next batch while the current training/evaluation step
# is running. Batch contents and order are unchanged.
# NOTE(review): the training set is not shuffled, so every epoch visits the
# same batches in the same order — consider adding `.shuffle(...)` before
# `.batch(...)` (this would change the recorded results).
dst_train_rdy = dst_train.batch(config.BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
dst_val_rdy = dst_val.batch(config.BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

## Define the model we're going to use

> It's going to be a very simple model just for demonstration purposes.

In [8]:
class Model(nn.Module):
    """Small CNN: two conv/ReLU/max-pool stages, spatial averaging and a 10-unit dense head."""

    @nn.compact
    def __call__(self,
                 inputs,
                 **kwargs,
                 ):
        """Map a batch of images to one score per class.

        Args:
            inputs: batch of images; the reduction pattern below reads it as
                `(batch, height, width, channels)`.
            **kwargs: accepted but not used.

        Returns:
            The output of the final `Dense(10)` layer (10 values per example),
            used as logits by the loss in this notebook.
        """
        x = inputs
        # Both stages are identical except for the number of filters.
        for n_filters in (32, 64):
            x = nn.Conv(features=n_filters, kernel_size=(3,3))(x)
            x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(2,2), strides=(2,2))
        # Average over both spatial axes, leaving one value per channel.
        x = reduce(x, "b h w c -> b c", reduction="mean")
        return nn.Dense(10)(x)

## Define the metrics with `clu`

In [9]:
@struct.dataclass
class Metrics(metrics.Collection):
    """Collection of metrics to be tracked during training.

    The step functions in this notebook create per-batch values with
    `single_from_model_output(logits=..., labels=..., loss=...)` and combine
    them with the running collection through `merge`; `compute()` is called
    once per epoch to read the results.
    """
    # Classification accuracy, declared with clu's `metrics.Accuracy`.
    accuracy: metrics.Accuracy
    # Average of the model output that is passed under the name "loss".
    loss: metrics.Average.from_output("loss")

By default, `TrainState` doesn't include metrics, but it's very easy to subclass it so that it does:

In [10]:
class TrainState(train_state.TrainState):
    """`train_state.TrainState` extended with a `Metrics` collection."""
    # Metrics accumulated so far; swapped out with `state.replace(metrics=...)`.
    metrics: Metrics

We'll define a function that initializes the `TrainState` from a module, an RNG key, an optimizer and the input shape:

In [11]:
def create_train_state(module, key, tx, input_shape):
    """Build the initial `TrainState` for `module`.

    Args:
        module: model whose `init`/`apply` methods are used.
        key: PRNG key passed to `module.init`.
        tx: optimizer (gradient transformation) stored in the state.
        input_shape: shape of the all-ones array used to initialise the parameters.

    Returns:
        A `TrainState` holding the initial parameters, `tx` and an empty `Metrics`.
    """
    # An all-ones array of the requested shape is enough to trigger initialisation.
    dummy_inputs = jnp.ones(input_shape)
    variables = module.init(key, dummy_inputs)
    initial_params = variables["params"]
    empty_metrics = Metrics.empty()
    return TrainState.create(
        apply_fn=module.apply,
        params=initial_params,
        tx=tx,
        metrics=empty_metrics,
    )

## Defining the training step

> We want to write a function that takes the `TrainState` and a batch of data and performs an optimization step.

In [12]:
@jax.jit
def train_step(state, batch):
    """Run one optimisation step and accumulate the metrics of the batch.

    Args:
        state: current `TrainState`.
        batch: `(inputs, labels)` pair.

    Returns:
        The `TrainState` after applying the gradients, with the metrics of
        this batch merged into `state.metrics`.
    """
    inputs, labels = batch

    def forward_loss(params):
        # The logits are returned as auxiliary output so that the metrics can
        # reuse this forward pass.
        logits = state.apply_fn({"params": params}, inputs)
        per_example = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels)
        return per_example.mean(), logits

    grad_fn = jax.value_and_grad(forward_loss, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    updated = state.apply_gradients(grads=grads)

    # Metrics use the loss/logits obtained before the parameter update.
    batch_metrics = updated.metrics.single_from_model_output(
        logits=logits, labels=labels, loss=loss,
    )
    merged = updated.metrics.merge(batch_metrics)
    return updated.replace(metrics=merged)

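In their example, they don't calculate the metrics inside the training step. I think that is kind of a waste because it means having to perform a new forward pass, so the `train_step` above already accumulates the training metrics. We still need a function that only computes the metrics, without updating the parameters, to evaluate on the validation data: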

In [13]:
@jax.jit
def compute_metrics(*, state, batch):
    """Accumulate the metrics of one batch without updating the parameters.

    Args:
        state: current `TrainState` (keyword-only).
        batch: `(inputs, labels)` pair (keyword-only).

    Returns:
        `state` with the metrics of this batch merged into `state.metrics`.
    """
    inputs, labels = batch
    logits = state.apply_fn({"params": state.params}, inputs)
    per_example = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels)
    batch_metrics = state.metrics.single_from_model_output(
        logits=logits, labels=labels, loss=per_example.mean(),
    )
    return state.replace(metrics=state.metrics.merge(batch_metrics))

## Train the model!

In [14]:
# Assemble the pieces of the initial training state: the model, the PRNG key
# used for parameter initialisation and the Adam optimizer.
model = Model()
init_key = random.PRNGKey(0)
optimizer = optax.adam(config.LEARNING_RATE)
state = create_train_state(model, init_key, optimizer, input_shape=(1,28,28,1))

In [15]:
# One (initially empty) list per split/metric pair; a value is appended to
# each of them at the end of every epoch.
metrics_history = {
    f"{split}_{metric}": []
    for metric in ("loss", "accuracy")
    for split in ("train", "val")
}

In [16]:
%%time
for epoch in range(config.EPOCHS):
    ## Training
    for batch in dst_train_rdy.as_numpy_iterator():
        state = train_step(state, batch)
        # state = compute_metrics(state=state, batch=batch)
        # break

    ## Log the metrics
    for name, value in state.metrics.compute().items():
        metrics_history[f"train_{name}"].append(value)
    
    ## Empty the metrics
    state = state.replace(metrics=state.metrics.empty())

    ## Evaluation
    for batch in dst_val_rdy.as_numpy_iterator():
        state = compute_metrics(state=state, batch=batch)
        # break
    for name, value in state.metrics.compute().items():
        metrics_history[f"val_{name}"].append(value)
    state = state.replace(metrics=state.metrics.empty())
    
    wandb.log({"epoch": epoch+1, **{name:values[-1] for name, values in metrics_history.items()}})
    print(f'Epoch {epoch} -> [Train] Loss: {metrics_history["train_loss"][-1]} | Accuracy: {metrics_history["train_accuracy"][-1]} [Val] Loss: {metrics_history["val_loss"][-1]} | Accuracy: {metrics_history["val_accuracy"][-1]}')
    # break

2023-04-26 17:33:50.399662: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_1' with dtype int32 and shape [60000]
	 [[{{node Placeholder/_1}}]]
2023-04-26 17:33:53.381922: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_1' with dtype int32 and shape [10000]
	 [[{{node Placeholder/_1}}]]


Epoch 0 -> [Train] Loss: 2.182008743286133 | Accuracy: 0.2583500146865845 [Val] Loss: 1.936364769935608 | Accuracy: 0.42110002040863037
Epoch 1 -> [Train] Loss: 1.6769200563430786 | Accuracy: 0.4977000057697296 [Val] Loss: 1.4756618738174438 | Accuracy: 0.5753000378608704
Epoch 2 -> [Train] Loss: 1.3775689601898193 | Accuracy: 0.593250036239624 [Val] Loss: 1.263392448425293 | Accuracy: 0.6458000540733337
Epoch 3 -> [Train] Loss: 1.2068145275115967 | Accuracy: 0.6537833213806152 [Val] Loss: 1.1156055927276611 | Accuracy: 0.6897000074386597
Epoch 4 -> [Train] Loss: 1.081493616104126 | Accuracy: 0.6955500245094299 [Val] Loss: 1.0011286735534668 | Accuracy: 0.7230000495910645
Epoch 5 -> [Train] Loss: 0.9830330014228821 | Accuracy: 0.7260167002677917 [Val] Loss: 0.909456729888916 | Accuracy: 0.751300036907196
Epoch 6 -> [Train] Loss: 0.9034460186958313 | Accuracy: 0.7495499849319458 [Val] Loss: 0.8353870511054993 | Accuracy: 0.7705000638961792
Epoch 7 -> [Train] Loss: 0.837811291217804 | Ac

In [17]:
# Close the W&B run opened by `wandb.init` at the top of the notebook; the
# run summary is printed as the cell output.
wandb.finish()

0,1
train_accuracy,▁▄▄▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇██████████████████████
train_loss,█▆▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▄▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█████████████████████
val_loss,█▆▅▅▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_accuracy,0.92842
train_loss,0.24722
val_accuracy,0.928
val_loss,0.23932
