This notebook fine-tunes the LUKE model on a named entity recognition (NER) task: identifying organizations mentioned in news. It consists of two parts:


1.   Clean and preprocess data for the model's use
2.   Finetune the model



**Part 1: data cleaning & preprocessing**

In [1]:
# Mount Google Drive at /content/drive so that later cells can read the dataset
# from, and write outputs to, /content/drive/MyDrive (requires a Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import os
import pandas as pd
import unicodedata
import numpy as np
import spacy
import gc
from tqdm import tqdm, trange
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from collections import defaultdict

In [4]:
# Load the raw dataset; later cells read its 'Sentence', 'Company A' and 'Company B' columns.
data = pd.read_csv('/content/drive/MyDrive/capstone/data.csv')

In [5]:
# Order the rows so that rows sharing a sentence end up next to each other;
# the grouping step below compares every sentence with its predecessor.
data = data.sort_values(['Sentence', 'Company A', 'Company B'])
all_sentences = data['Sentence'].tolist()
# Company names lose their surrounding whitespace here; sentences are kept raw for now.
company_a, company_b = (
    [name.strip() for name in data[column].tolist()]
    for column in ('Company A', 'Company B')
)

In [6]:
'''
In the provided dataset, some examples share the same sentence.
For the NER task, we need to combine those examples and label all organizations that occur in the same sentence.
'''
# The first row always opens the first group.
sentences = [all_sentences[0].strip()]
orgs = [{company_a[0], company_b[0]}]

# Walk over the remaining rows together with the sentence of the row before each one.
following_rows = zip(all_sentences[1:], all_sentences, company_a[1:], company_b[1:])
for current_sentence, previous_sentence, org_a, org_b in following_rows:
    if current_sentence != previous_sentence:
        # A different (raw) sentence starts a new group with its own organization set.
        sentences.append(current_sentence.strip())
        orgs.append(set())
    orgs[-1].update((org_a, org_b))

# Every sentence must mention at least two distinct organizations,
# and there must be exactly one organization set per sentence.
assert all(len(sentence_orgs) >= 2 for sentence_orgs in orgs)
assert len(sentences) == len(orgs)

In [7]:
'''
Define global variables, including model hyperparameters that are relevant
during the dataset creation phase.
'''
# Binary span labels: index 0 ('O') marks a non-organization span, index 1 ('ORG') an organization.
label_list = ['O', 'ORG']
id2label = {i: label for i, label in enumerate(label_list)}
label2id = {label: i for i, label in enumerate(label_list)}

# Shared random seed, used for the train/validation split and for transformers' set_seed.
SEED = 0
# each example contains at most MAX_ENTITY_LENGTH entities (candidate spans)
MAX_ENTITY_LENGTH = 1024
# the model only considers spans that are at most ENTITY_WORD_LIMIT words long
ENTITY_WORD_LIMIT = 10
# each example contains at most MAX_LENGTH tokens
# here, "token" is at sub-word level, created from spacy tokenization & LukeTokenizer
MAX_LENGTH = 512

In [8]:
# Run spaCy's small English pipeline over every sentence; the resulting docs supply
# the word tokens and character offsets used to enumerate candidate spans below.
nlp = spacy.load("en_core_web_sm")
docs = list(nlp.pipe(sentences))

In [9]:
'''
Create a list of dict (one dict per doc) from the docs (created by spacy) and their corresponding organizations.
Each dict consists of the following:
text, 
words in the text, 
all spans to be considered by the model indexed at character level ("entity_spans") and word level ("original_word_spans"),
labels that correspond to each span (as indices in label_list).
'''
examples = []
skipped = 0
total_org_count = 0
org_word_len = []
# Wrap the list itself (not the enumerate iterator) so that tqdm knows the total
# number of docs and can display a percentage and an ETA.
for doc_i, doc in enumerate(tqdm(docs)):
    text = doc.text
    entity_spans = []
    original_word_spans = []
    words = [token.text for token in doc]
    labels = []
    org_occur = set()
    for token_start in doc:
        # the span budget was used up while processing the previous start token
        if len(entity_spans) == MAX_ENTITY_LENGTH:
            break
        # enumerate every span that begins at token_start and covers at most ENTITY_WORD_LIMIT words
        for token_end in doc[token_start.i:token_start.i + ENTITY_WORD_LIMIT]:
            entity_start, entity_end = token_start.idx, token_end.idx + len(token_end)
            entity_spans.append((entity_start, entity_end))
            original_word_spans.append((token_start.i, token_end.i + 1))
            span = text[entity_start:entity_end]
            # a span is an organization only if its text exactly matches one of the sentence's organizations
            if span in orgs[doc_i]:
                labels.append(1)
                org_word_len.append(token_end.i - token_start.i + 1)
                org_occur.add(span)
            else:
                labels.append(0)
            # once MAX_ENTITY_LENGTH spans are accumulated for an example, we stop adding additional spans
            # data exploratory analysis shows that this truncation leads us to lose 2 valid examples
            if len(entity_spans) == MAX_ENTITY_LENGTH:
                break
    # if an example does not have at least 2 unique organization appearances, we exclude it
    if len(org_occur) < 2:
        skipped += 1
        continue
    total_org_count += sum(labels)
    examples.append(dict(text=text, words=words, entity_spans=entity_spans, original_word_spans=original_word_spans, labels=labels))
print(f"\n{skipped} sentences were skipped since less than 2 organizations were found in these sentences.\n" 
      + f"{len(examples)} sentences containing {total_org_count} organization occurrences remain.")
# max() raises ValueError on an empty list, so only report the longest organization when one was found
if org_word_len:
    print(f"The longest organization contains {max(org_word_len)} words.")
else:
    print("No organization occurrences were found in any sentence.")

253it [00:00, 694.85it/s]


26 sentences were skipped since less than 2 organizations were found in these sentences.
227 sentences containing 658 organization occurrences remain.
The longest organization contains 8 words.





In [10]:
# Hold out 20% of the examples for validation; random_state=SEED keeps the split reproducible.
train_examples, val_examples = train_test_split(examples, test_size=0.2, random_state=SEED)

**Part 2: model finetuning**

In [11]:
!pip install transformers
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [12]:
import torch
from transformers import LukeTokenizer, LukeForEntitySpanClassification, Trainer, TrainingArguments, set_seed
from datasets import Dataset, load_metric

In [13]:
# Seed the random number generators handled by transformers' set_seed,
# then select the GPU when one is available (CPU otherwise).
set_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [14]:
# may want to experiment with "studio-ousia/luke-large-finetuned-conll-2003"
# The tokenizer is configured for the entity span classification task and for up to
# MAX_ENTITY_LENGTH candidate spans per example; downloaded files are cached under .cache/.
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", max_entity_length=MAX_ENTITY_LENGTH, task="entity_span_classification", cache_dir=".cache/")

In [16]:
# Wrap the lists of example dicts as datasets.Dataset objects so they can be
# tokenized with .map() and handed to the Trainer.
train_dataset = Dataset.from_list(train_examples)
val_dataset = Dataset.from_list(val_examples)

In [17]:
'''
Use LukeTokenizer to tokenize the dataset
'''
def tokenize(example):
    """Tokenize a single example for entity span classification.

    Args:
        example: a dict with the keys 'text' (the sentence), 'entity_spans'
            (character-level (start, end) pairs of the candidate spans) and
            'labels' (one label id per candidate span).

    Returns:
        The tokenizer output with the leading batch dimension of size 1 removed
        from every tensor, plus a "labels" tensor built from example['labels'].
    """
    tokenized_inputs = tokenizer(
        text=example['text'],
        # spans may arrive as lists (e.g. after a round trip through a dataset); the tokenizer is given tuples
        entity_spans=[tuple(span) for span in example['entity_spans']],
        max_length=MAX_LENGTH,
        max_entity_length=MAX_ENTITY_LENGTH,
        truncation=True,
        return_tensors='pt'
    )

    # Drop only the batch axis. A bare .squeeze() would also remove every other
    # size-1 dimension, e.g. the entity axis of an example with a single span.
    for k, v in list(tokenized_inputs.items()):
        if isinstance(v, torch.Tensor) and v.shape[0] == 1:
            tokenized_inputs[k] = v.squeeze(0)

    tokenized_inputs["labels"] = torch.tensor(example['labels'])
    return tokenized_inputs

In [18]:
# Apply tokenize() to each example in turn; the original columns
# (text, words, entity_spans, original_word_spans) remain in the mapped datasets.
tokenized_train_data = train_dataset.map(tokenize)
tokenized_val_data = val_dataset.map(tokenize)

  0%|          | 0/181 [00:00<?, ?ex/s]

  0%|          | 0/46 [00:00<?, ?ex/s]

In [19]:
'''
Compute precision, recall, f1, and accuracy (for evaluation on the validation dataset during training)
'''
def compute_metrics(p):
    """Compute macro precision/recall/F1 and accuracy over all candidate spans.

    Args:
        p: a (predictions, labels) pair, where predictions holds the per-span
            class scores (classes on axis 2) and labels holds the gold label id
            of every span.

    Returns:
        A dict with the float entries "precision", "recall", "f1" and "accuracy".
    """
    logits, labels = p
    predicted_ids = np.argmax(logits, axis=2)
    labels = np.asarray(labels)

    # Labels are per candidate span; positions equal to -100 carry no gold label
    # (presumably padding added when examples with different span counts are
    # gathered -- TODO confirm) and are excluded from every metric.
    keep = labels != -100
    true_labels = labels[keep]
    true_predictions = predicted_ids[keep]

    # Use the scikit-learn metrics imported at the top of the notebook directly,
    # instead of re-loading four metric scripts on every evaluation pass.
    return {
        "precision": float(precision_score(true_labels, true_predictions, average="macro")),
        "recall": float(recall_score(true_labels, true_predictions, average="macro")),
        "f1": float(f1_score(true_labels, true_predictions, average="macro")),
        "accuracy": float(accuracy_score(true_labels, true_predictions)),
    }

In [20]:
# may want to experiment with "studio-ousia/luke-large-finetuned-conll-2003"
# Load the pretrained checkpoint into a span-classification model whose labels are
# given by id2label/label2id (two classes here), then move it to the selected device.
model = LukeForEntitySpanClassification.from_pretrained("studio-ousia/luke-base", ignore_mismatched_sizes=True, id2label=id2label, label2id=label2id, cache_dir=".cache/").to(device)

Some weights of the model checkpoint at studio-ousia/luke-base were not used when initializing LukeForEntitySpanClassification: ['entity_predictions.transform.LayerNorm.weight', 'entity_predictions.bias', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.dense.weight', 'entity_predictions.transform.dense.bias', 'entity_predictions.transform.dense.weight', 'lm_head.layer_norm.bias', 'entity_predictions.transform.LayerNorm.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing LukeForEntitySpanClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeForEntitySpanClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LukeForEntitySpanClassificati

In [21]:
# Training hyperparameters
EPOCHS = 20
LR = 1e-5  # learning rate
WD = 0.01  # weight decay
# One example per forward pass with gradients accumulated over 8 steps, i.e. an
# effective train batch size of BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS = 8 on a single device.
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 8

training_args = TrainingArguments(
    # change folder name here, to avoid replacing the previous model's outputs
    output_dir="output/model-name/", 
    # evaluate, save a checkpoint and log once per epoch
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    learning_rate=LR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    weight_decay=WD,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    # reload the best checkpoint once training has finished
    load_best_model_at_end=True
)

In [22]:
# Bundle the model, training arguments, tokenized datasets and tokenizer;
# compute_metrics is the callback used to score the validation set.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_data,
    eval_dataset=tokenized_val_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [23]:
# to prevent CUDA out of memory issues -- if they still exist, restart runtime (but first, download important files!)
torch.cuda.empty_cache()
gc.collect() # run until the number < 100

131

In [24]:
# define checkpoint to resume training from, if needed
CKPT = None
train_result = trainer.train(resume_from_checkpoint=CKPT)
trainer.save_model()
trainer.save_state()

The following columns in the training set don't have a corresponding argument in `LukeForEntitySpanClassification.forward` and have been ignored: original_word_spans, words, entity_spans, text. If original_word_spans, words, entity_spans, text are not expected by `LukeForEntitySpanClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 181
  Num Epochs = 20
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 8
  Total optimization steps = 440
  Number of trainable parameters = 274506754


Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
0,0.1095,0.034757,0.496585,0.5,0.498287,0.993171
1,0.0342,0.023238,0.496585,0.5,0.498287,0.993171
2,0.0242,0.019504,0.924239,0.621385,0.68782,0.994546
3,0.0199,0.017882,0.831333,0.707617,0.755051,0.994593
4,0.0168,0.016785,0.85032,0.714657,0.765954,0.994878
5,0.0145,0.015961,0.816418,0.7421,0.774231,0.994593
6,0.0137,0.015538,0.78279,0.78669,0.784726,0.994119
7,0.012,0.016468,0.868059,0.697439,0.756643,0.994925
8,0.0112,0.015352,0.804218,0.804218,0.804218,0.994688
9,0.0099,0.015535,0.814449,0.773207,0.792356,0.994736


The following columns in the evaluation set don't have a corresponding argument in `LukeForEntitySpanClassification.forward` and have been ignored: original_word_spans, words, entity_spans, text. If original_word_spans, words, entity_spans, text are not expected by `LukeForEntitySpanClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 46
  Batch size = 1
  """
  _warn_prf(average, modifier, msg_start, len(result))
Saving model checkpoint to output/ner-base-20/checkpoint-22
Configuration saved in output/ner-base-20/checkpoint-22/config.json
Model weights saved in output/ner-base-20/checkpoint-22/pytorch_model.bin
tokenizer config file saved in output/ner-base-20/checkpoint-22/tokenizer_config.json
Special tokens file saved in output/ner-base-20/checkpoint-22/special_tokens_map.json
added tokens file saved in output/ner-base-20/checkpoint-22/added_tokens.json
The following columns in the evaluation set don't have a corresponding ar

In [None]:
'''
Predict on the validation dataset and save output in CONLL format
'''
# change output_path first! save in drive, and then go to https://drive.google.com/drive/my-drive to download it
output_path = "/content/drive/MyDrive/capstone/model-name-val-pred.conll"

# if the model has finished training normally, the best checkpoint (based on validation loss) is loaded in the end
# however, if it did not finish training normally,
# or if you prefer to load the model from another epoch, maybe one with a higher f1, uncomment this line:
# model = LukeForEntitySpanClassification.from_pretrained("output/model-name/checkpoint-xxx").to(device)

# Switch to inference mode explicitly: if training was interrupted, the model may still be in
# training mode, in which case train-only layers such as dropout would perturb the predictions.
model.eval()

text_list = [e["text"] for e in val_dataset]
words_list = [e["words"] for e in val_dataset]
original_word_spans_list = [e["original_word_spans"] for e in val_dataset]
# the tokenizer is given tuples for the character-level spans
entity_spans_list = [[tuple(span) for span in e["entity_spans"]] for e in val_dataset]

with open(output_path, "w", encoding="utf-8") as f:
    for text, words, entity_spans, original_word_spans in zip(
        text_list, words_list, entity_spans_list, original_word_spans_list
    ):
        inputs = tokenizer(
            text=text,
            entity_spans=entity_spans,
            max_length=MAX_LENGTH,
            max_entity_length=MAX_ENTITY_LENGTH,
            truncation=True,
            return_tensors='pt'
        )
        inputs = inputs.to(device)
        with torch.no_grad():
            outputs = model(**inputs)

        # for every candidate span, keep the winning class and its logit
        logits = outputs.logits
        max_logits, max_indices = logits[0].max(dim=1)

        # convert to plain Python values once instead of comparing tensor elements one by one
        predictions = []
        for logit, index, span in zip(max_logits.tolist(), max_indices.tolist(), original_word_spans):
            if index != 0:  # the span is not NIL
                predictions.append((logit, span, model.config.id2label[index]))

        # construct an IOB2 label sequence
        # spans are placed from the highest logit down; a span is skipped when it overlaps an already placed one
        predicted_sequence = ["O"] * len(words)
        for _, span, label in sorted(predictions, key=lambda o: o[0], reverse=True):
            if all(o == "O" for o in predicted_sequence[span[0] : span[1]]):
                predicted_sequence[span[0]] = "B-" + label
                if span[1] - span[0] > 1:
                    predicted_sequence[span[0] + 1 : span[1]] = ["I-" + label] * (span[1] - span[0] - 1)

        # one "token label" pair per line, with a blank line between sentences
        for token, label in zip(words, predicted_sequence):
            f.write(f"{token} {label}\n")
        f.write("\n")

In [26]:
# zip the best checkpoint and save it in drive (it could take a while)
# then go to drive, download it, delete it, and find it in trash to delete it forever (files in trash also take up space)
# NOTE: "model-name" and "checkpoint-xxx" are placeholders -- replace them with the actual output folder and checkpoint step first
!zip -r /content/drive/MyDrive/capstone/model-name.zip output/model-name/checkpoint-xxx

  adding: output/ner-base-20/checkpoint-198/ (stored 0%)
  adding: output/ner-base-20/checkpoint-198/vocab.json (deflated 68%)
  adding: output/ner-base-20/checkpoint-198/special_tokens_map.json (deflated 85%)
  adding: output/ner-base-20/checkpoint-198/pytorch_model.bin (deflated 8%)
  adding: output/ner-base-20/checkpoint-198/merges.txt (deflated 53%)
  adding: output/ner-base-20/checkpoint-198/scheduler.pt (deflated 49%)
  adding: output/ner-base-20/checkpoint-198/optimizer.pt (deflated 63%)
  adding: output/ner-base-20/checkpoint-198/rng_state.pth (deflated 27%)
  adding: output/ner-base-20/checkpoint-198/config.json (deflated 52%)
  adding: output/ner-base-20/checkpoint-198/training_args.bin (deflated 48%)
  adding: output/ner-base-20/checkpoint-198/trainer_state.json (deflated 77%)
  adding: output/ner-base-20/checkpoint-198/entity_vocab.json (deflated 65%)
  adding: output/ner-base-20/checkpoint-198/tokenizer_config.json (deflated 77%)
  adding: output/ner-base-20/checkpoint-198