# BERT for Token Classification
In the last few notebooks, we learnt how to use BERT for a sequence classification task. In this notebook, we will fine-tune BERT for token-level classification, where each token in a sequence is tagged with a label.

The notebook is organized in four sections. In the first section, we define the Named Entity Recognition (NER) task and go through the conll2003 NER dataset. In the next section, we will use HuggingFace's `datasets` library to easily download and process the dataset. In the third section, we will go through `LightningDataModule`, which lets us encapsulate all data-processing procedures in one place. Finally, in the last section, we will implement a model using `PyTorch-Lightning`.

Topics covered:
- Token-level classification
- HuggingFace `datasets`
- PyTorch `LightningDataModule`
- `seqeval`



#### Install packages
As usual, install all the necessary packages:

In [2]:
!pip install --quiet transformers pytorch-lightning seqeval datasets
!pip install ipywidgets



# Named Entity Recognition
Named Entity Recognition is a token-level classification task, where each token in a sentence is classified as an entity (loosely speaking, a proper name) or not. The recognized entities are then further classified into their corresponding types such as `PERSON`, `LOCATION`, `ORGANIZATION` etc. 

Example:
```
I want to fly from BLR to YYZ on 31 Dec, 2020
```

In the above example, `BLR` and `YYZ` are entities of type `LOCATION` and `31 Dec, 2020` is an entity of type `DATE`. All other tokens are non-entities and receive the tag `O` (outside).

If you are not already familiar with the task, take a look at this [video](https://www.youtube.com/watch?v=5SUzf6252_0).


## Dataset
For our implementation, we will be using the `conll2003` dataset first introduced in [this](https://www.aclweb.org/anthology/W03-0419.pdf) paper. The dataset supports the recognition of four entity types: `PER` (PERSON), `LOC` (LOCATION), `ORG` (ORGANIZATION), and `MISC` (MISCELLANEOUS). The following example taken from the dataset demonstrates the format:
```
U.N. NNP I-NP I-ORG
official NN I-NP O
Ekeus NNP I-NP I-PER
heads VBZ I-VP O
for IN I-PP O
Baghdad NNP I-NP I-LOC
. . O O
```

Each line has four columns in the following order: token, part-of-speech tag, chunk tag, and NER tag. The NER tags are of the format `B-XXX`, `I-XXX`, or `O`, where `XXX` represents an entity type (e.g., `PER`).
- `B-XXX` marks a single-token entity, or the first token of an entity that spans more than one token.
- `I-XXX` marks every subsequent token of a multi-token entity. For example, `Priyanka Chopra` would have the tags `B-PER I-PER`, and `Toronto` would be tagged `B-LOC`.
- All non-entities are tagged `O`.

Note that the raw excerpt above follows the convention of the original data files, in which an entity may start with `I-XXX` and `B-XXX` is only used to separate two adjacent entities of the same type.
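To make the tagging scheme concrete, here is a minimal sketch that turns entity spans into one `B-`/`I-`/`O` tag per token. The helper name `spans_to_bio` and the span format `(start, end, type)` are assumptions made for this illustration; they are not part of the dataset or of any library.

```python
def spans_to_bio(num_tokens, spans):
    """Turn (start, end, entity_type) spans into one BIO tag per token.

    `start` is inclusive and `end` is exclusive, both counted in tokens.
    """
    tags = ["O"] * num_tokens                 # everything is a non-entity by default
    for start, end, entity_type in spans:
        tags[start] = f"B-{entity_type}"      # first (or only) token of the entity
        for i in range(start + 1, end):
            tags[i] = f"I-{entity_type}"      # remaining tokens of the entity
    return tags


tokens = ["Priyanka", "Chopra", "visited", "Toronto"]
spans = [(0, 2, "PER"), (3, 4, "LOC")]        # hypothetical annotations
tags = spans_to_bio(len(tokens), spans)
print(list(zip(tokens, tags)))
```

The two-token name receives `B-PER I-PER`, while the single-token location receives only `B-LOC`.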


# HuggingFace `datasets`
HuggingFace's `datasets` library provides an easy way to download and use common datasets and metrics. It has several advantages, including interoperability with PyTorch and efficient memory management when working with large datasets. The list of supported datasets can be retrieved with the `list_datasets()` function.


To load a dataset, simply use the `load_dataset` method from the library.




In [3]:
from datasets import load_dataset
dataset = load_dataset("conll2003")



Downloading and preparing dataset conll2003/conll2003 (download: 4.63 MiB, generated: 9.78 MiB, post-processed: Unknown size, total: 14.41 MiB) to /Users/kbak/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6...



Dataset conll2003 downloaded and prepared to /Users/kbak/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6. Subsequent calls will reuse this data.


Looking into the dataset object, we see that it has three splits: `train`, `validation`, and `test`. Each split is of type `Dataset` with five fields: `id`, `tokens`, `pos_tags`, `chunk_tags`, and `ner_tags`.

In [4]:
dataset

DatasetDict({
    train: Dataset({
        features: ['chunk_tags', 'id', 'ner_tags', 'pos_tags', 'tokens'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['chunk_tags', 'id', 'ner_tags', 'pos_tags', 'tokens'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['chunk_tags', 'id', 'ner_tags', 'pos_tags', 'tokens'],
        num_rows: 3453
    })
})

Each split needs to be processed (tokenization, padding, etc.). That is, we want to apply a function to every example in the split. This is done using the `map()` method of the `Dataset` class as defined [here](https://huggingface.co/docs/datasets/_modules/datasets/arrow_dataset.html#Dataset.map):
- first, we drop the columns we do not need using the `remove_columns` argument.
- we then process each example using the `convert_to_features` method, which we will define shortly. Just like `convert_examples_to_features` from our last notebook, it maps `tokens` to ids, adds special tokens, takes care of padding, and builds the attention mask.

```python
dataset["train"] = dataset["train"].map(
                      convert_to_features,
                      remove_columns=["id", "chunk_tags", "pos_tags"]
                   )
```
 


## Converting examples to features

We can leverage the `encode_plus()` method of the HuggingFace tokenizer class to easily prepare features for the model. `encode_plus()` automatically adds the special tokens `[CLS]` and `[SEP]` at the beginning and end of the sequence, respectively.
- `max_length`: maximum length of the sequence input into the model, beyond which the sequence will be truncated (`truncation=True`).
- `encode_plus` returns `input_ids`, `attention_mask`, and `token_type_ids` by default. We don't need `token_type_ids` for our implementation, and we can specify this with `return_token_type_ids=False`.
- `pad_to_max_length`: pad the sequence up to `max_length`. Recent versions of `transformers` deprecate this flag in favor of `padding="max_length"`.
- `is_split_into_words`: if the input sequence is already split into words, we need to set this flag to `True`.


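Since the first token in a sequence is always `[CLS]`, and since sequences shorter than `max_length` are padded with `[PAD]` tokens, the predictions at these positions should be ignored. We therefore assign them the label `PAD_LABEL_TOKEN` (set to `-100` later in our implementation), so that they can be easily skipped during evaluation. `-100` is also the default `ignore_index` of PyTorch's cross-entropy loss, so these positions do not contribute to the training loss either. Note that this simple scheme assumes every word maps to exactly one WordPiece token; when the tokenizer splits a word into several sub-tokens, the labels no longer line up with `input_ids`.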

```python
def convert_to_features(example):
    # convert tokens to ids
    features = tokenizer.encode_plus(
        example["tokens"],
        max_length=MAX_LEN,
        pad_to_max_length=True,
        truncation=True,
        return_token_type_ids=False,
        is_split_into_words=True,
    )

    # pad labels with PAD_LABEL_TOKEN
    features["labels"] = [PAD_LABEL_TOKEN] + example["ner_tags"]
    padding_length = MAX_LEN - len(features["labels"])
    features["labels"] += [PAD_LABEL_TOKEN] * padding_length

    return features
```
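BERT's WordPiece tokenizer can split a single word into several sub-tokens, in which case there are more `input_ids` than word-level labels. One common remedy, which is not what the code above does, is to label only the first sub-token of each word and mask the rest. The sketch below assumes you have the per-token word indices that fast HuggingFace tokenizers expose through `word_ids()` on the returned encoding. To stay runnable without downloading a tokenizer, it takes a plain list; the helper name `align_labels` and the `word_ids` values are hypothetical and hand-written for illustration, not real tokenizer output.

```python
PAD_LABEL_TOKEN = -100


def align_labels(word_ids, word_labels, pad_label=PAD_LABEL_TOKEN):
    """Spread word-level labels over sub-tokens.

    `word_ids[i]` is the index of the word that token i came from,
    or None for special tokens such as [CLS], [SEP], and [PAD].
    """
    aligned = []
    previous_word = None
    for word_id in word_ids:
        if word_id is None:
            aligned.append(pad_label)             # special or padding token
        elif word_id != previous_word:
            aligned.append(word_labels[word_id])  # first sub-token keeps the label
        else:
            aligned.append(pad_label)             # later sub-tokens are masked
        previous_word = word_id
    return aligned


# Hypothetical layout: [CLS] w0 w1 w2 w3 w4 w4 w4 [SEP] [PAD] [PAD]
word_ids = [None, 0, 1, 2, 3, 4, 4, 4, None, None, None]
ner_tags = [0, 0, 0, 1, 2]
labels = align_labels(word_ids, ner_tags)
print(labels)
```

The result always has one label per token, so it stays aligned with `input_ids` regardless of how many sub-tokens a word produces.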



# `LightningDataModule`

PyTorch Lightning also provides a `LightningDataModule` class that structures all the steps needed to process a dataset. Since all data-related code is encapsulated in one place, experiments become easier to reproduce. When you share your model implementation, other ML practitioners don't have to look in multiple places to determine exactly how the data was processed.

We define a Lightning data module by inheriting from `pl.LightningDataModule`. Just like `LightningModule`, `LightningDataModule` comes with a pre-defined template, and it is important to use these pre-defined method names to process the dataset.


Alright, let's dive right into it!
- `train_dataloader()`:
returns the train dataloader

- `test_dataloader()`:
returns the test dataloader

- `val_dataloader()`:
returns the validation dataloader

- `prepare_data()`:
Based on the [LightningDataModule documentation](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html#prepare-data), one-time steps such as downloading and tokenizing the data belong here. In a distributed multi-GPU setting, this part of the code runs in a single process rather than once per GPU. In our implementation, we download the dataset and the tokenizer in this function.

__Note:__ The code cells below are just markdown and are non-executable. The executable code is placed under the __`LightningDataModule`: putting it all together__ section.

```python
def prepare_data(self):
    load_dataset("conll2003")
    AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
```

- `setup()`:
In contrast to `prepare_data()`, the `setup()` method should implement actions that will be run on every single GPU.  In our implementation, we go through each split in our dataset, and convert tokens to input ids.

```python
def setup(self, stage=None):  # Lightning passes the current stage when it calls setup()
    self.dataset = load_dataset("conll2003")

    for split in self.dataset.keys():
        self.dataset[split] = self.dataset[split].map(
            self.convert_to_features,
            remove_columns=["id", "chunk_tags", "pos_tags"],
        )
        column_names = ["input_ids", "attention_mask", "labels"]
        self.dataset[split].set_format(type="torch", columns=column_names)
```


## `LightningDataModule`: putting it all together

In [5]:
import pytorch_lightning as pl
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from datasets import load_dataset
from torch.utils.data.dataset import random_split
import torch


class NERDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size
        self.max_len = 128
        self.label2idx = {
            tag: idx
            for idx, tag in enumerate(dataset["train"].features["ner_tags"].feature.names)
        }
        self.idx2label = {v:k for k, v in self.label2idx.items()}
        self.pad_token_label_id = -100


    def setup(self, stage=None):
        # `stage` is passed by the Trainer when it calls setup() itself
        # load dataset and tokenizer
        self.dataset = load_dataset("conll2003")
        self.tokenizer = AutoTokenizer.from_pretrained(
            "bert-base-uncased",
            use_fast=True
        )

        # convert examples in each split to features
        for split in self.dataset.keys():
            self.dataset[split] = self.dataset[split].map(
                self.convert_to_features,
                remove_columns=["id", "chunk_tags", "pos_tags"],
            )
            column_names = ["input_ids", "attention_mask", "labels"]
            self.dataset[split].set_format(type="torch", columns=column_names)

    def prepare_data(self):
        # download conll2003 dataset and BERT tokenizer
        load_dataset("conll2003")
        AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)

    def train_dataloader(self):
        # return train dataloader
        return DataLoader(self.dataset["train"], batch_size=self.batch_size)

    def val_dataloader(self):
        # return validation dataloader
        return DataLoader(self.dataset["validation"], batch_size=self.batch_size)

    def test_dataloader(self):
        # return test dataloader
        return DataLoader(self.dataset["test"], batch_size=self.batch_size)

    def convert_to_features(self, example):
        # convert tokens to ids
        features = self.tokenizer.encode_plus(
            example["tokens"],
            is_split_into_words=True,
            max_length=self.max_len,
            pad_to_max_length=True,
            truncation=True,
            return_token_type_ids=False
        )

        # Pad labels with pad_label_token
        # since first token is a `CLS` token, pad with token label id
        features["labels"] = [self.pad_token_label_id] + example["ner_tags"]
        padding_length = self.max_len - len(features["labels"])
        features["labels"] += [self.pad_token_label_id] * padding_length

        return features



Create a `LightningDataModule` object and process data. As discussed earlier, the `prepare_data()` method should be called before the `setup()` method since it downloads the dataset and the tokenizer.

In [6]:
ner_data = NERDataModule()
ner_data.prepare_data()
ner_data.setup()

Reusing dataset conll2003 (/Users/kbak/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6)
Reusing dataset conll2003 (/Users/kbak/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6)






# Model

## `AutoModelForTokenClassification`:


Since the task at hand is a token-level classification task, we will load the BERT model with `AutoModelForTokenClassification`, which adds a per-token classification layer on top of BERT.

```python
model = AutoModelForTokenClassification.from_pretrained(
            "bert-base-uncased",
            num_labels=NUM_LABELS,
        )
```

## `configure_optimizers()`:

In this implementation, we will use the `AdamW` optimizer instead of `SGD`, since adaptive optimizers such as `AdamW` typically converge faster when fine-tuning BERT. Let's also set the weight decay to `0.01` for all parameters except the `bias` terms and the LayerNorm weights.

```python
# assumes `model` is the AutoModelForTokenClassification instance loaded above
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "LayerNorm.weight"]

# note: the optimizer reads the `weight_decay` key of each group
optimizer_grouped_parameters = [
    {
        # apply weight decay to all parameters except the ones in no_decay
        "params": [
            p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.01,
    },
    {
        # set weight decay to 0.0 for items in no_decay
        "params": [
            p for n, p in param_optimizer if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]

optimizer = AdamW(
    optimizer_grouped_parameters,
    lr=3e-5,
    eps=1e-8
)
```
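The name-based grouping can be checked in isolation. The sketch below is a plain-Python illustration: the helper `split_decay_groups` is a hypothetical name, the parameter names are made up to resemble BERT's, and strings stand in for the actual tensors. It uses `weight_decay` as the per-group key, which is the key PyTorch-style optimizers read from each parameter group.

```python
def split_decay_groups(named_params,
                       no_decay=("bias", "LayerNorm.weight"),
                       weight_decay=0.01):
    """Split (name, param) pairs into a decayed and a non-decayed group."""
    decay, skip = [], []
    for name, param in named_params:
        # a parameter is exempt if any marker is a substring of its name
        if any(marker in name for marker in no_decay):
            skip.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": skip, "weight_decay": 0.0},
    ]


# Illustrative names only; strings stand in for tensors
named_params = [
    ("encoder.layer.0.attention.self.query.weight", "W_q"),
    ("encoder.layer.0.attention.self.query.bias", "b_q"),
    ("encoder.layer.0.output.LayerNorm.weight", "ln_w"),
    ("encoder.layer.0.output.LayerNorm.bias", "ln_b"),
    ("classifier.weight", "W_cls"),
]
groups = split_decay_groups(named_params)
for group in groups:
    print(group["weight_decay"], group["params"])
```

Because the match is a substring test, `"bias"` also catches `LayerNorm.bias`, so both LayerNorm parameters end up in the group without weight decay.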


## `validation_epoch_end()`:

Another pre-defined PyTorch Lightning method is `validation_epoch_end`, which is called once the entire validation dataset has been processed. The items returned by the `validation_step()` method are collected and passed as a list to `validation_epoch_end`. The `outputs` argument below is therefore a list of dictionaries, each containing the predicted scores (`pred`) and the target classes (`target`) for one batch.
- first, we concatenate all predictions and, for each token, pick the class with the highest score
- similarly, we concatenate all target/true labels
- we ignore predictions for special and padding tokens by checking whether the corresponding label id is equal to the padded token label id
- finally, we compute the F1 score and accuracy.

```python
def validation_epoch_end(self, outputs):
    # Concatenate all predictions
    preds = np.concatenate([x["pred"] for x in outputs], axis=0)
    preds = np.argmax(preds, axis=2)
    # Concatenate all target labels
    out_label_ids = np.concatenate([x["target"] for x in outputs], axis=0)

    out_label_list = [[] for _ in range(out_label_ids.shape[0])]
    preds_list = [[] for _ in range(out_label_ids.shape[0])]

    for i in range(out_label_ids.shape[0]):
        for j in range(out_label_ids.shape[1]):
            # ignore predictions for tokens that have `pad_token_label_id`
            if out_label_ids[i, j] != self.dataset.pad_token_label_id:
                out_label_list[i].append(self.dataset.idx2label[out_label_ids[i][j]])
                preds_list[i].append(self.dataset.idx2label[preds[i][j]])

    # calculate effective measures using seqeval's f1_score and accuracy_score
    f1 = f1_score(out_label_list, preds_list)
    acc = accuracy_score(out_label_list, preds_list)
```

#### Evaluation with `seqeval`:
`seqeval` is a Python library for evaluating sequence labeling tasks such as NER. Let's follow the documentation [here](https://github.com/chakki-works/seqeval) and calculate the accuracy and F1 score for a sequence of NER predictions:

```python
from seqeval.metrics import accuracy_score, f1_score

y_true = [["O", "O", "O", "B-MISC", "I-MISC", "I-MISC", "O"], ["B-PER", "I-PER", "O"]]
y_pred = [["O", "O", "B-MISC", "I-MISC", "I-MISC", "I-MISC", "O"], ["B-PER", "I-PER", "O"]]
print(f1_score(y_true, y_pred)) #prints 0.5
print(accuracy_score(y_true, y_pred)) #prints 0.8

```
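To see where `0.5` and `0.8` come from, the following is a simplified re-implementation written for illustration; it is not `seqeval`'s code, and its handling of malformed tag sequences may differ. The F1 score is computed over whole entities (an entity counts as correct only if its type and both boundaries match), while accuracy is computed per token. The helper names `extract_entities`, `entity_f1`, and `token_accuracy` are assumptions of this sketch.

```python
def extract_entities(tags):
    """Return the set of (entity_type, start, end) spans in one tagged sentence."""
    entities, start, current = set(), None, None
    for i, tag in enumerate(tags + ["O"]):        # sentinel closes a trailing entity
        prefix, _, entity_type = tag.partition("-")
        # an open entity ends at O, at a B- tag, or at an I- tag of another type
        if current is not None and (prefix != "I" or entity_type != current):
            entities.add((current, start, i))
            start, current = None, None
        if prefix == "B" or (prefix == "I" and current is None):
            start, current = i, entity_type
    return entities


def entity_f1(y_true, y_pred):
    n_true = n_pred = n_correct = 0
    for true_tags, pred_tags in zip(y_true, y_pred):
        true_entities = extract_entities(true_tags)
        pred_entities = extract_entities(pred_tags)
        n_true += len(true_entities)
        n_pred += len(pred_entities)
        n_correct += len(true_entities & pred_entities)   # exact span matches
    precision = n_correct / n_pred if n_pred else 0.0
    recall = n_correct / n_true if n_true else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def token_accuracy(y_true, y_pred):
    pairs = [(t, p) for ts, ps in zip(y_true, y_pred) for t, p in zip(ts, ps)]
    return sum(t == p for t, p in pairs) / len(pairs)


y_true = [["O", "O", "O", "B-MISC", "I-MISC", "I-MISC", "O"], ["B-PER", "I-PER", "O"]]
y_pred = [["O", "O", "B-MISC", "I-MISC", "I-MISC", "I-MISC", "O"], ["B-PER", "I-PER", "O"]]
f1 = entity_f1(y_true, y_pred)
acc = token_accuracy(y_true, y_pred)
print(f1, acc)
```

Only the `PER` entity matches exactly, so precision and recall are both 1/2. The predicted `MISC` entity starts one token too early, which costs the whole entity in F1 but only two of the ten tokens in accuracy.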


# Final Model implementation

In [7]:
import torch.nn as nn
import torch.nn.functional as F
from seqeval.metrics import accuracy_score, f1_score
import numpy as np
from transformers import AutoModelForTokenClassification, AdamW
from pytorch_lightning import Trainer


class BertNerClassification(pl.LightningModule):
    def __init__(
        self,
        dataset,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.dataset = dataset
        self.model = AutoModelForTokenClassification.from_pretrained(
                    "bert-base-uncased",
                    num_labels=len(self.dataset.label2idx),
                )

    def forward(self, **inputs):
        outputs = self.model(**inputs)

        return outputs

    def configure_optimizers(self):
        param_optimizer = list(self.named_parameters())
        no_decay = ["bias", "LayerNorm.weight"]
        # the optimizer reads the `weight_decay` key of each parameter group
        optimizer_grouped_parameters = [
            {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
             "weight_decay": 0.01},
            {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
             "weight_decay": 0.0}
        ]

        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=3e-5,
            eps=1e-8
        )

        return optimizer

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs[0]
        
        self.log(
            "train_loss",
            loss,
            on_epoch=True,
            on_step=True,
            prog_bar=True,
            logger=True
        )

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        outputs = self(**batch)
        val_loss, logits = outputs[:2]
        preds = logits.detach().cpu().numpy()
        out_label_ids = batch["labels"].detach().cpu().numpy()
        
        self.log(
            "val_loss",
            val_loss,
            on_epoch=True,
            on_step=True,
            prog_bar=True,
            logger=True
        )

        return {"pred": preds, "target": out_label_ids}


    def validation_epoch_end(self, outputs):
        preds = np.concatenate([x["pred"] for x in outputs], axis=0)
        preds = np.argmax(preds, axis=2)
        out_label_ids = np.concatenate([x["target"] for x in outputs], axis=0)

        out_label_list = [[] for _ in range(out_label_ids.shape[0])]
        preds_list = [[] for _ in range(out_label_ids.shape[0])]

        for i in range(out_label_ids.shape[0]):
            for j in range(out_label_ids.shape[1]):
                # ignore predictions for tokens that have `pad_token_label_id`
                if out_label_ids[i, j] != self.dataset.pad_token_label_id:
                    out_label_list[i].append(self.dataset.idx2label[out_label_ids[i][j]])
                    preds_list[i].append(self.dataset.idx2label[preds[i][j]])

        # calculate effective measures using seqeval's f1_score and accuracy_score
        f1 = f1_score(out_label_list, preds_list)
        acc = accuracy_score(out_label_list, preds_list)

        self.log(
            "f1",
            f1,
            on_epoch=True,
            prog_bar=True,
            logger=True
        )
        
        self.log(
            "acc",
            acc,
            on_epoch=True,
            prog_bar=True,
            logger=True
        )


# Running the model
Create a model object

In [8]:
bert_ner_model = BertNerClassification(ner_data)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-u

## Combining data + model
In our previous implementations, `pytorch-lightning` would look for the train dataloader in the `train_dataloader()` method. Since we have already defined the data pipeline in a `LightningDataModule`, we can simply pass the `LightningDataModule` object to the `fit()` method, and PyTorch Lightning takes care of loading the train, validation, and test dataloaders.

In [None]:
# if you are running on a GPU (which is ideal, a CPU could take forever), uncomment the `gpus` argument below
trainer = Trainer(max_epochs=5)#, gpus="0")
trainer.fit(bert_ner_model, ner_data)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | BertForTokenClassification | 108 M 
-----------------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
435.594   Total estimated model params size (MB)



## Inference

Load the best model from checkpoint. Your checkpoint number might be different from what you see in the following cell.

In [None]:
# Load the best model and perform inference
model = BertNerClassification.load_from_checkpoint(
    "lightning_logs/version_0/checkpoints/epoch=4.ckpt",
)
#Set model in eval mode
model.eval()

Set the input text in the format required by the `convert_to_features()` method. The input sentence is already split into tokens: `["My", "name", "is", "Don", "Quixote"]`. Also remember that the `ner_tags` in the `conll2003` dataset are already numericalized. Therefore, the labels `["O", "O", "O", "B-PER", "I-PER"]` are mapped to `[0, 0, 0, 1, 2]`.

In [None]:
input_text = {
    "tokens": ["My", "name", "is", "Don", "Quixote"],
    # ["O", "O", "O", "B-PER", "I-PER"]
    "ner_tags": [0, 0, 0, 1, 2]
}

Extract features corresponding to the input example:

In [None]:
features = ner_data.convert_to_features(input_text)

Convert features to tensors and add a new dimension along 0-axis to indicate batching. 



In [None]:
features["input_ids"] = torch.tensor(features["input_ids"]).unsqueeze(0)
features["attention_mask"] = torch.tensor(features["attention_mask"]).unsqueeze(0)
features["labels"] = torch.tensor(features["labels"]).unsqueeze(0)

Make predictions:

In [None]:
predictions = model(**features)

Obtain the class with the highest score for each token:

In [None]:
preds = np.argmax(predictions[1].detach().numpy(), axis=2)
# drop the batch dimension: array of shape (1, max_len) -> (max_len,)
preds = preds.squeeze()

Go through the predictions and ignore the predictions that have `pad_token_label_id`.

In [None]:
targets = features["labels"].squeeze()
preds_list = []

for i in range(targets.shape[0]):
    # ignore predictions for tokens that have `pad_token_label_id`
    if targets[i] != ner_data.pad_token_label_id:
        preds_list.append(ner_data.idx2label[preds[i]])

Here are the final predicted entities:

In [None]:
preds_list