# TPU Flower Classification with TFRecord & SPMD JAX training
### A comprehensive pipeline for JAX+multi-core TPU
Hi Kagglers! Here is a small contribution to Kaggle's ongoing [Google Open Source Software Experts](https://www.kaggle.com/google-oss-expert-prize-winners) Challenge. 

Before coming across Google's TPU and JAX, I had been a long-time Pytorch+GPU user. Pytorch's **imperative** programming paradigm makes manipulating tensors easy, at the cost of making programs harder to analyze. JAX's **functional** programming paradigm makes magic such as [auto-vectorization](https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html), [easy manipulation of higher-order derivatives](https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html), native integration of [JIT](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html), and [seamless composition of these features](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) accessible.

![Image](https://pbs.twimg.com/media/EobMOrBVEAAvrSf?format=jpg&name=large)
*Image credit: James Bradbury's tweet from two years ago, [here](https://twitter.com/jekbradbury/status/1335001804732395522)*

JAX and TPU go *particularly* well together. JAX uses the XLA compiler (which underlies TPU usage) and is, by design, easier to analyze (and optimize). The combination has driven cutting-edge research, as evidenced by [this paper](https://arxiv.org/pdf/2011.03641) and [DeepMind's blog post](https://www.deepmind.com/blog/using-jax-to-accelerate-our-research).   
Building on wonderful introductions to JAX by [Nilay Chauhan](https://www.kaggle.com/discussions/getting-started/315696) (general introduction), [aakashnain](https://www.kaggle.com/code/aakashnain/building-models-in-jax-part1-stax/notebook) (STAX library) and [Sanyam Bhutani](https://www.kaggle.com/getting-started/308753) (JAX 201), 
this notebook hopes to provide those (like myself) who have experience with Pytorch+GPU with a start in the next generation of deep learning technology.
\
Do you still remember the panic and chaos of Pytorch DDP / XLA (or worse, both)? And did I forget to mention that **switching between single/multiple GPU/TPU** for a JAX pipeline is **as easy as changing the Kaggle runtime**? Try it out!

This introduction series is the product of the past two weeks I spent with JAX, in the hope that it will make this wonderful library accessible to everyone.
### This introduction series features
#### 1. How to perform [common flax model surgery](https://www.kaggle.com/roguekk007/flax-model-surgery/)
#### 2. How to define a SOTA computer-vision backbone in `flax`.
#### 3. How to port pretrained weights from a Pytorch implementation of [EfficientnetV2](https://www.kaggle.com/code/roguekk007/efficientnetv2-jax) (the first such port, to the best of my knowledge).
#### 4. (This notebook) A highly-efficient general-purpose TPU training pipeline in JAX

![image.png](https://storage.googleapis.com/kaggle-competitions/kaggle/21154/logos/header.png?t=2020-06-04-00-33-35)

## Contents
#### 1. Set up Kaggle TPU / JAX runtime
#### 2. TFRecords: Don't make data the bottleneck
#### 3. Flax+Optax: [Model surgery](https://www.kaggle.com/roguekk007/flax-model-surgery/) and optimization
#### 4. Single-Program Multiple-Data (SPMD) TPU training using JAX
#### 5. Inference

We will be using the beautiful [Flower Classification on TPU](https://www.kaggle.com/competitions/tpu-getting-started) dataset to exemplify a general CV classification pipeline. Feel free to create wonders with it - and don't forget to upvote!

# 1. Set up Kaggle TPU / JAX Runtime

As mentioned [here](https://www.kaggle.com/discussions/getting-started/315696), JAX is not natively supported by Kaggle TPU Runtime (yet. Kaggle what are you waiting for?) so we'll do some setup - credits to [this kernel](https://www.kaggle.com/code/alexlwh/happywhale-flax-jax-tpu-gpu-resnet-baseline). We do not have to do this on cloud TPU in general.

In [None]:
import os
import warnings
import tensorflow as tf
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'

# Seems like the new one will break jax
!pip install --upgrade jax jaxlib==0.3.5 git+https://github.com/deepmind/optax.git flax optax -q
    
if 'TPU_NAME' in os.environ and 'KAGGLE_DATA_PROXY_TOKEN' in os.environ:
    use_tpu = True
    
    import requests 
    from jax.config import config
    if 'TPU_DRIVER_MODE' not in globals():
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    # Enforce bfloat16 multiplication
    config.update('jax_default_matmul_precision', 'bfloat16')
    print('Registered (Kaggle) TPU:', config.FLAGS.jax_backend_target)
else:
    use_tpu = False

In [None]:
import functools
from tqdm.notebook import tqdm

import jax
import jax.numpy as jnp
from jax import lax 

import flax
import flax.linen as nn
from flax.core import freeze, unfreeze
from flax.core.frozen_dict import FrozenDict
from flax.training.common_utils import shard, shard_prng_key
from flax.serialization import to_state_dict, from_state_dict,\
                        msgpack_serialize, msgpack_restore, from_bytes

import optax
import msgpack
from jax_efficientnetv2 import efficientnet_v2_pretrained

# Apologies for shamelessly suppressing the warnings here
import warnings
warnings.filterwarnings("ignore")

## Specify Training Arguments

Note we specify `batch_size` now for *a single core* and the effective batch size for training is `device_count * batch_size`. 
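
As a quick sanity check on this arithmetic, here is a minimal sketch of how the per-core settings translate into global ones. The helper name `global_hparams` is hypothetical, and the 8-device count is an assumption (it matches an 8-core TPU); linear learning-rate scaling is the heuristic this notebook uses, not a guarantee.

```python
def global_hparams(per_device_batch_size, base_lr, device_count):
    """Return (effective batch size, linearly scaled learning rate)."""
    effective_batch_size = per_device_batch_size * device_count
    scaled_lr = base_lr * device_count
    return effective_batch_size, scaled_lr

# Assumed setup: 8 cores, with this notebook's per-core batch size and base_lr.
effective_batch_size, scaled_lr = global_hparams(8, 7e-5, device_count=8)
print(effective_batch_size, scaled_lr)
```

On a single GPU the same helper would simply return the per-core values, which is why only `base_lr` and `batch_size` need to be tuned by hand.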

In [None]:
args = {
    'experiment_name': 'starter',
    # Efficientnetv2-m
    'model': 'm',
    'batch_size': 8, 
    'epochs': 16,
    'base_lr': 7e-5, # should directly correspond to `batch_size`
    
    # Data / augmentation 
    'img_size': 512, # 192, 224, 331, 512
    # The actual label number is 104
    'num_labels': 104 if not use_tpu else 128,
    'saving_dir': '/kaggle/working/',
    'device_count' : jax.device_count(),
    
    # Debugging-purposes
    'sanity_check': False,
}

# The effective lr should linearly scale with batch size
args['lr'] = args['base_lr'] * args['device_count']
if args['sanity_check']:
    args['epochs'] = 1

# Data-specific, should change for each dataset 
args['data_dir'] = '/kaggle/input/tpu-getting-started/tfrecords-jpeg-'\
                    + str(args['img_size'])+'x'+str(args['img_size'])
args['train_dir'] = os.path.join(args['data_dir'], 'train')
args['val_dir'] = os.path.join(args['data_dir'], 'val')
args['test_dir'] = os.path.join(args['data_dir'], 'test')

print('Running on', args['device_count'], 'processors')
print(jax.devices())

# 2. Dataloading using TFRecords
To be honest, Pytorch Dataloader -> TFRecords is a hard but worthwhile switch. For those used to freely playing around with tensors and numpy arrays (at the cost of large, unavoidable CPU overhead), `tf.data` provides a more efficient pipeline. Here is an [excellent introduction](https://towardsdatascience.com/a-practical-guide-to-tfrecords-584536bc786c). Please refer to this [notebook on this dataset](https://www.kaggle.com/code/ryanholbrook/tfrecords-basics/notebook#Parsing-Serialized-Examples); also refer to [this link](https://www.kaggle.com/code/yihdarshieh?scriptVersionId=41561359&cellId=29) for tips on `tf.data.TFRecordDataset`. The code below is adapted from these sources.

### Note on TFRecord: TFRecord addresses the data bottleneck by:
* Parallelizing I/O by sharding data across files.
* Pre-fetching I/O on large files.
* Rule of thumb from [the official docs](https://www.tensorflow.org/tutorials/load_data/tfrecord): create at least 10 * N files for N hosts, and each file should ideally be 100MB+.

### Different data pipelines
* `PIL / OpenCV image` -> `Pytorch Dataset` -> `DataLoader`
* `tfrecord` - (specify decoding) -> `TFRecordDataset` - (batching specifications) -> `DataLoader`

The following code block defines `read_labeled_tfrecord` and `read_unlabeled_tfrecord` to specify how each record is decoded. They are later used as mapping functions on `TFRecordDataset` to parse the data.

In [None]:
# Given a `tf.string` of JPEG bytes, returns the decoded image tensor
def decode_image(image_data):
    # image is of type `tf.uint8`
    image = tf.image.decode_jpeg(image_data, channels=3)
    image = tf.reshape(image, [args['img_size'], args['img_size'], 3]) 
    return image

# Parses a serialized example, returns a data sample (dict with 'image' and 'label')
def read_labeled_tfrecord(example):
    # Note how we are defining the example structure here
    labeled_format = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),  # shape [] means single element
    }
    parsed_example = tf.io.parse_single_example(example, labeled_format)
    image = decode_image(parsed_example['image'])
    label = tf.cast(parsed_example['class'], tf.int32)
    return {'image': image, 'label': label} 

def read_unlabeled_tfrecord(example):   
    unlabeled_format = {
        "image": tf.io.FixedLenFeature([], tf.string), 
        "id": tf.io.FixedLenFeature([], tf.string),  
    }
    parsed_example = tf.io.parse_single_example(example, unlabeled_format)
    image = decode_image(parsed_example['image'])
    idnum = parsed_example['id']
    return {'image': image, 'id': idnum}

train_filenames = tf.io.gfile.glob(args['train_dir']+'/*.tfrec')
val_filenames = tf.io.gfile.glob(args['val_dir']+'/*.tfrec')
test_filenames = tf.io.gfile.glob(args['test_dir']+'/*.tfrec')

### Batching and Augmentation
Below are batching operations such as `.batch` and `.prefetch`. Use `to_jax` to convert `tf.Tensor` data to JAX arrays and `shard` them across computational cores in the last step. Here is a walkthrough; also see the comments in the code below. The augmentation code is adapted from [here](https://gist.github.com/sayakpaul/e0024bae08afcd3d75b6d52fda191025).

* `to_jax`: Applied at the last step of the data pipeline to convert TF tensors to JAX arrays
* `tf_randaugment`, `normalize_and_resize`: image processing / augmentation. `tf_randaugment` exemplifies how to integrate *any* python function into the `tf.data` pipeline via `tf.py_function` (also see the [helpful doc](https://www.tensorflow.org/api_docs/python/tf/py_function))
* `load_dataset`: Defines the data pipeline, see comments

**Remarks:** `dataset.map` is the key to flexibility! With familiarity I find `tf.data` API no less flexible than Pytorch
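
To make the `shard` step concrete, here is a minimal numpy-only sketch of the reshape I understand `flax.training.common_utils.shard` to perform: the leading batch axis is split into `(device_count, per_device_batch)`. The helper name `shard_like_flax`, the 8-device count, and the small 32x32 images are assumptions for illustration.

```python
import numpy as np

def shard_like_flax(batch, device_count):
    """Reshape every array from (B, ...) to (device_count, B // device_count, ...)."""
    def _shard(x):
        assert x.shape[0] % device_count == 0, "batch must divide evenly across devices"
        return x.reshape((device_count, -1) + x.shape[1:])
    return {name: _shard(value) for name, value in batch.items()}

# Assumed setup: 8 devices with 8 samples each, and small images to keep this cheap.
batch = {
    'image': np.zeros((64, 32, 32, 3), dtype=np.float32),
    'label': np.zeros((64, 104), dtype=np.int16),
}
sharded = shard_like_flax(batch, device_count=8)
print({name: value.shape for name, value in sharded.items()})
```

This is also why a trailing partial batch is a problem: a batch whose size is not a multiple of the device count cannot be reshaped this way, so `drop_remainder=True` is the simplest safeguard.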

In [None]:
import tensorflow_datasets as tfds

input_dtype = jnp.bfloat16 if use_tpu else jnp.float32
label_dtype = jnp.int16
    
# Convert a datasample to JAX-pipeline-compatible format. last step
def to_jax(sample):
    sample['image'] = jnp.array(sample['image'], dtype=input_dtype)
    sample['label'] = jnp.array(sample['label'], dtype=label_dtype)
    # Convert labels to one_hot
    sample['label'] = jax.nn.one_hot(sample['label'], args['num_labels'], dtype=label_dtype, axis=-1)
    return shard(sample)

# Augmentation
from imgaug import augmenters as iaa
aug = iaa.RandAugment(n=2, m=15)
def tf_randaugment(sample):
    augment_fn = lambda img: aug(images=img.numpy())
    im_shape = sample['image'].shape
    [sample['image'],] = tf.py_function(augment_fn, [sample['image']], [tf.float32])
    sample['image'].set_shape(im_shape)
    return sample

def normalize_and_resize(sample):
    sample['image'] = tf.cast(sample['image'], tf.float32) / 128. - 1.
    return sample
    
def load_dataset(filenames, labeled=True, ordered=False, shuffle_buffer_size=1, drop_remainder=False,\
                augment=True):
    AUTO = tf.data.experimental.AUTOTUNE
    # tf.data runtime will optimize this parameter
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)

    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False
        
    # Step 1: Read in the data, shuffle and batching
    dataset = dataset.with_options(options)\
                .map(read_labeled_tfrecord if labeled else read_unlabeled_tfrecord,
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)\
                .shuffle(shuffle_buffer_size)\
                .batch(args['batch_size'] * args['device_count'], drop_remainder=drop_remainder)
    
    # We exemplify augmentation using RandAugment
    if augment:
        dataset = dataset.map(tf_randaugment, num_parallel_calls=AUTO)
    # Add `prefetch` at the last step to parallelize as much as possible!
    dataset = dataset.map(normalize_and_resize).prefetch(AUTO)
    # Finally, apply to_jax transformation
    return map(to_jax, tfds.as_numpy(dataset))

# 3. Flax & Optax: Model Surgery and Optimization

Since JAX **strictly decouples model state (parameters) and the forward-pass function**, performing surgery on pretrained models can look like a big challenge. After getting used to its design, I find [flax](https://flax.readthedocs.io/en/latest/) just as flexible for model-building, and [optax](https://optax.readthedocs.io/en/latest/) is a trivial extension in most use cases.

### Model Surgery: Customize your Backbone
Please refer to my **[Flax Model Surgery](https://www.kaggle.com/roguekk007/flax-model-surgery/)** kernel for more resources and examples on this section. The following topics are addressed in that kernel and used in the code block below:
1. Defining using `nn.Module`
2. Instantiating model parameters separately using `model.init`
3. Stochastic / Mutable behaviors in `nn.Dropout` and `nn.BatchNorm`
4. Common model surgery

In [this utility script](https://www.kaggle.com/code/roguekk007/jax-efficientnetv2), I implemented EfficientnetV2 in flax - go check it out and compare to pytorch implementations! Since JAX is still young and many pretrained weights are in pytorch, I spent some time to port pretrained EfficientnetV2 weights to JAX/flax model - check out [**this notebook on porting pretrained weights**](https://www.kaggle.com/code/roguekk007/efficientnetv2-jax/notebook)

In this notebook, we make the (reasonable) assumption that we need no stochasticity when there are no mutable states (see `__call__` below)

### 🔪 The Sharp Bits: Models on TPU 🔪 ###
**Very Important**: Whenever possible, **make *everything* a multiple of 128** (at least 8) when working with TPU's! 

TPU's underlying architecture does matrix multiplication in these sizes. 

If you don't manually pad the tensors, XLA will pad them for you and there can be *absolutely horrendous* paddings
* As an example, invoking a `512 x 81313` `NormalizedLinear` layer with our current model will cause TPU memory overflow, which is totally unreasonable because such a layer involves comparatively little computation. This is because during norm calculation, TPU somehow ends up with a `[Very very big size] * 2` matrix and wants to pad it to `[Very very big size] * 8`. Simply manually padding our layer to `512 x 81408` will fix this problem. This problem took me 3 hours to debug.
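
The padding arithmetic can be sketched in a few lines. This is an illustration only: `round_up` and `pad_last_dim` are hypothetical helper names, and numpy stands in for the real layer weights.

```python
import numpy as np

def round_up(n, multiple=128):
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple

def pad_last_dim(x, multiple=128):
    """Zero-pad the last axis of `x` up to the next multiple."""
    pad = round_up(x.shape[-1], multiple) - x.shape[-1]
    return np.pad(x, [(0, 0)] * (x.ndim - 1) + [(0, pad)])

print(round_up(104), round_up(81313))   # 128 81408
kernel = np.ones((512, 104), dtype=np.float32)
print(pad_last_dim(kernel).shape)       # (512, 128)
```

This notebook takes the simpler route of declaring `num_labels` as 128 on TPU, so the extra output units are just classes that never appear in the labels.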

In [None]:
from functools import partial
from typing import Any

# A very minimal head, really
# Retrieve output, perform average pooling, then output
class AddHeadtoBackbone(nn.Module):
    backbone: Any
    num_features : int
    dtype : Any = input_dtype
    
    @nn.compact 
    def __call__(self, x):
        mutable = self.is_mutable_collection('batch_stats')
        x = self.backbone(x)
        x = jax.nn.swish(jnp.mean(x, axis=(1, 2)))
        return nn.Dense(self.num_features, use_bias=False, param_dtype=self.dtype, dtype=self.dtype)(x)

# The random_key will be "consumed" by this function so pass a subkey
def get_model(random_key):
    pretrained_weights_path = f'/kaggle/input/efficientnetv2-jax/efficientnetv2-{args["model"]}.msgpack'
    # Go check the utility script out!
    backbone, backbone_params = efficientnet_v2_pretrained(args['model'], pretrained_weights_path, \
                                                           input_size=args['img_size'], dtype=input_dtype, verbose=True)
    # Add our own head and instantiate model parameters
    model = AddHeadtoBackbone(backbone=backbone, num_features=args['num_labels'])
    dummy_inputs = jnp.ones((1, args['img_size'], args['img_size'], 3), input_dtype)  
    random_key, key1, key2 = jax.random.split(random_key, 3)  
    params = unfreeze(model.init({'params': key1, 'dropout': key2}, dummy_inputs))
    
    # Replace the backbone portion of `params` with pretrained weights
    params['params']['backbone'] = backbone_params['params']
    params['batch_stats']['backbone'] = backbone_params['batch_stats']
    return model, params

### Initialize Model and Optimizer
Using an optax optimizer is fairly straightforward: 
1. Initialize with the parameters we wish to optimize **(don't pass in `batch_stats`)**. 
2. Use `train_state` to aggregate model and optimizer states
3. Use `flax.jax_utils.replicate` to broadcast `train_state` to each device

The most relevant usage of an optax optimizer:
```
### Define model and `optimizer`
op_states = optimizer.init(model_params)

### calculate `grads`
updates, op_states = optimizer.update(grads, op_states, model_params)
new_model_params = optax.apply_updates(model_params, updates)
```

In [None]:
random_key = jax.random.PRNGKey(0)
random_key, subkey = jax.random.split(random_key)

# This subkey is "consumed" by `get_model`
model, model_params = get_model(subkey)
# Cast to training dtype
model_params = jax.tree_map(lambda x : x.astype(input_dtype), model_params)

scheduler = optax.constant_schedule(args['lr'])
optimizer = optax.chain(
  optax.clip(1.0),
  optax.adamw(learning_rate=scheduler, weight_decay=1e-4))

# Don't throw the whole model_params in there!! else we'll be optimizing running statistics
op_params = optimizer.init(model_params['params'])
train_state = {'model': model_params, 'op': op_params}
# Broadcast the train state across cores (analogously done for data via "shard")
# Consider: override the name to save memory
pl_train_state = flax.jax_utils.replicate(train_state)

# 4. A JAX Training Pipeline

### What's at stake?
A JAX training pipeline like this offers:
* Seamless switch between XLA devices (CPU / GPU / TPU).
* Easy device parallelism - due to JAX's functional design.  While we need to take more care in writing JAX programs, the reward is running the *same code* on 1 GPU or 8 TPUs. 
* No more illegible Pytorch DDP and cryptic yet irresolvable TPU hangs - the final straw which motivated this project.

## 🔪 [The Sharp bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)  🔪
#### A (non-exhaustive) list of remarks (and bugs which took a long time):
1. Stochastic processes (dropout, initialization) always consume random keys - provide them
    * At each step, *split* your key and feed one subkey to `training_step`, which passes it on via `model.apply(...,rngs={'dropout':subkey})`
    * `model.apply` *will not* require rng in eval mode - plan for and check this!
2. Stateful changes with each forward pass: `BatchNorm` has never been such a pain
    * The initialized parameters for each model contain two fields: `params` and `batch_stats`
    * Do NOT use the gradients for `batch_stats`. They're not meant to be optimized.
3. Multi-device computation: sharding and syncing
    * Use `shard` to scatter a data batch (whose size is a multiple of the device count) across devices, use `shard_prng_key` for rng keys
    * `replicate` and `unreplicate` help you manage model states
    * Use `lax.pmean` to sync data (loss, gradients, batch stats) across devices - but a function that uses `pmean` can only run inside a transformation that names that axis, such as `pmap`
4. `jax.tree_map` is your best friend: it maps a function over parameters, dictionaries, lists - you name it
5. Carefully differentiate the behavior during training and testing (validation):
    * Don't be sloppy in managing `mutable=False` during validation, and don't provide RNG unless you *have to* - these will come back to bite you
    * A flax `FrozenDict` throws an error whenever you try to change it; `unfreeze/freeze` the state immediately before/after training for good practice
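
The key-handling pattern from remark 1 can be sketched on its own. This is a minimal illustration: `next_device_keys` is a hypothetical helper, and fanning a subkey out with `jax.random.split(subkey, n_devices)` is what I understand `shard_prng_key` to do internally.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
key = jax.random.PRNGKey(0)

def next_device_keys(key):
    """Split off a fresh subkey, then fan it out to one key per local device."""
    key, subkey = jax.random.split(key)
    device_keys = jax.random.split(subkey, n_devices)
    return key, device_keys

# Each device draws its own dropout-style mask from its own key.
draw_mask = jax.pmap(lambda k: jax.random.bernoulli(k, 0.5, (8,)))

key, keys_step1 = next_device_keys(key)
key, keys_step2 = next_device_keys(key)
mask_step1 = draw_mask(keys_step1)
print(mask_step1.shape)
```

Because the carried `key` is replaced on every call, no subkey is ever reused across steps, and every device gets different randomness within a step.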

In [None]:
from pipeline_utils_jax import accuracy_fn
from jax.scipy.special import logsumexp

# Softmax cross-entropy on one-hot labels; logits are scaled by a fixed factor of 16 first
cross_entropy = lambda logits, labels : -(jax.nn.log_softmax(logits * 16) * labels).sum(-1).mean()

metric_fn = accuracy_fn
criterion = cross_entropy

def training_step(apply_fn, update_fn, train_state, batch, subkey):
    def calculate_loss(model_params, apply_fn, batch):
        # Here the model *is* mutable!! set mutable='batch_stats' to enable dropout behavior
        # The mutable state is {'batch_stats': ...} (this bug took me two hours)
        logits, mutable_states = apply_fn(model_params, batch['image'], \
                                    mutable='batch_stats', rngs={'dropout': subkey})
        model_params['batch_stats'] = mutable_states['batch_stats']
        # Sync batch_stats and loss across devices
        model_params['batch_stats'] = jax.tree_map(functools.partial(lax.pmean, axis_name='devices'), \
                                                         model_params['batch_stats'])
        loss = lax.pmean(criterion(logits, batch['label']), 'devices')
        return loss, (logits, model_params)
    
    grad_fn = jax.value_and_grad(calculate_loss, has_aux=True)
    (loss, (logits, train_state['model'])), grads = grad_fn(train_state['model'], apply_fn, batch)
    # Only take the gradients for parameters, and sync across devices
    grads = lax.pmean(grads['params'], 'devices')
    
    updates, train_state['op'] = update_fn(grads, train_state['op'], train_state['model']['params'])
    train_state['model']['params'] = optax.apply_updates(train_state['model']['params'], updates)
    
    metric_value = lax.pmean(metric_fn(logits, batch['label']), 'devices')
    return {'train_state': train_state, 'loss': loss, 'metric':metric_value}

def val_step(apply_fn, model_params, batch):
    # Do not enable batch statistic update (mutable=False) ^or^ stochastic behavior
    # We do not need a rng on this forward pass because Dropout is not stochastic anymore
    logits = apply_fn(model_params, batch['image'], mutable=False)
    loss = lax.pmean(criterion(logits, batch['label']), 'devices')
    # Calculate metric
    metric_value = lax.pmean(metric_fn(logits, batch['label']), 'devices')
    return {'loss': loss, 'metric':metric_value}

train_step_parallel = jax.pmap(functools.partial(training_step, apply_fn=model.apply, update_fn=optimizer.update),\
         axis_name='devices')
val_step_parallel = jax.pmap(functools.partial(val_step, apply_fn=model.apply),\
         axis_name='devices')

### A note on `jax.pmap`
This **is** the magic function which makes parallelization trivial. 

We wrote `training_step(apply_fn, update_fn, train_state, batch, subkey)` and `val_step` *exactly as we would have written them for a single processor*. The only changes are:
* Use `lax.pmean` to sync necessary information - the cores host separate copies of the model and separate data, and they can easily drift apart on their own
* `apply_fn, update_fn` are arguments which do not vary across devices, so we bind them with `functools.partial`
* The rest of the arguments **must all be sharded / replicated**, else there will be an error. Bind device-independent arguments with `partial` or manually shard them.
* Name the pmap-ped axis (I like to call it `devices`) and use this name in `pmean(..., axis_name='devices')` for cross-device syncing
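
Stripped of the model, the same recipe looks like this. It is a minimal sketch, not this notebook's training step: `local_loss` and the `scale` argument are made up for illustration, and the example runs on however many local devices are available (including one).

```python
import functools
import jax
import jax.numpy as jnp
from jax import lax

n_devices = jax.local_device_count()

def local_loss(scale, x):
    # `scale` is bound via functools.partial; `x` is this device's shard
    per_device = jnp.mean(scale * x)
    # Average the per-device value over the named axis
    return lax.pmean(per_device, axis_name='devices')

parallel_loss = jax.pmap(functools.partial(local_loss, 2.0), axis_name='devices')

# The leading axis must equal the number of devices: one row per device.
x = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)
out = parallel_loss(x)
print(out)
```

Every entry of `out` holds the same synced value, which is why the training loop only reads index `[0]` of the returned loss and metric.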

In [None]:
def save_checkpoint(log_path, logs, verbose=True):
    with open(log_path, "wb") as outfile:
        outfile.write(msgpack_serialize(to_state_dict(logs)))
    if verbose:
        print("Checkpoint written to:", log_path)
        
def load_checkpoint(log_path, init_log, verbose=True):
    # Read msgpack file
    with open(log_path, "rb") as data_file:
        byte_data = data_file.read()
    if verbose:
        print('Checkpoint retrieved from:', log_path)
    return from_bytes(init_log, byte_data)

## Boilerplate Pipeline Code
I copied this line-for-line from my pytorch pipeline; the only differences are:
* Create a new random key every step
* Generate dataloaders anew at every epoch
* Manage `freeze` / `unfreeze` meticulously

### And that's it! That's all that's needed for easy parallelism to work in JAX
#### A few notes to congratulate ourselves:
1. We're running *very efficient training* on TPU via this pipeline!! We're effectively running a model **the size of `efficientnet-b7`** on a **batch size of `64`** with an **image size of `512x512`** (!!!) TFRecords are a necessity to sustain this throughput. Try changing the kernel runtime to GPU and see what a P100 can handle.
2. By changing `input_dtype`, we can easily see how `bfloat16` saves RAM without the pains of `fp16` mixed precision - mixed precision is always another layer of wrapping in Pytorch (and on GPU in general), and TPU's `bfloat16` largely sidesteps this problem.
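
A small sketch of the `bfloat16` point: the same tensor takes half the bytes, while the representable range stays far wider than `float16`, which is my reading of why loss scaling is usually not needed. The tensor shape and the `nbytes` helper are assumptions for illustration.

```python
import jax.numpy as jnp

def nbytes(x):
    """Memory footprint of an array in bytes."""
    return x.size * x.dtype.itemsize

# Assumed toy activation tensor; real activations are much larger.
activations_fp32 = jnp.ones((8, 64, 64, 3), dtype=jnp.float32)
activations_bf16 = activations_fp32.astype(jnp.bfloat16)
print(nbytes(activations_fp32), nbytes(activations_bf16))

# bfloat16 keeps a float32-sized exponent, so its maximum value dwarfs float16's.
bf16_max = float(jnp.finfo(jnp.bfloat16).max)
fp16_max = float(jnp.finfo(jnp.float16).max)
print(bf16_max > fp16_max)
```

The trade-off is fewer mantissa bits, so `bfloat16` is coarser than `float16` for values of similar magnitude.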

In [None]:
logs = {
    'train_loss': [], 'train_metric': [], 'val_loss': [], 'val_metric': [], 
    'train_state': None, 'args': args}

for epoch in range(args['epochs']):
    ### Training ###
    trainloader = load_dataset(train_filenames, labeled=True, augment=True,\
                    ordered=False, shuffle_buffer_size=4*args['batch_size'], drop_remainder=True)
    # This "total" is specific to batch size!
    counter = tqdm(trainloader, total=199, leave=False)
    pl_train_state['model'] = unfreeze(pl_train_state['model'])
    train_loss, train_metric, train_counter = 0, 0, 0
    for batch in counter:
        random_key, subkey = jax.random.split(random_key)
        out = train_step_parallel(train_state=pl_train_state, batch=batch, subkey=shard_prng_key(subkey))
        pl_train_state = out['train_state']
        loss_value, metric_value = out['loss'][0].item(), out['metric'][0].item()
        
        counter.set_postfix({'loss':loss_value, 'metric':metric_value})
        train_loss += loss_value
        train_metric += metric_value
        train_counter += 1
        if args['sanity_check']:
            break
    pl_train_state['model'] = freeze(pl_train_state['model'])
    train_loss, train_metric = train_loss / train_counter, train_metric / train_counter
    
    
    ### Validation ###
    valloader = load_dataset(val_filenames, labeled=True, augment=False,\
                    ordered=False, shuffle_buffer_size=4*args['batch_size'], drop_remainder=True)
    
    val_steps, val_loss, val_metric = 0, 0, 0
    counter = tqdm(valloader, total=58, leave=False)
    for batch in counter:
        out = val_step_parallel(model_params=pl_train_state['model'], batch=batch)
        loss_value, metric_value = out['loss'][0].item(), out['metric'][0].item()
        val_loss, val_metric, val_steps = val_loss + loss_value, val_metric + metric_value,\
                                        val_steps + 1
        counter.set_postfix({'val_loss':loss_value, 'val_metric':metric_value})
        if args['sanity_check']:
            break
    val_metric, val_loss = val_metric / val_steps, val_loss / val_steps
    
    # Logging
    logs['train_loss'].append(train_loss)
    logs['train_metric'].append(train_metric)
    logs['val_loss'].append(val_loss)
    logs['val_metric'].append(val_metric)
    
    # Checkpointing
    if (val_metric == max(logs['val_metric'])) or (val_loss == min(logs['val_loss'])):
        print(f'Epoch {epoch}, loss:{train_loss:.4f} acc:{train_metric:.4f} val_loss:{val_loss:.4f} val_metric:{val_metric:.4f}')
        out_path = os.path.join(args['saving_dir'], 'checkpoint.msgpack')
        logs['train_state'] = jax.tree_map(lambda x : x.astype(jnp.float16), \
                    flax.jax_utils.unreplicate(pl_train_state))
        save_checkpoint(out_path, logs, verbose=False)
    # load_checkpoint(out_path, logs)