# Intro
Download RJNLPBA from https://arxiv.org/abs/1901.10219

Place the extracted files in the same directory as this notebook.

In [None]:
# %pip install numpy scikit-learn torch tqdm pandas spacy seqeval transformers datasets

# Import

In [None]:
from datetime import datetime
import os
import re
from collections import OrderedDict
import pandas as pd
import numpy as np

from sklearn.model_selection import ShuffleSplit
import torch

from transformers import (AutoTokenizer,
                          AutoModelForTokenClassification, 
                          DataCollatorForTokenClassification,
                          get_scheduler,
                          Trainer,
                          TrainingArguments,
                          EarlyStoppingCallback,
                          logging
                         )
from datasets import (Dataset,
                      Split,
                      Features,
                      Value,
                      ClassLabel,
                      load_metric,
                     )

from tqdm.auto import tqdm
from seqeval.metrics import classification_report
from seqeval.scheme import BILOU
from spacy.training import biluo_tags_to_spans
from spacy.tokens import Doc
import spacy

# Turn off the tokenizers library's own parallelism via its environment switch.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU.
# The MPS backend is probed with getattr because torch builds without it do not
# expose `torch.backends.mps`; this replaces a bare `except: pass`, which would
# also have hidden unrelated errors.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Load pretrained models

In [None]:
# Checkpoint to fine-tune.
# Another candidate: "sberbank-ai/bert-base-NER-reptile-5-datasets"
model_name = "dmis-lab/biobert-base-cased-v1.2"

# The tokenizer and the padding collator both come from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

# Token-classification head with 21 output classes, moved to the selected device.
model = AutoModelForTokenClassification.from_pretrained(model_name, num_labels=21).to(device)

# Test tokenizer correctness

In [None]:
# Tokenize two sample strings and print, for each one, the token ids followed
# by the text decoded back from those ids, so the round trip can be inspected.
tokenized = tokenizer(["ELK-1 immunoglobulin,", "SAP-1,"])
for i in tokenized['input_ids']:
    print(i)
    print(tokenizer.decode(i))

# Preprocessing

## Label name mapping to ID

In [None]:
# id -> tag name. Ids 1..20 enumerate the B/I/E/S tags of each entity type in
# turn, 0 is the outside tag, and -100 (the ignore index) maps to itself so
# that ignored positions survive a lookup unchanged.
label_names = OrderedDict(
    enumerate(
        (
            f"{prefix}-{entity}"
            for entity in ("DNA", "RNA", "protein", "cell_line", "cell_type")
            for prefix in "BIES"
        ),
        start=1,
    )
)
label_names[0] = "O"
label_names[-100] = -100

# tag name -> id, in the same insertion order as label_names.
label_names_to_id = OrderedDict((type_name, type_id) for type_id, type_name in label_names.items())



## Tokenize

In [None]:
def tokenize_and_align_labels(examples):
    """Tokenize pre-split sentences and attach per-token label ids.

    `examples["tokens"]` is tokenized with the module-level `tokenizer`
    (truncated to 512 tokens). For each sentence, only the first sub-token of
    every word receives the word's tag from `examples["ner_tags"]`; special
    tokens and the remaining sub-tokens of a word receive -100. The tags are
    then converted through `label_names_to_id` and stored under "labels".

    Returns the tokenizer output with the added "labels" entry.
    """
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, max_length=512, is_split_into_words=True)

    labels = []
    for sentence_index, word_tags in enumerate(examples["ner_tags"]):
        aligned = []
        previous_word_idx = None
        # word_ids gives, per token, the index of the word it came from
        # (None for special tokens).
        for word_idx in tokenized_inputs.word_ids(batch_index=sentence_index):
            starts_new_word = word_idx is not None and word_idx != previous_word_idx
            # Alternative for continuation sub-tokens: repeat the word's tag
            # instead of ignoring them.
            aligned.append(word_tags[word_idx] if starts_new_word else -100)
            previous_word_idx = word_idx
        labels.append([label_names_to_id[tag] for tag in aligned])

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

## File loading

In [None]:
def convert_bi_tags(original_tags):
    """Convert BIO-style tag sequences to BIES-style ones.

    A `B` tag that ends its entity (last in the sentence, or followed by a tag
    starting with `O` or `B`) becomes `S`; an `I` tag in the same situation
    becomes `E`. Every other tag is kept as is.

    original_tags: list of sentences, each a list of tag strings.
    Returns a new list of lists with the converted tags.
    Raises ValueError if any tag already starts with `E`.
    """
    closing_prefix = {'B': 'S', 'I': 'E'}
    converted_all = []
    for sentence in original_tags:
        last_pos = len(sentence) - 1
        converted = []
        for pos, tag in enumerate(sentence):
            if tag.startswith('E'):
                raise ValueError("Input tags already are BIE format!")
            replacement = closing_prefix.get(tag[:1])
            # The next tag is only inspected for B/I tags that are not last.
            if replacement is not None and (pos == last_pos or sentence[pos + 1][0] in ('O', 'B')):
                converted.append(replacement + tag[1:])
            else:
                converted.append(tag)
        converted_all.append(converted)
    return converted_all

def load_CONLL_sentences(file_name):
    """Read a tab-separated CoNLL-style file into sentences and tag sequences.

    Expected layout, one token per line:
        WORD \t TAG
        WORD \t TAG
        ...
    Any line with fewer than two tab-separated fields (e.g. a blank line) ends
    the current sentence. The longest sentence length is printed, and the tags
    are passed through `convert_bi_tags` before being returned.

    Returns (List[List[words]], List[List[tags]]).
    """
    with open(file_name, 'r', encoding='utf-8') as in_f:
        rows = [raw.strip().split('\t') for raw in in_f.readlines()]

    sentences, tag_sequences = [], []
    current_words, current_tags = [], []
    # The extra empty row acts as a final separator so the last sentence is
    # flushed by the same branch as all the others.
    for fields in rows + [[]]:
        if len(fields) >= 2:
            current_words.append(fields[0])
            current_tags.append(fields[1])
        elif len(current_words) > 0:
            sentences.append(current_words)
            tag_sequences.append(current_tags)
            current_words, current_tags = [], []

    max_sent_len = max((len(sentence) for sentence in sentences), default=0)
    print(f"Max sent length in {file_name}: {max_sent_len}")
    return sentences, convert_bi_tags(tag_sequences)


In [None]:
# Make sure the tag converter works. For this input the call below evaluates to
# [['S-1', 'B-1', 'I-1', 'E-1', 'O', 'O', 'S-2', 'O'], ['S-3'], ['O']].
tags = [['B-1', 'B-1', 'I-1', 'I-1', 'O', 'O', 'B-2', 'O'], ['B-3'], ['O']]
convert_bi_tags(tags)

# Metrics

In [None]:

metric = load_metric("seqeval")


def compute_metrics(p):
    """Turn a (predictions, labels) pair into overall seqeval scores.

    The predictions are reduced with argmax over the last axis (assumed layout:
    batch, length, classes). Positions whose gold label is -100 are dropped
    from both sequences, the remaining ids are mapped to tag names through
    `label_names`, and the module-level `metric` scores them.

    Returns a dict with "precision", "recall", "f1" and "accuracy".
    """
    scores, gold_ids = p
    predicted_ids = np.argmax(scores, axis=-1)

    true_predictions = []
    true_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold_ids):
        # Keep only positions that carry a real gold label.
        kept = [(pred_id, gold_id) for pred_id, gold_id in zip(predicted_row, gold_row) if gold_id != -100]
        true_predictions.append([label_names[pred_id] for pred_id, _ in kept])
        true_labels.append([label_names[gold_id] for _, gold_id in kept])

    results = metric.compute(predictions=true_predictions, references=true_labels)
    summary_keys = ("precision", "recall", "f1", "accuracy")
    return {key: results[f"overall_{key}"] for key in summary_keys}

# Training

## Load training data

In [None]:
# Read the training corpus and tokenize it with aligned label ids.
sent_tags = load_CONLL_sentences("./Genia4ERtask1.iob2")
examples = {'tokens': sent_tags[0], 'ner_tags': sent_tags[1]}
examples = tokenize_and_align_labels(examples)
training_data = Dataset.from_dict(examples)

# Single seeded 90/10 shuffle split into training and validation subsets.
train_ids, test_ids = next(
    ShuffleSplit(n_splits=1, test_size=0.1, random_state=1331).split(training_data)
)
ds_train = Dataset.from_dict(training_data[train_ids])
ds_valid = Dataset.from_dict(training_data[test_ids])

# Left disabled on purpose:
# ds_train.set_format("pt")
# ds_valid.set_format("pt")

## Create trainer
Customize trainer to use Mac M1 GPU if available (https://github.com/huggingface/transformers/issues/17971)

In [None]:
class TrainingArgumentsWithMPSSupport(TrainingArguments):
    """TrainingArguments whose `device` also considers Apple's MPS backend."""

    @property
    def device(self) -> torch.device:
        """Return the CUDA device if available, else MPS if available, else CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # torch builds without the MPS backend do not expose
        # `torch.backends.mps`; probe it so they fall through to CPU instead of
        # raising AttributeError.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

# Training configuration: log every 10 steps, evaluate once per epoch, train
# for 3 epochs with AdamW (lr 2e-5, weight decay 0.01). With a per-device batch
# of 16 and 2 gradient-accumulation steps, each optimizer step covers 32
# examples per device. Outputs are written under ./results.
training_args = TrainingArgumentsWithMPSSupport(
    report_to='all',
    logging_steps=10,
    logging_strategy="steps",
    optim="adamw_torch",
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)

# Trainer wiring: the model and the train/validation splits, the tokenizer and
# padding collator, and compute_metrics for the evaluation scores.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_train,
    eval_dataset=ds_valid,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)


## Start training

In [None]:

# Run fine-tuning with the arguments and datasets given to the Trainer.
trainer.train()

## Testing
Evaluate the trained model on the validation set.

In [None]:
# Model scores and gold label ids for the validation split.
predictions, labels, _ = trainer.predict(ds_valid)
# The class scores sit on the last axis; argmax turns them into label ids.
predictions = np.argmax(predictions, axis=-1)

# Drop every position whose gold id is -100 (ignored tokens) and map the
# remaining ids to tag names.
true_predictions = [
    [label_names[pred_id] for (pred_id, gold_id) in zip(pred_row, gold_row) if gold_id != -100]
    for pred_row, gold_row in zip(predictions, labels)
]
true_labels = [
    [label_names[gold_id] for (pred_id, gold_id) in zip(pred_row, gold_row) if gold_id != -100]
    for pred_row, gold_row in zip(predictions, labels)
]

results = metric.compute(predictions=true_predictions, references=true_labels)
results

## Load test set

In [None]:
# Read the evaluation corpus; keep the raw words/tags for later reporting.
sent_tags = load_CONLL_sentences("./Genia4EReval1.iob2")
test_sentences_tags = {'tokens': sent_tags[0], 'ner_tags': sent_tags[1]}

# Tokenized form (kept separately because its word_ids are needed later) and
# the Dataset built from it.
test_data_tokenized = tokenize_and_align_labels(test_sentences_tags)
test_data = Dataset.from_dict(test_data_tokenized)
# Left disabled on purpose:
# test_data.set_format("pt")


## Run test

In [None]:
# Model scores and gold label ids for the test set.
predictions, labels, _ = trainer.predict(test_data)
predictions = np.argmax(predictions, axis=-1)

# Drop every position whose gold id is -100 (ignored tokens) and map the
# remaining ids to tag names.
true_predictions = [
    [label_names[pred_id] for (pred_id, gold_id) in zip(pred_row, gold_row) if gold_id != -100]
    for pred_row, gold_row in zip(predictions, labels)
]
true_labels = [
    [label_names[gold_id] for (pred_id, gold_id) in zip(pred_row, gold_row) if gold_id != -100]
    for pred_row, gold_row in zip(predictions, labels)
]

results = metric.compute(predictions=true_predictions, references=true_labels)
results

# Load final model

In [None]:
# Load the saved model from ./final_model/ with transformers logging lowered to
# ERROR during the load, then restore WARNING and move the model to `device`.
# NOTE(review): no visible cell writes ./final_model/ - presumably it is saved
# or copied there manually after training; confirm.
logging.set_verbosity(logging.ERROR)
model = AutoModelForTokenClassification.from_pretrained("./final_model/")
logging.set_verbosity(logging.WARNING)
model = model.to(device)

# Predict test set

In [None]:
# Gold tags for the test sentences (convert BIE back to BI here if needed).
answers = test_sentences_tags['ner_tags']

# Run the model one sentence at a time and keep the argmax label id per token.
predictions = []
with test_data.formatted_as('torch'), torch.no_grad():
    for example in tqdm(test_data):
        # Only input_ids are passed; attention_mask / token_type_ids are left
        # at the model's defaults.
        batch_of_one = example["input_ids"].unsqueeze(0).to(model.device)
        token_scores = model(input_ids=batch_of_one)['logits'][0]
        token_label_ids = np.argmax(token_scores.cpu().detach().numpy(), axis=-1)
        predictions.append(token_label_ids.tolist())


## Convert prediction IDs back to text
Also map to original words

In [None]:
# Collapse token-level predictions to one tag per original word.
num_of_sents = len(test_data_tokenized['input_ids'])
predicted_labels = []
for sent_i in range(num_of_sents):
    words = test_sentences_tags['tokens'][sent_i]
    word_id_map = test_data_tokenized.word_ids(sent_i)
    # None marks "no tag assigned yet"; the first sub-token of a word decides.
    sent_output = [None] * len(words)
    for i, c in enumerate(predictions[sent_i]):
        word_idx = word_id_map[i]
        if word_idx is None:
            # special tokens
            continue
        if sent_output[word_idx] is None:
            # To convert BIE back to BI, rewrite the prefix here:
            # "S" -> "B" and "E" -> "I".
            sent_output[word_idx] = label_names[c]

    # A word that received no sub-token would otherwise keep an empty tag,
    # which ends up in the output file and breaks later steps that read the
    # tag prefix; fall back to the outside tag instead.
    predicted_labels.append([lab if lab is not None else "O" for lab in sent_output])

## Write to output file

In [None]:
# Write predictions as "WORD<TAB>TAG" lines, with one blank line between
# sentences and none after the last.
num_of_sents = len(test_data_tokenized['input_ids'])
if len(predicted_labels) != num_of_sents:
    print("ERROR num of input sentences and predicted labels do not match!")

# The corpus is read as UTF-8, so write UTF-8 explicitly instead of relying on
# the platform's default encoding.
with open("testout.iob2", 'w', encoding='utf-8') as outfile:
    for sent_i in range(num_of_sents):
        words = test_sentences_tags['tokens'][sent_i]
        tags = predicted_labels[sent_i]
        sent = '\n'.join(f"{w}\t{gue}" for (w, gue) in zip(words, tags))
        outfile.write(f"{sent}\n")
        if sent_i != num_of_sents - 1:
            outfile.write("\n")


## Testing prediction of one input

In [None]:
# One hand-written, pre-split sentence to tag.
words = [
    "1", ",", "25-Dihydroxyvitamin", "D3", "receptors", "in", "lymphocytes", "and",
    "T-", "and", "B-lymphocyte", "count", "in", "patients", "with", "glomerulonephritis",
]
tokenized_inputs = tokenizer(words, is_split_into_words=True, return_tensors="pt")
word_id_map = tokenized_inputs.word_ids(0)
input_ids = tokenized_inputs['input_ids'][0]
tokenized_inputs_t = tokenized_inputs.to(model.device)
with torch.no_grad():
    model_output = model(**tokenized_inputs_t)['logits'][0]
model_output = np.argmax(model_output.cpu().detach().numpy(), axis=-1)
if len(input_ids) != len(model_output):
    print("ERROR")

# BERT works on sub-word tokens, so rebuild one label per original word:
# the first sub-token of a word sets its label and later ones never overwrite it.
sent_output = [''] * len(words)
for token_pos, label_id in enumerate(model_output):
    word_pos = word_id_map[token_pos]
    if word_pos is not None and sent_output[word_pos] == '':
        sent_output[word_pos] = label_names[label_id]

if len(sent_output) == len(words):
    # Right-align the words so the tags line up in one column.
    max_word_len = max(len(w) for w in words)
    for w, c in zip(words, sent_output):
        print(f"{w:>{max_word_len}s}\t{c}")
else:
    print(f"ERROR: output labels {len(sent_output)} - original words {len(words)}")

## Evaluate test set metric

In [None]:
# Per-entity-type report of the word-level predictions against the gold tags,
# returned as a dict.
# NOTE(review): the tags produced in this notebook use E-/S- prefixes, while
# scheme=BILOU is passed and no `mode` is set - confirm whether seqeval applies
# `scheme` in that case and whether the IOBES scheme was intended.
classification_report(y_true=answers, y_pred=predicted_labels, scheme=BILOU, output_dict=True)

# Generate HTML report

## HTML template

In [None]:

# Opening part of the HTML error report: inline script, inline styles and a
# sticky top panel with TP / FP / FN buttons.
# - Clicking a button toggles the "not-shown" class on every element carrying
#   that button's first class (TP, FP or FN), which removes its highlight.
# - After each toggle, a ".sentence" element is hidden when all of its
#   TP/FP/FN spans are "not-shown", and shown again otherwise.
html_head = \
"""
<html>
<meta charset='UTF-8'>
<head>
<script>
	function add_listeners() {
		const tp_btn = document.querySelector('#tp_btn');
		const fp_btn = document.querySelector('#fp_btn');
		const fn_btn = document.querySelector('#fn_btn');
		tp_btn.addEventListener('click', function (event) { toggle_class_visibility(event); }, false);
		fp_btn.addEventListener('click', function (event) { toggle_class_visibility(event); }, false);
		fn_btn.addEventListener('click', function (event) { toggle_class_visibility(event); }, false);
	
	}
	function toggle_class_visibility(event) {
		const is_shown = !event.srcElement.classList.contains("not-shown");
		const this_class = event.srcElement.classList[0];
		const spans_in_the_class = document.querySelectorAll("."+this_class);
		if (is_shown) {
			spans_in_the_class.forEach( item => {
				item.classList.add("not-shown");
			})
		}
		else {
			spans_in_the_class.forEach( item => {
				item.classList.remove("not-shown");
			})
		}
        const sentences = document.querySelectorAll(".sentence");
        sentences.forEach(sent => {
            const spans_count = sent.querySelectorAll(".TP, .FP, .FN").length;
            const not_shown_spans_count = sent.querySelectorAll(".TP.not-shown, .FP.not-shown, .FN.not-shown").length;
            if (spans_count === not_shown_spans_count) {
                sent.classList.add("hide");
            }
            else {
                sent.classList.remove("hide");
            }

        });
    
	}
	document.addEventListener('DOMContentLoaded', add_listeners);
</script>
<style>
body {
	background-color: #33302c;
	color: #aab1a3;
	line-height: 1.2em;
}

.TP {
	background-color: #44690a;
    padding-left: 2px;
    padding-right: 2px;
    border: solid;
    border-color: darkslategray;
    border-width: 2px;
}

.FP {
	background-color: #69120b;
    padding-left: 2px;
    padding-right: 2px;
    border: solid;
    border-color: darkslategray;
    border-width: 2px;
}

.FN {
	background-color: #0c346d;
    padding-left: 2px;
    padding-right: 2px;
    border: solid;
    border-color: darkslategray;
    border-width: 2px;
}

.not-shown {
	background-color: transparent!important;
    border: none!important;
}

.hide {
    display: none!important;
}

.Config {
	background-color: #545454;
    width: 5em;
    display: inline-block;
}

.fix_top {
	position: sticky;
	top: 0;
	background-color: #676760;
}
</style>
</head>
<body>
<div id="top_panel" class="fix_top">
<p><span class="TP" id="tp_btn">TP</span>
<span class="FP" id="fp_btn">FP</span>
<span class="FN" id="fn_btn">FN</span>
</p>
</div>
"""

# Closing part of the HTML error report.
html_end = """
</body></html>
"""

# Entity types reported by generate_html_analysis when it is called without
# an explicit ne_types argument.
default_ne_types = ["protein", "DNA", "RNA", "cell_line", "cell_type"]


In [None]:
def generate_html_analysis(sent_words, gold_tags, predict_tags, ne_types=None):
    """Render one sentence as HTML with TP / FP / FN spans highlighted.

    For each entity type, the sentence is emitted once (only if it contains at
    least one span of that type) with:
      * TP - predicted spans that exactly match a gold span,
      * FP - predicted spans without an exact gold match,
      * FN - gold spans without an exact predicted match.

    sent_words: list of the sentence's words (plain text).
    gold_tags / predict_tags: span objects exposing `.start`, `.end`
        (exclusive) and `.label_`.
    ne_types: entity types to report; defaults to `default_ne_types`.

    Returns the concatenated `<div class="sentence">` blocks, or "" if no
    entity type has any span in this sentence.
    """
    from html import escape  # function-scope import keeps the block self-contained

    return_list = []
    if ne_types is None:
        ne_types = default_ne_types
    # Words come straight from the corpus and may contain "<", ">" or "&";
    # escape them so they cannot break or inject markup in the report.
    escaped_words = [escape(word) for word in sent_words]
    for ne_type in ne_types:
        sent = list(escaped_words)
        should_add = False
        for pred in predict_tags:
            if pred.label_ != ne_type:
                continue
            css_class = "TP" if any(span_equal(pred, g) for g in gold_tags) else "FP"
            sent[pred.start] = f"<span class='{css_class}'>" + sent[pred.start]
            sent[pred.end - 1] = sent[pred.end - 1] + "</span>"
            should_add = True
        for go in gold_tags:
            if go.label_ != ne_type:
                continue
            if not any(span_equal(go, p) for p in predict_tags):
                sent[go.start] = "<span class='FN'>" + sent[go.start]
                sent[go.end - 1] = sent[go.end - 1] + "</span>"
                should_add = True

        # Entity types with nothing to show are skipped entirely.
        if should_add:
            sent.insert(0, f"<span class='Config'>{ne_type}:</span>")
            return_list.append('<div class="sentence">' + " ".join(sent) + "<br /></div>\n")
    return "".join(return_list)


def span_equal(span_1, span_2):
    """Return True when both spans have the same start, end and label."""
    return (
        span_1.start == span_2.start
        and span_1.end == span_2.end
        and span_1.label_ == span_2.label_
    )


## Fix invalid BILUO tag sequences

In [None]:
def fix_labels(labels):
    """Heuristically repair an invalid BILUO tag sequence.

    Each tag is checked against the ORIGINAL previous tag (not the repaired
    one) and rewritten as follows:

    * ``I`` - first in the sentence -> ``B``; last in the sentence -> ``L``;
      after ``O`` -> ``B``; after ``U`` -> previous tag is rewritten to ``B``;
      after ``L`` -> previous tag is rewritten to ``I``; otherwise kept.
    * ``L`` - first in the sentence -> ``U``; after ``L``/``U``/``O`` -> ``O``;
      otherwise kept.
    * ``U`` - after ``B``/``I``/``U`` -> ``O``; otherwise kept.
    * anything else is kept. Empty tags are passed through unchanged.

    labels: list of tag strings for one sentence.
    Returns a new list of the same length; the input is not modified.
    """
    labels_fix = []
    last_pos = len(labels) - 1
    for l_pos, l in enumerate(labels):
        prefix = l[:1]
        # The first tag has no predecessor. Indexing labels[-1] here would wrap
        # around to the LAST tag and wrongly turn a leading "U" into "O".
        prev = labels[l_pos - 1][:1] if l_pos > 0 else None

        if prefix == "I":
            if l_pos == 0:
                # first label is I, should be B
                labels_fix.append("B" + l[1:])
            elif l_pos == last_pos:
                # last label is I, should be L
                labels_fix.append("L" + l[1:])
            elif prev == "O":
                # previous label is O, should be B
                labels_fix.append("B" + l[1:])
            elif prev == "U":
                # previous label is U, change previous label to B
                # NOTE: danger - rewrites an already emitted tag
                labels_fix[-1] = "B" + l[1:]
                labels_fix.append(l)
            elif prev == "L":
                # previous label is L, change previous label to I
                # NOTE: danger - rewrites an already emitted tag
                labels_fix[-1] = "I" + l[1:]
                labels_fix.append(l)
            else:
                # previous label is B or I (good case), or an unknown prefix
                labels_fix.append(l)
        elif prefix == "L":
            if l_pos == 0:
                # first label is L, should be U
                labels_fix.append("U" + l[1:])
            elif prev in ("L", "U", "O"):
                # previous label is L, U, O, should be O
                # NOTE: may change
                labels_fix.append("O")
            else:
                # previous label is B or I (good case), or an unknown prefix
                labels_fix.append(l)
        elif prefix == "U" and prev in ("B", "I", "U"):
            # previous label is B, I, U, change to O
            # NOTE: danger
            # NOTE(review): "U" directly after "U" looks like two adjacent
            # single-token entities - confirm that dropping the second is intended.
            labels_fix.append("O")
        else:
            labels_fix.append(l)
    return labels_fix

In [None]:
# Build one HTML fragment per test sentence, comparing predicted and gold spans.
nlp = spacy.blank("en")
html_strings = []

for sent_i in range(num_of_sents):
    words = test_sentences_tags['tokens'][sent_i]
    doc = Doc(vocab=nlp.vocab, words=words, spaces=[True] * len(words))
    word_id_map = test_data_tokenized.word_ids(sent_i)

    # BIOES -> BILUO: E- becomes L-, S- becomes U-.
    pred = [t.replace('E-', 'L-').replace('S-', 'U-') for t in predicted_labels[sent_i]]
    answer = [t.replace('E-', 'L-').replace('S-', 'U-') for t in answers[sent_i]]

    # Repair invalid predicted sequences before handing them to spaCy.
    pred_fix = fix_labels(pred)
    try:
        pred_spans = biluo_tags_to_spans(doc, tags=pred_fix)
    except Exception as e:
        # Dump the offending sentence and stop the whole loop.
        print('\n'.join(f"{token} {tag}" for token, tag in zip(doc, pred_fix)))
        print(e)
        break

    try:
        ans_spans = biluo_tags_to_spans(doc, tags=answer)
    except Exception as e:
        print("Why is answer incorrect?")
        print('\n'.join(f"{token} {tag}" for token, tag in zip(doc, answer)))
        print(e)
        break

    html_string = generate_html_analysis(words, ans_spans, pred_spans)
    html_strings.append(html_string)

In [None]:
# Write the report to a file named after the current time (minute resolution).
datetime_now = datetime.now().strftime('%Y%m%d%H%M')
# The page declares charset UTF-8, so the file must be written as UTF-8 rather
# than in the platform's default encoding.
with open(f'error_output_{datetime_now}.html', 'w', encoding='utf-8') as outfile:
    outfile.write(html_head + '<br />\n'.join(html_strings) + html_end)