In [1]:
from transformers import BertForTokenClassification, BertTokenizer
import torch
import numpy as np
from hpobert.dataset import HPODataset
from hpobert.utils import bio_to_entity_tokens, character_annotations_to_spacy_doc, token_span_to_char_span
import spacy
from spacy import displacy

In [2]:
# --- Inputs -----------------------------------------------------------------
# JSONL dataset; handed to HPODataset, which supplies the tag vocabulary.
data_file = "../data/meh_eyedisease.jsonl"
# Checkpoint location handed to BertForTokenClassification.from_pretrained.
model_checkpoint_path = "../models/hponer_epoch167_f1_0.7012.pth/"
# Model identifier handed to BertTokenizer.from_pretrained.
bert_type = "dmis-lab/biobert-v1.1"

# --- Runtime ----------------------------------------------------------------
# Torch device string for inference.
device = "cpu"

In [3]:
# Tokenizer is loaded from `bert_type`; do_lower_case=False keeps the input casing.
tokenizer = BertTokenizer.from_pretrained(bert_type, do_lower_case=False)
# Token-classification model is loaded from the local checkpoint path.
model = BertForTokenClassification.from_pretrained(model_checkpoint_path)

# The dataset is only needed here for its tag vocabulary: `tag_values` is later
# indexed with the predicted label indices to recover tag strings.
dataset = HPODataset(data_file)
tag2idx, tag_values = dataset.get_tag_info()

In [4]:
# Example input for the tagger (adjacent literals are concatenated into one string).
test_sentence = (
    "The patient has loss of vision on both eyes "
    "and a history of severe nyctalopia and macular atrophy."
)

In [5]:
# Encode the sentence to vocabulary ids. The later merge step filters out
# '[CLS]' / '[SEP]', so the encoding is presumably wrapped in those markers.
tokenized_sentence = tokenizer.encode(test_sentence)
# Add a leading batch axis so the model receives a batch containing one sequence.
input_ids = torch.tensor(tokenized_sentence).unsqueeze(0).cpu()

In [6]:
# Predict.
# `device` comes from the configuration cell. Moving both the model and the
# input here keeps them on the same device regardless of where earlier cells
# left them, so changing `device` in one place is enough.
model.to(device)
model.eval()
with torch.no_grad():  # no gradients are needed for inference
    output = model(input_ids.to(device))

# output[0] is reduced over axis 2, giving one predicted tag index per
# (batch item, token position).
label_indices = np.argmax(output[0].to('cpu').numpy(), axis=2)

# Join sub-word pieces back into whole words.
# Pieces starting with "##" continue the previous word: their text is appended
# to it and their own predicted label is discarded, so every merged word keeps
# the label predicted for its first piece.
tokens = tokenizer.convert_ids_to_tokens(input_ids.to('cpu').numpy()[0])
new_tokens, new_labels = [], []
for token, label_idx in zip(tokens, label_indices[0]):
    # Special marker tokens carry no text and are skipped.
    if token in ('[CLS]', '[SEP]', '[PAD]'):
        continue
    if token.startswith("##"):
        new_tokens[-1] = new_tokens[-1] + token[2:]
    else:
        new_labels.append(tag_values[label_idx])
        new_tokens.append(token)

In [7]:
# Emit one "<label><TAB><word>" line per merged word.
for label, token in zip(new_labels, new_tokens):
    print("{}\t{}".format(label, token))

O	The
O	patient
O	has
B-pnt	loss
I-pnt	of
I-pnt	vision
O	on
O	both
O	eyes
O	and
O	a
O	history
O	of
O	severe
B-pnt	nyctalopia
O	and
B-pnt	macular
I-pnt	atrophy
O	.


In [8]:
# Convert the predictions to a spaCy document for display.
# A blank English pipeline is created and handed to the conversion helper below.
nlp = spacy.blank('en')

# Turn the per-word BIO labels into entity spans, then attach the sentence text.
# (Span semantics are inferred from the helper names - confirm in hpobert.utils.)
out_spans = bio_to_entity_tokens(new_labels)
ann_token_span = {'text': test_sentence, 'spans': out_spans}
# NOTE(review): the span indices refer to the word-piece-merged tokens built
# earlier in the notebook; this presumably requires token_span_to_char_span to
# split `text` into the same tokens - confirm against the helper.
ann_char_span = token_span_to_char_span(ann_token_span)

# The document is built once here from the character-level annotations; the
# earlier throwaway `doc = nlp(test_sentence)` was overwritten before any use.
doc = character_annotations_to_spacy_doc(ann_char_span, nlp)

In [9]:
# Visualise the entities annotated on `doc` using displaCy's "ent" style.
displacy.render(doc, style="ent")