In [5]:
from transformers import RobertaTokenizerFast, RobertaForTokenClassification, BatchEncoding, CharSpan
from datasets import load_dataset
import torch, os, numpy as np

import utils

import importlib
# Re-execute the local `utils` module so edits made to it are picked up without
# restarting the kernel; the trailing `;` suppresses the cell's output repr.
importlib.reload(utils);

In [6]:
# Local directory of the fine-tuned token-classification checkpoint.
path = 'checkpoints/original-models/checkpoint-1005'

# The tokenizer is loaded from the stock `roberta-base` files, not from `path`.
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base', add_prefix_space=True)
# NOTE(review): num_labels=37 is hard-coded; presumably it equals the number of
# entries in the tag mapping returned by utils.load_label_mapping() — confirm.
model = RobertaForTokenClassification.from_pretrained(path, num_labels=37)

In [7]:
# Load the `tner/ontonotes5` dataset, kept in memory, with one worker process
# per CPU reported by os.cpu_count().
ds = load_dataset("tner/ontonotes5", keep_in_memory=True, num_proc=os.cpu_count())
# First example of the test split, kept for manual inspection.
inspected = ds['test'][0]

# `ds_tag_label_mapping` is indexed by integer label id and yields the tag name
# (that is how print_labels_texts uses it). `ds_label_tag_mapping` is presumably
# the inverse (tag name -> id) — verify against utils.load_label_mapping.
ds_label_tag_mapping, ds_tag_label_mapping = utils.load_label_mapping()

In [8]:
# Example sentence to run through the tagger.
text = "Yesterday I met the president of the USA Joe Biden."
tokens: BatchEncoding = tokenizer(text, truncation=True)

# Wrap every encoded field in a one-element list before converting it, so each
# tensor carries a leading batch dimension of size 1.
inputs = dict(
    (field, torch.tensor([values]))
    for field, values in tokens.items()
)

In [9]:
# Forward pass with gradient tracking disabled; only the logits are kept.
with torch.no_grad():
    logits = model(inputs['input_ids'], inputs['attention_mask']).logits

# For the single sequence in the batch, take the index of the highest logit
# along the last axis for every token position.
predicted_token_class_ids = torch.argmax(logits[0], dim=-1).tolist()

In [10]:
def print_labels_texts(tokens: BatchEncoding, labels: list[int], ds_tag_label_mapping: dict[int, str]):
    """Print two lines that annotate a tokenized text with its predicted entities.

    The first line underlines every entity with ``-`` characters placed at the
    entity's character offsets; the second line writes the entity type under
    the start of each entity. Both lines are meant to be printed directly
    below the original text.

    Args:
        tokens: Encoding of the text; ``tokens.token_to_chars(i)`` must return
            an object with ``start``/``end`` character offsets for token ``i``.
        labels: One label id per token position. The first and last positions
            are ignored (presumably the ``<s>``/``</s>`` special tokens —
            confirm for other tokenizers). Label id ``0`` means "no entity".
        ds_tag_label_mapping: Maps a label id to its tag name. The first two
            characters of the name are dropped to obtain the entity type
            (e.g. ``"B-PERSON"`` -> ``"PERSON"``).

    Returns:
        None. The two annotation lines are written with ``print``.

    Notes:
        Consecutive non-zero tokens are merged into one entity as long as
        their entity type is the same; the tag prefix is not inspected, so two
        directly adjacent entities of the same type are shown as one.
    """

    def entity_type(label: int) -> str:
        # Strip the two-character tag prefix from the tag name.
        return ds_tag_label_mapping[label][2:]

    # Each entity is (start_char, end_char, label_id_of_its_first_token).
    entities: list[tuple[int, int, int]] = []
    previous_label = 0

    # Skip the first and last position, mirroring the special-token slots.
    for index in range(1, len(labels) - 1):
        label = labels[index]
        span = tokens.token_to_chars(index) if label != 0 else None
        if span is None:
            # Unlabelled token, or a token without a character span: it cannot
            # be drawn and it ends any entity that was being extended.
            previous_label = 0
            continue

        if previous_label != 0 and entity_type(entities[-1][2]) == entity_type(label):
            # Same type as the entity directly before: extend that entity.
            start, _, first_label = entities[-1]
            entities[-1] = (start, span.end, first_label)
        else:
            entities.append((span.start, span.end, label))
        previous_label = label

    # Line 1: dashes under every entity, aligned to the character offsets.
    underline = ''
    for start, end, _ in entities:
        underline += ' ' * (start - len(underline))
        underline += '-' * (end - start)
    print(underline)

    # Line 2: entity type names. A type name can be longer than the text it
    # labels; always keep at least one space before the next name so that
    # neighbouring names never run together.
    type_line = ''
    for start, _, label in entities:
        padding = start - len(type_line)
        if type_line and padding < 1:
            padding = 1
        type_line += ' ' * padding + entity_type(label)
    print(type_line)

# Print the sentence, then the entity underline and entity-type lines that
# print_labels_texts writes beneath it.
print(text)
print_labels_texts(tokens, predicted_token_class_ids, ds_tag_label_mapping)

Yesterday I met the president of the USA Joe Biden.
---------                            --- ---------
DATE                                 ORG PERSON
