#### ⚠️⚠️⚠️ Note that this notebook is still **WORK IN PROGRESS**; any constructive feedback is welcome!

---
### **🎯 OBJECTIVE**
- Demonstrate how you could utilise Flax/JAX to build a ⚡high-performance model training pipeline for Kaggle competitions
- Demonstrate how you could apply transfer learning on different CV backbones (e.g. ResNet) with Flax/JAX

### **🚧 TODO**
- [Basic] Why is TPU not significantly faster than GPU?
- [Basic] In inference, determine when to assign new individual
- [Basic] Add validation set into training loop
- [Basic] Find the best learning rate
- [Basic] Varying learning rate over epochs
- [Basic] Train on bigger images (224x224 / 512x512)
- [Basic] Add K-fold validation split
- [Advanced] Change PyTorch dataloader to more performant data loader
- [Advanced] Add W&B for experiment logging
- [Advanced] Pad remainder batch as last iteration instead of skipping it
- [Advanced] Different head, backbones, loss function... etc.

### **📝 LOG**
- [20/02/2022] First draft released
- [22/02/2022] Resolve bug in cross entropy loss
- [23/02/2022] Work in Kaggle TPU
- [24/02/2022] (1) Speed up training with a pre-resized dataset, (2) enable running on Kaggle TPU, (3) freeze the backbone so its params are not updated, (4) get the loss to converge
- [25/02/2022] add LB submission pipeline

### **📚 REFERENCES**
- [Flax & JAX training loop example from huggingface](https://github.com/huggingface/transformers/blob/master/examples/flax/image-captioning/run_image_captioning_flax.py)
- [Flax & JAX training loop example from Flax repo](https://github.com/google/flax/blob/d068512a932da3e05b822790a591bac391aeab36/examples/nlp_seq/train.py#L206-L207)
- [Sentiment CLF - JAX/FLAX on TPUs + 🤗 + W&B 🚀](https://www.kaggle.com/heyytanay/sentiment-clf-jax-flax-on-tpus-w-b)
- [[Pytorch|0.374] 🐳 EffB5 KFold with FocalLoss](https://www.kaggle.com/snoop2head/pytorch-0-374-effb5-kfold-with-focalloss)
- [Train EffNet in TF TPU](https://www.kaggle.com/dschettler8845/baseline-solution-train-indiv-model#dataset_preparation)
- [Resize Images of Happy Whales and Dolphins](https://www.kaggle.com/gpreda/resize-images-of-happy-whales-and-dolphins)
- [How to freeze layer or detach a layer in training stage](https://github.com/google/flax/issues/825)
- [How to checkpoint Flax model](https://github.com/google/flax/discussions/1876)

---


### 0. Pre-setup

In [None]:
%%capture
# Upgrading jax, jaxlib and flax is essential; otherwise loading the pretrained model may fail.
import os
import requests
import torch

# Kaggle mode: detected by the presence of the /kaggle/input directory
if os.path.isdir("/kaggle/input"):

    # TPU mode: detected by the TPU_NAME environment variable
    if "TPU_NAME" in os.environ:
        print("TPU & Kaggle mode: Upgrade JAX and set up TPU")
        !pip install --upgrade jax
        !pip install --upgrade jaxlib
        !pip install git+https://github.com/deepmind/optax.git
        !pip install flax
        
        # NOTE(review): the `jax.config` import path and the FLAGS set below may not
        # exist in the JAX release pulled in by `pip install --upgrade` - confirm.
        from jax.config import config
        # Ask the TPU host (port 8475) for the nightly TPU driver. The
        # TPU_DRIVER_MODE global guards the request so that re-running this
        # cell does not send it again.
        if 'TPU_DRIVER_MODE' not in globals():
            url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
            resp = requests.post(url)
            TPU_DRIVER_MODE = 1
        # point JAX at the TPU named by the TPU_NAME environment variable
        config.FLAGS.jax_xla_backend = "tpu_driver"
        config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
        print('Registered TPU:', config.FLAGS.jax_backend_target)
        
    # GPU mode: detected through torch's CUDA availability check
    elif torch.cuda.is_available():
        print("GPU & Kaggle mode: Upgrade JAX specific to CUDA and CuDNN")
        !pip install --upgrade "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
        !pip install --upgrade jaxlib
        !pip install git+https://github.com/deepmind/optax.git
        !pip install flax
            
    # CPU mode
    else:
        raise NotImplementedError("CPU & Kaggle mode is not supported yet")
#         print("CPU & Kaggle mode: Upgrade JAX")
#         !pip install jax
#         !pip install --upgrade jaxlib
#         !pip install git+https://github.com/deepmind/optax.git
#         !pip install flax
        
# Colab mode: taken for every environment without a /kaggle/input directory
else:
    !pip install kaggle
    
    # TPU mode
    if "TPU_NAME" in os.environ:
        print("TPU & Colab mode: Set up TPU devices")
        import jax.tools.colab_tpu
        jax.tools.colab_tpu.setup_tpu()

# Installed in every mode (flax was already installed by the Kaggle branches above).
!pip install flax
!pip install jax-resnet

In [None]:
import os
import joblib
from typing import *
from functools import partial

from tqdm import tqdm
import pandas as pd
import numpy as np
from PIL import Image
from sklearn.preprocessing import LabelEncoder

# for data loading
import torch
from torch.utils.data import Dataset, DataLoader
# for augmentation
from albumentations import *

import optax
import flax
from flax import linen as nn
from flax.core import frozen_dict
from flax.core import FrozenDict
from flax.jax_utils import replicate, unreplicate
from flax.training import train_state, checkpoints
from flax.training.common_utils import shard, shard_prng_key

import jaxlib
import jax
import jax.numpy as jnp
from jax import lax
from jax_resnet import pretrained_resnet, slice_variables, Sequential

# Report the library versions and the devices JAX can see, to make
# environment problems (e.g. a missing accelerator) visible early.
print(f"Flax version: {flax.__version__}")
print(f"Optax version: {optax.__version__}")
print(f"JAX version: {jax.__version__}")
print(f"JAX lib version: {jaxlib.__version__}")
print(f"JAX device: {jax.devices()}")

### **1. Configuration**

In [None]:
# --- training schedule ---
EPOCH_N = 20
FIRST_N_ITER = None  # early stop on iterations for testing
LEARNING_RATE = 0.05
# batch sizes that were used per accelerator and image size:
# GPU: (128, 128) -> 64, (256, 256) -> 16
# TPU: (128, 128) -> 64*8*2
BATCH_SIZE = 64*8*2
NUM_WORKERS = 2
PRE_RESIZE = True  # if so, no need resizing during dataloading
FREEZE_BACKBONE = True
MOCK_DATASET = False  # control dataloading for speed test

# --- input locations ---
ROOT_DIR = '/kaggle/input/happy-whale-and-dolphin'
TRAIN_CSV = os.path.join(ROOT_DIR, 'train.csv')
TEST_CSV = os.path.join(ROOT_DIR, 'sample_submission.csv')
TRAIN_DIR = '/kaggle/input/resize-images-of-happy-whales-and-dolphins/train_images_128'
TEST_DIR = '/kaggle/input/resize-images-of-happy-whales-and-dolphins/test_images_128'

# --- image preprocessing ---
IMAGE_SIZE: Optional[Tuple[int, int]] = (128, 128)
# ref: https://pytorch.org/vision/stable/models.html
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# --- model ---
OUTPUT_N = 15587  # number of classes (checked against the train labels once they are loaded)
MODEL_ARCH = 'resnet18'
MODEL_SAVE_DIR = 'resnet_checkpoints'

# validator: only square images are supported. IMAGE_SIZE is declared Optional,
# so skip the check when it is None instead of failing with a TypeError.
if IMAGE_SIZE is not None:
    assert all(IMAGE_SIZE[0] == s for s in IMAGE_SIZE), "must be equal"

### **2. Data Preproc and Loading**

In [None]:
class HappyWhaleDataset(Dataset):
    """
    Map-style dataset yielding images, together with their label when training.

    borrow from:
    - https://www.kaggle.com/snoop2head/pytorch-0-374-effb5-kfold-with-focalloss

    Args:
        df: dataframe with a `file_path` column; unless `inference` is set it
            must also provide `species` and `individual_id` columns.
        transforms: optional callable invoked as `transforms(image=...)`; its
            result is indexed with `'image'`.
        mock_dataset: if True, items are constant images (with random labels)
            and nothing is read from disk; meant for data-loading speed tests.
        inference: if True, no labels are kept and items are images only.
    """
    def __init__(self, df, transforms=None, mock_dataset=False, inference=False):
        if mock_dataset:
            print("Mocking the dataset for testing...")

        self.mock_dataset = mock_dataset
        self.transforms = transforms
        self.df = df
        self.file_names = df.file_path.values
        if inference:
            self.species = None
            self.labels = None
        else:
            self.species = df.species.values
            self.labels = df.individual_id.values

    def __getitem__(self, index):
        """Return `image` in inference mode, otherwise `(image, label)`."""
        getter = self.__get_mock_item if self.mock_dataset else self.__get_item
        return getter(index)

    def __get_mock_item(self, index):
        """Return an all-ones image of shape (*IMAGE_SIZE, 3), plus a random label unless in inference mode."""
        mock_img = np.ones((*IMAGE_SIZE, 3))
        if self.labels is None and self.species is None:
            return mock_img
        mock_label = np.random.randint(0, OUTPUT_N)
        return mock_img, mock_label

    def __get_item(self, index):
        """Load the image at `index` from disk, apply the transforms and attach its label."""
        image_path = self.file_names[index]
        # the context manager releases the file handle as soon as the pixels
        # have been copied into the numpy array
        with Image.open(image_path) as pil_image:
            image = np.array(pil_image)

        # take care of grayscale image by adding channel
        if len(image.shape) == 2:
            image = np.dstack((image,)*3)
        if self.transforms:
            image = self.transforms(image=image)['image']

        if self.labels is None and self.species is None:
            return image

        label = self.labels[index]
        return image, label

    def __len__(self):
        return len(self.df)

    def set_transform(self, transform):
        """Replace the transforms applied to every loaded image."""
        # must target `self.transforms`, the attribute read when loading items;
        # assigning to any other name would silently leave the old transforms in use
        self.transforms = transform


def get_train_file_path(filename, image_dir):
    """Return the path of `filename` located inside the `image_dir` folder."""
    path_parts = (image_dir, filename)
    return os.path.join(*path_parts)

In [None]:
# load the training metadata and resolve every image name to a path on disk
df = pd.read_csv(TRAIN_CSV)
df['file_path'] = df['image'].apply(
    lambda image_name: get_train_file_path(image_name, image_dir=TRAIN_DIR)
)
print(f"No. of unique labels: {len(df.individual_id.unique())}")

# encode label from str to integer, and persist the encoder so that
# predictions can be decoded back to individual ids later on
label_encoder = LabelEncoder().fit(df['individual_id'])
df['individual_id'] = label_encoder.transform(df['individual_id'])
with open("label_encoder.pkl", "wb") as p:
    joblib.dump(label_encoder, p)

# the number of encoded classes must agree with the configured head size
assert len(df.individual_id.unique()) == OUTPUT_N

df.head(3)

In [None]:
# define transformations applied on dataset: an optional resize (skipped
# when the images on disk are already resized) followed by normalisation
transforms_ls = []
if not PRE_RESIZE:
    assert IMAGE_SIZE is not None
    transforms_ls.append(Resize(*IMAGE_SIZE, p=1.0))
transforms_ls.append(
    Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD, max_pixel_value=255.0, p=1.0)
)
transforms = Compose(transforms_ls)

# set dataset and dataloader
train_dataset = HappyWhaleDataset(
    df,
    transforms=transforms,
    mock_dataset=MOCK_DATASET,
    inference=False,
)
# default collate_func converts numpy to tensor in dataloading
# reference: https://discuss.pytorch.org/t/torch-dataloader-gives-torch-tensor-as-ouput/31084/6
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# sanity check: one sample has the configured image size and one batch
# holds BATCH_SIZE samples
x, _ = train_dataset[0]
batch, labels = next(iter(train_loader))

assert x.shape[:2] == IMAGE_SIZE
assert batch.shape[0] == BATCH_SIZE
print(f"Shape for one batch: {batch.shape}")

### **3. Model Construction**

In [None]:
# @TODO fill in later
class MarginLayer(nn.Module):
    """
    Placeholder module: calling it currently raises NotImplementedError.

    Presumably meant to become an ArcFace-style margin layer (see the
    references) - confirm once it is implemented.

    reference:
    - https://www.kaggle.com/ragnar123/unsupervised-baseline-arcface
    - https://arxiv.org/pdf/1801.07698.pdf
    """
    @nn.compact
    def __call__(self, inputs):
        raise NotImplementedError


class Head(nn.Module):
    """
    Classification head: two (BatchNorm -> Dropout -> Dense) stages with a
    ReLU after the first Dense layer; the last Dense emits OUTPUT_N logits.

    references:
    - fastai
    - https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html
    """
    batch_norm_cls: partial = partial(nn.BatchNorm, momentum=0.9)

    @nn.compact
    def __call__(self, inputs, train: bool):
        """ `train` enables dropout and makes batchnorm use batch statistics """
        inference = not train
        hidden_n = inputs.shape[-1]

        # first stage keeps the width of the incoming features
        hidden = self.batch_norm_cls(use_running_average=inference)(inputs)
        hidden = nn.Dropout(rate=0.25)(hidden, deterministic=inference)
        hidden = nn.relu(nn.Dense(features=hidden_n)(hidden))

        # second stage projects onto the classes
        hidden = self.batch_norm_cls(use_running_average=inference)(hidden)
        hidden = nn.Dropout(rate=0.5)(hidden, deterministic=inference)
        return nn.Dense(features=OUTPUT_N)(hidden)

    
class HappyWhaleModel(nn.Module):
    """Backbone, followed by an average pool over axes 1 and 2, followed by the head."""
    backbone: Sequential
    head: Head

    def __call__(self, inputs, train: bool):
        feature_map = self.backbone(inputs)
        # average pool layer: collapse axes 1 and 2 of the backbone output
        pooled = jnp.mean(feature_map, axis=(1, 2))
        return self.head(pooled, train)

    
def _get_backbone_and_params(model_arch: str) -> Tuple[nn.Module, FrozenDict]:
    """
    Return the pretrained network without its last two layers, together with
    the matching slice of the pretrained variables.

    Only 'resnet18' is supported; any other value raises NotImplementedError.
    """
    if model_arch != 'resnet18':
        raise NotImplementedError

    # get model & param structure for pretrained model
    resnet_tmpl, params = pretrained_resnet(18)
    model = resnet_tmpl()

    # get model & param structure for backbone: keep all but the last two layers
    keep_n = len(model.layers) - 2
    backbone = Sequential(model.layers[0:keep_n])
    backbone_params = slice_variables(params, 0, keep_n)
    return backbone, backbone_params


def get_model_and_variables(model_arch: str, head_init_key: int) -> Tuple[nn.Module, FrozenDict]:
    """
    Assemble the full model (pretrained backbone + freshly initialised head).

    `head_init_key` seeds the PRNG key used to initialise the head. The
    returned variables are a composition of `params` and `batch_stats`, each
    holding one `backbone` and one `head` subtree.
    """
    # get backbone
    backbone, backbone_params = _get_backbone_and_params(model_arch)

    # determine input size for head model by running the backbone on a dummy image
    dummy_images = jnp.ones((1, *IMAGE_SIZE, 3), jnp.float32)
    feature_n = backbone.apply(backbone_params, dummy_images, mutable=False).shape[-1]

    # get head
    head = Head()
    head_params = head.init(
        jax.random.PRNGKey(head_init_key),
        jnp.ones((1, feature_n), jnp.float32),
        train=False,
    )

    # combine params and batch stats from backbone and head
    variables = FrozenDict({
        collection: {
            'backbone': backbone_params[collection],
            'head': head_params[collection],
        }
        for collection in ('params', 'batch_stats')
    })
    return HappyWhaleModel(backbone, head), variables

In [None]:
%%capture
# sanity check: a forward pass on a dummy batch must return one row of
# OUTPUT_N logits per input sample.
# `model` and `variables` created here are reused by the optimizer/TrainState cell.
model, variables = get_model_and_variables('resnet18', 0)
inputs = jnp.ones((1, *IMAGE_SIZE, 3), jnp.float32)
# NOTE(review): `key` is not used by the forward pass below
key = jax.random.PRNGKey(0)
o = model.apply(variables, inputs, train=False, mutable=False)

assert o.shape[0] == inputs.shape[0] and o.shape[-1] == OUTPUT_N

### **4. Loss Function and Metrics**

In [None]:
def topk_accuracy(logits, labels, top_k=(1,)) -> List[jnp.ndarray]:
    """
    Compute the top-k accuracy (in percent) for every k in `top_k`.

    reference:
    - https://www.kaggle.com/snoop2head/pytorch-0-374-effb5-kfold-with-focalloss

    logits: (bs, class_n)
    labels: (bs, ) class indices
    top_k: the values of k to evaluate

    Returns a list with one scalar array per entry of `top_k`, in the same order.
    """
    max_k = max(top_k)
    batch_size = labels.shape[0]

    _, preds = jax.lax.top_k(logits, k=max_k)
    # (bs, max_k) -> (max_k, bs): row i holds every sample's (i+1)-th best class
    preds = preds.transpose()
    # broadcasting labels as (1, bs) against (max_k, bs) marks, per sample,
    # the rank at which its label was predicted
    correct = preds == labels[None, :]

    # a sample counts as a hit for k when its label is among the first k rows
    return [correct[:k].sum() * 100. / batch_size for k in top_k]


# inputs: (logits, labels); train_step passes one-hot encoded labels and
# averages the returned per-sample losses
loss_fn = optax.softmax_cross_entropy
# inputs: (logits, labels); returns [top-1 accuracy, top-3 accuracy]
eval_fn = partial(topk_accuracy, top_k = (1, 3))

### **5. Optimizer and TrainState**

In [None]:
class TrainState(train_state.TrainState):
    """Flax train state extended with batch statistics and the loss/metric callables."""
    # batch statistics collection, updated together with the params in train_step
    batch_stats: FrozenDict
    # declared with pytree_node=False: plain callables, not array data
    loss_fn: Callable = flax.struct.field(pytree_node=False)
    eval_fn: Callable = flax.struct.field(pytree_node=False)


def create_mask(params, label_fn):
    """
    Build a label tree for `params`: every key accepted by `label_fn` is
    labelled 'zero' (together with everything below it), nested FrozenDicts
    are walked recursively, and all remaining entries are labelled 'adam'.

    reference:
    - https://colab.research.google.com/drive/16wcmLt0AIKzMmLPrliuBMfmBvM8VEc4p#scrollTo=TqDvTL_tIQCH
    """
    def _label_tree(subtree):
        labels = {}
        for name in subtree:
            if label_fn(name):
                labels[name] = 'zero'
            elif isinstance(subtree[name], FrozenDict):
                labels[name] = _label_tree(subtree[name])
            else:
                labels[name] = 'adam'
        return labels

    return frozen_dict.freeze(_label_tree(params))


def zero_grads():
    """
    Return a stateless gradient transformation that replaces every update
    with zeros, i.e. the parameters it is applied to are never changed.

    reference:
    - https://github.com/deepmind/optax/issues/159#issuecomment-896459491
    """
    def init_fn(_):
        # nothing to track between steps
        return ()

    def update_fn(updates, state, params=None):
        # `jax.tree_util.tree_map` is used instead of the top-level
        # `jax.tree_map` alias, which newer JAX releases no longer provide
        return jax.tree_util.tree_map(jnp.zeros_like, updates), ()

    return optax.GradientTransformation(init_fn, update_fn)

In [None]:
# AdamW is the optimizer applied to every parameter that gets trained
adamw = optax.adamw(
    learning_rate=LEARNING_RATE,
    weight_decay=1e-2,
    b1=0.9, b2=0.999, eps=1e-6,
)

if not FREEZE_BACKBONE:
    # update all params
    optimizer = adamw
else:
    # freeze backbone by masking target params: params below a key starting
    # with "backbone" are labelled 'zero', all the others 'adam'
    optimizer = optax.multi_transform(
        transforms={'adam': adamw, 'zero': zero_grads()},
        param_labels=create_mask(
            variables['params'], lambda s: s.startswith('backbone')
        ),
    )

state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    batch_stats=variables['batch_stats'],
    tx=optimizer,
    loss_fn=loss_fn,
    eval_fn=eval_fn,
)

### **6. Training Loop**

In [None]:
def train_step(state: TrainState, batch, labels, dropout_rng) -> Tuple[TrainState, dict, jnp.ndarray]:
    """
    Run one optimisation step on this device's shard of the batch.

    Must run under a mapped axis named 'batch' (e.g. `jax.pmap(...,
    axis_name='batch')`): gradients and metrics are averaged over that axis.

    Returns:
        the updated state, a dict with the averaged 'loss' and 'accuracy',
        and the dropout rng to pass to the next step.
    """
    # one half is consumed by this step, the other half is handed back
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    # params as input because we differentiate wrt it
    def loss_function(params):
        # if you set state.params, then params can't be backpropagated through!
        variables = {'params': params, 'batch_stats': state.batch_stats}
        # return mutated states if mutable is specified
        logits, new_batch_stats = state.apply_fn(
            variables, batch, train=True,
            mutable=['batch_stats'],
            rngs={'dropout': dropout_rng}
        )
        # logits: (BS, OUTPUT_N), one_hot: (BS, OUTPUT_N)
        one_hot = jax.nn.one_hot(labels, OUTPUT_N)
        loss = state.loss_fn(logits, one_hot).mean()
        return loss, (logits, new_batch_stats)

    # if you wanna vary lr per step
    #lr = learning_rate_fn(state.step)

    # backpropagation and update params & batch_stats
    grad_fn = jax.value_and_grad(loss_function, has_aux=True)
    (loss, aux), grads = grad_fn(state.params)
    logits, new_batch_stats = aux
    # average the gradients over all devices so every replica applies the same update
    grads = lax.pmean(grads, axis_name='batch')
    new_state = state.apply_gradients(
        grads=grads, batch_stats=new_batch_stats['batch_stats']
    )

    # evaluation metrics
    accuracy = state.eval_fn(logits, labels)

    # store metadata, averaged over all devices
    metadata = jax.lax.pmean(
        {'loss': loss, 'accuracy': accuracy},
        axis_name='batch'
    )
    return new_state, metadata, new_dropout_rng


def val_step(state: TrainState, batch, labels):
    """Run the model with train=False on `batch` and return `state.eval_fn(logits, labels)`."""
    logits = state.apply_fn(
        {'params': state.params, 'batch_stats': state.batch_stats},
        batch, train=False,
    )
    return state.eval_fn(logits, labels)

In [None]:
# train_step returns a new state, so the buffers of the old one can be donated
parallel_train_step = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,))
# val_step only returns metrics and the state is reused afterwards, so the
# state must NOT be donated: a donated buffer may no longer be used by the caller
parallel_val_step = jax.pmap(val_step, axis_name='batch')

# required for parallelism
state = replicate(state)

# control randomness on dropout and update inside train_step
rng = jax.random.PRNGKey(0)
dropout_rng = jax.random.split(rng, jax.local_device_count())  # one key per device

In [None]:
# number of full batches per epoch; it does not change between epochs
iter_n = len(train_dataset)//BATCH_SIZE

for epoch_i in tqdm(range(EPOCH_N), desc=f"{EPOCH_N} epochs", position=0, leave=True):
    # training set
    train_loss, train_accuracy = [], []

    # @TODO: print out from tqdm can be hard to read, find an alternative
    with tqdm(total=iter_n, desc=f"{iter_n} iterations", leave=False) as progress_bar:
        for iter_i, (batch, labels) in enumerate(train_loader):

            # skip the trailing partial batch: it cannot be split evenly
            # across the devices
            if iter_i + 1 > iter_n:
                break
            # the early stop only applies to the first epoch
            if (epoch_i == 0) and (FIRST_N_ITER is not None) and (iter_i >= FIRST_N_ITER):
                print(f"Early stop at iteration: {iter_i+1}")
                break

            # shard to enable parallelism
            batch, labels = shard(batch), shard(labels)
            batch = jnp.array(batch, dtype=jnp.float32)
            # labels are class indices, so keep them as integers
            labels = jnp.array(labels, dtype=jnp.int32)

            # backprop and update param & batch stats
            state, train_metadata, dropout_rng = parallel_train_step(state, batch, labels, dropout_rng)
            train_metadata = unreplicate(train_metadata)

            # update train statistics
            _train_loss, _train_top1_acc, _train_top3_acc = map(float, [train_metadata['loss'], *train_metadata['accuracy']])
            train_loss.append(_train_loss)
            train_accuracy.append(_train_top1_acc)
            progress_bar.update(1)

    # no iteration may have run (e.g. FIRST_N_ITER = 0 or fewer samples than
    # BATCH_SIZE); report it instead of dividing by zero
    if train_loss:
        avg_train_loss = sum(train_loss)/len(train_loss)
        avg_train_acc = sum(train_accuracy)/len(train_accuracy)
        print(f"[{epoch_i+1}/{EPOCH_N}] Train Loss: {avg_train_loss:.03} | Train Accuracy: {avg_train_acc:.03}")
    else:
        print(f"[{epoch_i+1}/{EPOCH_N}] No training iteration was run")

    # validation set
    pass

In [None]:
# undo the `replicate(state)` done before training, so that the state is
# saved without the per-device copies
state = unreplicate(state)

# checkpoint the model; the checkpoint is saved under the current step number
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
checkpoints.save_checkpoint(ckpt_dir=MODEL_SAVE_DIR, target=state,
                            step=int(state.step), overwrite=True)
# to load the checkpoint back:
#restored_state = checkpoints.restore_checkpoint(ckpt_dir=MODEL_SAVE_DIR, target=state)
print(f"Model checkpoint saved: {MODEL_SAVE_DIR}")

### **7. Submission**

In [None]:
def infer_step(state: TrainState, batch):
    """
    return the indices of the 5 highest logits for each sample
    """
    logits = state.apply_fn(
        {'params': state.params, 'batch_stats': state.batch_stats},
        batch, train=False,
    )
    # top_k yields (values, indices); only the indices are needed
    top_preds = jax.lax.top_k(logits, k=5)[1]
    return top_preds


# load csv for submission
test_df = pd.read_csv(TEST_CSV)
test_df['file_path'] = test_df['image'].apply(
    partial(get_train_file_path, image_dir=TEST_DIR)
)
print(f"No. of Submission Entries: {len(test_df)}")

# setup dataloader for submission (inference mode: items are images only)
test_dataset = HappyWhaleDataset(
    test_df, transforms=transforms,
    mock_dataset=MOCK_DATASET,
    inference=True
)
# keep the original order so that predictions line up with the rows of test_df
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS, shuffle=False
)

# ready for running batch
# The state is NOT donated here: infer_step does not return it and the same
# state is passed again for every batch, which is not allowed for donated buffers.
parallel_infer_step = jax.pmap(infer_step, axis_name='batch')
state = replicate(state)

In [None]:
# batch processing samples for submission
top5_idxs = []
iter_n = len(test_dataset)//BATCH_SIZE
for iter_i, batch in tqdm(enumerate(test_loader)):
    if iter_i + 1 > iter_n:
        state = unreplicate(state)
        _batch_top5_idxs = infer_step(state, batch)
        # make this cell block re-runnable
        state = replicate(state)
    else:
        batch = shard(batch)
        batch = jnp.array(batch, dtype=jnp.float32)
        _batch_top5_idxs = parallel_infer_step(state, batch)
        _batch_top5_idxs = _batch_top5_idxs.reshape(-1, 5)
    top5_idxs.append(_batch_top5_idxs)

top5_idxs = jnp.concatenate(top5_idxs)
assert top5_idxs.shape[0] == test_df.shape[0]

In [None]:
# label decoding
prediction_parser = lambda idxs: ' '.join(label_encoder.inverse_transform(idxs)).strip()
top5_labels = list(map(prediction_parser, top5_idxs))
test_df['predictions'] = top5_labels

# write the same csv
test_df.drop(columns=['file_path'], inplace=True)
test_df.to_csv('submission.csv', index=False)
print(f"Submission csv written!!")