<div class="alert alert-info">
    <h2 align='center'>🍀 Herbarium Jax/Flax Training - 🤗 + KFolds + W&B  Tracking 📊</h2>
</div>

<p style='text-align: center'>
    JAX/Flax training notebook that fine-tunes a Vision Transformer from 🤗 transformers, with Weights and Biases tracking. <br> Keep in mind that this <em>is</em> a proper training notebook, but it is set to train only the first fold, and only for a small number of epochs (<code>N_EPOCHS</code> in the config): there is so much data that training every fold for many epochs would otherwise take far too long.<br>
</p>

<h1 style='color: #fc0362; font-family: Segoe UI; font-size: 1.5em; font-weight: 300; font-size: 24px'>If you liked this notebook, kindly leave an upvote ⬆️</h1>

#### Attribution
This notebook takes a lot of functions and major inspiration from HuggingFace's Flax Vision example [here](https://github.com/huggingface/transformers/blob/master/examples/flax/vision/run_image_classification.py)

<h1 align='center' style='color: #8532a8; font-family: Segoe UI; font-size: 1.5em; font-weight: 300; font-size: 32px'>1. Installation & Imports 📩</h1>

In [None]:
%%sh
pip install -q wandb
pip install -q transformers

In [None]:
import json
import os
from typing import Callable

import wandb
import numpy as np
import pandas as pd
from PIL import Image
import logging

import torch
import torchvision.transforms as transforms
from tqdm.notebook import tqdm
from sklearn.model_selection import StratifiedKFold

import jax
import jax.numpy as jnp
import optax
import transformers
from flax import jax_utils
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from transformers import AutoConfig, FlaxAutoModelForImageClassification

# To keep out those nasty warnings
import warnings
# Ignore every Python warning raised from here on.
warnings.simplefilter('ignore')
# Restrict the transformers logger to ERROR-level messages.
transformers.utils.logging.set_verbosity(transformers.logging.ERROR)
# NOTE(review): presumably quiets TensorFlow's native logging if TensorFlow gets imported - confirm it is still needed.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

In [None]:
# Every tunable used by the notebook lives here; the whole dict is also sent to W&B as the run config.
Config = dict(
    # Pretrained checkpoint and dataset locations
    MODEL_NAME='google/vit-base-patch16-224-in21k',
    JSON_PATH="../input/herbarium-2022-fgvc9/train_metadata.json",
    IMG_PATH="../input/herbarium-2022-fgvc9/train_images",
    # Number of classes given to the model config
    NUM_LABELS=15501,
    # Number of stratified folds
    N_SPLITS=5,
    # Per-device batch sizes (multiplied by jax.device_count() when the loaders are built)
    TRAIN_BS=32,
    VALID_BS=32,
    N_EPOCHS=2,
    # DataLoader worker processes
    NUM_WORKERS=4,
    # Initial learning rate of the schedule
    LR=1e-3,
    # Side length used by the crop/resize transforms and the model config
    IMG_SIZE=224,
    # Extra fields that are only forwarded to W&B with the rest of the config
    infra="Kaggle",
    competition='herbarium_2022',
    _wandb_kernel='tanaym',
    # Switch for all W&B logging
    wandb=True,
)

<h1 align='center' style='color: #8532a8; font-family: Segoe UI; font-size: 1.5em; font-weight: 300; font-size: 32px'>2. About Weights and Biases 📊</h1><br>
<p style="text-align:center">WandB is a developer tool that helps companies turn deep learning research projects into deployed software by helping teams track their models, visualize model performance, and easily automate training and improving models.
We will use their tools to log hyperparameters and output metrics from your runs, then visualize and compare results and quickly share findings with your colleagues.<br><br></p>

To log in to W&B, you can use the snippet below.

```python
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
wb_key = user_secrets.get_secret("WANDB_API_KEY")

wandb.login(key=wb_key)
```
Make sure you have your W&B key stored as `WANDB_API_KEY` under Add-ons -> Secrets

You can view [this](https://www.kaggle.com/ayuraj/experiment-tracking-with-weights-and-biases) notebook to learn more about W&B tracking.

If you don't want to login to W&B, the kernel will still work and log everything to W&B in anonymous mode.

In [None]:
# Start W&B logging
if Config['wandb']:
    # 'allow' keeps the run on the logged-in account when an API key was provided
    # and only falls back to an anonymous account when nobody is logged in.
    # ('must' would send the run to an anonymous account even after wandb.login.)
    run = wandb.init(
        project='jax',
        config=Config,
        group='vision',
        job_type='train',
        anonymous='allow'
    )

<h1 align='center' style='color: #8532a8; font-family: Segoe UI; font-size: 1.5em; font-weight: 300; font-size: 32px'>3. Utility functions ⚒️</h1>

In [None]:
# Some utility functions
def wandb_log(**kwargs):
    """
    Logs the given key-value pairs to W&B.

    All keyword arguments are sent in a single `wandb.log` call so that
    metrics passed together end up in the same history entry instead of
    one entry per key. Nothing is logged when no keyword arguments are given.
    """
    if not kwargs:
        return
    wandb.log(kwargs)

def load_json(json_file):
    """
    Reads the metadata json and returns one dataframe that joins its
    'images' and 'annotations' records on the 'image_id' column.
    """
    with open(json_file, "r", encoding="ISO-8859-1") as handle:
        metadata = json.load(handle)

    images = pd.DataFrame(metadata['images'])
    annotations = pd.DataFrame(metadata['annotations'])
    return pd.merge(images, annotations, on='image_id')

def create_learning_rate_fn(
    train_ds_size: int,
    train_batch_size: int,
    num_train_epochs: int,
    num_warmup_steps: int,
    learning_rate: float
) -> Callable[[int], jnp.array]:
    """
    Builds a schedule that ramps linearly from 0 to `learning_rate` over
    `num_warmup_steps` and then decays linearly back to 0 over the remaining steps.

    The total step count is (train_ds_size // train_batch_size) * num_train_epochs.
    """
    total_steps = (train_ds_size // train_batch_size) * num_train_epochs
    decay_steps = total_steps - num_warmup_steps

    warmup = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
    )
    decay = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=decay_steps
    )
    # Switch from the warmup schedule to the decay schedule once warmup is over
    return optax.join_schedules(schedules=[warmup, decay], boundaries=[num_warmup_steps])

def loss_fn(logits, labels):
    """Mean softmax cross-entropy of `logits` against one-hot encoded `labels`."""
    num_classes = logits.shape[-1]
    targets = onehot(labels, num_classes)
    per_example = optax.softmax_cross_entropy(logits, targets)
    return per_example.mean()

# Define gradient update step fn
def train_step(state, batch):
    """
    Runs one optimisation step and returns the updated state with its metrics.

    Gradients and metrics are averaged with `jax.lax.pmean` over the axis
    named "batch", so the function has to run under a mapped axis of that name
    (e.g. `jax.pmap(train_step, "batch")`).

    Args:
        state: train state providing `apply_fn`, `params`, `dropout_rng`,
            `step` and `apply_gradients`.
        batch: dict holding the model inputs plus a "labels" entry.
            The dict itself is left untouched.

    Returns:
        Tuple `(new_state, metrics)`; `metrics` has the keys "loss" and "learning_rate".

    Note:
        Reads the module-level `linear_decay_lr_schedule_fn`, which must exist
        before this function is traced.
    """
    dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

    # Split targets from model inputs up front, without mutating the caller's
    # dict, so the loss closure stays free of side effects.
    labels = batch["labels"]
    model_inputs = {name: value for name, value in batch.items() if name != "labels"}

    def compute_loss(params):
        logits = state.apply_fn(**model_inputs, params=params, dropout_rng=dropout_rng, train=True)[0]
        return loss_fn(logits, labels)

    grad_fn = jax.value_and_grad(compute_loss)
    loss, grad = grad_fn(state.params)
    # Average gradients across the "batch" axis before applying them
    grad = jax.lax.pmean(grad, "batch")

    new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

    metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    metrics = jax.lax.pmean(metrics, axis_name="batch")

    return new_state, metrics

# Define validation function
def valid_step(params, batch):
    """
    Evaluates one batch and returns its loss and accuracy.

    Metrics are averaged with `jax.lax.pmean` over the axis named "batch", so
    the function has to run under a mapped axis of that name
    (e.g. `jax.pmap(valid_step, "batch")`).

    Args:
        params: model parameters to evaluate.
        batch: dict holding the model inputs plus a "labels" entry.
            The dict itself is left untouched.

    Returns:
        Dict with the keys "loss" and "accuracy".

    Note:
        Reads the module-level `model`, which must exist before this
        function is traced.
    """
    # Split targets from model inputs without mutating the caller's dict
    labels = batch["labels"]
    model_inputs = {name: value for name, value in batch.items() if name != "labels"}

    logits = model(**model_inputs, params=params, train=False)[0]
    loss = loss_fn(logits, labels)

    # summarize metrics
    accuracy = (jnp.argmax(logits, axis=-1) == labels).mean()
    metrics = {"loss": loss, "accuracy": accuracy}
    metrics = jax.lax.pmean(metrics, axis_name="batch")
    return metrics

<h1 align='center' style='color: #8532a8; font-family: Segoe UI; font-size: 1.5em; font-weight: 300; font-size: 32px'>4. Custom dataset class & Augmentations 🖼️</h1>

In [None]:
class HerbariumData(torch.utils.data.Dataset):
    """
    Custom dataset class for this competition's data

    Args:
        df: dataframe with a 'file_name' column holding paths relative to
            Config['IMG_PATH'].
        labels: pandas Series of targets aligned with `df` by position;
            required unless `is_test` is True.
        augments: callable applied to each loaded image. Non-callable values
            (including the default `True` and `None`) leave the image as loaded.
        is_test: when True, items are returned without a label.
    """
    def __init__(self, df, labels=None, augments=True, is_test=False):
        self.df = df
        self.labels = labels
        self.augments = augments
        self.is_test = is_test

    def __getitem__(self, idx):
        file_name = self.df['file_name'].values[idx]
        file_path = os.path.join(Config['IMG_PATH'], file_name)
        # Close the file handle right after decoding, and force three channels
        # so grayscale/CMYK/alpha images match the 3-channel normalisation.
        with Image.open(file_path) as raw_image:
            image = raw_image.convert("RGB")

        # Only call `augments` when it really is a transform; the default
        # value `True` is not callable.
        if callable(self.augments):
            image = self.augments(image)

        if not self.is_test:
            labels = self.labels.values[idx]
            return image, labels
        return image

    def __len__(self):
        return len(self.df)

In [None]:
def collate_fn(samples):
    """
    Collates (image_tensor, label) samples into a dict of numpy arrays with
    the keys "pixel_values" (stacked images) and "labels".
    """
    images = torch.stack([item[0] for item in samples])
    targets = torch.tensor([item[1] for item in samples])

    return {
        "pixel_values": images.numpy(),
        "labels": targets.numpy(),
    }

In [None]:
class Augments:
    """
    Namespace for the image transform pipelines.

    Both builders are static methods, so they can be called on the class
    (`Augments.train_augments()`) as well as on an instance.
    """

    @staticmethod
    def train_augments():
        """Training pipeline: random resized crop, random horizontal flip, tensor conversion, normalisation."""
        return transforms.Compose(
            [
                transforms.RandomResizedCrop(Config['IMG_SIZE']),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

    @staticmethod
    def valid_augments():
        """Validation pipeline: resize, center crop, tensor conversion, normalisation (no randomness)."""
        return transforms.Compose(
            [
                transforms.Resize(Config['IMG_SIZE']),
                transforms.CenterCrop(Config['IMG_SIZE']),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

<h1 align='center' style='color: #8532a8; font-family: Segoe UI; font-size: 1.5em; font-weight: 300; font-size: 32px'>5. Training 🚀</h1>

<img src="https://upload.wikimedia.org/wikipedia/commons/thumb/3/3e/Vision_Transformer.gif/675px-Vision_Transformer.gif">

In [None]:
# Create a model Config
# Loads the configuration of the pretrained checkpoint named in Config and
# passes this notebook's label count and image size along with it.
# The resulting `config` is handed to the model constructor in the training cell.
config = AutoConfig.from_pretrained(
    Config['MODEL_NAME'],
    num_labels=Config['NUM_LABELS'],
    image_size=Config['IMG_SIZE'],
)

A simple training state for a single optax optimizer.

In [None]:
class TrainState(train_state.TrainState):
    """Train state that additionally carries the PRNG key passed to the model as `dropout_rng`."""
    # PRNG key stored next to the parameters and optimizer state
    dropout_rng: jnp.ndarray

    def replicate(self):
        """
        Returns the state passed through `jax_utils.replicate`, with
        `dropout_rng` replaced by `shard_prng_key(self.dropout_rng)`.
        """
        replicated_state = jax_utils.replicate(self)
        device_keys = shard_prng_key(self.dropout_rng)
        return replicated_state.replace(dropout_rng=device_keys)

A quick overview of what I'm doing in the training cell below:
1. Reading the JSON file into a pandas dataframe and splitting it into k-folds
2. Getting the current fold's split of training and validation data
3. Making a dataset out of the training and validation pandas dataframes and then making dataloaders of them
4. Defining the model with the proper parameters and the config we defined above
5. Initializing jax-specific variables such as random key and linear decay learning rate schedule function
6. Defining Adam optimizer using Optax
7. Creating parallel versions of the training and validation step functions (for multi-device functionality)
8. Making a train state and replicating it on each device (for multi-device functionality)
9. Now running the epochs, inside it, we are running training and validation loops over all batches
10. Inside each training or validation loop, the logic is:
    * We shard (basically, breaking into small subsets) the batch data
    * Pass the current state and the sharded batch data through corresponding step function (train or valid)
    * Getting the metric from the output and printing it along with logging it to WandB

In [None]:
# Main trainer
if __name__ == "__main__":
    kf = StratifiedKFold(n_splits=Config['N_SPLITS'])
    train_file = load_json(Config['JSON_PATH'])

    # Validation is slow on this dataset, so it only runs when the optional
    # 'RUN_VALIDATION' key is set to a truthy value in Config (off when absent).
    run_validation = Config.get('RUN_VALIDATION', False)

    for fold_, (train_idx, valid_idx) in enumerate(kf.split(X=train_file, y=train_file['category_id'])):
        # Only training for one fold since the data is huge and I don't want you all to wait an eternity
        if fold_ != 0:
            continue
        print(f"{'='*40} Fold: {fold_+1} / {Config['N_SPLITS']} {'='*40}")

        # kf.split yields positional indices, so rows are selected by position
        train_ = train_file.iloc[train_idx]
        valid_ = train_file.iloc[valid_idx]

        # Create train and validation dataloaders
        train_dataset = HerbariumData(train_, labels=train_['category_id'], augments=Augments.train_augments())
        valid_dataset = HerbariumData(valid_, labels=valid_['category_id'], augments=Augments.valid_augments())

        # drop_last keeps every batch at the full global batch size, so `shard`
        # can always split it evenly across devices and the number of training
        # steps matches the `//` used by the learning rate schedule.
        train = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=Config['TRAIN_BS']*jax.device_count(),
            shuffle=True,
            pin_memory=True,
            num_workers=Config['NUM_WORKERS'],
            collate_fn=collate_fn,
            drop_last=True,
        )
        # NOTE: drop_last discards the final incomplete validation batch so that
        # every batch can be sharded; those few samples are not evaluated.
        valid = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=Config['VALID_BS']*jax.device_count(),
            shuffle=False,
            num_workers=Config['NUM_WORKERS'],
            collate_fn=collate_fn,
            drop_last=True,
        )

        # Instantiate the model
        model = FlaxAutoModelForImageClassification.from_pretrained(
            Config['MODEL_NAME'], config=config, seed=42, dtype='float32'
        )

        # Initialize our training parameters
        rng = jax.random.PRNGKey(42)
        rng, dropout_rng = jax.random.split(rng)

        # Create learning rate schedule (no warmup steps)
        linear_decay_lr_schedule_fn = create_learning_rate_fn(
            len(train_dataset),
            Config['TRAIN_BS']*jax.device_count(),
            Config['N_EPOCHS'],
            0,
            Config['LR'],
        )

        # Create Adam optimizer (the variable keeps its original name, but this is optax.adam)
        adamw = optax.adam(
            learning_rate=linear_decay_lr_schedule_fn,
        )

        # Create parallel version of the train and eval step
        p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
        p_valid_step = jax.pmap(valid_step, "batch")

        # Replicate the train state on each device
        state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
        state = state.replicate()

        for epoch in range(Config['N_EPOCHS']):
            print(f"{'-'*20} Epoch: {epoch+1} / {Config['N_EPOCHS']} {'-'*20}")
            # input_rng is currently unused; shuffling is done by the DataLoader
            rng, input_rng = jax.random.split(rng)

            # Training
            train_metrics = []
            steps_per_epoch = len(train)
            train_prog = tqdm(train, total=steps_per_epoch)

            for batch in train_prog:
                batch = shard(batch)
                state, train_metric = p_train_step(state, batch)
                train_metrics.append(train_metric)

                # Pull the loss to the host once and reuse it for display and logging
                step_loss = train_metric['loss'].tolist()[0]
                train_prog.set_description(f"loss: {step_loss:.4f}")

                # Log to wandb
                if Config['wandb']:
                    wandb_log(
                        train_loss=step_loss
                    )
            train_prog.close()

            # Report the mean loss over the whole epoch rather than the last batch only
            if train_metrics:
                epoch_train_metrics = jax.tree_util.tree_map(jnp.mean, get_metrics(train_metrics))
                print(f"Train loss at Epoch {epoch+1}: {epoch_train_metrics['loss']}")

            # Evaluating
            # Skipped by default since it takes a lot of time; see run_validation above.
            if not run_validation:
                continue

            valid_metrics = []
            steps_per_epoch = len(valid)
            valid_prog = tqdm(valid, total=steps_per_epoch)

            for batch in valid_prog:
                batch = shard(batch)
                metric = p_valid_step(state.params, batch)
                valid_metrics.append(metric)

                step_loss = metric['loss'].tolist()[0]
                valid_prog.set_description(f"val_loss: {step_loss:.4f}")

                # Log to wandb under its own key so it does not mix with the training loss
                if Config['wandb']:
                    wandb_log(
                        valid_loss=step_loss
                    )
            valid_prog.close()

            # Normalize eval metrics
            if valid_metrics:
                valid_metrics = get_metrics(valid_metrics)
                valid_metrics = jax.tree_util.tree_map(jnp.mean, valid_metrics)

                print(f"Valid loss at Epoch {epoch+1}: {valid_metrics['loss']}")
                print(f"Valid accuracy at Epoch {epoch+1}: {valid_metrics['accuracy']}")

In [None]:
# Code taken from https://www.kaggle.com/ayuraj/interactive-eda-using-w-b-tables

# Finish the logging run
# `run` is only created when Config['wandb'] is truthy, hence the same guard here.
if Config['wandb']:
    run.finish()

<center>
<img src="https://img.shields.io/badge/Upvote-If%20you%20like%20my%20work-07b3c8?style=for-the-badge&logo=kaggle">
</center>