# Finetune a BERT Text Classifier with LightningTrainer

This is an advanced example for LightningTrainer, which demonstrates how to use LightningTrainer with Dataset.

If you just want to quickly convert your existing PyTorch Lightning scripts into Ray AIR, you can refer to this starter example: Train a Pytorch Lightning Image Classifier.

Source: https://docs.ray.io/en/latest/train/examples/lightning/lightning_cola_advanced.html

In [2]:
import ray
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset, load_metric
import numpy as np

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## Pre-process CoLA Dataset

In [5]:
dataset = load_dataset("glue", "cola")
metric = load_metric("glue", "cola")

ray_datasets = ray.data.from_huggingface(dataset)

2023-09-07 17:06:53,841	INFO worker.py:1612 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m


In [6]:
from ray.data.preprocessors import BatchMapper

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")


def tokenize_sentence(batch):
    encoded_sent = tokenizer(
        batch["sentence"].tolist(),
        max_length=128,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    batch["input_ids"] = encoded_sent["input_ids"].numpy()
    batch["attention_mask"] = encoded_sent["attention_mask"].numpy()
    batch["label"] = np.array(batch["label"])
    batch.pop("sentence")
    return batch


preprocessor = BatchMapper(tokenize_sentence, batch_format="numpy")

Downloading (…)okenizer_config.json: 100%|██████████| 29.0/29.0 [00:00<00:00, 3.10kB/s]
Downloading (…)lve/main/config.json: 100%|██████████| 570/570 [00:00<00:00, 59.8kB/s]
Downloading (…)solve/main/vocab.txt: 100%|██████████| 213k/213k [00:00<00:00, 395kB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 436k/436k [00:11<00:00, 38.2kB/s]


## Define a PyTorch Lightning Model

In [7]:
class SentimentModel(pl.LightningModule):
    def __init__(self, lr=2e-5, eps=1e-8):
        super().__init__()
        self.lr = lr
        self.eps = eps # epsilon
        self.num_classes = 2
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-cased", num_labels=self.num_classes
        )
        self.metric = load_metric("glue", "cola")
        self.predictions = []
        self.references = []

    def forward(self, batch):
        input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]
        outputs = self.model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        return logits

    def training_step(self, batch, batch_idx):
        labels = batch["label"]
        logits = self.forward(batch)
        loss = F.cross_entropy(logits.view(-1, self.num_classes), labels)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        labels = batch["label"]
        logits = self.forward(batch)
        preds = torch.argmax(logits, dim=1)
        self.predictions.append(preds)
        self.references.append(labels)

    def on_validation_epoch_end(self):
        predictions = torch.concat(self.predictions).view(-1)
        references = torch.concat(self.references).view(-1)
        matthews_correlation = self.metric.compute(
            predictions=predictions, references=references
        )

        # self.metric.compute() returns a dictionary:
        # e.g. {"matthews_correlation": 0.53}
        self.log_dict(matthews_correlation, sync_dist=True)
        self.predictions.clear()
        self.references.clear()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=self.eps)

## Configure your LightningTrainer

In [8]:
from ray.train.lightning import LightningTrainer, LightningConfigBuilder
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig

# Define the configs for LightningTrainer
lightning_config = (
    LightningConfigBuilder()
    .module(cls=SentimentModel, lr=1e-5, eps=1e-8)
    .trainer(max_epochs=5, accelerator="gpu")
    .checkpointing(save_on_train_epoch_end=False)
    .build()
)

In [9]:
# Save AIR checkpoints according to the performance on validation set
run_config = RunConfig(
    name="ptl-sent-classification",
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="matthews_correlation",
        checkpoint_score_order="max",
    ),
)

# Scale the DDP training workload across 4 GPUs
# You can change this config based on your compute resources.
scaling_config = ScalingConfig(
    num_workers=1, use_gpu=True, resources_per_worker={"CPU": 15, "GPU": 1}
)

## Fine-tune the model with LightningTrainer

In [10]:
trainer = LightningTrainer(
    lightning_config=lightning_config,
    run_config=run_config,
    scaling_config=scaling_config,
    datasets={"train": ray_datasets["train"], "val": ray_datasets["validation"]},
    datasets_iter_config={"batch_size": 32},
    preprocessor=preprocessor,
)
result = trainer.fit()

0,1
Current time:,2023-09-07 17:38:50
Running for:,00:19:03.66
Memory:,8.1/30.9 GiB

Trial name,status,loc
LightningTrainer_11eaa_00000,RUNNING,192.168.33.188:12028


[2m[36m(LightningTrainer pid=12028)[0m The `preprocessor` arg to Trainer is deprecated. Apply preprocessor transformations ahead of time by calling `preprocessor.transform(ds)`. Support for the preprocessor arg will be dropped in a future release.
[2m[36m(LightningTrainer pid=12028)[0m Starting distributed worker processes: ['12071 (192.168.33.188)']
[2m[36m(RayTrainWorker pid=12071)[0m Setting up process group for: env:// [rank=0, world_size=1]
Downloading model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]
Downloading model.safetensors:   2%|▏         | 10.5M/436M [00:44<29:50, 238kB/s]
Downloading model.safetensors:   2%|▏         | 10.5M/436M [01:00<29:50, 238kB/s]
Downloading model.safetensors:   5%|▍         | 21.0M/436M [01:01<18:42, 369kB/s]
Downloading model.safetensors:   5%|▍         | 21.0M/436M [01:20<18:42, 369kB/s]
Downloading model.safetensors:   7%|▋         | 31.5M/436M [01:20<15:31, 434kB/s]
Downloading model.safetensors:   7%|▋         | 31.5M/436

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  2.60it/s]


[2m[36m(RayTrainWorker pid=12071)[0m   rank_zero_warn(
[2m[36m(RayTrainWorker pid=12071)[0m Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(BatchMapper._transform_numpy)] -> AllToAllOperator[RandomizeBlockOrder]
[2m[36m(RayTrainWorker pid=12071)[0m Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)
[2m[36m(RayTrainWorker pid=12071)[0m Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`
                                                                                                                                     

Epoch 0: : 1it [00:00,  1.07it/s, v_num=0]
Epoch 0: : 2it [00:01,  1.36it/s, v_num=0]
Epoch 0: : 3it [00:02,  1.48it/s, v_num=0]
Epoch 0: : 4it [00:02,  1.57it/s, v_num=0]
Epoch 0: : 5it [00:03,  1.61it/s, v_num=0]
Epoch 0: : 6it [00:03,  1.64it/s, v_num=0]
Epoch 0: : 7it [00:04,  1.66it/s, v_num=0]
Epoch 0: : 8it [00:04,  1.68it/s, v_num=0]
Epoch 0: : 9it [00:05,  1.69it/s, v_num=0]
Epoch 0: : 10it [00:05,  1.70it/s, v_num=0]
Epoch 0: : 11it [00:06,  1.71it/s, v_num=0]
Epoch 0: : 12it [00:06,  1.72it/s, v_num=0]
Epoch 0: : 13it [00:07,  1.73it/s, v_num=0]
Epoch 0: : 14it [00:08,  1.73it/s, v_num=0]
Epoch 0: : 15it [00:08,  1.73it/s, v_num=0]
Epoch 0: : 16it [00:09,  1.74it/s, v_num=0]
Epoch 0: : 17it [00:09,  1.74it/s, v_num=0]
Epoch 0: : 18it [00:10,  1.74it/s, v_num=0]
Epoch 0: : 19it [00:10,  1.75it/s, v_num=0]
Epoch 0: : 20it [00:11,  1.75it/s, v_num=0]
Epoch 0: : 21it [00:11,  1.75it/s, v_num=0]
Epoch 0: : 22it [00:12,  1.75it/s, v_num=0]
Epoch 0: : 23it [00:13,  1.75it/s, v_num=



Epoch 0: : 254it [03:16,  1.29it/s, v_num=0]
Epoch 0: : 255it [03:17,  1.29it/s, v_num=0]
Epoch 0: : 256it [03:18,  1.29it/s, v_num=0]
Epoch 0: : 257it [03:19,  1.29it/s, v_num=0]
Epoch 0: : 258it [03:19,  1.29it/s, v_num=0]
Epoch 0: : 259it [03:20,  1.29it/s, v_num=0]
Epoch 0: : 260it [03:21,  1.29it/s, v_num=0]
Epoch 0: : 261it [03:22,  1.29it/s, v_num=0]
Epoch 0: : 262it [03:22,  1.29it/s, v_num=0]
Epoch 0: : 263it [03:23,  1.29it/s, v_num=0]
Epoch 0: : 264it [03:24,  1.29it/s, v_num=0]
Epoch 0: : 265it [03:25,  1.29it/s, v_num=0]
Epoch 0: : 266it [03:26,  1.29it/s, v_num=0]


2023-09-07 17:39:00,443	INFO tune.py:1148 -- Total run time: 1153.77 seconds (1143.66 seconds for the tuning loop).
Resume training with: Trainer.restore(path="/home/mpp/ray_results/ptl-sent-classification", ...)
- /home/mpp/ray_results/ptl-sent-classification/LightningTrainer_11eaa_00000_0_2023-09-07_17-19-46


[2m[36m(LightningTrainer pid=12028)[0m Traceback (most recent call last):
[2m[36m(LightningTrainer pid=12028)[0m   File "python/ray/_raylet.pyx", line 1364, in ray._raylet.execute_task.function_executor
[2m[36m(LightningTrainer pid=12028)[0m   File "/home/mpp/miniconda3/envs/ray-torch/lib/python3.9/site-packages/ray/_private/function_manager.py", line 726, in actor_method_executor
[2m[36m(LightningTrainer pid=12028)[0m     return method(__ray_actor, *args, **kwargs)
[2m[36m(LightningTrainer pid=12028)[0m   File "/home/mpp/miniconda3/envs/ray-torch/lib/python3.9/site-packages/ray/util/tracing/tracing_helper.py", line 464, in _resume_span
[2m[36m(LightningTrainer pid=12028)[0m     return method(self, *_args, **_kwargs)
[2m[36m(LightningTrainer pid=12028)[0m   File "/home/mpp/miniconda3/envs/ray-torch/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 372, in train
[2m[36m(LightningTrainer pid=12028)[0m     result = self.step()
[2m[36m(LightningTr

In [8]:
result

Result(
  metrics={'_report_on': 'validation_end', 'train_loss': 0.08543746918439865, 'matthews_correlation': 0.5930452712523209, 'epoch': 4, 'step': 2675, 'should_checkpoint': True, 'done': True, 'trial_id': '9082c_00000', 'experiment_tag': '0'},
  path='/home/dino/ray_results/ptl-sent-classification/LightningTrainer_9082c_00000_0_2023-09-06_23-51-03',
  checkpoint=LightningCheckpoint(local_path=/home/dino/ray_results/ptl-sent-classification/LightningTrainer_9082c_00000_0_2023-09-06_23-51-03/checkpoint_000004)
)