# Image Classification in JAX using `TrainState` (and `WandB`!) ((and checkpointing!!))

> Adding checkpointing to the *WandB* example!

In [None]:
import tensorflow as tf
# Hide every GPU from TensorFlow, which this notebook only uses for data
# loading (keras datasets + tf.data); presumably this keeps TF from
# reserving GPU memory that JAX needs.
tf.config.set_visible_devices(devices=[], device_type="GPU")

2023-04-27 12:53:59.255932: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-04-27 12:53:59.310993: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
import os
from typing import Any, Callable, Sequence, Union
import numpy as np

import jax
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze
from flax import linen as nn
from flax import struct
from flax.training import train_state
from flax.training import orbax_utils

import optax
import orbax.checkpoint

from clu import metrics
from ml_collections import ConfigDict

from einops import reduce
import wandb

from iqadatasets.datasets import *

## Get the data

> We'll be using MNIST from Keras.

In [None]:
(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.mnist.load_data()

# Add a trailing channel axis (NHWC layout for `nn.Conv`) and divide by 255.0
# to bring pixel intensities into [0, 1]. The division yields float64; cast to
# float32 explicitly. This halves host memory (for both the arrays and the
# tf.data pipelines built from them), and the values are identical to what
# JAX would feed the model anyway, because JAX computes in float32 unless x64
# mode is enabled.
X_train = (X_train[:, :, :, None] / 255.0).astype(np.float32)
X_test = (X_test[:, :, :, None] / 255.0).astype(np.float32)
# Integer class ids, as required by `softmax_cross_entropy_with_integer_labels`.
Y_train = Y_train.astype(np.int32)
Y_test = Y_test.astype(np.int32)

X_train.shape, Y_train.shape, X_test.shape, Y_test.shape

((60000, 28, 28, 1), (60000,), (10000, 28, 28, 1), (10000,))

In [None]:
# Wrap the in-memory arrays as tf.data pipelines of (image, label) pairs.
dst_train, dst_val = (
    tf.data.Dataset.from_tensor_slices(pair)
    for pair in ((X_train, Y_train), (X_test, Y_test))
)

In [None]:
# Training hyper-parameters, wrapped in a ConfigDict for attribute access
# (config.BATCH_SIZE, ...) and passed to W&B below.
config = ConfigDict(dict(
    BATCH_SIZE=256,
    EPOCHS=50,
    LEARNING_RATE=3e-4,
))
config

BATCH_SIZE: 256
EPOCHS: 50
LEARNING_RATE: 0.0003

In [None]:
wandb.init(project="MNIST_JAX",
           name="Single_Forward",
           job_type="training",
           config=config,
           mode="online",
           )
# The local `config` (a ConfigDict) keeps driving the run. If values injected
# by W&B (e.g. from a sweep) should take precedence, rebind it here with
# `config = wandb.config`.
config

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjorgvt[0m. Use [1m`wandb login --relogin`[0m to force relogin


BATCH_SIZE: 256
EPOCHS: 50
LEARNING_RATE: 0.0003

In [None]:
# Reshuffle the training set on every pass. Each epoch calls
# `as_numpy_iterator()`, which starts a new iteration and therefore a new
# order. Without this, every epoch would see the exact same batches in the
# exact same order. The fixed seed keeps the sequence of orders
# reproducible. The validation set is only evaluated, so it stays ordered.
dst_train_rdy = (
    dst_train
    .shuffle(buffer_size=10_000, seed=0, reshuffle_each_iteration=True)
    .batch(config.BATCH_SIZE)
)
dst_val_rdy = dst_val.batch(config.BATCH_SIZE)

## Define the model we're going to use

> It's going to be a very simple model just for demonstration purposes.

In [None]:
class Model(nn.Module):
    """Small CNN classifier producing 10 logits per example.

    Two (3x3 conv -> ReLU -> 2x2 max-pool) stages with 32 and 64 filters,
    a mean over the spatial axes, and a final dense layer.
    """

    @nn.compact
    def __call__(self,
                 inputs,
                 **kwargs,
                 ):
        hidden = inputs
        # Submodules are created in the same order as a hand-unrolled version,
        # so the auto-generated names (Conv_0, Conv_1, Dense_0) are unchanged.
        for width in (32, 64):
            hidden = nn.Conv(features=width, kernel_size=(3, 3))(hidden)
            hidden = nn.relu(hidden)
            hidden = nn.max_pool(hidden, window_shape=(2, 2), strides=(2, 2))
        # Global average pooling over height and width.
        pooled = reduce(hidden, "b h w c -> b c", reduction="mean")
        return nn.Dense(10)(pooled)

## Define the metrics with `clu`

In [None]:
@struct.dataclass
class Metrics(metrics.Collection):
    """Collection of metrics to be tracked during training and evaluation.

    Per-batch instances are built with
    ``single_from_model_output(logits=..., labels=..., loss=...)`` and folded
    together with ``merge``; ``compute()`` returns a dict keyed by the field
    names below (``"accuracy"``, ``"loss"``).
    """
    # Classification accuracy, computed by clu from the ``logits`` and ``labels`` outputs.
    accuracy: metrics.Accuracy
    # Running average of the ``loss`` output passed to ``single_from_model_output``.
    loss: metrics.Average.from_output("loss")

By default, `TrainState` doesn't include metrics, but it's very easy to subclass it so that it does:

In [None]:
class TrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with a ``metrics`` field.

    Keeping the metric accumulators inside the state lets the jitted step
    functions update them functionally via ``state.replace(metrics=...)``.
    """
    # Accumulated ``Metrics`` collection; the training loop resets it with ``.empty()``.
    metrics: Metrics

We'll define a function that initializes the `TrainState` from a module, an RNG key and an optimizer:

In [None]:
def create_train_state(module, key, tx, input_shape):
    """Build a fresh `TrainState` for `module`.

    Parameters are initialised by running `module.init` with RNG `key` on a
    dummy all-ones input of shape `input_shape`. The optimizer `tx` and an
    empty `Metrics` collection are attached to the returned state.
    """
    dummy_input = jnp.ones(input_shape)
    variables = module.init(key, dummy_input)
    return TrainState.create(
        apply_fn=module.apply,
        params=variables["params"],
        tx=tx,
        metrics=Metrics.empty(),
    )

## Defining the training step

> We want to write a function that takes the `TrainState` and a batch of data and performs an optimization step.

In [None]:
@jax.jit
def train_step(state, batch):
    """Run one optimisation step and accumulate this batch's training metrics.

    `batch` is unpacked as `(inputs, labels)`. The loss is the mean softmax
    cross-entropy with integer labels. Gradients are applied to `state`, and
    the metrics reuse the logits of this same forward pass, so no extra
    forward pass is needed.

    Returns the updated `TrainState`.
    """
    inputs, labels = batch

    def objective(params):
        logits = state.apply_fn({"params": params}, inputs)
        per_example = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels)
        return per_example.mean(), logits

    grad_fn = jax.value_and_grad(objective, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)

    # Metrics come from the logits computed *before* the parameter update.
    batch_metrics = state.metrics.single_from_model_output(
        logits=logits, labels=labels, loss=loss,
    )
    return state.replace(metrics=state.metrics.merge(batch_metrics))

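In the example this notebook is based on, the metrics are not computed inside the training step, which means performing an extra forward pass. Our `train_step` above avoids that by reusing the logits from its own forward pass to update the training metrics. For the validation set, where no optimization step is taken, we still need a function that performs a forward pass and computes the metrics: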

In [None]:
@jax.jit
def compute_metrics(*, state, batch):
    """Evaluate `state` on one batch and merge the result into `state.metrics`.

    Performs a forward pass only (parameters are not updated). Arguments are
    keyword-only: `compute_metrics(state=..., batch=...)`, where `batch` is
    `(inputs, labels)`.
    """
    inputs, labels = batch
    logits = state.apply_fn({"params": state.params}, inputs)
    per_example = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels)
    update = state.metrics.single_from_model_output(
        logits=logits, labels=labels, loss=per_example.mean(),
    )
    return state.replace(metrics=state.metrics.merge(update))

## Train the model!

In [None]:
# Adam optimizer + a single-example dummy input (N, H, W, C) for shape inference.
optimizer = optax.adam(config.LEARNING_RATE)
state = create_train_state(Model(), random.PRNGKey(0), optimizer, input_shape=(1, 28, 28, 1))

Before actually training the model we're going to set up the checkpointer to be able to save our trained models:

In [None]:
# Per-leaf save arguments derived from the structure of `state`; the training
# loop reuses them for every save.
save_args = orbax_utils.save_args_from_target(state)
# Checkpointer that writes/reads a whole PyTree (here the full TrainState).
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
# NOTE(review): the restore cells further down read "test_save", which only this
# commented-out call writes; the training loop saves to os.path.join(wandb.run.dir, "model").
# orbax_checkpointer.save("test_save", state, save_args=save_args)

To get versioning and automatic bookkeeping, we would need to wrap `PyTreeCheckpointer` with `orbax.checkpoint.CheckpointManager`, which also lets us customize the saving further if needed. Since saving a model is an I/O operation, we may also benefit from doing it asynchronously, which is as easy as using `AsyncCheckpointer` instead of `PyTreeCheckpointer`. (Neither is used below: the training loop calls `PyTreeCheckpointer.save` directly.)

In [None]:
# One list per (split, metric) pair, filled once per epoch by the training loop.
# Key order: train_loss, val_loss, train_accuracy, val_accuracy.
metrics_history = {
    f"{split}_{metric}": []
    for metric in ("loss", "accuracy")
    for split in ("train", "val")
}

In [None]:
%%time
for epoch in range(config.EPOCHS):
    ## Training
    for batch in dst_train_rdy.as_numpy_iterator():
        state = train_step(state, batch)
        # state = compute_metrics(state=state, batch=batch)
        # break

    ## Log the metrics
    for name, value in state.metrics.compute().items():
        metrics_history[f"train_{name}"].append(value)
    
    ## Empty the metrics
    state = state.replace(metrics=state.metrics.empty())

    ## Evaluation
    for batch in dst_val_rdy.as_numpy_iterator():
        state = compute_metrics(state=state, batch=batch)
        # break
    for name, value in state.metrics.compute().items():
        metrics_history[f"val_{name}"].append(value)
    state = state.replace(metrics=state.metrics.empty())

    ## Checkpointing
    if metrics_history["val_accuracy"][-1] >= max(metrics_history["val_accuracy"]):
        orbax_checkpointer.save(os.path.join(wandb.run.dir, "model"), state, save_args=save_args, force=True) # force=True means allow overwritting.

    
    wandb.log({"epoch": epoch+1, **{name:values[-1] for name, values in metrics_history.items()}})
    print(f'Epoch {epoch} -> [Train] Loss: {metrics_history["train_loss"][-1]} | Accuracy: {metrics_history["train_accuracy"][-1]} [Val] Loss: {metrics_history["val_loss"][-1]} | Accuracy: {metrics_history["val_accuracy"][-1]}')
    # break

2023-04-27 12:54:25.313858: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_1' with dtype int32 and shape [60000]
	 [[{{node Placeholder/_1}}]]
2023-04-27 12:54:28.345723: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_1' with dtype int32 and shape [10000]
	 [[{{node Placeholder/_1}}]]


Epoch 0 -> [Train] Loss: 2.1820199489593506 | Accuracy: 0.2582666575908661 [Val] Loss: 1.9363574981689453 | Accuracy: 0.42100003361701965
Epoch 1 -> [Train] Loss: 1.67693293094635 | Accuracy: 0.4978833496570587 [Val] Loss: 1.4757040739059448 | Accuracy: 0.5752000212669373
Epoch 2 -> [Train] Loss: 1.377483606338501 | Accuracy: 0.5930666923522949 [Val] Loss: 1.2632259130477905 | Accuracy: 0.6459000110626221
Epoch 3 -> [Train] Loss: 1.2066980600357056 | Accuracy: 0.6537666916847229 [Val] Loss: 1.1153720617294312 | Accuracy: 0.6894000172615051
Epoch 4 -> [Train] Loss: 1.081371545791626 | Accuracy: 0.6956666707992554 [Val] Loss: 1.0010062456130981 | Accuracy: 0.7227000594139099
Epoch 5 -> [Train] Loss: 0.9829501509666443 | Accuracy: 0.7260167002677917 [Val] Loss: 0.9094106554985046 | Accuracy: 0.7510000467300415
Epoch 6 -> [Train] Loss: 0.9034146666526794 | Accuracy: 0.7497000098228455 [Val] Loss: 0.8354026079177856 | Accuracy: 0.770300030708313
Epoch 7 -> [Train] Loss: 0.8378530740737915 |

In [None]:
# Close the W&B run opened by `wandb.init(...)` earlier in the notebook.
wandb.finish()

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train_accuracy,▁▄▄▅▆▆▆▆▇▇▇▇▇▇▇▇▇███████████████████████
train_loss,█▆▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▄▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█████████████████████
val_loss,█▆▅▅▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,50.0
train_accuracy,0.92835
train_loss,0.24733
val_accuracy,0.9284
val_loss,0.23935


## Restore the trained model

> We have trained our model, let's see if we can load the trained weights.

In [None]:
fresh_optimizer = optax.adam(config.LEARNING_RATE)
new_state = create_train_state(Model(), random.PRNGKey(0), fresh_optimizer, input_shape=(1, 28, 28, 1))
# Sanity check: a freshly initialised state must differ from the trained one.
assert not new_state == state

To restore a saved checkpoint we only have to call the `.restore()` method of the checkpointer:

In [None]:
# Without a target structure, Orbax hands back the checkpoint as nested dicts.
# NOTE(review): "test_save" is only written by the commented-out save in the
# checkpointer cell; the training loop saves to os.path.join(wandb.run.dir, "model").
checkpoint_path = "test_save"
new_state = orbax_checkpointer.restore(checkpoint_path)

In [None]:
# Compare the Adam step counter stored in the restored dict with the in-memory state.
restored_count = new_state["opt_state"][0]["count"]
restored_count == state.opt_state[0].count

Array(True, dtype=bool)

We were able to load the same state, but it was loaded as a plain Python `dict`, not as a `TrainState`. If we want to load it as a custom object, we have to provide *Orbax* with an example of the type of *PyTree* that we want to load. First we'll instantiate a new `TrainState`, and then we will pass it to `.restore()` through the `item` argument (`.restore(..., item=sample_object)`):

In [None]:
template_optimizer = optax.adam(config.LEARNING_RATE)
new_state = create_train_state(Model(), random.PRNGKey(0), template_optimizer, input_shape=(1, 28, 28, 1))
# The fresh state only serves as a structural template and must not match the trained one.
assert not new_state == state

In [None]:
# Passing a template TrainState via `item=` makes Orbax rebuild that structure
# instead of returning plain dicts.
# NOTE(review): this reads "test_save", not the checkpoint written during training.
template_state = new_state
new_state = orbax_checkpointer.restore("test_save", item=template_state)
assert not new_state == state

Testing the two states for equality still gives `False`, because the restoration loads the original `jnp.Array`s as `np.array`s. Their contents, however, are the same:

In [None]:
def _leaves_equal(a, b):
    """Element-wise equality of two array leaves, reduced to a single boolean."""
    return (a == b).all()

jax.tree_util.tree_map(_leaves_equal, state.params, new_state.params)

FrozenDict({
    Conv_0: {
        bias: Array(True, dtype=bool),
        kernel: Array(True, dtype=bool),
    },
    Conv_1: {
        bias: Array(True, dtype=bool),
        kernel: Array(True, dtype=bool),
    },
    Dense_0: {
        bias: Array(True, dtype=bool),
        kernel: Array(True, dtype=bool),
    },
})

In [None]:
def _same_values(a, b):
    """Element-wise equality of two array leaves, reduced to a single boolean."""
    return (a == b).all()

jax.tree_util.tree_map(_same_values, state.opt_state, new_state.opt_state)

(ScaleByAdamState(count=Array(True, dtype=bool), mu=FrozenDict({
     Conv_0: {
         bias: Array(True, dtype=bool),
         kernel: Array(True, dtype=bool),
     },
     Conv_1: {
         bias: Array(True, dtype=bool),
         kernel: Array(True, dtype=bool),
     },
     Dense_0: {
         bias: Array(True, dtype=bool),
         kernel: Array(True, dtype=bool),
     },
 }), nu=FrozenDict({
     Conv_0: {
         bias: Array(True, dtype=bool),
         kernel: Array(True, dtype=bool),
     },
     Conv_1: {
         bias: Array(True, dtype=bool),
         kernel: Array(True, dtype=bool),
     },
     Dense_0: {
         bias: Array(True, dtype=bool),
         kernel: Array(True, dtype=bool),
     },
 })),
 EmptyState())