# JAX and TPUs
This is the best possible combination of hardware and software to run the training loops as fast as possible. JAX uses the incredible XLA compiler which makes the code efficient to run on TPUs. HuggingFace is now porting all its models to Flax library which is a JAX based Neural Network library developed by Google Brain. 

This is the best news to try out our NLP based tasks using the Pre-trained models powered by HuggingFace and using the raw power of TPUs. Once the compilation is done, a TPU-v3 can finish an entire epoch of training and evaluating on this dataset in less than 5 seconds (tested on Colab, Kaggle and TPU-VM), this is amazing!

Best part of all, Kaggle gives you 30 free hours of TPUv3-8 usage every week! That is 240$ of free compute every week. (Price estimated as per costs at us-central1 on [Google Cloud TPU pricing](https://cloud.google.com/tpu/pricing))

Without further ado... Let's GO!

This notebook is inspired by Flax examples provided by the HuggingFace official library [here](https://github.com/huggingface/transformers/tree/master/examples/flax)

# Install dependencies

In [1]:
%%capture
!conda install -y -c conda-forge jax jaxlib flax optax datasets transformers
!conda install -y importlib-metadata

# Setup TPU
Prepare and setup our TPU so that it can be used with JAX. This snip is adapted from [this notebook](https://www.kaggle.com/narainp/jax-tpu-demo-wip)

In [2]:
import os
if 'TPU_NAME' in os.environ:
    import requests
    if 'TPU_DRIVER_MODE' not in globals():
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1


    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')

Registered TPU: grpc://10.0.0.2:8470


In [3]:
import jax
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

# Model Name and Batch Size
`model_checkpoint` is the name of the pre-trained model as per HuggingFace nomenclature. I had checked on several choices of this, but for the sake of simplicity for this tutorial let's just consider **BERT-BASE-UNCASED**. In my tests as on 28th June 2021, **ROBERTA-BASE** is giving an error related to an implementation error.

`per_device_batch_size` here I'm giving a good starting point with 32. TPUv3-8 has a massive computational capability. It can handle a lot more than 32 per device/core of TPU. But as the dataset we have is very tiny with just 2834 samples, I chose the smaller batch size. General rule of thumb, we want to do mini-batch training for better generalization purposes (TODO: Add Citation to the original paper concluded this.)

Feel free to experiment with various models and batch sizes.

In [4]:
model_checkpoint = "bert-base-uncased" # 'roberta-base' has an error remaining are working.
per_device_batch_size = 32

# Define the RMSE Metric
The contest **Common Lit Readability** evaluates all the submission with the **RMSE** metric. Please check the Evaluation tab of the contest to confirm. I implemented the formula given in the Evaluation tab in normal numpy (not JAX). 

I created a Metric wrapper from HuggingFace which handles the device and several other device related issues.

Refer to [this link](https://huggingface.co/docs/datasets/add_metric.html) for creating and defining your own metric using HuggingFace library.

In [5]:
import numpy as np
import datasets

def simple_rmse(preds, labels):
    rmse = np.sqrt(np.sum(np.square(preds-labels))/preds.shape[0])
    return rmse


class RMSE(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description="Calculates Root Mean Squared Error (RMSE) metric.",
            citation="TODO: _CITATION",
            inputs_description="_KWARGS_DESCRIPTION",
            features=datasets.Features({
                'predictions': datasets.Value('float32'),
                'references': datasets.Value('float32'),
            }),
            codebase_urls=[],
            reference_urls=[],
            format='numpy'
        )

    def _compute(self, predictions, references):
        return {"RMSE": simple_rmse(predictions, references)}


# Loading dataset and metric
I personally prefer HugginFace datasets because they are very well designed and makes it easy to pre-process all the samples very easily and it has several features like easily loading from the CSV file without using any Pandas data frame objects as intermediates.

Full documentation of the HuggingFace dataset can be found [here](https://huggingface.co/docs/datasets/package_reference/main_classes.html#dataset)

In [6]:
from datasets import load_dataset, load_metric
raw_train = load_dataset("csv", data_files={'train': ['../input/commonlitreadabilityprize/train.csv']})
raw_test = load_dataset('csv', data_files={'test': ['../input/commonlitreadabilityprize/test.csv']})

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1562.0, style=ProgressStyle(description…




Using custom data configuration default


Downloading and preparing dataset csv/default-f4f55e3cdc4fe7ff (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/csv/default-f4f55e3cdc4fe7ff/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-f4f55e3cdc4fe7ff/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2. Subsequent calls will reuse this data.


Using custom data configuration default


Downloading and preparing dataset csv/default-d5d3f0a81b7bcc19 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/csv/default-d5d3f0a81b7bcc19/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-d5d3f0a81b7bcc19/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2. Subsequent calls will reuse this data.


In [7]:
# Split the train set into train and valid sets
raw_train = raw_train["train"].train_test_split(0.1)

In [8]:
metric = RMSE()

# Pre-process the dataset
This is a very generic pre-processing nothing special. Just tokenized the sentence and padded it appropriately.

In [9]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=570.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…




In [10]:
def preprocess_function(examples):
    texts = (examples["excerpt"],)
    processed = tokenizer(*texts, padding="max_length", max_length=128, truncation=True)
    
    processed["labels"] = examples["target"]
    return processed

In [11]:
tokenized_dataset = raw_train.map(preprocess_function, batched=True, remove_columns=raw_train["train"].column_names)

HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [12]:
tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'token_type_ids'],
        num_rows: 2550
    })
    test: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'token_type_ids'],
        num_rows: 284
    })
})

In [13]:
# The test was created by the 0.1 split of the data which is our validation/evaluation dataset.
train_dataset = tokenized_dataset["train"]
eval_dataset = tokenized_dataset["test"]

# Model
We have a regression problem at hand so the model just needs to output 1 number.

In [14]:
from transformers import FlaxAutoModelForSequenceClassification, AutoConfig

num_labels = 1
seed = 0

config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config, seed=seed)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=438064459.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-uncased were not used when initializing FlaxBertForSequenceClassification: {('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'dense', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'bias')}
- This IS expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxBertForSequenceClassification were not initialized from the model checkpoint at bert-base-unca

# Training and evaluation loop

In [15]:
import flax
import jax
import optax

from itertools import chain
from tqdm.notebook import tqdm
from typing import Callable

import jax.numpy as jnp

from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.training import train_state
from flax import traverse_util

In [16]:
num_train_epochs = 10
learning_rate = 2e-5

There are 8 cores in TPUv3-8, so the effective `batch_size = 8 * per_device_batch_size`

In [17]:
total_batch_size = per_device_batch_size * jax.local_device_count()
print("The overall batch size (both for training and eval) is", total_batch_size)

The overall batch size (both for training and eval) is 256


I used the One-Cycle LR Scheduler with Cosine Annealing. It is super easy to create this LR Schedule with the [Optax](https://github.com/deepmind/optax) library, it is the recommended library while using any JAX based NN libraries. Optax is being developed by **DeepMind** has several amazing features, definitely give it a try!

TODO: Add citations to the original One-Cycle and Cosine Annealing papers.

In [18]:
num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs

learning_rate_function = optax.cosine_onecycle_schedule(transition_steps=num_train_steps, peak_value=learning_rate, pct_start=0.1, )
print("The number of train steps (all the epochs) is", num_train_steps)

The number of train steps (all the epochs) is 90


## Create a Train State
Next, we will create the *training state* that includes the optimizer, the loss function, and is responsible for updating the model's parameters during training.

Most JAX transformations (notably [jax.jit](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)) require functions that are transformed to have no side-effects as it follows a functional programming type paradigm at its core. This is because any such side-effects will only be executed once, when the Python version of the function is run during compilation (see [Stateful Computations in JAX](https://jax.readthedocs.io/en/latest/jax-101/07-state.html)). As a consequence, Flax models (which can be transformed by JAX transformations) are **immutable**, and the state of the model (i.e., its weight parameters) are stored *outside* of the model instance.

Flax provides a convenience class [`flax.training.train_state.TrainState`](https://github.com/google/flax/blob/9da95cdd12591f42d2cd4c17089861bff7e43cc5/flax/training/train_state.py#L22), which stores things such as the model parameters, the loss function, the optimizer, and exposes an `apply_gradients` function to update the model's weight parameters.

We create a derived `TrainState` class that additionally stores the model's forward pass as `eval_function` as well as a `loss_function`.

In [19]:
class TrainState(train_state.TrainState):
    logits_function: Callable = flax.struct.field(pytree_node=False)
    loss_function: Callable = flax.struct.field(pytree_node=False)

## AdamW Optimizer
We will be using the standard Adam optimizer with weight decay. For more information on AdamW (Adam + weight decay), one can take a look at [this](https://www.fast.ai/2018/07/02/adam-weight-decay/) blog post. `weight_decay` value of 0.01 is a good starting point, you can tweak this hyper-parameter and experiment with how it influences the final trained model.

Regularizing the *bias* and/or *LayerNorm* has not shown to improve performance and can even be disadvantageous, which is why we disable it here. For more information on this, please check out the following [blog post](https://medium.com/@shrutijadon10104776/why-we-dont-use-bias-in-regularization-5a86905dfcd6) or [paper](https://arxiv.org/abs/1711.05101).

Hence we create a `decay_mask_fn` which makes sure that weight decay is not applied to any *bias* or *LayerNorm* weights. This can easily be done by passing a `mask_fn` to `optax.adamw`.

**NOTE**: Beginners can **ignore** the `decay_mask_fn`, the changes are minimal if you leave out doing this step.

In [20]:
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)

In [21]:
def adamw(weight_decay):
    return optax.adamw(learning_rate=learning_rate_function, b1=0.9, b2=0.999, eps=1e-6, weight_decay=weight_decay, mask=decay_mask_fn)

In [22]:
adamw = adamw(1e-2)

## Loss and eval functions
The standard loss function for regression problems is the MSE loss. The book by Bishop has an additional 0.5 term, but we're skipping in that without loss of generality. That term just scales the loss by a constant factor and doesn't have an impact on the gradients (other than scaling).

In [23]:
@jax.jit
def loss_function(logits, labels):
    return jnp.mean((logits[..., 0] - labels) ** 2)

@jax.jit    
def eval_function(logits):
    return logits[..., 0]

## Create the initial train state
Finally!!

In [24]:
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw,
    logits_function=eval_function,
    loss_function=loss_function,
)

### Defining the training and evaluation step

During fine-tuning, we want to update the model parameters and evaluate the performance after each epoch. 

Let's write the functions `train_step` and `eval_step` accordingly. During training the weight parameters should be updated as follows:

1. Define a loss function `loss_function` that first runs a forward pass of the model given data input. Remember that Flax models are immutable, and we explicitly pass it the state (in this case the model parameters and the RNG). `loss_function` returns a scalar loss (using the previously defined `state.loss_function`) between the model output and input targets.
2. Differentiate this loss function using [`jax.value_and_grad`](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#evaluate-a-function-and-its-gradient-using-value-and-grad). This is a JAX transformation called [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), which computes the gradient of `loss_function` given the input to the function (i.e., the parameters of the model), and returns the value and the gradient in a pair `(loss, gradients)`.
3. Compute the mean gradient over all devices using the collective operation [lax.pmean](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.pmean.html). As we will see below, each device runs `train_step` on a different batch of data, but by taking the mean here we ensure the model parameters are the same on all devices.
4. Use `state.apply_gradients`, which applies the gradients to the weights.

Below, you can see how each of the described steps above is put into practice.

**NOTE: Taken from HuggingFace examples** 

In [25]:
def train_step(state, batch, dropout_rng):
    targets = batch.pop("labels")
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_function(params):
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        loss = state.loss_function(logits, targets)
        return loss

    grad_function = jax.value_and_grad(loss_function)
    loss, grad = grad_function(state.params)
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)
    metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_function(state.step)}, axis_name="batch")
    return new_state, metrics, new_dropout_rng

Now, we want to do parallelized training over all TPU devices. To do so, we use [`jax.pmap`](https://jax.readthedocs.io/en/latest/jax.html?highlight=pmap#parallelization-pmap). This will compile the function once and run the same program on each device (it is an [SPMD program](https://en.wikipedia.org/wiki/SPMD)). When calling this pmapped function, all inputs (`"state"`, `"batch"`, `"dropout_rng"`) should be replicated for all devices, which means that the first axis of each argument is used to map over all TPU devices.

The argument `donate_argnums` is used to tell JAX that the first argument `"state"` is "donated" to the computation, because it is not needed anymore afterwards. XLA can make use of donated buffers to reduce the memory needed.

In [26]:
parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

In [27]:
def eval_step(state, batch):
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    return state.logits_function(logits)

In [28]:
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")

## Define Data Loaders
In a final step before we can start training, we need to define the data collators. The data collator is important to shuffle the training data before each epoch and to prepare the batch for each training and evaluation step.

First, a random permutation of the whole dataset is defined. 
Then, every time the training data collator is called the next batch of the randomized dataset is extracted, converted to a JAX array and sharded over all local TPU devices.

In [29]:
def train_data_loader(rng, dataset, batch_size):
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        batch = dataset[perm]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)

        yield batch

In [30]:
def eval_data_loader(dataset, batch_size):
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)

        yield batch

Next, we replicate/copy the weight parameters on each device, so that we can pass them to our pmapped functions.


In [31]:
state = flax.jax_utils.replicate(state)

  "jax.host_count has been renamed to jax.process_count. This alias "
  "jax.host_id has been renamed to jax.process_index. This alias "


# Training
Now we define the full training loop. For each batch in each epoch, we run a training step. Here, we also need to make sure that the PRNGKey is sharded/split over each device. Having completed an epoch, we report the training metrics and can run the evaluation.

The first batch takes a bit longer to process but nothing to worry because during the first batch, XLA compiler is working hard to make everything super fast. The first takes close to 5 mins for processing and then entire epochs take ~5 sec to process. Aren't TPUs amazing!!

**5 seconds for an entire EPOCH!!**

Note: The times mentioned above are an average estimate over 8 different runs on several different TPU machines and several model architectures.

In [32]:
rng = jax.random.PRNGKey(seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

In [33]:
for i, epoch in enumerate(tqdm(range(1, num_train_epochs + 1), desc=f"Epoch ...", position=0, leave=True)):
    rng, input_rng = jax.random.split(rng)

    # train
    with tqdm(total=len(train_dataset) // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
        for batch in train_data_loader(input_rng, train_dataset, total_batch_size):
            state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            progress_bar_train.update(1)

    # evaluate
    with tqdm(total=len(eval_dataset) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
        for batch in eval_data_loader(eval_dataset, total_batch_size):
            labels = batch.pop("labels")
            predictions = parallel_eval_step(state, batch)
            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
            progress_bar_eval.update(1)

    eval_metric = metric.compute()

    loss = round(flax.jax_utils.unreplicate(train_metrics)['loss'].item(), 3)
    eval_score = round(list(eval_metric.values())[0], 3)
    metric_name = list(eval_metric.keys())[0]

    print(f"{i+1}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name}: {eval_score}")

HBox(children=(FloatProgress(value=0.0, description='Epoch ...', max=10.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Training...', max=9.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Evaluating...', max=1.0, style=ProgressStyle(description_…

1/10 | Train loss: 1.384 | Eval RMSE: 1.215


HBox(children=(FloatProgress(value=0.0, description='Training...', max=9.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Evaluating...', max=1.0, style=ProgressStyle(description_…

2/10 | Train loss: 0.629 | Eval RMSE: 0.905


HBox(children=(FloatProgress(value=0.0, description='Training...', max=9.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Evaluating...', max=1.0, style=ProgressStyle(description_…

3/10 | Train loss: 0.477 | Eval RMSE: 0.708


HBox(children=(FloatProgress(value=0.0, description='Training...', max=9.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Evaluating...', max=1.0, style=ProgressStyle(description_…

4/10 | Train loss: 0.456 | Eval RMSE: 0.642


HBox(children=(FloatProgress(value=0.0, description='Training...', max=9.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Evaluating...', max=1.0, style=ProgressStyle(description_…

5/10 | Train loss: 0.367 | Eval RMSE: 0.663


HBox(children=(FloatProgress(value=0.0, description='Training...', max=9.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Evaluating...', max=1.0, style=ProgressStyle(description_…

6/10 | Train loss: 0.349 | Eval RMSE: 0.642


HBox(children=(FloatProgress(value=0.0, description='Training...', max=9.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Evaluating...', max=1.0, style=ProgressStyle(description_…

7/10 | Train loss: 0.307 | Eval RMSE: 0.642


HBox(children=(FloatProgress(value=0.0, description='Training...', max=9.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Evaluating...', max=1.0, style=ProgressStyle(description_…

8/10 | Train loss: 0.276 | Eval RMSE: 0.64


HBox(children=(FloatProgress(value=0.0, description='Training...', max=9.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Evaluating...', max=1.0, style=ProgressStyle(description_…

9/10 | Train loss: 0.315 | Eval RMSE: 0.641


HBox(children=(FloatProgress(value=0.0, description='Training...', max=9.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Evaluating...', max=1.0, style=ProgressStyle(description_…

10/10 | Train loss: 0.304 | Eval RMSE: 0.642



# Generating Results
Our test dataset has slightly different pre-processing step because we do not have a label in the dataset. So, we should handle accordingly.

In [34]:
def preprocess_test_set_function(examples):
    texts = (examples["excerpt"],)
    processed = tokenizer(*texts, padding="max_length", max_length=128, truncation=True)
    
    return processed

In [35]:
tokenized_test_dataset = raw_test.map(preprocess_test_set_function, batched=True, remove_columns=raw_test["test"].column_names)

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [36]:
test_dataset = tokenized_test_dataset["test"]
test_dataset

Dataset({
    features: ['attention_mask', 'input_ids', 'token_type_ids'],
    num_rows: 7
})

We won't shard our data anymore because usually the test sets are very small and can be done entirely on one-core without having the additional overheads. So, we also have to "un-shard" our model and run entirely on the single device of the device slice. So we use the `unreplicate` method in the flax library, [here is the documentation](https://flax.readthedocs.io/en/latest/flax.jax_utils.html#flax.jax_utils.unreplicate)

# Generation
Final step. We have successfully fine-tuned a BERT model to the Lit-Readability task. That's amazing! It took us less than 10 mins to reach a very good score! Now it is time to get our model predictions on our test set.

In [37]:
def test_data_loader(dataset, batch_size):
    if len(dataset)<batch_size:
        batch = dataset[:]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        yield batch
    else:
        for i in range(len(dataset) // batch_size):
            batch = dataset[i * batch_size : (i + 1) * batch_size]
            batch = {k: jnp.array(v) for k, v in batch.items()}

            yield batch
        batch = dataset[(i+1) * batch_size:]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        yield batch

In [38]:
from flax.jax_utils import unreplicate

unrep_state = unreplicate(state)

In [39]:
def generate_results():
    preds = []
    for batch in test_data_loader(test_dataset, total_batch_size):
        if jax.process_index()==0:
            predictions = unrep_state.apply_fn(**batch, train=False, return_dict=False)
            preds.append(predictions[0])
    return preds

In [40]:
preds = generate_results()

Now we clean-up and make our results "Submission ready". First we convert all JAX **DeviceArray** objects to Numpy arrays, then we create a submission file.

In [41]:
import numpy as np
preds = np.vstack([np.asarray(x) for x in preds])
preds

array([[-0.0051295 ],
       [ 0.11400247],
       [ 0.20827302],
       [ 0.0565445 ],
       [ 0.13570154],
       [ 0.17821385],
       [ 0.21502109]], dtype=float32)

In [42]:
import pandas as pd
sample = pd.read_csv('../input/commonlitreadabilityprize/sample_submission.csv')
sample.target = preds
sample

Unnamed: 0,id,target
0,c0f722661,-0.00513
1,f0953f0a5,0.114002
2,0df072751,0.208273
3,04caf4e0c,0.056544
4,0e63f8bea,0.135702
5,12537fe78,0.178214
6,965e592c0,0.215021


Export our results to a CSV File.

In [43]:
sample.to_csv('submission.csv',index=False)