
# Finetune Transformers Models with PyTorch Lightning

* **Author:** PL team
* **License:** CC BY-SA
* **Generated:** 2021-06-28T09:27:48.748750

This notebook will use HuggingFace's `datasets` library to get data, which will be wrapped in a `LightningDataModule`.
Then, we write a class to perform text classification on any dataset from the [GLUE Benchmark](https://gluebenchmark.com/).
(We just show CoLA and MRPC due to constraint on compute/disk)


---
Open in [![Open In Colab](https://colab.research.google.com/assets/colab-badge.png){height="20px" width="117px"}](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/text-transformers.ipynb)

Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)
| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)
| Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)

In [1]:
from datetime import datetime
from typing import Optional

import datasets
import torch
from pytorch_lightning import LightningDataModule, LightningModule, seed_everything, Trainer
from torch.utils.data import DataLoader
from transformers import (
    AdamW,
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

AVAIL_GPUS = min(1, torch.cuda.device_count())

## Training BERT with Lightning

### Lightning DataModule for GLUE

In [2]:
class GLUEDataModule(LightningDataModule):

    task_text_field_map = {
        'cola': ['sentence'],
        'sst2': ['sentence'],
        'mrpc': ['sentence1', 'sentence2'],
        'qqp': ['question1', 'question2'],
        'stsb': ['sentence1', 'sentence2'],
        'mnli': ['premise', 'hypothesis'],
        'qnli': ['question', 'sentence'],
        'rte': ['sentence1', 'sentence2'],
        'wnli': ['sentence1', 'sentence2'],
        'ax': ['premise', 'hypothesis']
    }

    glue_task_num_labels = {
        'cola': 2,
        'sst2': 2,
        'mrpc': 2,
        'qqp': 2,
        'stsb': 1,
        'mnli': 3,
        'qnli': 2,
        'rte': 2,
        'wnli': 2,
        'ax': 3
    }

    loader_columns = [
        'datasets_idx', 'input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions',
        'labels'
    ]

    def __init__(
        self,
        model_name_or_path: str,
        task_name: str = 'mrpc',
        max_seq_length: int = 128,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        **kwargs
    ):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.task_name = task_name
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size

        self.text_fields = self.task_text_field_map[task_name]
        self.num_labels = self.glue_task_num_labels[task_name]
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def setup(self, stage: str):
        self.dataset = datasets.load_dataset('glue', self.task_name)

        for split in self.dataset.keys():
            self.dataset[split] = self.dataset[split].map(
                self.convert_to_features,
                batched=True,
                remove_columns=['label'],
            )
            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
            self.dataset[split].set_format(type="torch", columns=self.columns)

        self.eval_splits = [x for x in self.dataset.keys() if 'validation' in x]

    def prepare_data(self):
        datasets.load_dataset('glue', self.task_name)
        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def train_dataloader(self):
        return DataLoader(self.dataset['train'], batch_size=self.train_batch_size)

    def val_dataloader(self):
        if len(self.eval_splits) == 1:
            return DataLoader(self.dataset['validation'], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]

    def test_dataloader(self):
        if len(self.eval_splits) == 1:
            return DataLoader(self.dataset['test'], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]

    def convert_to_features(self, example_batch, indices=None):

        # Either encode single sentence or sentence pairs
        if len(self.text_fields) > 1:
            texts_or_text_pairs = list(
                zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])
            )
        else:
            texts_or_text_pairs = example_batch[self.text_fields[0]]

        # Tokenize the text/text pairs
        features = self.tokenizer.batch_encode_plus(
            texts_or_text_pairs, max_length=self.max_seq_length, pad_to_max_length=True, truncation=True
        )

        # Rename label to labels to make it easier to pass to model forward
        features['labels'] = example_batch['label']

        return features

**You could use this datamodule with standalone PyTorch if you wanted...**

In [3]:
dm = GLUEDataModule('distilbert-base-uncased')
dm.prepare_data()
dm.setup('fit')
next(iter(dm.train_dataloader()))

Reusing dataset glue (/home/edmundlylee/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Reusing dataset glue (/home/edmundlylee/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Loading cached processed dataset at /home/edmundlylee/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9e81c8360baad98a.arrow
100%|██████████| 1/1 [00:00<00:00, 60.31ba/s]
100%|██████████| 2/2 [00:00<00:00, 35.47ba/s]


{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'input_ids': tensor([[  101,  2572,  3217,  ...,     0,     0,     0],
         [  101,  9805,  3540,  ...,     0,     0,     0],
         [  101,  2027,  2018,  ...,     0,     0,     0],
         ...,
         [  101,  1996,  2922,  ...,     0,     0,     0],
         [  101,  6202,  1999,  ...,     0,     0,     0],
         [  101, 16565,  2566,  ...,     0,     0,     0]]),
 'labels': tensor([1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1,
         1, 1, 0, 0, 1, 1, 1, 0])}

### Transformer LightningModule

In [4]:
class GLUETransformer(LightningModule):

    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int,
        task_name: str,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        eval_splits: Optional[list] = None,
        **kwargs
    ):
        super().__init__()

        self.save_hyperparameters()

        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name_or_path, config=self.config
        )
        self.metric = datasets.load_metric(
            'glue', self.hparams.task_name, experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
        )

    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs[0]
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(**batch)
        val_loss, logits = outputs[:2]

        if self.hparams.num_labels >= 1:
            preds = torch.argmax(logits, axis=1)
        elif self.hparams.num_labels == 1:
            preds = logits.squeeze()

        labels = batch["labels"]

        return {'loss': val_loss, "preds": preds, "labels": labels}

    def validation_epoch_end(self, outputs):
        if self.hparams.task_name == 'mnli':
            for i, output in enumerate(outputs):
                # matched or mismatched
                split = self.hparams.eval_splits[i].split('_')[-1]
                preds = torch.cat([x['preds'] for x in output]).detach().cpu().numpy()
                labels = torch.cat([x['labels'] for x in output]).detach().cpu().numpy()
                loss = torch.stack([x['loss'] for x in output]).mean()
                self.log(f'val_loss_{split}', loss, prog_bar=True)
                split_metrics = {
                    f"{k}_{split}": v
                    for k, v in self.metric.compute(predictions=preds, references=labels).items()
                }
                self.log_dict(split_metrics, prog_bar=True)
            return loss

        preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()
        labels = torch.cat([x['labels'] for x in outputs]).detach().cpu().numpy()
        loss = torch.stack([x['loss'] for x in outputs]).mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)
        return loss

    def setup(self, stage=None) -> None:
        if stage != 'fit':
            return
        # Get dataloader by calling it - train_dataloader() is called after setup() by default
        train_loader = self.train_dataloader()

        # Calculate total steps
        tb_size = self.hparams.train_batch_size * max(1, self.trainer.gpus)
        ab_size = self.trainer.accumulate_grad_batches * float(self.trainer.max_epochs)
        self.total_steps = (len(train_loader.dataset) // tb_size) // ab_size

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.total_steps,
        )
        scheduler = {'scheduler': scheduler, 'interval': 'step', 'frequency': 1}
        return [optimizer], [scheduler]

## Training

#### CoLA

See an interactive view of the
CoLA dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola)

In [5]:
seed_everything(42)

dm = GLUEDataModule(model_name_or_path='albert-base-v2', task_name='cola')
dm.setup('fit')
model = GLUETransformer(
    model_name_or_path='albert-base-v2',
    num_labels=dm.num_labels,
    eval_splits=dm.eval_splits,
    task_name=dm.task_name,
)

trainer = Trainer(max_epochs=3, gpus=AVAIL_GPUS)
trainer.fit(model, dm)

Global seed set to 42
Reusing dataset glue (/home/edmundlylee/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
100%|██████████| 9/9 [00:00<00:00, 41.15ba/s]
100%|██████████| 2/2 [00:00<00:00, 58.86ba/s]
100%|██████████| 2/2 [00:00<00:00, 53.88ba/s]
Some weights of the model checkpoint at albert-base-v2 were not used when initializing AlbertForSequenceClassification: ['predictions.decoder.bias', 'predictions.LayerNorm.weight', 'predictions.decoder.weight', 'predictions.bias', 'predictions.LayerNorm.bias', 'predictions.dense.bias', 'predictions.dense.weight']
- This IS expected if you are initializing AlbertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertForSequenceClassification from the checkpoint of a model that you expect

Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                      

  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)
Global seed set to 42


Epoch 0:   0%|          | 1/301 [00:00<00:28, 10.67it/s]  

  rank_zero_warn(


Epoch 1:   0%|          | 0/301 [00:00<00:00, 2228.64it/s, loss=0.62, v_num=0, val_loss=0.608, matthews_correlation=0.000] 

  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)


Epoch 2:   0%|          | 0/301 [00:00<00:00, 1833.98it/s, loss=0.614, v_num=0, val_loss=0.608, matthews_correlation=0.000] 

  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)


Epoch 2: 100%|██████████| 301/301 [00:44<00:00,  6.81it/s, loss=0.614, v_num=0, val_loss=0.608, matthews_correlation=0.000]


  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)


#### MRPC

See an interactive view of the
MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mrpc)

In [6]:
seed_everything(42)

dm = GLUEDataModule(
    model_name_or_path='distilbert-base-cased',
    task_name='mrpc',
)
dm.setup('fit')
model = GLUETransformer(
    model_name_or_path='distilbert-base-cased',
    num_labels=dm.num_labels,
    eval_splits=dm.eval_splits,
    task_name=dm.task_name
)

trainer = Trainer(max_epochs=3, gpus=AVAIL_GPUS)
trainer.fit(model, dm)

Global seed set to 42
Downloading: 100%|██████████| 29.0/29.0 [00:00<00:00, 30.9kB/s]
Downloading: 100%|██████████| 411/411 [00:00<00:00, 316kB/s]
Downloading: 100%|██████████| 213k/213k [00:00<00:00, 540kB/s]
Downloading: 100%|██████████| 436k/436k [00:00<00:00, 1.06MB/s]
Reusing dataset glue (/home/edmundlylee/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
100%|██████████| 4/4 [00:00<00:00, 36.94ba/s]
100%|██████████| 1/1 [00:00<00:00, 61.51ba/s]
100%|██████████| 2/2 [00:00<00:00, 39.06ba/s]
Downloading: 100%|██████████| 263M/263M [00:22<00:00, 11.5MB/s]
Some weights of the model checkpoint at distilbert-base-cased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the ch

                                                              

  rank_zero_warn(
Global seed set to 42


Epoch 0:   0%|          | 0/128 [00:00<00:00, 3075.00it/s]  

  rank_zero_warn(


Epoch 2: 100%|██████████| 128/128 [00:09<00:00, 13.28it/s, loss=0.61, v_num=1, val_loss=0.614, accuracy=0.684, f1=0.812]


#### MNLI

 - The MNLI dataset is huge, so we aren't going to bother trying to train on it here.
 - We will skip over training and go straight to validation.

See an interactive view of the
MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mnli)

In [7]:
dm = GLUEDataModule(
    model_name_or_path='distilbert-base-cased',
    task_name='mnli',
)
dm.setup('fit')
model = GLUETransformer(
    model_name_or_path='distilbert-base-cased',
    num_labels=dm.num_labels,
    eval_splits=dm.eval_splits,
    task_name=dm.task_name
)

trainer = Trainer(gpus=AVAIL_GPUS, progress_bar_refresh_rate=20)
trainer.validate(model, dm.val_dataloader())

Downloading and preparing dataset glue/mnli (download: 298.29 MiB, generated: 78.65 MiB, post-processed: Unknown size, total: 376.95 MiB) to /home/edmundlylee/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...


Downloading: 100%|██████████| 313M/313M [00:34<00:00, 9.17MB/s]
  1%|          | 3/393 [00:00<00:18, 20.80ba/s]

Dataset glue downloaded and prepared to /home/edmundlylee/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.


100%|██████████| 393/393 [00:10<00:00, 36.33ba/s]
100%|██████████| 10/10 [00:00<00:00, 40.68ba/s]
100%|██████████| 10/10 [00:00<00:00, 32.63ba/s]
100%|██████████| 10/10 [00:00<00:00, 43.74ba/s]
100%|██████████| 10/10 [00:00<00:00, 41.65ba/s]
Some weights of the model checkpoint at distilbert-base-cased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertFor

Validating:   0%|          | 0/615 [00:00<?, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Validating: 100%|██████████| 615/615 [00:12<00:00, 47.43it/s]--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'accuracy_matched': 0.32399389147758484,
 'accuracy_mismatched': 0.3184499740600586,
 'val_loss_matched': 1.104953408241272,
 'val_loss_mismatched': 1.1044032573699951}
--------------------------------------------------------------------------------
DATALOADER:1 VALIDATE RESULTS
{'accuracy_matched': 0.32399389147758484,
 'accuracy_mismatched': 0.3184499740600586,
 'val_loss_matched': 1.104953408241272,
 'val_loss_mismatched': 1.1044032573699951}
--------------------------------------------------------------------------------


[{'val_loss_matched': 1.104953408241272,
  'accuracy_matched': 0.32399389147758484,
  'val_loss_mismatched': 1.1044032573699951,
  'accuracy_mismatched': 0.3184499740600586},
 {'val_loss_matched': 1.104953408241272,
  'accuracy_matched': 0.32399389147758484,
  'val_loss_mismatched': 1.1044032573699951,
  'accuracy_mismatched': 0.3184499740600586}]

In [9]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6008 (pid 101745), started 0:00:59 ago. (Use '!kill 101745' to kill it.)