# **🚀💫Supercharge KerasNLP Models with Wandb**

Semantic similarity refers to the task of determining the degree of similarity between two
sentences in terms of their meaning. We already saw in [this](https://keras.io/examples/nlp/semantic_similarity_with_bert/)
example how to use the SNLI (Stanford Natural Language Inference) corpus to predict sentence
semantic similarity with the HuggingFace Transformers library. In this tutorial we will
learn how to use [KerasNLP](https://keras.io/keras_nlp/), an extension of the core Keras API,
for the same task. We will also see how KerasNLP reduces boilerplate code and simplifies
the process of building and using models. For more information on KerasNLP,
please refer to [KerasNLP's official documentation](https://keras.io/keras_nlp/).

Weights & Biases is an amazing platform for experiment tracking. However, it doesn't support multi-backend Keras Core yet. Kudos to [Soumik Rakshit](https://www.kaggle.com/soumikrakshit) for his work on [**Wandb-addons**](https://geekyrakshit.dev/wandb-addons/), which provides multi-backend compatible Keras callbacks.

<div style="
           display:fill;
           border:solid;
           border-radius:5px;
           font-size:110%;
           font-family:Verdana;
           letter-spacing:0.5px;
            border-style:solid;">

<h3 style="padding: 10px;text-align: center;"> Outline </h3></div>

1. **Getting started with KerasNLP**
2. **Overview of dataset**
3. **Establishing baseline with BERT**
4. **Improving baseline by tweaking learning rate**
5. **Improving baseline further with learning rate scheduler**
6. **Choosing right hyperparameters with Wandb Sweep**

<div style="
           display:fill;
           border:solid;
           border-radius:5px;
           font-size:110%;
           font-family:Verdana;
           letter-spacing:0.5px;
            border-style:solid;">

<h3 style="padding: 10px;text-align: center;"> 1. Getting started with KerasNLP </h3></div>

The following guide uses [Keras Core](https://keras.io/keras_core/) to work in
any of `tensorflow`, `jax` or `torch`. Support for Keras Core is baked into
KerasNLP: simply change the `KERAS_BACKEND` environment variable to select
the backend you would like to use. The variable must be set before Keras is
imported. We select the `jax` backend below, which gives us a particularly
fast train step.

In [1]:
!pip install -q keras-nlp wandb
!pip install --upgrade -q git+https://github.com/soumik12345/wandb-addons

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


In [2]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"

import numpy as np
import tensorflow as tf
import keras_core as keras
import keras_nlp
import tensorflow_datasets as tfds
import wandb
from wandb_addons.keras import WandbMetricsLogger

PROJECT_NAME = "keras-nlp-x-wandb"

Using JAX backend.


In [3]:
wandb.login()

wandb: Currently logged in as: yash22222 (techtitans). Use `wandb login --relogin` to force relogin


True

To load the SNLI dataset, we use the tensorflow-datasets library. The dataset
contains over 550,000 samples in total. However, to ensure that this example runs
quickly, we use only 20% of the training samples.

<div style="
           display:fill;
           border:solid;
           border-radius:5px;
           font-size:110%;
           font-family:Verdana;
           letter-spacing:0.5px;
            border-style:solid;">

<h3 style="padding: 10px;text-align: center;"> 2. Overview of dataset </h3></div>


Every sample in the dataset contains three components: `hypothesis`, `premise`,
and `label`. The premise represents the original caption provided to the author of the pair,
while the hypothesis refers to the hypothesis caption created by the author of
the pair. The label is assigned by annotators to indicate the similarity between
the two sentences.

The dataset contains three possible similarity label values: Contradiction, Entailment,
and Neutral. Contradiction represents completely dissimilar sentences, while Entailment
denotes similar meaning sentences. Lastly, Neutral refers to sentences where no clear
similarity or dissimilarity can be established between them.
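
As a small sketch, the integer labels can be mapped back to readable names. The ordering used here (0 = entailment, 1 = neutral, 2 = contradiction) is an assumption based on the `tensorflow-datasets` SNLI label definition, and `decode_label` is a hypothetical helper. It is worth checking the order against `tfds.builder("snli").info` before relying on it.

```python
# Assumed label order for the tensorflow-datasets SNLI corpus; -1 marks
# samples without a usable label.
LABEL_NAMES = ("entailment", "neutral", "contradiction")


def decode_label(label_id):
    """Return a readable name for an integer SNLI label (hypothetical helper)."""
    if label_id == -1:
        return "unlabeled"
    return LABEL_NAMES[label_id]


print([decode_label(i) for i in (0, 1, 2, -1)])
```

The `-1` case corresponds to the samples that are filtered out in the preprocessing step.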

In [4]:
snli_train = tfds.load("snli", split="train[:20%]")
snli_val = tfds.load("snli", split="validation")
snli_test = tfds.load("snli", split="test")

# Here's an example of what our samples look like. We take the first batch of
# four samples from the test split:
sample = snli_test.batch(4).take(1).get_single_element()
sample

{'hypothesis': <tf.Tensor: shape=(4,), dtype=string, numpy=
 array([b'A girl is entertaining on stage',
        b'A group of people posing in front of a body of water.',
        b"The group of people aren't inide of the building.",
        b'The people are taking a carriage ride.'], dtype=object)>,
 'label': <tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 0, 0, 0])>,
 'premise': <tf.Tensor: shape=(4,), dtype=string, numpy=
 array([b'A girl in a blue leotard hula hoops on a stage with balloon shapes in the background.',
        b'A group of people taking pictures on a walkway in front of a large body of water.',
        b'Many people standing outside of a place talking to each other in front of a building that has a sign that says "HI-POINTE."',
        b'Three people are riding a carriage pulled by four horses.'],
       dtype=object)>}

### Preprocessing

In our dataset, we have identified that some samples have missing or incorrectly labeled
data, which is denoted by a value of -1. To ensure the accuracy and reliability of our model,
we simply filter out these samples from our dataset.

In [5]:
def filter_labels(sample):
    return sample["label"] >= 0

Here's a utility function that splits the example into an `(x, y)` tuple that is suitable
for `model.fit()`. By default, `keras_nlp.models.BertClassifier` will tokenize and pack
together raw strings using a `"[SEP]"` token during training. Therefore, this label
splitting is all the data preparation that we need to perform.

In [6]:
def split_labels(sample):
    x = (sample["hypothesis"], sample["premise"])
    y = sample["label"]
    return x, y


train_ds = (
    snli_train.filter(filter_labels)
    .map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)
)

val_ds = (
    snli_val.filter(filter_labels)
    .map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)
)

test_ds = (
    snli_test.filter(filter_labels)
    .map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)
)

def get_batched_dataset(batch_size):
    train_set = train_ds.batch(batch_size)
    val_set = val_ds.batch(batch_size)
    test_set = test_ds.batch(batch_size)
    return train_set, val_set, test_set
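
To make the effect of the two mapping functions concrete, here is a framework-free sketch that applies the same filter and split logic to plain Python dictionaries instead of a `tf.data.Dataset`. The toy rows are made up for illustration and are not taken from SNLI.

```python
def filter_labels(sample):
    # Keep only samples that carry a usable label (-1 means "no label").
    return sample["label"] >= 0


def split_labels(sample):
    # The (hypothesis, premise) pair is the model input; the label is the target.
    x = (sample["hypothesis"], sample["premise"])
    y = sample["label"]
    return x, y


# Made-up toy rows, not taken from SNLI.
toy_samples = [
    {"hypothesis": "A dog runs.", "premise": "A dog is running in a park.", "label": 0},
    {"hypothesis": "A cat sleeps.", "premise": "A dog is running in a park.", "label": 2},
    {"hypothesis": "It is sunny.", "premise": "A dog is running in a park.", "label": -1},
]

pairs = [split_labels(s) for s in toy_samples if filter_labels(s)]
print(len(pairs))
print(pairs[0])
```

The row labeled `-1` is dropped, and each remaining row becomes an `(x, y)` tuple of the shape that `model.fit()` expects.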

<div style="
           display:fill;
           border:solid;
           border-radius:5px;
           font-size:110%;
           font-family:Verdana;
           letter-spacing:0.5px;
            border-style:solid;">

<h3 style="padding: 10px;text-align: center;"> 3. Establishing baseline with BERT </h3></div>



We use the BERT model from KerasNLP to establish a baseline for our semantic similarity
task. The `keras_nlp.models.BertClassifier` class attaches a classification head to the BERT
Backbone, mapping the backbone outputs to a logit output suitable for a classification task.
This significantly reduces the need for custom code.
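
Because the classifier emits raw logits rather than probabilities, the loss used later in this guide is built with `from_logits=True`. The following minimal sketch shows how logits relate to class probabilities and to the predicted class. The logit values are made up, and `softmax` is written by hand purely for illustration.

```python
import math


def softmax(logits):
    """Convert raw logits to probabilities (numerically stable form)."""
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits for one sentence pair, one value per class.
logits = [2.0, 0.5, -1.0]
probs = softmax(logits)
predicted_class = probs.index(max(probs))
print(probs, predicted_class)
```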

KerasNLP models have built-in tokenization capabilities that handle tokenization by default
based on the selected model. However, users can also use custom preprocessing techniques
as per their specific needs. If we pass a tuple as input, the model will tokenize all the
strings and concatenate them with a `"[SEP]"` separator.
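
The packing step can be sketched in plain Python. This is only an illustration of the standard BERT input layout (`[CLS] A [SEP] B [SEP]` with segment ids), under the assumption that whitespace splitting stands in for WordPiece tokenization. The real preprocessor also maps tokens to ids and pads or truncates to a fixed length, and `pack_segments` is a hypothetical helper, not a KerasNLP function.

```python
def pack_segments(first, second):
    """Pack two token lists BERT-style and return (tokens, segment_ids)."""
    tokens = ["[CLS]"] + first + ["[SEP]"] + second + ["[SEP]"]
    # Segment 0 covers [CLS], the first sentence and its [SEP];
    # segment 1 covers the second sentence and the final [SEP].
    segment_ids = [0] * (len(first) + 2) + [1] * (len(second) + 1)
    return tokens, segment_ids


# Whitespace splitting stands in for WordPiece tokenization here.
hypothesis = "a girl is on stage".split()
premise = "a girl hula hoops on a stage".split()
tokens, segment_ids = pack_segments(hypothesis, premise)
print(tokens)
print(segment_ids)
```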

We use this model with pretrained weights, loaded through the `from_preset()` method,
which also accepts a custom preprocessor. For the SNLI dataset, we set `num_classes` to 3.

In [None]:
%%wandb

with wandb.init(project=PROJECT_NAME, name="baseline") as run:
    bert_classifier = keras_nlp.models.BertClassifier.from_preset(
        "bert_tiny_en_uncased", num_classes=3
    )
    train_set, val_set, test_set = get_batched_dataset(512)
    bert_classifier.fit(train_set, validation_data=val_set, epochs=1, callbacks=[WandbMetricsLogger(log_freq="batch")])
    bert_classifier.evaluate(test_set, callbacks=[WandbMetricsLogger(log_freq="batch")])



<div style="
           display:fill;
           border:solid;
           border-radius:5px;
           font-size:110%;
           font-family:Verdana;
           letter-spacing:0.5px;
            border-style:solid;">

<h3 style="padding: 10px;text-align: center;"> 4. Improving baseline by tweaking learning rate </h3></div>

In [None]:
%%wandb

with wandb.init(project=PROJECT_NAME, name="change-lr-bs") as run:
    bert_classifier = keras_nlp.models.BertClassifier.from_preset(
        "bert_tiny_en_uncased", num_classes=3
    )
    bert_classifier.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(5e-5),
        metrics=[keras.metrics.SparseCategoricalAccuracy()]
    )

    train_set, val_set, test_set = get_batched_dataset(512)

    bert_classifier.fit(train_set, validation_data=val_set, epochs=1, callbacks=[WandbMetricsLogger(log_freq="batch")])

    bert_classifier.evaluate(test_set, callbacks=[WandbMetricsLogger(log_freq="batch")])

Tweaking the learning rate alone was not enough to boost performance. Let's try again, but this time with
`keras.optimizers.AdamW` and a learning rate schedule.

<div style="
           display:fill;
           border:solid;
           border-radius:5px;
           font-size:110%;
           font-family:Verdana;
           letter-spacing:0.5px;
            border-style:solid;">

<h3 style="padding: 10px;text-align: center;"> 5. Improving baseline further with learning rate scheduler </h3></div>

In [None]:
class TriangularSchedule(keras.optimizers.schedules.LearningRateSchedule):
    """Linear ramp up for `warmup` steps, then linear decay to zero at `total` steps."""

    def __init__(self, rate, warmup, total):
        self.rate = rate
        self.warmup = warmup
        self.total = total

    def get_config(self):
        config = {"rate": self.rate, "warmup": self.warmup, "total": self.total}
        return config

    def __call__(self, step):
        step = keras.ops.cast(step, dtype="float32")
        rate = keras.ops.cast(self.rate, dtype="float32")
        warmup = keras.ops.cast(self.warmup, dtype="float32")
        total = keras.ops.cast(self.total, dtype="float32")

        warmup_rate = rate * step / warmup
        cooldown_rate = rate * (total - step) / (total - warmup)
        triangular_rate = keras.ops.minimum(warmup_rate, cooldown_rate)
        return keras.ops.maximum(triangular_rate, 0.0)
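
To see the shape of the schedule without running a training job, the same arithmetic can be evaluated in plain Python. This is a minimal sketch: `triangular_rate` is a hypothetical stand-in that mirrors `TriangularSchedule.__call__`, and the peak rate, warmup and total values are illustrative only.

```python
def triangular_rate(step, rate, warmup, total):
    """Plain-Python version of the triangular learning-rate formula."""
    warmup_rate = rate * step / warmup
    cooldown_rate = rate * (total - step) / (total - warmup)
    return max(min(warmup_rate, cooldown_rate), 0.0)


# Illustrative values only: peak 1e-4, 20 warmup steps, 100 steps in total.
for step in (0, 10, 20, 60, 100, 120):
    print(step, triangular_rate(step, rate=1e-4, warmup=20, total=100))
```

The rate climbs linearly to its peak at the end of warmup, falls linearly to zero at `total`, and is clamped at zero for any step beyond that.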

In [None]:
%%wandb

with wandb.init(project=PROJECT_NAME, name="lr-schedule") as run:
    bert_classifier = keras_nlp.models.BertClassifier.from_preset(
        "bert_tiny_en_uncased", num_classes=3
    )

    train_set, val_set, test_set = get_batched_dataset(512)
    # Get the total count of training batches (one optimizer step per batch).
    # This requires walking the batched dataset, because filtering out the
    # -1 labels makes its length unknown in advance.
    epochs = 3
    total_steps = sum(1 for _ in train_set.as_numpy_iterator()) * epochs
    warmup_steps = int(total_steps * 0.2)

    bert_classifier.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.AdamW(
            TriangularSchedule(1e-4, warmup_steps, total_steps)
        ),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )

    bert_classifier.fit(train_set, validation_data=val_set, epochs=epochs, callbacks=[WandbMetricsLogger(log_freq="batch")])
    bert_classifier.evaluate(test_set, callbacks=[WandbMetricsLogger(log_freq="batch")])
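
The schedule is evaluated once per optimizer step, so `total` and `warmup` are measured in batches, not samples. When the number of filtered samples is already known, the step counts can be derived arithmetically. The sketch below shows that arithmetic: `schedule_steps` is a hypothetical helper, and the sample count is illustrative rather than the real size of the filtered SNLI subset.

```python
import math


def schedule_steps(num_samples, batch_size, epochs, warmup_fraction=0.2):
    """Return (total_steps, warmup_steps) for a batched dataset."""
    # The last batch may be partial, so round up.
    steps_per_epoch = math.ceil(num_samples / batch_size)
    total_steps = steps_per_epoch * epochs
    warmup_steps = int(total_steps * warmup_fraction)
    return total_steps, warmup_steps


# Illustrative sample count, not the real size of the filtered SNLI subset.
total_steps, warmup_steps = schedule_steps(num_samples=10_000, batch_size=512, epochs=3)
print(total_steps, warmup_steps)
```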

<div style="
           display:fill;
           border:solid;
           border-radius:5px;
           font-size:110%;
           font-family:Verdana;
           letter-spacing:0.5px;
            border-style:solid;">

<h3 style="padding: 10px;text-align: center;"> 6. Choosing right hyperparameters with Wandb Sweep </h3></div>

In [None]:
import wandb

sweep_config = {
    'project': PROJECT_NAME,
    'method': 'grid',
    'run_cap': 6,
    'metric': {
      'name': 'accuracy',
      'goal': 'maximize'
    },
    'parameters': {

        'learning_rate': {
            'values': [5e-6, 2e-5, 5e-5, 1e-4]
        },
        'batch_size': {
            'values': [256, 512]
        }
    }
}
sweep_defaults = {
    'learning_rate': 5e-5,
    'batch_size': 512,
}

sweep_id = wandb.sweep(sweep_config, project=PROJECT_NAME)
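
It helps to count how many configurations this grid contains. The sketch below enumerates the combinations with `itertools.product`, using the values from `sweep_config`. It only counts runs: which combinations are left out once the cap is reached is decided by W&B and is not modeled here.

```python
import itertools

learning_rates = [5e-6, 2e-5, 5e-5, 1e-4]
batch_sizes = [256, 512]
run_cap = 6

# A grid search visits every combination of the listed values.
grid = list(itertools.product(learning_rates, batch_sizes))
# With a cap below the grid size, some combinations are never run.
planned_runs = min(len(grid), run_cap)
print(len(grid), planned_runs)
```

With four learning rates and two batch sizes the full grid has eight entries, so a `run_cap` of 6 leaves two combinations unexplored.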

In [None]:
def train():
    # Inside a sweep, the agent's parameter values override these defaults.
    wandb.init(project=PROJECT_NAME, config=sweep_defaults)

    bert_classifier = keras_nlp.models.BertClassifier.from_preset(
        "bert_tiny_en_uncased", num_classes=3
    )

    train_set, val_set, test_set = get_batched_dataset(wandb.config.batch_size)

    optimizer = keras.optimizers.AdamW(
        learning_rate=wandb.config.learning_rate, epsilon=1e-8
    )

    bert_classifier.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=optimizer,
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )

    epochs = 2

    bert_classifier.fit(
        train_set,
        validation_data=val_set,
        epochs=epochs,
        callbacks=[WandbMetricsLogger(log_freq="batch")],
    )

    bert_classifier.evaluate(test_set, callbacks=[WandbMetricsLogger(log_freq="batch")])

In [None]:
wandb.agent(sweep_id, function=train)

In [None]:
%wandb shivance/"keras-nlp-x-wandb"/reports/Vmlldzo1Mjk1ODQ4

We hope this tutorial has been helpful in demonstrating the ease and effectiveness
of using KerasNLP and BERT for semantic similarity tasks.

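Throughout this tutorial, we demonstrated how to use a pretrained BERT model to
establish a baseline, and then tried to improve on it by adjusting the learning rate,
adding a learning rate schedule, and searching over hyperparameters with a Wandb Sweep,
logging every run to Weights & Biases along the way.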

The KerasNLP toolbox provides a range of modular building blocks for preprocessing
text, including pretrained state-of-the-art models and low-level Transformer Encoder
layers. We believe that this makes experimenting with natural language solutions
more accessible and efficient.