# Fine-tuning BERT for entity labeling
This notebook contains starter code for fine-tuning a BERT-style model for named entity recognition (token-level entity labeling).

In [None]:
!pip install protobuf==3.20.2
!pip install transformers
!pip install datasets
!pip install evaluate
!pip install seqeval

In [None]:
# This code block just contains standard setup code for running in Python
import time

# PyTorch imports
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset
import numpy as np

# Fix the random seed(s) for reproducibility
SEED = 8942764
torch.random.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)


# Pick the best available device: CUDA GPU first, then Apple-silicon MPS
# (M1 or newer), then CPU. `torch.backends.mps` is looked up defensively
# because older PyTorch builds do not have it.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

In [None]:
from transformers import AutoTokenizer, BertModel, DataCollatorForTokenClassification

import evaluate

In [None]:
# Load the train / dev / test splits from local JSON-lines files.
# (Judging by its filename, the test split presumably has no gold labels.)
from datasets import ClassLabel, Sequence, load_dataset

data_splits = load_dataset(
    'json',
    data_files=dict(
        train='dinos_and_deities_train_bio.jsonl',
        dev='dinos_and_deities_dev_bio_sm.jsonl',
        test='dinos_and_deities_test_bio_nolabels.jsonl',
    ),
)

In [None]:
# Label inventory: whitespace-separated label names. A label's position in
# this list is its integer id.
label_names_fname = "dinos_and_deities_train_bio.jsonl.labels"
# An explicit encoding keeps decoding independent of the platform default.
with open(label_names_fname, encoding="utf-8") as f:
    labels_int2str = f.read().split()
print(f"Labels: {labels_int2str}")

In [None]:
# Reverse lookup: label string -> integer id.
labels_str2int = {label: idx for idx, label in enumerate(labels_int2str)}

# `cast_column` returns a NEW DatasetDict instead of modifying in place, so the
# result must be reassigned for "ner_tags" to actually carry ClassLabel names.
# NOTE(review): ClassLabel casting rejects integer ids >= len(labels_int2str);
# this assumes every split's ner_tags fit in the train label inventory.
data_splits = data_splits.cast_column("ner_tags", Sequence(ClassLabel(names=labels_int2str)))
print(data_splits)

In [None]:
# Load the pretrained BERT tokenizer. It is downloaded on first use, so this
# may take a while the first time.
# Note: if you change the BERT model later, don't forget to also change this!!
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-cased")

In [None]:
# Peek at one training example and one dev example (printed on separate lines).
print(data_splits['train'][8], data_splits['dev'][5], sep='\n')

In [None]:
# The dataset has train, validation (dev) and test splits, and every token has a label.
# Rows can be accessed like a Python dict.
print(data_splits['train'].features)

# Fetch example #8's original (whitespace-tokenized) words and its integer labels.
example_input_tokens, example_ner_labels = (
    data_splits['train'][8][column] for column in ('tokens', 'ner_tags')
)
print("Original tokens: " + str(example_input_tokens))
print("NER labels: " + str(example_ner_labels))

# Translate integer label ids into their string names.
example_mapped_labels = list(map(labels_int2str.__getitem__, example_ner_labels))
print(f'Labels: {example_mapped_labels}')

# Re-tokenize the pre-split words with BERT's sub-word tokenizer.
example_tokenized = tokenizer(example_input_tokens, is_split_into_words=True)
print(f'BERT Tokenized:  {example_tokenized.tokens()}')

# Size of the tokenizer's vocabulary.
print(f'Vocab size: {tokenizer.vocab_size}')

# The same sub-word tokens, expressed as vocabulary ids.
print(f'Token IDs:  {tokenizer.convert_tokens_to_ids(example_tokenized.tokens())}')

# There are now more tokens than labels. word_ids() maps each token back to the
# index of the word it came from (None for special tokens).
print(example_tokenized.word_ids())

In [None]:
# Use word_ids() together with the original labels to produce one label
# per BERT sub-word token.
def align_labels_with_tokens(labels, word_ids):
    """Expand word-level label ids to one label id per sub-word token.

    Args:
        labels: integer label ids, one per original (whitespace-split) word.
        word_ids: the tokenizer's ``word_ids()`` output; for each token, the
            index of the word it came from, or ``None`` for special tokens.

    Returns:
        A list with one entry per token:
          * ``-100`` for special tokens, so the loss ignores them;
          * the word's label for the first sub-token of each word;
          * for continuation sub-tokens, the word's label, except that a
            label starting with ``B`` is replaced by its ``I`` counterpart
            (looked up through the module-level ``labels_int2str`` /
            ``labels_str2int``).
    """
    aligned = []
    prev_word = None
    for wid in word_ids:
        if wid is None:
            # Special token: no label. A None also ends the current word.
            aligned.append(-100)
            prev_word = None
            continue

        if wid != prev_word:
            # First sub-token of a new word keeps the word's own label.
            prev_word = wid
            aligned.append(labels[wid])
            continue

        # Continuation sub-token of the same word: B-xxx becomes I-xxx.
        tag = labels_int2str[labels[wid]]
        if tag[0] == 'B':
            aligned.append(labels_str2int['I' + tag[1:]])
        else:
            aligned.append(labels[wid])

    return aligned

In [None]:
# Align the example's word-level labels to its BERT sub-word tokens.
tokenizer_aligned_labels = align_labels_with_tokens(
    example_ner_labels, example_tokenized.word_ids()
)
print('Aligned labels: ' + str(tokenizer_aligned_labels))
# Show label names; ignored (-100) positions such as special tokens appear as "_".
print('Mapped aligned labels: ' + str(['_' if l < 0 else labels_int2str[l] for l in tokenizer_aligned_labels]))

In [None]:
# Check the function on the example from before, printing tokens next to labels.
# Special tokens have no label, so they are shown as _
aligned_labels = align_labels_with_tokens(
    labels=example_ner_labels, word_ids=example_tokenized.word_ids()
)
print("Tokens: {}".format(example_tokenized.tokens()))
print("Aligned labels: {}".format(['_' if l < 0 else labels_int2str[l] for l in aligned_labels]))

In [None]:
# Batched helper so the whole dataset can be converted efficiently with Dataset.map.
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and align their labels to sub-word tokens.

    Args:
        examples: a batch dict with ``"tokens"`` (a list of word lists) and
            ``"ner_tags"`` (a list of per-word integer label lists).

    Returns:
        The tokenizer output for the batch, truncated to the model's maximum
        length, plus a ``"labels"`` entry holding each sentence's per-token
        labels as produced by ``align_labels_with_tokens``.
    """
    encoded = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
    encoded["labels"] = [
        align_labels_with_tokens(sentence_labels, encoded.word_ids(batch_index))
        for batch_index, sentence_labels in enumerate(examples["ner_tags"])
    ]
    return encoded

In [None]:
# Tokenize every split and attach aligned labels. The train split's original
# columns are dropped so that only tokenizer outputs and "labels" remain
# (this assumes every split shares the train split's column names).
tokenized_data_splits = data_splits.map(
    function=tokenize_and_align_labels,
    batched=True,
    remove_columns=data_splits["train"].column_names,
)

In [None]:
# Testing batcher: show the raw (unpadded) labels of two examples, then collate
# them and show the padded label tensor so the padding is visible.
print("Examples:")
for i in range(2):
    print(tokenized_data_splits["train"][i]["labels"])

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
batch = data_collator([tokenized_data_splits["train"][i] for i in range(2)])
print("Collated labels:")
print(batch["labels"])

In [None]:
# Evaluation: the seqeval library computes span-level precision, recall and F1.
metric = evaluate.load("seqeval")

# Label strings for the first training sentence.
labels = [labels_int2str[tag_id] for tag_id in data_splits["train"][0]["ner_tags"]]
print(labels)

# Corrupt the first prediction to see how the metric reacts.
predictions = list(labels)
predictions[0] = "O"
metric.compute(predictions=[predictions], references=[labels])

In [None]:
def postprocess(predictions, labels):
    """Turn padded label-id matrices into per-sentence label-string lists for seqeval.

    Positions whose gold label is -100 (special tokens / padding) are dropped
    from both the references and the predictions, so each pair of lists lines up.

    Args:
        predictions: per-sentence predicted label ids, as a nested list or a
            2-D tensor/array (e.g. ``logits.argmax(-1)``).
        labels: per-sentence gold label ids in the same layout, with -100
            marking ignored positions.

    Returns:
        ``(true_labels, true_predictions)``: two lists of lists of label strings
        (looked up in the module-level ``labels_int2str``).
    """
    # Convert tensors/arrays to plain Python lists once. Indexing a list with
    # individual tensor elements works, but each access on a CUDA tensor forces
    # its own device-to-host transfer.
    if hasattr(predictions, "tolist"):
        predictions = predictions.tolist()
    if hasattr(labels, "tolist"):
        labels = labels.tolist()

    true_labels = [[labels_int2str[l] for l in label_row if l != -100] for label_row in labels]
    true_predictions = [
        [labels_int2str[p] for p, l in zip(pred_row, label_row) if l != -100]
        for pred_row, label_row in zip(predictions, labels)
    ]
    return true_labels, true_predictions

In [None]:
# This code runs evaluation on a labelled split (e.g. the dev set).

@torch.no_grad()
def run_eval(model, dataset, batch_size, device, collate_fn=None, print_out=False):
    """Evaluate a token-classification model: mean loss plus seqeval span metrics.

    Args:
        model: module that returns per-token logits shaped (B, T, C) when called
            with the collated batch (minus ``labels``) as keyword arguments.
        dataset: tokenized examples that include a ``labels`` field.
        batch_size: evaluation batch size.
        device: device to evaluate on; the model is moved there and put in eval mode.
        collate_fn: batch collator, e.g. ``DataCollatorForTokenClassification``.
        print_out: if True, also print the loss and overall metrics.

    Returns:
        ``(validation_loss, results)``: the mean of the per-batch cross-entropy
        losses (positions labelled -100 are ignored) and the dict from the
        module-level seqeval ``metric.compute()``.
    """
    model.eval().to(device)
    dataloader = DataLoader(dataset, batch_size, shuffle=False, collate_fn=collate_fn)

    # The model emits raw logits, so CrossEntropyLoss (log-softmax + NLL) is the
    # right criterion; NLLLoss on raw logits is not a valid loss. ignore_index
    # defaults to -100, which skips special tokens and padding.
    lossfn = nn.CrossEntropyLoss()
    val_loss = []

    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
        y = batch.pop('labels')

        logits = model(**batch)
        num_classes = logits.shape[-1]
        loss = lossfn(logits.reshape(-1, num_classes), y.reshape(-1))
        val_loss.append(loss.item())

        pred = logits.argmax(-1)
        # Move ids to host lists once per batch before mapping them to label names.
        true_labels, true_predictions = postprocess(pred.tolist(), y.tolist())
        metric.add_batch(predictions=true_predictions, references=true_labels)

    results = metric.compute()
    validation_loss = np.mean(val_loss)
    if print_out:
        print(f"val_loss: {validation_loss:.3e}\t val_acc:{results['overall_accuracy']:.3f}\t "
              f"val_p:{results['overall_precision']:.3f}\t val_r:{results['overall_recall']:.3f}\t "
              f"val_f1:{results['overall_f1']:.3f}")
    return validation_loss, results

In [None]:
@torch.no_grad()
def valid(model, dataset, batch_size, device, collate_fn=None):
    """Alternative evaluation loop: mean loss plus seqeval metrics over ``dataset``.

    Args:
        model: module returning per-token logits shaped (B, T, C).
        dataset: tokenized examples that include a ``labels`` field.
        batch_size: evaluation batch size.
        device: device to evaluate on; the model is moved there and put in eval mode.
        collate_fn: batch collator, e.g. ``DataCollatorForTokenClassification``.

    Returns:
        ``(eval_loss, results)``: the mean of the per-batch cross-entropy losses
        (positions labelled -100 are ignored) and the dict returned by the
        module-level seqeval ``metric.compute()``.
    """
    model.eval().to(device)
    dataloader = DataLoader(dataset, batch_size, shuffle=False, collate_fn=collate_fn)

    # Raw logits in, so use CrossEntropyLoss (log-softmax + NLL); ignore_index=-100 by default.
    lossfn = nn.CrossEntropyLoss()
    eval_preds, eval_labels = [], []
    eval_loss = []

    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
        y = batch.pop('labels')

        logits = model(**batch)
        num_classes = logits.shape[-1]
        loss = lossfn(logits.reshape(-1, num_classes), y.reshape(-1))
        eval_loss.append(loss.item())

        predictions = logits.argmax(-1)
        # Keep one sequence per sentence. seqeval extracts spans per sequence, so
        # concatenating all sentences could merge an entity at the end of one
        # sentence with an I- tag at the start of the next.
        for pred_row, label_row in zip(predictions.tolist(), y.tolist()):
            active = [(p, l) for p, l in zip(pred_row, label_row) if l != -100]
            eval_preds.append([labels_int2str[p] for p, _ in active])
            eval_labels.append([labels_int2str[l] for _, l in active])

    results = metric.compute(predictions=eval_preds, references=eval_labels)

    return np.mean(eval_loss), results

In [None]:
# This code trains the model and evaluates it on the validation split after every
# epoch. It prints progress messages (loss, accuracy, training speed) during training.

def train(model,
          train_dataset,
          val_dataset,
          num_epochs,
          batch_size,
          optimizer_cls,
          lr,
          weight_decay,
          device,
          collate_fn=None,
          log_every=100):
    """Fine-tune a token-classification model, validating at the end of each epoch.

    Args:
        model: module returning per-token logits shaped (B, T, C).
        train_dataset: tokenized training examples with a ``labels`` field.
        val_dataset: tokenized validation examples, evaluated with ``run_eval``.
        num_epochs: number of passes over ``train_dataset``.
        batch_size: batch size used for both training and validation.
        optimizer_cls: one of ``'SGD'``, ``'Adam'`` or ``'AdamW'``.
        lr: learning rate.
        weight_decay: weight decay passed to the optimizer.
        device: device to train on.
        collate_fn: batch collator, e.g. ``DataCollatorForTokenClassification``.
        log_every: print running training loss/accuracy every this many batches.

    Returns:
        ``(model, (train_loss_history, train_acc_history, val_loss_history,
        val_acc_history))``, where each history has one entry per epoch.

    Raises:
        ValueError: if ``optimizer_cls`` is not one of the supported names.
    """
    model = model.train().to(device)
    dataloader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=collate_fn)

    optimizers = {
        'SGD': torch.optim.SGD,
        'Adam': torch.optim.Adam,
        'AdamW': torch.optim.AdamW,
    }
    if optimizer_cls not in optimizers:
        raise ValueError(f"Unknown optimizer_cls {optimizer_cls!r}; expected one of {sorted(optimizers)}")
    optimizer = optimizers[optimizer_cls](model.parameters(), lr=lr, weight_decay=weight_decay)

    train_loss_history = []
    train_acc_history = []
    val_loss_history = []
    val_acc_history = []

    # The model emits raw logits, so CrossEntropyLoss (log-softmax + NLL) is the
    # right criterion. Its default ignore_index=-100 skips special tokens and padding.
    lossfn = nn.CrossEntropyLoss()
    for e in range(num_epochs):
        model.train(True)  # run_eval at the end of the previous epoch left it in eval mode
        epoch_loss_history = []
        epoch_acc_history = []
        start_time = time.time()
        for i, batch in enumerate(dataloader):
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            y = batch.pop('labels')

            logits = model(**batch)
            num_classes = logits.shape[-1]
            # The loss expects classes on dim 1, so flatten tokens: (B*T, C) vs (B*T,).
            loss = lossfn(logits.reshape(-1, num_classes), y.reshape(-1))

            # Token accuracy: argmax over the class axis, counted on real
            # (non -100) positions only. clamp avoids 0/0 on an all-ignored batch.
            pred = logits.argmax(-1)
            active = y != -100
            acc = ((pred == y) & active).sum().float() / active.sum().clamp(min=1)

            epoch_loss_history.append(loss.item())
            epoch_acc_history.append(acc.item())

            if i % log_every == 0:
                speed = 0 if i == 0 else log_every / (time.time() - start_time)
                print(f'epoch: {e}\t iter: {i}\t train_loss: {np.mean(epoch_loss_history):.3e}\t train_acc:{np.mean(epoch_acc_history):.3f}\t speed:{speed:.3f} b/s')
                start_time = time.time()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Validate at the end of every epoch so each history gets one entry per epoch.
        val_loss, val_metrics = run_eval(model, val_dataset, batch_size, device, collate_fn=collate_fn)

        val_acc = val_metrics['overall_accuracy']
        val_p = val_metrics['overall_precision']
        val_r = val_metrics['overall_recall']
        val_f1 = val_metrics['overall_f1']

        train_loss_history.append(np.mean(epoch_loss_history))
        train_acc_history.append(np.mean(epoch_acc_history))
        val_loss_history.append(float(val_loss))
        val_acc_history.append(val_acc)
        print(f'epoch: {e}\t train_loss: {train_loss_history[-1]:.3e}\t train_accuracy:{train_acc_history[-1]:.3f}\t val_loss: {val_loss_history[-1]:.3e}\t val_acc:{val_acc_history[-1]:.3f}\t val_p:{val_p:.3f}\t val_r:{val_r:.3f}\t val_f1:{val_f1:.3f}')

    return model, (train_loss_history, train_acc_history, val_loss_history, val_acc_history)

In [None]:
# Token-classification model: BERT followed by a small classifier applied to
# every token's final-layer hidden state.
# The classifier has one hidden layer with 128 units; we have found that using
# fewer hidden units does not make much difference.

class BertForTokenClassification(nn.Module):
    def __init__(self, bert_pretrained_config_name, num_classes, freeze_bert=False, dropout_prob=0.1):
        '''
        BERT encoder with a per-token classification MLP on top.
        args:
        - bert_pretrained_config_name (str): model name on the Hugging Face hub
        - num_classes (int): number of output label classes
        - freeze_bert (bool): [default False] if True, BERT's parameters do not
                              require gradients, so only the classifier learns.
        - dropout_prob (float): [default 0.1] dropout probability applied after
                                the hidden layer.
        '''
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_pretrained_config_name)
        self.bert.requires_grad_(not freeze_bert)
        hidden_size = self.bert.config.hidden_size
        head_layers = [
            nn.Linear(hidden_size, 128),
            nn.Tanh(),
            nn.Dropout(dropout_prob),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, **bert_kwargs):
        '''Run BERT on the given inputs and return raw (unnormalized) per-token class logits.'''
        hidden_states = self.bert(**bert_kwargs).last_hidden_state
        return self.classifier(hidden_states)

In [None]:
# Hyperparameters.
# Rule of thumb: multiply the learning rate by k when using a batch size of k*N.
lr = 4 * 2e-5          # 2e-5 scaled by k = 4
weight_decay = 0.01
epochs = 5
batch_size = 16
dropout_prob = 0.2
freeze_bert = False    # True -> only the classifier head is trained

In [None]:
# Build the model. During training, validation loss and accuracy are also
# printed at the end of each epoch.

# Must match the checkpoint used for the tokenizer above!
bert_model = 'bert-base-cased'

num_labels = len(labels_int2str)
print("Num labels: {}".format(num_labels))

bert_cls = BertForTokenClassification(
    bert_model,
    num_labels,
    freeze_bert=freeze_bert,
    dropout_prob=dropout_prob,
)

# Count only parameters that will receive gradients (BERT's are excluded when frozen).
print('Trainable parameters: {}\n'.format(sum(p.numel() for p in bert_cls.parameters() if p.requires_grad)))

# "Debug" mode trains on a small random subset; set debug to False for full training.
debug = False

In [None]:
# Sample a subset of the training data for faster iteration in debug mode.
subset_size = 1000
# `.tolist()` makes the indices plain Python ints. 0-d torch tensors are not
# guaranteed to be accepted as row keys by Hugging Face datasets when they are
# indexed through `Subset`.
subset_indices = torch.randperm(len(tokenized_data_splits['train']))[:subset_size].tolist()
train_subset = Subset(tokenized_data_splits['train'], subset_indices)

bert_cls, bert_cls_logs = train(bert_cls, tokenized_data_splits['train'] if not debug else train_subset, tokenized_data_splits['dev'],
                                num_epochs=epochs, batch_size=batch_size, optimizer_cls='AdamW',
                                lr=lr, weight_decay=weight_decay, device=device,
                                collate_fn=data_collator, log_every=10 if debug else 100)

# Final eval on the dev split (the test split has no usable labels).
final_loss, final_metrics = run_eval(bert_cls, tokenized_data_splits['dev'], batch_size=32, device=device, collate_fn=data_collator)
final_acc = final_metrics['overall_accuracy']
final_p = final_metrics['overall_precision']
final_r = final_metrics['overall_recall']
final_f1 = final_metrics['overall_f1']
print(f'\nFinal Loss: {final_loss:.3e}\t Final Accuracy: {final_acc:.3f}\t dev_p:{final_p:.3f}\t dev_r:{final_r:.3f}\t dev_f1:{final_f1:.3f}')

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Plot per-epoch training vs. validation loss.
loss, acc, val_loss, val_acc = bert_cls_logs
# Use a separate name for the x-axis. Rebinding `epochs` would clobber the
# hyperparameter, so re-running the training cell would receive a range object.
epoch_axis = range(len(loss))

fig, ax = plt.subplots()
ax.plot(epoch_axis, loss, 'r', label='Training loss')
ax.plot(epoch_axis, val_loss, 'b', label='Validation loss')
ax.set(title='Training and Validation loss', xlabel='Epoch', ylabel='Loss')
ax.legend()
plt.show()