<h1 style='text-align: center'> Sentiment Classification - 1.6M Tweets + HuggingFace BERT + Jax/Flax TPUs + W&B Tracking </h1>

<p style='text-align: center'>
Sentiment Classification on 1.6 Million tweets using Jax/Flax with TPUs using HuggingFace BERT and W&B Tracking!<br> 
I have used this <a href='https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/text_classification_flax.ipynb'> script</a> as a base and then modified it to work with this dataset.<br>
I have also used Weights and Biases tracking to keep track of the training process and the experiments I will conduct.
</p>

<center><img src="https://i.imgur.com/gb6B4ig.png" width="400" alt="Weights & Biases"/></center><br>
<p style="text-align:center">WandB is a developer tool for companies turn deep learning research projects into deployed software by helping teams track their models, visualize model performance and easily automate training and improving models.
We will use their tools to log hyperparameters and output metrics from your runs, then visualize and compare results and quickly share findings with your colleagues.<br><br>We'll be using this to train our K Fold Cross Validation and gain better insights about our training. <br><br></p>

![img](https://i.imgur.com/BGgfZj3.png)

**You can upvote this kernel, if you found it useful!**

In [1]:
%%capture
! pip install --upgrade jax
! pip install --upgrade jaxlib
! pip install git+https://github.com/huggingface/transformers.git
! pip install git+https://github.com/deepmind/optax.git
! pip install flax
! conda install -y -c conda-forge datasets
! conda install -y importlib-metadata

In [2]:
%%sh
pip install -q wandb --upgrade

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
allennlp 2.5.0 requires transformers<4.7,>=4.1, but you have transformers 4.10.0.dev0 which is incompatible.
allennlp 2.5.0 requires wandb<0.11.0,>=0.10.0, but you have wandb 0.11.2 which is incompatible.


In [3]:
import os
if 'TPU_NAME' in os.environ:
    import requests
    if 'TPU_DRIVER_MODE' not in globals():
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1

    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')

Registered TPU: grpc://10.0.0.2:8470


In [4]:
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

import datasets
from datasets import load_dataset, load_metric

import jax
import flax
import optax
import jaxlib
import jax.numpy as jnp

from itertools import chain
from typing import Callable

from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.training import train_state
from flax import traverse_util

from transformers import AutoTokenizer, FlaxAutoModelForSequenceClassification, AutoConfig

import wandb

jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [5]:
class Config:
    nb_epochs = 5
    lr = 2e-5
    per_device_bs = 32
    num_labels = 2
    model_name = 'bert-base-uncased'
    total_batch_size = per_device_bs * jax.local_device_count()
    tokenizer = AutoTokenizer.from_pretrained(model_name)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=570.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…




In [None]:
# You can add your W&B API token as Kaggle secret with the name "WANDB_API_KEY".
# To get your W&B API token, visit https://wandb.ai/authorize

# W&B Login
# from kaggle_secrets import UserSecretsClient
# user_secrets = UserSecretsClient()
# wb_key = user_secrets.get_secret("WANDB_API_KEY")

# wandb.login(key=wb_key)

In [6]:
CONFIG = dict(
    lr=2e-5,
    model_name = 'bert-base-uncased',
    epochs = 5,
    split = 0.10,
    per_device_bs = 32,
    seed = 42,
    num_labels = 2,
    infra = "Kaggle",
    competition = 'none',
    _wandb_kernel = 'tanaym'
)

run = wandb.init(project='jax_flax', 
                 config=CONFIG,
                 group='bert',
                 job_type='train',
                 anonymous='allow'
                )

[34m[1mwandb[0m: (1) Private W&B dashboard, no account required
[34m[1mwandb[0m: (2) Use an existing W&B account


[34m[1mwandb[0m: Enter your choice:  1


[34m[1mwandb[0m: You chose 'Private W&B dashboard, no account required'
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [7]:
def wandb_log(**kwargs):
    """
    Logs a key-value pair to W&B
    """
    for k, v in kwargs.items():
        wandb.log({k: v})

In [8]:
def simple_acc(preds, labels):
    assert len(preds) == len(labels), "Predictions and Labels matrices must be of same length"
    acc = (preds == labels).sum() / len(preds)
    return acc

class ACCURACY(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description="Calculates Accuracy metric.",
            citation="TODO: _CITATION",
            inputs_description="_KWARGS_DESCRIPTION",
            features=datasets.Features({
                'predictions': datasets.Value('int64'),
                'references': datasets.Value('int64'),
            }),
            codebase_urls=[],
            reference_urls=[],
            format='numpy'
        )

    def _compute(self, predictions, references):
        return {"ACCURACY": simple_acc(predictions, references)}
    
metric = ACCURACY()

In [9]:
def split_and_save(file_path: str, split: float = 0.10):
    file = pd.read_csv(file_path, encoding='latin-1', names=['sentiment', 'id', 'date', 'query', 'username', 'text'])
    file = file[['sentiment', 'text']]
    file['sentiment'] = file['sentiment'].map({4: 1, 0: 0})
    
    file = file.sample(frac=1).reset_index(drop=True)
    split_nb = int(len(file) * split)
    
    test_set = file[:split_nb].reset_index(drop=True)
    train_set = file[split_nb:].reset_index(drop=True)
    
    train_set.to_csv("train_file.csv", index=None)
    test_set.to_csv("test_file.csv", index=None)
    print("Done.")

split_and_save("../input/sentiment140/training.1600000.processed.noemoticon.csv")

Done.


In [10]:
# Get the training and testing files loaded in HF dataset format
raw_train = load_dataset("csv", data_files={'train': ['./train_file.csv']})
raw_test = load_dataset("csv", data_files={'test': ['./test_file.csv']})

Downloading and preparing dataset csv/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/csv/default-af255dc9caab3475/0.0.0/9144e0a4e8435090117cea53e6c7537173ef2304525df4a077c435d8ee7828ff...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-af255dc9caab3475/0.0.0/9144e0a4e8435090117cea53e6c7537173ef2304525df4a077c435d8ee7828ff. Subsequent calls will reuse this data.
Downloading and preparing dataset csv/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/csv/default-6b1cd4bbd25aec0e/0.0.0/9144e0a4e8435090117cea53e6c7537173ef2304525df4a077c435d8ee7828ff...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-6b1cd4bbd25aec0e/0.0.0/9144e0a4e8435090117cea53e6c7537173ef2304525df4a077c435d8ee7828ff. Subsequent calls will reuse this data.


In [11]:
def preprocess_function(data):
    """
    Preprocessing function
    """
    texts = (data["text"],)
    processed = Config.tokenizer(*texts, padding="max_length", max_length=128, truncation=True)
    processed["labels"] = data["sentiment"]
    return processed

In [12]:
%%time
train_dataset = raw_train.map(preprocess_function, batched=True, remove_columns=raw_train["train"].column_names)
test_dataset = raw_test.map(preprocess_function, batched=True, remove_columns=raw_test['test'].column_names)

HBox(children=(FloatProgress(value=0.0, max=1440.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=160.0), HTML(value='')))


CPU times: user 8min 13s, sys: 1min 14s, total: 9min 28s
Wall time: 3min 40s


In [13]:
train = train_dataset['train']
valid = test_dataset['test']
print(len(train), len(valid))

1440000 160000


In [14]:
config = AutoConfig.from_pretrained(Config.model_name, num_labels=Config.num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(Config.model_name, config=config, seed=42)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=438064459.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-uncased were not used when initializing FlaxBertForSequenceClassification: {('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'dense', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias')}
- This IS expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxBertForSequenceClassification were not initialized from the model checkpoint at bert-base-unca

In [15]:
num_train_steps = len(train) // Config.total_batch_size * Config.nb_epochs
learning_rate_function = optax.cosine_onecycle_schedule(transition_steps=num_train_steps, peak_value=Config.lr, pct_start=0.1)
print("The number of train steps (all the epochs) is", num_train_steps)

The number of train steps (all the epochs) is 28125


In [16]:
optimizer = optax.adamw(learning_rate=Config.lr, b1=0.9, b2=0.999, eps=1e-6, weight_decay=1e-2)

In [17]:
def loss_fn(logits, targets):
    loss = optax.softmax_cross_entropy(logits, onehot(targets, num_classes=Config.num_labels))
    return jnp.mean(loss)
def eval_fn(logits):
    return logits.argmax(-1)

In [18]:
class TrainState(train_state.TrainState):
    eval_function: Callable = flax.struct.field(pytree_node=False)
    loss_function: Callable = flax.struct.field(pytree_node=False)

In [19]:
state = TrainState.create(
    apply_fn = model.__call__,
    params = model.params,
    tx = optimizer,
    eval_function=eval_fn,
    loss_function=loss_fn,
)

In [20]:
def train_step(state, batch, dropout_rng):
    targets = batch.pop("labels")
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
    
    def loss_function(params):
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        loss = state.loss_function(logits, targets)
        return loss
    
    grad_fn = jax.value_and_grad(loss_function)
    loss, grad = grad_fn(state.params)
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)
    metrics = jax.lax.pmean({'loss': loss, 'learning_rate': learning_rate_function(state.step)}, axis_name='batch')
    
    return new_state, metrics, new_dropout_rng

In [21]:
parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

In [22]:
def eval_step(state, batch):
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    return state.eval_function(logits)

In [23]:
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")

In [24]:
def sentimentTrainDataLoader(rng, dataset, batch_size):
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        batch = dataset[perm]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)

        yield batch

In [25]:
def sentimentEvalDataLoader(dataset, batch_size):
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)

        yield batch

In [26]:
state = flax.jax_utils.replicate(state)

  "jax.host_count has been renamed to jax.process_count. This alias "
  "jax.host_id has been renamed to jax.process_index. This alias "


In [27]:
rng = jax.random.PRNGKey(42)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

In [28]:
for i, epoch in enumerate(tqdm(range(1, Config.nb_epochs + 1), desc=f"Epoch...", position=0, leave=True)):
    rng, input_rng = jax.random.split(rng)

    # train
    with tqdm(total=len(train) // Config.total_batch_size, desc="Training...", leave=False) as progress_bar_train:
        for batch in sentimentTrainDataLoader(input_rng, train, Config.total_batch_size):
            state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            progress_bar_train.update(1)

    # evaluate
    with tqdm(total=len(valid) // Config.total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
        for batch in sentimentEvalDataLoader(valid, Config.total_batch_size):
            labels = batch.pop("labels")
            predictions = parallel_eval_step(state, batch)
            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
            progress_bar_eval.update(1)

    eval_metric = metric.compute()

    loss = round(flax.jax_utils.unreplicate(train_metrics)['loss'].item(), 3)
    eval_score = round(list(eval_metric.values())[0], 3)
    metric_name = list(eval_metric.keys())[0]
    
    wandb_log(
        train_loss=loss,
        valid_acc=eval_score
    )
    print(f"{i+1}/{Config.nb_epochs} | Train loss: {loss} | Eval {metric_name}: {eval_score}")

HBox(children=(FloatProgress(value=0.0, description='Epoch...', max=5.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Training...', max=5625.0, style=ProgressStyle(description…

HBox(children=(FloatProgress(value=0.0, description='Evaluating...', max=625.0, style=ProgressStyle(descriptio…

1/5 | Train loss: 0.257 | Eval ACCURACY: 0.869


HBox(children=(FloatProgress(value=0.0, description='Training...', max=5625.0, style=ProgressStyle(description…

HBox(children=(FloatProgress(value=0.0, description='Evaluating...', max=625.0, style=ProgressStyle(descriptio…

2/5 | Train loss: 0.308 | Eval ACCURACY: 0.875


HBox(children=(FloatProgress(value=0.0, description='Training...', max=5625.0, style=ProgressStyle(description…

HBox(children=(FloatProgress(value=0.0, description='Evaluating...', max=625.0, style=ProgressStyle(descriptio…

3/5 | Train loss: 0.3 | Eval ACCURACY: 0.875


HBox(children=(FloatProgress(value=0.0, description='Training...', max=5625.0, style=ProgressStyle(description…

HBox(children=(FloatProgress(value=0.0, description='Evaluating...', max=625.0, style=ProgressStyle(descriptio…

4/5 | Train loss: 0.234 | Eval ACCURACY: 0.872


HBox(children=(FloatProgress(value=0.0, description='Training...', max=5625.0, style=ProgressStyle(description…

HBox(children=(FloatProgress(value=0.0, description='Evaluating...', max=625.0, style=ProgressStyle(descriptio…

5/5 | Train loss: 0.192 | Eval ACCURACY: 0.871



### [Check out the W&B Run Page here $\rightarrow$](https://wandb.ai/anony-mouse-121867/jax_flax/runs/28s0ucnx)

![img](https://i.imgur.com/YnmHupI.gif)