In [None]:
%%capture
!pip install datasets
!pip install git+https://github.com/huggingface/transformers.git
!pip install flax
!pip install git+https://github.com/deepmind/optax.git
!pip install evaluate
!pip install ipywidgets
!pip install black isort #jupyter notebook code formatter; optional

In [164]:
import os
from itertools import chain
from typing import Callable

import evaluate
import flax
import jax
import jax.numpy as jnp
import optax
import pandas as pd
from datasets import load_dataset
from flax import traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from ipywidgets import IntProgress as IProgress
from tqdm.notebook import tqdm
from transformers import (
    AutoConfig,
    AutoTokenizer,
    FlaxAutoModelForSequenceClassification,
)

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

In [165]:
jax.local_devices()

[gpu(id=0),
 gpu(id=1),
 gpu(id=2),
 gpu(id=3),
 gpu(id=4),
 gpu(id=5),
 gpu(id=6),
 gpu(id=7)]

In [207]:
task = "qqp"
model_checkpoint = "bert-base-cased"
per_device_batch_size = 64

In [None]:
raw_dataset = load_dataset("glue", task)
metric = evaluate.load("glue", task)

In [209]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [210]:
def preprocess_function(examples):
    texts = (examples["question1"], examples["question2"])
    processed = tokenizer(*texts, padding="max_length", max_length=128, truncation=True)
    processed["labels"] = examples["label"]
    return processed

In [211]:
# Details about how to handle and process huggingface dataset:
# https://huggingface.co/docs/datasets/process
data = raw_dataset["train"].shuffle(seed=0)
train_data = data.select(list(range(int(data.shape[0] * 0.1))))
eval_data = data.select(list(range(int(data.shape[0] * 0.1), int(data.shape[0] * 0.2))))
print(f"Shape of the original training dataset is: {data.shape}")
print(f"Shape of the current training dataset is: {train_data.shape}")
print(f"Shape of the current evaluation dataset is: {eval_data.shape}")

Shape of the original training dataset is: (363846, 4)
Shape of the current training dataset is: (36384, 4)
Shape of the current evaluation dataset is: (36385, 4)


In [212]:
train_dataset = train_data.map(
    preprocess_function, batched=True, remove_columns=train_data.column_names
)
eval_dataset = eval_data.map(
    preprocess_function, batched=True, remove_columns=eval_data.column_names
)

Map:   0%|          | 0/36385 [00:00<?, ? examples/s]

In [213]:
pd.DataFrame(train_dataset[:3])

Unnamed: 0,input_ids,token_type_ids,attention_mask,labels
0,"[101, 1327, 1132, 1103, 2182, 1115, 1127, 1309...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",0
1,"[101, 1327, 1132, 1103, 4583, 2489, 1827, 1111...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",0
2,"[101, 2009, 1132, 1199, 22413, 1603, 15376, 13...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",0


In [214]:
num_labels = 2
seed = 0
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(
    model_checkpoint, config=config, seed=seed
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing FlaxBertForSequenceClassification: {('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'transform', 'dense', 'bias')}
- This IS expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxBertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased 

In [215]:
num_train_epochs = 6
learning_rate = 2e-5

In [216]:
total_batch_size = per_device_batch_size * jax.local_device_count()
print("The overall batch size (both for training and eval) is", total_batch_size)

The overall batch size (both for training and eval) is 512


In [217]:
num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs

learning_rate_function = optax.linear_schedule(
    init_value=learning_rate, end_value=0, transition_steps=num_train_steps
)

In [218]:
class TrainState(train_state.TrainState):
    logits_function: Callable = flax.struct.field(pytree_node=False)
    loss_function: Callable = flax.struct.field(pytree_node=False)

In [219]:
# Create a decay_mask_fn function to make sure that weight decay is not applied to any bias or LayerNorm weights
# as it may not improve model performance and even be harmful.
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = {
        path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale"))
        for path in flat_params
    }
    return traverse_util.unflatten_dict(flat_mask)

In [220]:
# Standard Adam optimizer with weight decay
def adamw(weight_decay):
    return optax.adamw(
        learning_rate=learning_rate_function,
        b1=0.9,
        b2=0.999,
        eps=1e-6,
        weight_decay=weight_decay,
        mask=decay_mask_fn,
    )

In [221]:
def loss_function(logits, labels):
    xentropy = optax.softmax_cross_entropy(
        logits, onehot(labels, num_classes=num_labels)
    )
    return jnp.mean(xentropy)


def eval_function(logits):
    return logits.argmax(-1)

In [222]:
# Instantiate the TrainState
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw(weight_decay=0.01),
    logits_function=eval_function,
    loss_function=loss_function,
)

In [223]:
def train_step(state, batch, dropout_rng):
    targets = batch.pop("labels")
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_function(params):
        logits = state.apply_fn(
            **batch, params=params, dropout_rng=dropout_rng, train=True
        )[0]
        loss = state.loss_function(logits, targets)
        return loss

    grad_function = jax.value_and_grad(loss_function)
    loss, grad = grad_function(state.params)
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)
    metrics = jax.lax.pmean(
        {"loss": loss, "learning_rate": learning_rate_function(state.step)},
        axis_name="batch",
    )
    return new_state, metrics, new_dropout_rng

In [224]:
def eval_step(state, batch):
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    return state.logits_function(logits)

In [225]:
parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")
state = flax.jax_utils.replicate(state)

In [226]:
def glue_train_data_loader(rng, dataset, batch_size):
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        batch = dataset[perm]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)

        yield batch

In [227]:
def glue_eval_data_loader(dataset, batch_size):
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)

        yield batch

In [228]:
rng = jax.random.PRNGKey(seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

In [229]:
for i, epoch in enumerate(
    tqdm(range(1, num_train_epochs + 1), desc=f"Epoch ...", position=0, leave=True)
):
    rng, input_rng = jax.random.split(rng)

    # train
    with tqdm(
        total=len(train_dataset) // total_batch_size, desc="Training...", leave=True
    ) as progress_bar_train:
        for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):
            state, train_metrics, dropout_rngs = parallel_train_step(
                state, batch, dropout_rngs
            )
            progress_bar_train.update(1)

    # evaluate
    with tqdm(
        total=len(eval_dataset) // total_batch_size, desc="Evaluating...", leave=False
    ) as progress_bar_eval:
        for batch in glue_eval_data_loader(eval_dataset, total_batch_size):
            labels = batch.pop("labels")
            predictions = parallel_eval_step(state, batch)
            metric.add_batch(
                predictions=list(chain(*predictions)), references=list(chain(*labels))
            )
            progress_bar_eval.update(1)

    eval_metric = metric.compute()

    loss = round(flax.jax_utils.unreplicate(train_metrics)["loss"].item(), 3)
    eval_score1 = round(list(eval_metric.values())[0], 3)
    metric_name1 = list(eval_metric.keys())[0]
    eval_score2 = round(list(eval_metric.values())[1], 3)
    metric_name2 = list(eval_metric.keys())[1]
    print(
        f"{i+1}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name1}: {eval_score1}, {metric_name2}: {eval_score2}"
    )

Epoch ...:   0%|          | 0/6 [00:00<?, ?it/s]

Training...:   0%|          | 0/71 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]

1/6 | Train loss: 0.475 | Eval accuracy: 0.799, f1: 0.762


Training...:   0%|          | 0/71 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]

2/6 | Train loss: 0.369 | Eval accuracy: 0.834, f1: 0.789


Training...:   0%|          | 0/71 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]

3/6 | Train loss: 0.299 | Eval accuracy: 0.846, f1: 0.797


Training...:   0%|          | 0/71 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]

4/6 | Train loss: 0.239 | Eval accuracy: 0.846, f1: 0.806


Training...:   0%|          | 0/71 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]

5/6 | Train loss: 0.252 | Eval accuracy: 0.849, f1: 0.802


Training...:   0%|          | 0/71 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]

6/6 | Train loss: 0.212 | Eval accuracy: 0.849, f1: 0.805


## <a class="anchor" id="appendix"></a>Appendix: using JAX device mesh to achieve parallelism

In [230]:
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

In [231]:
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(
    model_checkpoint, config=config, seed=seed
)
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw(weight_decay=0.01),
    logits_function=eval_function,
    loss_function=loss_function,
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing FlaxBertForSequenceClassification: {('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'transform', 'dense', 'bias')}
- This IS expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxBertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased 

In [232]:
@jax.jit
def train_step(state, batch, dropout_rng):
    targets = batch.pop("labels")
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_function(params):
        logits = state.apply_fn(
            **batch, params=params, dropout_rng=dropout_rng, train=True
        )[0]
        loss = state.loss_function(logits, targets)
        return loss

    grad_function = jax.value_and_grad(loss_function)
    loss, grad = grad_function(state.params)
    new_state = state.apply_gradients(grads=grad)
    metrics = {"loss": loss, "learning_rate": learning_rate_function(state.step)}
    return new_state, metrics, new_dropout_rng

In [233]:
@jax.jit
def eval_step(state, batch):
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    return state.logits_function(logits)

In [234]:
num_devices = len(jax.local_devices())
devices = mesh_utils.create_device_mesh((num_devices,))

# Data will be split along the batch axis
data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
data_sharding = NamedSharding(
    data_mesh,
    P(
        "batch",
    ),
)  # naming axes of the sharded partition


def glue_train_data_loader(rng, dataset, batch_size):
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        batch = dataset[perm]
        batch = {
            k: jax.device_put(jnp.array(v), data_sharding) for k, v in batch.items()
        }

        yield batch


def glue_eval_data_loader(dataset, batch_size):
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {
            k: jax.device_put(jnp.array(v), data_sharding) for k, v in batch.items()
        }

        yield batch

In [235]:
# Replicate the model and optimizer variable on all devices
def get_replicated_train_state(devices, state):
    # All variables will be replicated on all devices
    var_mesh = Mesh(devices, axis_names=("_"))
    # In NamedSharding, axes not mentioned are replicated (all axes here)
    var_replication = NamedSharding(var_mesh, P())

    # Apply the distribution settings to the model variables
    state = jax.device_put(state, var_replication)

    return state


state = get_replicated_train_state(devices, state)

In [236]:
rng = jax.random.PRNGKey(seed)
dropout_rng = jax.random.PRNGKey(seed)

In [237]:
for i, epoch in enumerate(
    tqdm(range(1, num_train_epochs + 1), desc=f"Epoch ...", position=0, leave=True)
):
    rng, input_rng = jax.random.split(rng)

    # train
    with tqdm(
        total=len(train_dataset) // total_batch_size, desc="Training...", leave=True
    ) as progress_bar_train:
        for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):
            state, train_metrics, dropout_rng = train_step(state, batch, dropout_rng)
            progress_bar_train.update(1)

    # evaluate
    with tqdm(
        total=len(eval_dataset) // total_batch_size, desc="Evaluating...", leave=False
    ) as progress_bar_eval:
        for batch in glue_eval_data_loader(eval_dataset, total_batch_size):
            labels = batch.pop("labels")
            predictions = eval_step(state, batch)
            metric.add_batch(predictions=list(predictions), references=list(labels))
            progress_bar_eval.update(1)

    eval_metric = metric.compute()

    loss = round(train_metrics["loss"].item(), 3)
    eval_score1 = round(list(eval_metric.values())[0], 3)
    metric_name1 = list(eval_metric.keys())[0]
    eval_score2 = round(list(eval_metric.values())[1], 3)
    metric_name2 = list(eval_metric.keys())[1]
    print(
        f"{i+1}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name1}: {eval_score1}, {metric_name2}: {eval_score2}"
    )

Epoch ...:   0%|          | 0/6 [00:00<?, ?it/s]

Training...:   0%|          | 0/71 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]

1/6 | Train loss: 0.469 | Eval accuracy: 0.796, f1: 0.759


Training...:   0%|          | 0/71 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]

2/6 | Train loss: 0.376 | Eval accuracy: 0.833, f1: 0.788


Training...:   0%|          | 0/71 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]

3/6 | Train loss: 0.296 | Eval accuracy: 0.844, f1: 0.795


Training...:   0%|          | 0/71 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]

4/6 | Train loss: 0.267 | Eval accuracy: 0.846, f1: 0.805


Training...:   0%|          | 0/71 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]

5/6 | Train loss: 0.263 | Eval accuracy: 0.848, f1: 0.804


Training...:   0%|          | 0/71 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]

6/6 | Train loss: 0.222 | Eval accuracy: 0.849, f1: 0.805
