# Fine-tuning a BERT Model for Span Categorization

This notebook demonstrates how to fine-tune a span categorization model based on BERT and use it to produce out-of-sample predicted probabilities for each token and each span class. These probabilities are required to find label issues in span classification datasets with cleanlab. The specific span classification task we consider here is extractive question answering with the SQuAD dataset, and we train a Transformer model from HuggingFace's transformers library. This notebook covers only how to produce the `pred_probs`; using them to find label issues is demonstrated in the `"find label errors in span classification dataset"` tutorial.

***Note: running this notebook requires the `spanbert.py` and `span_classification_tutorial_utils.py` files to be in the same folder.***

Overview of what we'll do in this notebook:

- Read and process an extractive question answering dataset.
- Compute out-of-sample predicted probabilities by training a BERT transformer model via cross-validation.
- Separate question and context tokens for more meaningful error detection.

## 1. Load data and required dependencies

In [1]:
import os
from datasets import load_dataset
from sklearn.model_selection import KFold
import numpy as np

from cleanlab.internal.token_classification_utils import process_token

from spanbert import QA, QATrainer
from span_classification_tutorial_utils import to_dict



In [2]:
NUM_EXAMPLES = 100
NUM_SPLITS = 3
model_folder_path = "./folds"

raw_datasets = load_dataset("squad", split=f"train[:{NUM_EXAMPLES}]")

We will use the `context`, `question`, and `answers` fields in the dataset. Let's print the first example in our dataset.

In [3]:
print("Context: ", raw_datasets[0]["context"])
print("Question: ", raw_datasets[0]["question"])
print("Answers: ", raw_datasets[0]["answers"])

Context:  Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
Question:  To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Answers:  {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}


The `context` and `question` fields are straightforward to use. The `answers` field follows the format used by common span classification datasets: `text` holds the answer exactly as it appears in the context, and `answer_start` holds the character index in the context at which that answer begins.
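
As a minimal sketch of how these two fields fit together, the character span of an answer can be recovered from `answer_start` and the length of `text`. The helper `answer_char_span` and the toy record below are made up for illustration and are not part of the dataset or of the tutorial's `.py` files.

```python
def answer_char_span(answers, answer_idx=0):
    """Return the (start, end) character offsets of one answer in the context."""
    start = answers["answer_start"][answer_idx]
    end = start + len(answers["text"][answer_idx])  # end offset is exclusive
    return start, end


# Toy record shaped like a SQuAD `answers` field (hypothetical data).
context = "The Grotto is a replica of the grotto at Lourdes, France."
answers = {"text": ["Lourdes"], "answer_start": [context.index("Lourdes")]}

start, end = answer_char_span(answers)
recovered = context[start:end]
print(start, end, repr(recovered))
```

Slicing the context with the returned offsets should give back the answer text, which is a quick sanity check to run on a dataset before converting character spans into token labels.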

In [3]:
label2id = {'O': 0, 'ANS': 1}
id2label = {v:k for k, v in label2id.items()}
max_length = 384

model_checkpoint = "bert-base-cased"
model = QA(model_checkpoint)
trainer = QATrainer(model, max_length, label2id)

Some weights of BertForSpanClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


No label map found, please train the model with the QATrainer class


We will train a custom span classification model based on BERT. This model produces the predicted probabilities we need for each token. You can find more details about the model in the `spanbert.py` file.

## 2. Train Span Classification Model using Cross-Validation

To compute out-of-sample predicted probabilities (`pred_probs`) using cross-validation, we first partition the dataset into `k = 3` disjoint folds. We then train one span classification model per fold, using the other `k - 1` folds as training data and holding that fold out for prediction.
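
To make the partitioning concrete, here is a pure-Python sketch of contiguous, unshuffled k-fold splitting. `contiguous_kfold` is a hypothetical helper written for illustration; the notebook itself relies on `sklearn.model_selection.KFold`, whose unshuffled splits give the first `n_samples % n_splits` folds one extra example.

```python
def contiguous_kfold(n_samples, n_splits):
    """Yield (train_ids, val_ids) pairs for unshuffled k-fold splitting."""
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        # The first `extra` folds receive one additional validation example.
        size = base + (1 if fold < extra else 0)
        val_ids = list(range(start, start + size))
        train_ids = list(range(0, start)) + list(range(start + size, n_samples))
        yield train_ids, val_ids
        start += size


folds = list(contiguous_kfold(100, 3))
fold_sizes = [(len(train_ids), len(val_ids)) for train_ids, val_ids in folds]
print(fold_sizes)  # [(66, 34), (67, 33), (67, 33)]
```

With 100 examples and 3 folds, the training sets contain 66, 67, and 67 examples, and every example appears in exactly one validation fold.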

In [4]:
# Define the number of splits for cross-validation
kf = KFold(n_splits=NUM_SPLITS)
splits = kf.split(raw_datasets)

fold_to_ids = {} # store the validation ids for each fold
for i, (train_ids, val_ids) in enumerate(splits):
    fold_to_ids[i] = val_ids
    output_dir = os.path.join(model_folder_path, f"model_fold_{i}")
    if os.path.exists(output_dir):
        print(f"Model for fold {i} already exists, skipping training")
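    else:
        # Start every fold from the pretrained checkpoint. If `trainer.train`
        # updates the model in place, reusing one model object would carry over
        # weights already trained on this fold's validation examples.
        model = QA(model_checkpoint)
        trainer = QATrainer(model, max_length, label2id)
        train_ds = raw_datasets.select(train_ids)
        trainer.train(
            train_ds,
            output_dir,
            num_train_epochs=3, # 3 epochs is enough for demonstration
        )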

Map: 100%|██████████| 66/66 [00:00<00:00, 586.33 examples/s]
100%|██████████| 27/27 [09:12<00:00, 20.47s/it]


{'train_runtime': 552.6949, 'train_samples_per_second': 0.358, 'train_steps_per_second': 0.049, 'train_loss': 0.08430206334149395, 'epoch': 3.0}


Map: 100%|██████████| 67/67 [00:00<00:00, 506.23 examples/s]
100%|██████████| 27/27 [11:06<00:00, 24.69s/it]


{'train_runtime': 666.4771, 'train_samples_per_second': 0.302, 'train_steps_per_second': 0.041, 'train_loss': 0.021081337222346553, 'epoch': 3.0}


Map: 100%|██████████| 67/67 [00:00<00:00, 396.85 examples/s]
100%|██████████| 27/27 [12:21<00:00, 27.46s/it]


{'train_runtime': 741.4954, 'train_samples_per_second': 0.271, 'train_steps_per_second': 0.036, 'train_loss': 0.016514805731949984, 'epoch': 3.0}


## 3. Compute Out-of-Sample Predicted Probabilities

We obtain the predicted class probabilities for each token from the model that did not see that token's example during training. From our custom QA model, we collect the predicted probabilities, the tokenized question and context, and the corresponding given labels.
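
The bookkeeping behind this step can be sketched as follows: predictions are stored under each example's original index, then read back in dataset order. `collect_out_of_sample` and the toy `predict_for_fold` callable are hypothetical stand-ins, not the tutorial's `QA.predict`.

```python
def collect_out_of_sample(fold_to_ids, predict_for_fold, n_examples):
    """Merge per-fold held-out predictions into one list in dataset order."""
    by_index = {}
    for fold, val_ids in fold_to_ids.items():
        for index in val_ids:
            # Each example is scored by the model of the fold that held it out.
            by_index[index] = predict_for_fold(fold, index)
    missing = [i for i in range(n_examples) if i not in by_index]
    if missing:
        raise ValueError(f"No held-out prediction for examples {missing}")
    return [by_index[i] for i in range(n_examples)]


# Toy stand-in for a trained model: it just records which fold scored the example.
fold_to_ids = {0: [0, 1], 1: [2, 3], 2: [4]}
merged = collect_out_of_sample(fold_to_ids, lambda fold, index: (fold, index), 5)
print(merged)
```

Raising an error on missing indices guards against silently misaligned predictions, which would otherwise show up later as spurious label issues.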

In [5]:
sentence_tokens = {}
sentence_probs = {}
sentence_labels = {}

# Token clean-up rules: drop '#' characters and normalize quote tokens
replace = [('#', ''), ('``', '"'), ("''", '"')]

for fold in range(NUM_SPLITS):
    model_path = os.path.join(model_folder_path, f"model_fold_{fold}")
    model = QA(model_path)

    indices = fold_to_ids[fold]
    val_ds = raw_datasets.select(indices)
    # `position` indexes into val_ds; `index` is the example's id in raw_datasets
    for position, index in enumerate(indices):
        sentence_probs[index], tokens, sentence_labels[index] = model.predict(val_ds[position])
        sentence_tokens[index] = [process_token(t, replace) for t in tokens]

sentence_tokens = [sentence_tokens[i] for i in range(NUM_EXAMPLES)]
sentence_probs = [sentence_probs[i] for i in range(NUM_EXAMPLES)]
sentence_labels = [sentence_labels[i] for i in range(NUM_EXAMPLES)]


## 4. Isolate Question and Context

During prediction, we give the model both the question and the context. However, only the context portion has labels in the original SQuAD dataset, so we separate the question and context `pred_probs` and keep only the context-related ones.
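
As a small illustration of the slicing involved, assume the tokenized input has the layout `[CLS] question [SEP] context [SEP]` with no padding after the final `[SEP]`. The helper `context_slice` and the toy tokens and probabilities below are hypothetical.

```python
def context_slice(tokens):
    """Return the slice covering the context in `[CLS] q [SEP] ctx [SEP]`."""
    sep_idx = tokens.index("[SEP]")  # the first [SEP] closes the question
    return slice(sep_idx + 1, len(tokens) - 1)  # exclude the final [SEP]


tokens = ["[CLS]", "Who", "appeared", "?", "[SEP]", "Saint", "Bernadette", "[SEP]"]
# One [P(O), P(ANS)] pair per token (made-up values).
probs = [[1.0, 0.0]] * 5 + [[0.1, 0.9], [0.2, 0.8]] + [[1.0, 0.0]]

keep = context_slice(tokens)
context_tokens = tokens[keep]
context_probs = probs[keep]
print(context_tokens, context_probs)
```

Applying the same slice to tokens, probabilities, and labels keeps the three sequences aligned token by token.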

In [6]:
sentence_questions = [q["question"] for q in raw_datasets]

for i in range(len(sentence_tokens)):
    # The first [SEP] token ends the question; the context is everything
    # after it, up to but excluding the last token.
    sep_idx = sentence_tokens[i].index('[SEP]')
    sentence_tokens[i] = sentence_tokens[i][sep_idx + 1:-1]
    sentence_probs[i] = sentence_probs[i][sep_idx + 1:-1]
    sentence_labels[i] = sentence_labels[i][sep_idx + 1:-1]

Our model produces `pred_probs` for both the answer span class (`ANS`) and the other class (`O`). We only care about the `ANS` class, so we isolate its `pred_probs` below.
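
A toy version of this selection, assuming each token carries one `[O, ANS]` pair of probabilities and a one-hot label in the same order (the one-hot layout is an assumption inferred from how the labels are indexed in this notebook; `select_class` is a hypothetical helper):

```python
label2id = {"O": 0, "ANS": 1}


def select_class(rows, class_name):
    """Keep only the column belonging to `class_name` from per-token rows."""
    col = label2id[class_name]
    return [row[col] for row in rows]


token_probs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]  # made-up probabilities
token_labels = [[1, 0], [0, 1], [1, 0]]  # assumed one-hot labels

ans_probs = select_class(token_probs, "ANS")
ans_labels = select_class(token_labels, "ANS")
print(ans_probs, ans_labels)
```

After this step, each token is described by a single number: the predicted probability that it belongs to the answer span, paired with a 0/1 label.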

In [7]:
final_labels = []
for labels in sentence_labels:
    temp = [lab[label2id["ANS"]] for lab in labels]
    final_labels.append(temp)

final_pred_probs = []
for probs in sentence_probs:
    temp = [prob[label2id["ANS"]] for prob in probs]
    final_pred_probs.append(np.array(temp))

pred_probs_dict = to_dict(final_pred_probs)
labels_dict = to_dict(final_labels)
tokens_dict = to_dict(sentence_tokens)
question_dict = to_dict(sentence_questions)

We now have properly formatted `pred_probs`, `labels`, `tokens`, and `questions` for use with `cleanlab.experimental.span_classification`. We use the `to_dict` function to convert them into a format suitable for saving as `.npz` files.
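
The actual `to_dict` lives in `span_classification_tutorial_utils.py` and is not shown here. One plausible shape for such a helper, offered only as an assumption, is a mapping from each example's position (as a string) to its per-token values. String keys matter because `np.savez` stores each keyword argument as a separately named array, which also sidesteps the fact that examples have different lengths.

```python
def to_dict_sketch(items):
    """Hypothetical stand-in for `to_dict`: key each element by its position."""
    return {str(i): item for i, item in enumerate(items)}


def from_dict_sketch(mapping):
    """Inverse of `to_dict_sketch`: rebuild the list in its original order."""
    return [mapping[str(i)] for i in range(len(mapping))]


ragged = [[0.1, 0.8], [0.4], [0.3, 0.2, 0.9]]  # examples of different lengths
as_dict = to_dict_sketch(ragged)
restored = from_dict_sketch(as_dict)
print(sorted(as_dict), restored == ragged)
```

Rebuilding the list by integer position rather than by sorting the keys avoids the `"10" < "2"` ordering pitfall of string keys.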

In [8]:
np.savez('pred_probs.npz', **pred_probs_dict)
np.savez('labels.npz', **labels_dict)
np.savez('tokens.npz', **tokens_dict)
np.savez('questions.npz', **question_dict)