### Intro to JAX
[JAX](https://github.com/google/jax) is a framework for high-performance numerical computing and machine learning research, developed by teams at [Google Research](https://research.google/). It lets you write Python programs against a NumPy-compatible API and then transform them: JAX can automatically differentiate, vectorize, and parallelize your functions, and compile them Just-In-Time (JIT) for GPUs and TPUs. JAX was designed with performance and speed as a first priority, and it runs natively on common machine learning accelerators such as [GPUs](https://www.kaggle.com/docs/efficient-gpu-usage) and [TPUs](https://www.kaggle.com/docs/tpu). Large ML models can take a long time to train, so JAX is worth considering for applications where speed and performance are particularly important!
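As a small illustration of these transformations (not part of the original notebook), the sketch below applies `jax.grad`, `jax.vmap`, and `jax.jit` to a hypothetical squared-error function for a one-parameter linear model.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss: squared error of a 1-D linear model y = w * x.
def squared_error(w, x, y):
    return (w * x - y) ** 2

# jax.grad differentiates with respect to the first argument (w) by default.
grad_fn = jax.grad(squared_error)

# jax.vmap maps the gradient over a batch of (x, y) pairs while sharing a single w.
batched_grad = jax.vmap(grad_fn, in_axes=(None, 0, 0))

# jax.jit compiles the batched function with XLA for the available backend.
fast_batched_grad = jax.jit(batched_grad)

xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([2.0, 4.0, 6.0])
grads = fast_batched_grad(1.0, xs, ys)
print(grads)
```

The analytic gradient is 2(wx - y)x, so with w = 1 the three per-example gradients are -2, -8, and -18.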
### When to use JAX vs TensorFlow?
[TensorFlow](https://www.tensorflow.org/guide) is a fantastic product with a rich, fully featured ecosystem that supports almost every use case a machine learning practitioner might have (e.g. [TFLite](https://www.tensorflow.org/lite) for on-device inference, [TFHub](https://tfhub.dev/) for sharing pre-trained models, and many other specialized applications). This broad mandate both contrasts with and complements JAX's philosophy, which is more narrowly focused on speed and performance. We recommend using JAX when you want to maximize speed and performance but do not need the long tail of features and additional functionality that only the [TensorFlow ecosystem](https://www.tensorflow.org/learn) can provide.
### Intro to Flax
Just as [JAX](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) focuses on speed, other members of the JAX ecosystem are encouraged to specialize as well. For example, [Flax](https://flax.readthedocs.io/en/latest/) focuses on neural networks and [Jraph](https://github.com/deepmind/jraph) focuses on graph networks.

[Flax](https://flax.readthedocs.io/en/latest/) is a JAX-based neural network library that was initially developed by Google Research's Brain Team (in close collaboration with the JAX team) but is now open source. If you want to train machine learning models on GPUs and TPUs at an accelerated speed, or if you have an ML project that might benefit from bringing together both [Autograd](https://github.com/hips/autograd) and [XLA](https://www.tensorflow.org/xla), consider using [Flax](https://flax.readthedocs.io/en/latest/) for your next project! [Flax](https://flax.readthedocs.io/en/latest/) is especially well-suited for projects that use large language models, and is a popular choice for cutting-edge [machine learning research](https://arxiv.org/search/?query=JAX&searchtype=all&abstracts=show&order=-announced_date_first&size=50).

### Disclaimer:
**We recommend using [GPUs](https://www.kaggle.com/docs/efficient-gpu-usage) when working with JAX on Kaggle.** These notebooks are compatible with the v3-8 [TPUs](https://www.kaggle.com/docs/tpu) that are provided for free in [Kaggle Notebooks](https://www.kaggle.com/code/new), but JAX was optimized for the newly updated [TPU VM](https://cloud.google.com/blog/products/compute/introducing-cloud-tpu-vms) architecture, which is not yet available on Kaggle.


## Imports

In [None]:
#%%capture
# Uncomment and run these lines only when the accelerator is a TPU
# (%%capture must stay on the first line of the cell to work)
#!conda install -y -c conda-forge jax jaxlib flax optax datasets transformers
#!conda install -y importlib-metadata

In [None]:
!pip install datasets transformers

In [None]:
# Importing all the libraries necessary for the project
import os
# Suppress TensorFlow's C++ log messages (e.g. warnings caused by the CUDA version);
# set this before TensorFlow is imported so that it takes effect
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import time
import warnings
import jax
import flax
import optax
import datasets
import pandas as pd
import numpy as np
from jax import jit
import jax.numpy as jnp
import tensorflow as tf
from flax.training import train_state
from itertools import chain
from tqdm.notebook import tqdm
from typing import Callable
from flax import traverse_util
from datasets import load_dataset, load_metric, Dataset, list_metrics, load_from_disk
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from transformers import FlaxAutoModelForSequenceClassification, AutoConfig, AutoTokenizer, BertTokenizer
warnings.filterwarnings("ignore")

### TPU detection and configuration


In [None]:
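if 'TPU_NAME' in os.environ:
    import requests
    # Ask the TPU host for the nightly TPU driver; the globals() check ensures this runs only once per session
    if 'TPU_DRIVER_MODE' not in globals():
        # TPU_NAME looks like grpc://<ip>:<port>; reuse the IP and target the driver service on port 8475
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1
    # Point JAX's XLA backend at the remote TPU driver
    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected.')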

In [None]:
jax.local_devices()

## Load and preprocess the data
Load the training data CSV file with Hugging Face's [`load_dataset`](https://huggingface.co/docs/datasets/loading_datasets.html) function from the `datasets` library.

In [None]:
raw_train = load_dataset("csv", data_files={'train': ['../input/jigsaw-multilingual-toxic-comment-classification/jigsaw-toxic-comment-train.csv']})

Split the dataset into training and evaluation sets (80% / 20%).

In [None]:
raw_train = raw_train["train"].train_test_split(0.2)

In [None]:
raw_train

## Loading the model checkpoint and tokenizer

In [None]:
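model_checkpoint = "bert-base-cased"
# AutoTokenizer honors use_fast=True and returns the Rust-backed fast BERT tokenizer;
# BertTokenizer is the slow Python tokenizer and does not use this flag
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)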

## Pre-process the dataset
The function below takes a batch of examples, tokenizes the `comment_text` column (padding or truncating every comment to 128 tokens), and stores the `toxic` column as `labels`.

In [None]:
def preprocess_function(input_batch):
    '''
    INPUT - a batch of examples from the original dataset
    RETURNS - the tokenized batch, with the "toxic" column added as "labels"
    '''
    texts = (input_batch["comment_text"],)
    processed = tokenizer(*texts, padding="max_length", max_length=128, truncation=True)
    processed["labels"] = input_batch["toxic"]
    return processed
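To make `padding="max_length"` and `truncation=True` concrete, here is a minimal pure-Python sketch of what they do to a list of token IDs. `pad_and_truncate` is a hypothetical helper, the token IDs are arbitrary, and pad ID 0 is an assumption; the real tokenizer also adds special tokens such as `[CLS]` and `[SEP]`, which this sketch omits.

```python
from typing import Dict, List

def pad_and_truncate(token_ids: List[int], max_length: int, pad_id: int = 0) -> Dict[str, List[int]]:
    """Hypothetical helper mimicking padding="max_length" together with truncation=True."""
    # Truncate sequences that are longer than max_length.
    ids = token_ids[:max_length]
    # attention_mask marks real tokens with 1 and padding with 0.
    attention_mask = [1] * len(ids) + [0] * (max_length - len(ids))
    # Right-pad with pad_id so every sequence has exactly max_length entries.
    input_ids = ids + [pad_id] * (max_length - len(ids))
    return {"input_ids": input_ids, "attention_mask": attention_mask}

short = pad_and_truncate([11, 12, 13], max_length=5)
long = pad_and_truncate(list(range(1, 9)), max_length=5)
print(short)
print(long)
```

Fixed-length sequences give every batch the same shape, which lets the JIT-compiled training step be compiled once instead of once per sequence length.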

In [None]:
tokenized_dataset = raw_train.map(preprocess_function, batched=True, remove_columns=raw_train["train"].column_names)

In [None]:
tokenized_dataset

In [None]:
train_dataset = tokenized_dataset["train"]
validation_dataset = tokenized_dataset["test"]

In [None]:
train_dataset

In [None]:
validation_dataset

## Listing and selecting the metrics
Select the evaluation metric using Hugging Face's [`load_metric`](https://huggingface.co/docs/datasets/loading_metrics.html) function.

In [None]:
metrics_list = list_metrics()
metrics_list

In [None]:
metric = load_metric('f1')
metric
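The training loop calls `metric.compute(average='macro')`, so the reported score is the unweighted mean of the per-class F1 scores. The NumPy sketch below (a hypothetical `macro_f1` helper, not the `datasets` implementation) shows that arithmetic for the binary case.

```python
import numpy as np

def macro_f1(predictions, references, num_classes=2):
    """Hypothetical sketch: unweighted mean of per-class F1 scores (average='macro')."""
    predictions = np.asarray(predictions)
    references = np.asarray(references)
    scores = []
    for c in range(num_classes):
        tp = np.sum((predictions == c) & (references == c))
        fp = np.sum((predictions == c) & (references != c))
        fn = np.sum((predictions != c) & (references == c))
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); use 0 when the class never occurs at all.
        scores.append(2 * tp / denom if denom > 0 else 0.0)
    return float(np.mean(scores))

print(macro_f1([0, 0, 1, 1, 0, 1], [0, 1, 1, 1, 0, 0]))
# A model that always predicts "non-toxic" (0) gets F1 = 0 on the toxic class.
print(macro_f1([0, 0, 0, 0], [0, 0, 0, 1]))
```

Macro averaging weights the rare toxic class as much as the common non-toxic class, so a classifier that ignores toxic comments cannot score well just because most comments are clean.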

## Model config
Define the model and training hyperparameters below.

In [None]:
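num_labels = 2              # binary task: 0 = non-toxic, 1 = toxic
seed = 0                    # seed for model initialization, shuffling, and dropout
num_train_epochs = 5
learning_rate = 2e-5        # peak learning rate of the one-cycle schedule
per_device_batch_size = 32
weight_decay = 1e-2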

In [None]:
total_batch_size = per_device_batch_size * jax.local_device_count()
print("The overall batch size (both for training and eval) is", total_batch_size)

In [None]:
# Loading the config and the pre-trained model using HuggingFace's from_pretrained 
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config, seed=seed)

In [None]:
num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs
learning_rate_function = optax.cosine_onecycle_schedule(transition_steps=num_train_steps, peak_value=learning_rate, pct_start=0.1)
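`optax.cosine_onecycle_schedule` warms the learning rate up from `peak_value / div_factor` to `peak_value` over the first `pct_start` fraction of `transition_steps`, then anneals it with a cosine curve down to `init_value / final_div_factor`. To see that shape without optax, here is a plain-Python sketch; it assumes optax's default `div_factor=25.0` and `final_div_factor=1e4`, and `cosine_onecycle_sketch` is a hypothetical re-implementation, not optax's code.

```python
import math

def cosine_onecycle_sketch(step, transition_steps, peak_value, pct_start=0.1,
                           div_factor=25.0, final_div_factor=1e4):
    """Hypothetical sketch of a one-cycle cosine learning-rate schedule."""
    init_value = peak_value / div_factor
    final_value = init_value / final_div_factor
    warmup_steps = max(1, int(pct_start * transition_steps))

    def cosine_interp(start, end, frac):
        # Moves from start (frac = 0) to end (frac = 1) along half a cosine period.
        return end + (start - end) * (1 + math.cos(math.pi * frac)) / 2

    if step <= warmup_steps:
        # Warm-up phase: rise from init_value to peak_value.
        return cosine_interp(init_value, peak_value, step / warmup_steps)
    if step <= transition_steps:
        # Annealing phase: fall from peak_value to final_value.
        frac = (step - warmup_steps) / (transition_steps - warmup_steps)
        return cosine_interp(peak_value, final_value, frac)
    return final_value

for step in (0, 50, 100, 550, 1000):
    print(step, cosine_onecycle_sketch(step, transition_steps=1000, peak_value=2e-5))
```

With `pct_start=0.1`, as in the cell above, the learning rate peaks after 10% of `num_train_steps` and spends the remaining 90% decaying.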

## Train state


In [None]:
class TrainState(train_state.TrainState):
    '''
    Extends Flax's TrainState with two static (non-pytree) fields: a function that turns
    the model's logits into predictions, and the loss function
    '''
    logits_function: Callable = flax.struct.field(pytree_node=False)
    loss_function: Callable = flax.struct.field(pytree_node=False)

In [None]:
def decay_mask_fn(params):
    '''
    Returns a mask that applies weight decay to every parameter except biases and LayerNorm scales
    '''
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)
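To see which parameters the mask above excludes from weight decay, here is a dependency-free sketch on a toy parameter tree. `flatten` is a hypothetical stand-in for `flax.traverse_util.flatten_dict` (it also produces tuple-path keys), the toy tree is invented, and unlike `decay_mask_fn` this sketch returns the flat mask without unflattening it.

```python
from typing import Any, Dict, Tuple

def flatten(tree: Dict[str, Any], prefix: Tuple[str, ...] = ()) -> Dict[Tuple[str, ...], Any]:
    """Hypothetical stand-in for flax.traverse_util.flatten_dict: nested dict -> {path tuple: leaf}."""
    flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat

def decay_mask(params):
    # Same rule as decay_mask_fn: True means "apply weight decay".
    return {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale"))
            for path in flatten(params)}

toy_params = {
    "dense": {"kernel": 1.0, "bias": 0.0},
    "LayerNorm": {"scale": 1.0, "bias": 0.0},
}
for path, decayed in decay_mask(toy_params).items():
    print(path, decayed)
```

Only `("dense", "kernel")` is decayed; biases and LayerNorm scales are left alone, which is the usual convention when fine-tuning BERT with AdamW.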

In [None]:
# AdamW optimizer (Adam with decoupled weight decay) built with optax.adamw
def make_adamw(weight_decay):
    return optax.adamw(learning_rate=learning_rate_function, b1=0.9, b2=0.999, eps=1e-6, weight_decay=weight_decay, mask=decay_mask_fn)

In [None]:
adamw = make_adamw(weight_decay)

In [None]:
## Defining the loss and the evaluation function
@jit
def loss_function(logits, labels):
    xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
    return jnp.mean(xentropy)
 
@jit
def eval_function(logits):
    return logits.argmax(-1)
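The loss above is the softmax cross-entropy between the logits and one-hot labels. The sketch below writes that formula out explicitly with `jax.nn.one_hot` and `jax.nn.log_softmax` in place of flax's `onehot` and `optax.softmax_cross_entropy`; `manual_xent` is a hypothetical name and the logits are toy values.

```python
import jax
import jax.numpy as jnp

def manual_xent(logits, labels, num_classes=2):
    # One-hot encode integer labels, e.g. 1 -> [0., 1.] for two classes.
    one_hot = jax.nn.one_hot(labels, num_classes)
    # Cross-entropy per example: -sum(one_hot * log_softmax(logits)); then average over the batch.
    per_example = -jnp.sum(one_hot * jax.nn.log_softmax(logits), axis=-1)
    return jnp.mean(per_example)

logits = jnp.array([[0.0, 0.0], [2.0, -1.0]])
labels = jnp.array([1, 0])
loss = manual_xent(logits, labels)
# eval_function's rule: the predicted class is the index of the largest logit.
predictions = logits.argmax(-1)
print(loss, predictions)
```

The first example has equal logits, so its loss is log 2; the second strongly favors the correct class 0, so its loss is only log(1 + e^-3).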

In [None]:
# Instantiate a TrainState.
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw,
    logits_function=eval_function,
    loss_function=loss_function,
)

## Train and evaluate steps

In [None]:
def train_step(state, batch, dropout_rng):
    # Separate the targets from the model inputs
    targets = batch.pop("labels")
    # One key for this step's dropout, one to carry over to the next step
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
    
    # Forward pass plus loss, written as a function of the parameters only
    def compute_loss(params):
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        loss = state.loss_function(logits, targets)
        return loss
    
    grad_fn = jax.value_and_grad(compute_loss) # differentiate the loss function
    loss, grad = grad_fn(state.params)
    grad = jax.lax.pmean(grad, "batch") #compute the mean gradient over all devices 
    new_state = state.apply_gradients(grads=grad) #applies the gradients to the weights.
    metrics = jax.lax.pmean({'loss': loss, 'learning_rate': learning_rate_function(state.step)}, axis_name='batch')
    
    return new_state, metrics, new_dropout_rng
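The key line in `train_step` is `jax.lax.pmean(grad, "batch")`. The next cell wraps `train_step` in `jax.pmap`, which runs one copy of the step per local device, and `pmean` averages a value across the mapped axis named `"batch"` so that every replica applies the same update. This toy sketch uses one scalar per device in place of a gradient pytree; on a single-device machine it still runs, with `n_devices == 1`.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

# Each device holds one toy "gradient": 0., 1., ..., n_devices - 1.
per_device_grads = jnp.arange(n_devices, dtype=jnp.float32)

# pmean averages over the named axis, so every device ends up with the same mean value.
average = jax.pmap(lambda g: jax.lax.pmean(g, axis_name="batch"), axis_name="batch")
synced = average(per_device_grads)
print(synced)
```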

In [None]:
parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,)) # parallelize training over all local devices (GPUs or TPU cores); donate the old state's buffers


In [None]:
# Define evaluation step
def eval_step(state, batch):
    # Forward pass in inference mode (train=False disables dropout)
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    # Convert logits to predicted class IDs with the stored logits_function (argmax)
    return state.logits_function(logits)

In [None]:
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")

## Data loader

In [None]:
# Yields sharded training batches:
# 1. draw a random permutation of the dataset indices, dropping the incomplete last batch
# 2. index the dataset with each slice of the permutation, convert the columns to JAX arrays,
#    and shard the batch across all local devices
def train_data_loader(rng, dataset, batch_size):
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        batch = dataset[np.asarray(perm)]  # index the Hugging Face dataset with a NumPy array
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)
        yield batch
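The shapes produced by this loader can be checked without a dataset or accelerator. In the NumPy sketch below, `shard_like` is a hypothetical stand-in for flax's `shard`, which reshapes the leading batch axis into `(local_device_count, per_device_batch_size, ...)`; the device count, batch size, and dataset size are assumed toy values.

```python
import numpy as np

def shard_like(batch, n_devices):
    """Hypothetical stand-in for flax's shard: split the leading axis across devices."""
    return {k: v.reshape((n_devices, -1) + v.shape[1:]) for k, v in batch.items()}

n_devices, per_device_batch_size, seq_len = 8, 4, 128  # assumed values
total_batch_size = n_devices * per_device_batch_size
dataset_size = 100  # toy dataset size

rng = np.random.default_rng(0)
perms = rng.permutation(dataset_size)
steps_per_epoch = dataset_size // total_batch_size
# Drop the incomplete last batch, as train_data_loader does.
perms = perms[: steps_per_epoch * total_batch_size].reshape((steps_per_epoch, total_batch_size))

batch = {"input_ids": np.zeros((total_batch_size, seq_len), dtype=np.int32),
         "labels": np.zeros((total_batch_size,), dtype=np.int32)}
sharded = shard_like(batch, n_devices)
print(perms.shape, sharded["input_ids"].shape, sharded["labels"].shape)
```

Here 100 // 32 = 3 full batches are used and 4 examples are skipped this epoch; a different permutation next epoch means different examples are skipped. `eval_data_loader` also uses integer division, so the last partial validation batch is never evaluated.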

In [None]:
# Like the training loader, but iterates in order without shuffling (the incomplete last batch is also dropped)
def eval_data_loader(dataset, batch_size):
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)
        yield batch

In [None]:
state = flax.jax_utils.replicate(state)

In [None]:
# generating a seeded PRNGKey for the dropout layers and dataset shuffling.
rng = jax.random.PRNGKey(seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
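JAX has no global random state: every random function takes an explicit key, and reusing a key reproduces the same numbers. That is why the cell above splits one seeded key into one dropout key per device, and why `train_step` splits its key again on every step. A small sketch of the same pattern:

```python
import jax

seed = 0
rng = jax.random.PRNGKey(seed)
n_devices = jax.local_device_count()

# One independent dropout key per device, stacked along the leading axis for pmap.
dropout_rngs = jax.random.split(rng, n_devices)

# Inside train_step, each device splits its key again: one key for this step's
# dropout and one to carry forward to the next step.
step_key, next_key = jax.random.split(dropout_rngs[0])
print(dropout_rngs.shape)
```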

## Training 

In [None]:
# Now, we'll define the training loop and fine-tune the pre-trained model
start = time.time()
# Full training loop
for i, epoch in enumerate(tqdm(range(1, num_train_epochs + 1), desc="Epoch ...", position=0, leave=True)):
    rng, input_rng = jax.random.split(rng)

    # train
    with tqdm(total=len(train_dataset) // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
        for batch in train_data_loader(input_rng, train_dataset, total_batch_size):
            state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            progress_bar_train.update(1)

    # evaluate
    with tqdm(total=len(validation_dataset) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
        for batch in eval_data_loader(validation_dataset, total_batch_size):
            labels = batch.pop("labels")
            predictions = parallel_eval_step(state, batch)
            # predictions and labels have a leading device axis; chain(*...) flattens them
            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
            progress_bar_eval.update(1)

    eval_metric = metric.compute(average='macro')

    # train_metrics comes from the last training step of the epoch (identical on every device after pmean)
    loss = round(flax.jax_utils.unreplicate(train_metrics)['loss'].item(), 3)
    eval_score = round(list(eval_metric.values())[0], 3)
    metric_name = list(eval_metric.keys())[0]

    print(f"{i+1}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name}: {eval_score}")
    
print("Total time: ", time.time() - start, "seconds")
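In the evaluation loop, `parallel_eval_step` returns predictions with a leading device axis, so `chain(*predictions)` flattens them before `metric.add_batch`. The toy arrays below (two hypothetical devices with three examples each) show that this matches a plain one-dimensional reshape.

```python
from itertools import chain
import numpy as np

# Sharded outputs as a pmapped eval step would return them: (n_devices, per_device_batch_size).
predictions = np.array([[0, 1, 1], [1, 0, 0]])
labels = np.array([[0, 1, 0], [1, 1, 0]])

# chain(*x) walks through each device's row in turn, producing one flat sequence.
flat_preds = list(chain(*predictions))
flat_labels = list(chain(*labels))

# Equivalent to reshaping to one dimension.
assert flat_preds == predictions.reshape(-1).tolist()
print(flat_preds, flat_labels)
```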

### **Conclusion**
In this notebook, we've shown how [JAX](https://github.com/google/jax) and [Flax](https://flax.readthedocs.io/en/latest/) can be used to fine-tune a pre-trained BERT model on a toxic-comment text classification dataset, reaching a macro F1 score of more than 80%. To see more examples of how to use [JAX](https://github.com/google/jax) and [Flax](https://flax.readthedocs.io/en/latest/) with different data formats, please see this discussion post.

Now it's your turn to create some amazing notebooks using [JAX](https://github.com/google/jax) and [Flax](https://flax.readthedocs.io/en/latest/).

### **Useful resources which helped me:**

* https://flax.readthedocs.io/en/latest/index.html
* https://github.com/google/flax/tree/main/examples
* https://www.kaggle.com/heyytanay/sentiment-clf-jax-flax-on-tpus-w-b/notebook
* https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook
* https://huggingface.co/docs/datasets/package_reference/main_classes.html#dataset
* https://colab.sandbox.google.com/github/huggingface/notebooks/blob/master/examples/text_classification_flax.ipynb#scrollTo=Mn1GdGpipfWK