This notebook fine-tunes the LUKE model on a named entity recognition (NER) task: identifying organizations mentioned in news text. It consists of two parts:


1.   Clean and preprocess data for the model's use
2.   Finetune the model



**Part 1: data cleaning & preprocessing**

In [None]:
# Mount Google Drive at /content/drive; the input CSVs and the prediction outputs
# used by later cells live under /content/drive/MyDrive/capstone.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import csv
import gc
import os
import unicodedata
from collections import defaultdict

import numpy as np
import pandas as pd
import spacy
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm, trange

In [None]:
# spaCy pipeline used by combine_samples (via nlp.pipe) to turn sentences into tokenized docs
nlp = spacy.load("en_core_web_sm")

In [None]:
# Load the cleaned dataset (the first CSV column becomes the index) and the list of validation row ids.
data = pd.read_csv('/content/drive/MyDrive/capstone/Cleaned_full_data.csv', index_col=0)
validation_index = pd.read_csv('/content/drive/MyDrive/capstone/valid_ids.csv', header = None)
valid_ids = validation_index[0].tolist()
# valid_ids are consumed as row *positions* (iloc), so the training split has to be the
# positional complement. Building it from index labels (data.index) and then passing those
# labels to iloc only gives a disjoint split when the index happens to be exactly 0..n-1.
val_data = data.iloc[valid_ids]
valid_positions = set(valid_ids)  # set membership instead of scanning the list for every row
train_ids = [pos for pos in range(len(data)) if pos not in valid_positions]
train_data = data.iloc[train_ids]

In [None]:
def preprocess_data(data):
    """Order the samples and pull out sentences and entity names as parallel lists.

    Rows are sorted by sentence, then entity_a, then entity_b, so that samples
    sharing a sentence end up next to each other. Entity names are stripped of
    surrounding whitespace; sentences are returned as stored.

    Returns:
        (all_sentences, company_a, company_b), three lists of equal length.
    """
    ordered = data.sort_values(['sentence', 'entity_a', 'entity_b'])
    all_sentences = ordered['sentence'].tolist()
    company_a = [name.strip() for name in ordered['entity_a'].tolist()]
    company_b = [name.strip() for name in ordered['entity_b'].tolist()]
    return all_sentences, company_a, company_b

In [None]:
def combine_samples(all_sentences, company_a, company_b):
    """Merge consecutive samples that share a sentence into one NER sample.

    The inputs are parallel lists in which samples of the same sentence are
    adjacent. For each run of identical sentences, the organizations of all its
    samples are collected into one set.

    Returns:
        (docs, orgs): the stripped unique sentences parsed with the spaCy
        pipeline `nlp`, and the matching list of organization-name sets.
    """
    orgs = []
    current_orgs = {company_a[0], company_b[0]}
    sentences = [all_sentences[0].strip()]
    for idx in range(1, len(all_sentences)):
        # a change of sentence closes the current group and opens a new one
        if all_sentences[idx] != all_sentences[idx - 1]:
            sentences.append(all_sentences[idx].strip())
            assert len(current_orgs) >= 2
            orgs.append(current_orgs)
            current_orgs = set()
        current_orgs.update((company_a[idx], company_b[idx]))
    # close the last group
    assert len(current_orgs) >= 2
    orgs.append(current_orgs)
    assert len(sentences) == len(orgs)
    docs = list(nlp.pipe(sentences))
    return docs, orgs

In [None]:
"""Define global variables, including model hyperparameters that are relevant
during the dataset creation phase.
"""
SEED = 0

# classification labels and the two lookup tables between label names and ids
label_list = ['O', 'ORG']
id2label = dict(enumerate(label_list))
label2id = {name: idx for idx, name in id2label.items()}

# upper bound on the number of candidate entity spans per example
MAX_ENTITY_LENGTH = 1024
# candidate spans longer than ENTITY_WORD_LIMIT words are not considered by the model
ENTITY_WORD_LIMIT = 10
# upper bound on the number of tokens per example
# ("token" means sub-word here, created from spacy tokenization & LukeTokenizer)
MAX_LENGTH = 512

In [None]:
def create_examples(docs, orgs):
    """Create one example dict per doc from the spaCy docs and their organizations.

    Every run of up to ENTITY_WORD_LIMIT consecutive tokens is a candidate span;
    at most MAX_ENTITY_LENGTH candidates are kept per doc. A candidate is labeled
    1 when its text exactly equals one of the doc's organizations, otherwise 0
    (indices into label_list).

    Args:
        docs: spaCy docs, one per sentence.
        orgs: for each doc (same order), the collection of organization names
            that occur in it.

    Returns:
        A list of dicts with the keys
        text, words (token texts), entity_spans (character-level (start, end)),
        original_word_spans (word-level (start, end), end exclusive) and labels.
        Docs in which fewer than 2 distinct organizations were matched are skipped.
    """
    examples = []
    skipped = 0
    total_org_count = 0
    org_word_len = []
    # wrap `docs` itself (not the enumerate iterator) so tqdm knows the total and can draw a real progress bar
    for doc_i, doc in enumerate(tqdm(docs)):
        text = doc.text
        entity_spans = []
        original_word_spans = []
        words = [token.text for token in doc]
        labels = []
        org_occur = set()
        for token_start in doc:
            if len(entity_spans) == MAX_ENTITY_LENGTH:
                break
            for token_end in doc[token_start.i:token_start.i + ENTITY_WORD_LIMIT]:
                entity_start, entity_end = token_start.idx, token_end.idx + len(token_end)
                entity_spans.append((entity_start, entity_end))
                original_word_spans.append((token_start.i, token_end.i + 1))
                span = text[entity_start:entity_end]
                if span in orgs[doc_i]:
                    labels.append(1)
                    org_word_len.append(token_end.i - token_start.i + 1)
                    org_occur.add(span)
                else:
                    labels.append(0)
                # once MAX_ENTITY_LENGTH spans are accumulated for an example, we stop adding further spans to it
                # data exploratory analysis shows that this truncation leads us to lose 2 valid examples
                if len(entity_spans) == MAX_ENTITY_LENGTH:
                    break
        # if an example does not have at least 2 unique organization appearances, we exclude it
        if len(org_occur) < 2:
            skipped += 1
            continue
        total_org_count += sum(labels)
        examples.append(dict(text=text, words=words, entity_spans=entity_spans, original_word_spans=original_word_spans, labels=labels))
    print(f"\n{skipped} sentences were skipped since less than 2 organizations were found in these sentences.\n"
        + f"{len(examples)} sentences containing {total_org_count} organization occurrences remain.")
    # max() of an empty list raises ValueError; this happens when no organization matched any span (or docs is empty)
    if org_word_len:
        print(f"The longest organization contains {max(org_word_len)} words.")
    else:
        print("No organization occurrences were found.")
    return examples

In [None]:
# create train_examples: sort the rows and strip entity names, merge rows that share a sentence,
# then enumerate and label the candidate spans of every sentence
train_all_sentences, train_company_a, train_company_b = preprocess_data(train_data)
train_docs, train_orgs = combine_samples(train_all_sentences, train_company_a, train_company_b)
train_examples = create_examples(train_docs, train_orgs)

356it [00:00, 1797.13it/s]


6 sentences were skipped since less than 2 organizations were found in these sentences.
350 sentences containing 1041 organization occurrences remain.
The longest organization contains 6 words.





In [None]:
# create val_examples with the same three steps as the training split
val_all_sentences, val_company_a, val_company_b = preprocess_data(val_data)
val_docs, val_orgs = combine_samples(val_all_sentences, val_company_a, val_company_b)
val_examples = create_examples(val_docs, val_orgs)

90it [00:00, 1839.48it/s]


0 sentences were skipped since less than 2 organizations were found in these sentences.
90 sentences containing 254 organization occurrences remain.
The longest organization contains 5 words.





**Part 2: model finetuning**

In [None]:
# install the Hugging Face libraries imported in the next cell
# NOTE(review): versions are unpinned, so a re-run may install releases that behave differently
# from the recorded outputs (this notebook relies on `datasets.load_metric`) -- consider pinning
!pip install transformers
!pip install datasets

In [None]:
import torch
from transformers import LukeTokenizer, LukeForEntitySpanClassification, Trainer, TrainingArguments, set_seed
from datasets import Dataset, load_metric

In [None]:
# fix the random seed via transformers.set_seed, and use the GPU when one is available
set_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [None]:
# may want to experiment with "studio-ousia/luke-large-finetuned-conll-2003"
# the tokenizer is later called with a text plus its candidate entity_spans (at most MAX_ENTITY_LENGTH per example)
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", max_entity_length=MAX_ENTITY_LENGTH, task="entity_span_classification", cache_dir=".cache/")

Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/15.3M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/33.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.00k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.04k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/836 [00:00<?, ?B/s]

In [None]:
# wrap the lists of example dicts in Hugging Face Datasets so they can be tokenized with .map and passed to Trainer
train_dataset = Dataset.from_list(train_examples)
val_dataset = Dataset.from_list(val_examples)

In [None]:
"""Use LukeTokenizer to tokenize the dataset"""
def tokenize(example):
    """Tokenize one example with LukeTokenizer and attach its span labels.

    Args:
        example: a dict with the keys 'text', 'entity_spans' (character-level
            (start, end) pairs) and 'labels' (one label id per span).

    Returns:
        The tokenizer output for this single example, with the leading batch
        dimension removed from every tensor, plus a "labels" tensor.
    """
    tokenized_inputs = tokenizer(
        text=example['text'],
        entity_spans=[tuple(span) for span in example['entity_spans']],
        max_length=MAX_LENGTH,
        max_entity_length=MAX_ENTITY_LENGTH,
        truncation=True,
        return_tensors='pt'
    )

    # Remove only the batch dimension (dim 0). A bare squeeze() would also drop any other
    # size-1 dimension, e.g. the span axis of an example that has a single candidate span.
    for k, v in tokenized_inputs.items():
        if isinstance(v, torch.Tensor) and v.shape[0] == 1:
            tokenized_inputs[k] = v.squeeze(0)

    tokenized_inputs["labels"] = torch.tensor(example['labels'])
    return tokenized_inputs

In [None]:
# apply tokenize to every example, one example per call
tokenized_train_data = train_dataset.map(tokenize)
tokenized_val_data = val_dataset.map(tokenize)

  0%|          | 0/350 [00:00<?, ?ex/s]

  0%|          | 0/90 [00:00<?, ?ex/s]

In [None]:
"""
Compute precision, recall, f1, and accuracy (for evaluation on the validation dataset during training)
"""
def compute_metrics(p):
    """Compute macro precision, recall, F1 and accuracy over all labeled spans.

    Uses the scikit-learn metric functions imported at the top of the notebook,
    so nothing has to be loaded or downloaded on each evaluation.

    Args:
        p: a (predictions, labels) pair; predictions hold per-span class scores
            with the classes on axis 2, labels hold the class id of each span.

    Returns:
        dict with the keys "precision", "recall", "f1" and "accuracy".
    """
    logits, labels = p
    predictions = np.argmax(logits, axis=2)
    labels = np.asarray(labels)

    # -100 marks positions that must be ignored in metric computation
    # (presumably padding added when examples with different span counts are gathered -- confirm)
    keep = labels != -100
    true_predictions = predictions[keep]
    true_labels = labels[keep]

    return {
        "precision": float(precision_score(true_labels, true_predictions, average="macro")),
        "recall": float(recall_score(true_labels, true_predictions, average="macro")),
        "f1": float(f1_score(true_labels, true_predictions, average="macro")),
        "accuracy": float(accuracy_score(true_labels, true_predictions)),
    }

In [None]:
# may want to experiment with "studio-ousia/luke-large-finetuned-conll-2003"
# id2label / label2id store the O/ORG label maps in the model config (read back later as model.config.id2label)
model = LukeForEntitySpanClassification.from_pretrained("studio-ousia/luke-base", ignore_mismatched_sizes=True, id2label=id2label, label2id=label2id, cache_dir=".cache/").to(device)

Downloading:   0%|          | 0.00/1.10G [00:00<?, ?B/s]

Some weights of the model checkpoint at studio-ousia/luke-base were not used when initializing LukeForEntitySpanClassification: ['lm_head.dense.weight', 'entity_predictions.transform.dense.weight', 'entity_predictions.transform.LayerNorm.weight', 'lm_head.bias', 'entity_predictions.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'entity_predictions.transform.dense.bias', 'entity_predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing LukeForEntitySpanClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeForEntitySpanClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LukeForEntitySpanClassificati

In [None]:
# fine-tuning hyperparameters
EPOCHS = 20
LR = 1e-5
WD = 0.01
BATCH_SIZE = 1
# gradients are accumulated over this many batches, so the total train batch size is
# BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS = 8 (matches the training log)
GRADIENT_ACCUMULATION_STEPS = 8

training_args = TrainingArguments(
    # change folder name here, to avoid replacing the previous model's outputs
    # NOTE(review): the prediction and zip cells mention checkpoints under "output/model-name/";
    # keep those paths in sync with this folder
    output_dir="output/ner-full-data/",
    # evaluate, save a checkpoint and log once per epoch
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    learning_rate=LR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    weight_decay=WD,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    # reload the best checkpoint into `model` when training finishes
    load_best_model_at_end=True
)

In [None]:
# Trainer fine-tunes `model` on the tokenized training set and reports compute_metrics
# on the tokenized validation set at every evaluation
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_data,
    eval_dataset=tokenized_val_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [None]:
# to prevent CUDA out of memory issues -- if they still exist, restart runtime (but first, download important files!)
# empty PyTorch's CUDA cache, then run the Python garbage collector
# (the displayed number is gc.collect()'s return value)
torch.cuda.empty_cache()
gc.collect() # run until the number < 100

0

In [None]:
# define checkpoint to resume training from, if needed
# (None is passed through to trainer.train as resume_from_checkpoint)
CKPT = None
train_result = trainer.train(resume_from_checkpoint=CKPT)
# persist the final model and the trainer state
trainer.save_model()
trainer.save_state()

The following columns in the training set don't have a corresponding argument in `LukeForEntitySpanClassification.forward` and have been ignored: text, words, original_word_spans, entity_spans. If text, words, original_word_spans, entity_spans are not expected by `LukeForEntitySpanClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 350
  Num Epochs = 20
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 8
  Total optimization steps = 860
  Number of trainable parameters = 274506754


Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
0,0.0714,0.023613,0.996304,0.501969,0.502067,0.992608
1,0.0211,0.016664,0.82222,0.823479,0.822848,0.99477
2,0.0151,0.013836,0.838346,0.811888,0.824569,0.995033
3,0.0128,0.014766,0.851937,0.700169,0.754856,0.994332
4,0.0112,0.012521,0.827073,0.892155,0.856611,0.995355
5,0.0097,0.012479,0.835338,0.878538,0.855606,0.995471
6,0.0087,0.012983,0.822551,0.915616,0.863097,0.995384
7,0.0081,0.013261,0.813875,0.895886,0.850045,0.995004
8,0.0072,0.01367,0.812303,0.893903,0.848291,0.994946
9,0.0068,0.012971,0.839943,0.890364,0.863379,0.995676


The following columns in the evaluation set don't have a corresponding argument in `LukeForEntitySpanClassification.forward` and have been ignored: text, words, original_word_spans, entity_spans. If text, words, original_word_spans, entity_spans are not expected by `LukeForEntitySpanClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 90
  Batch size = 1
  precision = load_metric("precision")


Downloading builder script:   0%|          | 0.00/2.58k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/2.52k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/1.65k [00:00<?, ?B/s]

Saving model checkpoint to output/model-name/checkpoint-43
Configuration saved in output/model-name/checkpoint-43/config.json
Model weights saved in output/model-name/checkpoint-43/pytorch_model.bin
tokenizer config file saved in output/model-name/checkpoint-43/tokenizer_config.json
Special tokens file saved in output/model-name/checkpoint-43/special_tokens_map.json
added tokens file saved in output/model-name/checkpoint-43/added_tokens.json
The following columns in the evaluation set don't have a corresponding argument in `LukeForEntitySpanClassification.forward` and have been ignored: text, words, original_word_spans, entity_spans. If text, words, original_word_spans, entity_spans are not expected by `LukeForEntitySpanClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 90
  Batch size = 1
Saving model checkpoint to output/model-name/checkpoint-86
Configuration saved in output/model-name/checkpoint-86/config.json
Model weights 

In [33]:
"""Predict on the validation dataset and save output in CONLL format and csv format
"""
# change output paths first! save in drive, and then go to https://drive.google.com/drive/my-drive to download it
output_conll_path = "/content/drive/MyDrive/capstone/ner-full-data-val-pred.conll"
output_csv_path = "/content/drive/MyDrive/capstone/ner-full-data-val-pred.csv"

# if the model has finished training normally, the best checkpoint (based on validation loss) is loaded in the end
# however, if it did not finish training normally,
# or if you prefer to load the model from another epoch, maybe one with a higher f1, uncomment this line:
# model = LukeForEntitySpanClassification.from_pretrained("output/model-name/checkpoint-688").to(device)

# inference must run in eval mode (dropout off), whatever mode training left the model in
# (e.g. after an interrupted run)
model.eval()

text_list = [e["text"] for e in val_dataset]
words_list = [e["words"] for e in val_dataset]
original_word_spans_list = [e["original_word_spans"] for e in val_dataset]
entity_spans_list = [[tuple(span) for span in e["entity_spans"]] for e in val_dataset]
all_predicted_entities = [list() for e in val_dataset]

with open(output_conll_path, "w", encoding="utf-8") as f:
    for i in range(len(text_list)):
        text = text_list[i]
        words = words_list[i]
        entity_spans = entity_spans_list[i]
        original_word_spans = [tuple(span) for span in original_word_spans_list[i]]
        # word-level span -> covered text; both span lists hold one aligned entry per candidate span
        ows2text = {ows: text[start:end] for ows, (start, end) in zip(original_word_spans, entity_spans)}

        inputs = tokenizer(
            text=text,
            entity_spans=entity_spans,
            max_length=MAX_LENGTH,
            max_entity_length=MAX_ENTITY_LENGTH,
            truncation=True,
            return_tensors='pt'
        )
        inputs = inputs.to(device)
        with torch.no_grad():
            outputs = model(**inputs)

        logits = outputs.logits
        max_logits, max_indices = logits[0].max(dim=1)
        # convert to plain Python numbers once, instead of comparing and sorting 0-d device tensors one by one
        max_logits, max_indices = max_logits.tolist(), max_indices.tolist()

        predictions = [
            (logit, span, model.config.id2label[index])
            for logit, index, span in zip(max_logits, max_indices, original_word_spans)
            if index != 0  # the span is not NIL
        ]

        # construct an IOB2 label sequence and a list of predicted non-NULL entities;
        # spans are taken greedily from the highest logit down, and a span is dropped
        # when it overlaps one that was already accepted
        predicted_sequence = ["O"] * len(words)
        predicted_entities = []
        for _, span, label in sorted(predictions, key=lambda o: o[0], reverse=True):
            if all([o == "O" for o in predicted_sequence[span[0] : span[1]]]):
                predicted_sequence[span[0]] = "B-" + label
                if span[1] - span[0] > 1:
                    predicted_sequence[span[0] + 1 : span[1]] = ["I-" + label] * (span[1] - span[0] - 1)
                predicted_entities.append(ows2text[span])
        all_predicted_entities[i] = predicted_entities

        for token, label in zip(words, predicted_sequence):
            f.write(f"{token} {label}\n")
        f.write("\n")

# csv.writer escapes double quotes inside the sentence or the entity names; writing the row by
# hand with an f-string produced a malformed row whenever the text itself contained a `"`.
# QUOTE_ALL and "\n" keep the previous `"text","entities"` layout for rows without quotes.
with open(output_csv_path, "w", encoding="utf-8", newline="") as f:
    writer = csv.writer(f, quoting=csv.QUOTE_ALL, lineterminator="\n")
    for text, entities in zip(text_list, all_predicted_entities):
        writer.writerow([text, ','.join(entities)])

In [None]:
# zip the best checkpoint and save it in drive (it could take a while)
# then go to drive, download it, delete it, and find it in trash to delete it forever (files in trash also take up space)
!zip -r /content/drive/MyDrive/capstone/ner-full-data.zip output/model-name/checkpoint-688

updating: output/model-name/checkpoint-688/ (stored 0%)
updating: output/model-name/checkpoint-688/rng_state.pth (deflated 27%)
updating: output/model-name/checkpoint-688/added_tokens.json (deflated 28%)
updating: output/model-name/checkpoint-688/pytorch_model.bin (deflated 8%)
updating: output/model-name/checkpoint-688/scheduler.pt (deflated 49%)
updating: output/model-name/checkpoint-688/training_args.bin (deflated 48%)
updating: output/model-name/checkpoint-688/tokenizer_config.json (deflated 77%)
updating: output/model-name/checkpoint-688/optimizer.pt (deflated 63%)
updating: output/model-name/checkpoint-688/entity_vocab.json (deflated 65%)
updating: output/model-name/checkpoint-688/merges.txt (deflated 53%)
updating: output/model-name/checkpoint-688/trainer_state.json (deflated 78%)
updating: output/model-name/checkpoint-688/config.json (deflated 52%)
updating: output/model-name/checkpoint-688/special_tokens_map.json (deflated 85%)
updating: output/model-name/checkpoint-688/vocab.