In [None]:
from transformers import RobertaTokenizerFast, RobertaForTokenClassification, BatchEncoding, CharSpan
from datasets import load_dataset
import torch, os, numpy as np

import utils

import importlib
# Re-import the local `utils` module so that edits made to it on disk are picked
# up without restarting the kernel. The trailing `;` keeps the notebook from
# echoing the returned module object as the cell's output.
importlib.reload(utils);

In [None]:
# Local training checkpoint directory.
# NOTE(review): `path` is not referenced by any cell visible here — the model
# below is loaded from the Hugging Face Hub instead; confirm which is intended.
path = 'checkpoints/roberta-training-default-dataset-fp16/checkpoint-1005'

# Tokenizer comes from the base RoBERTa checkpoint; the classification model is
# the fine-tuned NER checkpoint, instantiated with 37 output classes.
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
model = RobertaForTokenClassification.from_pretrained(
    'EdoardoLuciani/roberta-on-english-ner',
    num_labels=37,
)

In [None]:
# Load the two label lookup tables from the local `utils` module.
# `ds_tag_label_mapping` is indexed with predicted class ids further down and
# yields tag strings; `ds_label_tag_mapping` is presumably the inverse lookup
# (tag string -> class id) — TODO confirm in `utils.load_label_mapping`.
ds_label_tag_mapping, ds_tag_label_mapping = utils.load_label_mapping()

In [None]:
# Hand-written example sentences used as the inference batch; each one is
# tokenized, classified and printed with its predicted entity spans below.
text = [
    "Barack Obama was born in Hawaii and served as the 44th President of the United States.",
    "Apple Inc. is headquartered in Cupertino, California, and was founded by Steve Jobs, Steve Wozniak, and Ronald Wayne.",
    "The Eiffel Tower, located in Paris, France, is one of the most famous landmarks in the world.",
    "Amazon's CEO, Andy Jassy, announced new plans for expanding their headquarters in Seattle.",
    "The novel 'To Kill a Mockingbird' was written by Harper Lee and published in 1960.",
    "Tesla's new Gigafactory in Berlin is expected to produce thousands of electric vehicles each year.",
    "The Great Wall of China stretches across northern China and was built to protect against invasions.",
    "On July 20, 1969, Neil Armstrong and Buzz Aldrin became the first humans to walk on the moon as part of the Apollo 11 mission.",
    "The United Nations headquarters is located in New York City, and it is an international organization founded in 1945.",
    "The Nobel Prize in Literature 2020 was awarded to Louise Glück, an American poet.",
]

In [None]:
# Tokenize the whole batch at once, truncating over-long sentences and padding
# every sentence to the length of the longest one. The BatchEncoding is kept
# (not just the tensors) because its token -> character mapping is needed later.
tokens: BatchEncoding = tokenizer(
    text,
    truncation=True,
    padding='longest',
)

# Convert each field of the encoding (plain Python lists) into a torch tensor.
inputs = {
    field_name: torch.tensor(field_values)
    for field_name, field_values in tokens.items()
}

In [None]:
# Forward pass without gradient tracking — only the logits are needed.
with torch.no_grad():
    model_output = model(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
    )
    logits = model_output.logits

# Most likely class id per token, as nested Python lists (one list per sentence).
predicted_token_class_ids = torch.argmax(logits, dim=-1).tolist()

In [None]:
def _group_entity_spans(
    tokens: "BatchEncoding",
    batch_index: int,
    labels: list[int],
    ds_tag_label_mapping: dict[int, str],
) -> list[tuple[int, int, int]]:
    """Merge consecutive same-type entity tokens into (start, end, label) character spans."""
    groups: list[tuple[int, int, int]] = []
    previous_label = 0

    for token_index, label in enumerate(labels):
        span = tokens.token_to_chars(batch_or_token_index=batch_index, token_index=token_index)

        if span is None:
            # The token has no position in the sentence (typically a special or
            # padding token). Whatever class was predicted for it is meaningless,
            # so skip it and make sure the next real token starts a fresh group.
            previous_label = 0
            continue

        if label != 0:
            # `[2:]` drops the first two characters of the tag string — assumed
            # to be a prefix such as `B-`/`I-`; confirm against the mapping.
            entity_type = ds_tag_label_mapping[label][2:]
            continues_group = (
                previous_label != 0
                and ds_tag_label_mapping[groups[-1][2]][2:] == entity_type
            )
            if continues_group:
                # Extend the open group to the end of this token, keeping the
                # label of the group's first token.
                group_start, _, group_label = groups[-1]
                groups[-1] = (group_start, span.end, group_label)
            else:
                groups.append((span.start, span.end, label))

        previous_label = label

    return groups


def print_labels_texts(tokens: "BatchEncoding", batch_index: int, labels: list[int], ds_tag_label_mapping: dict[int, str]):
    """Print two annotation rows meant to sit directly under the printed sentence.

    The first row underlines every predicted entity with `-` characters at the
    entity's character offsets; the second row writes the entity type (the tag
    string without its first two characters) starting at the same offsets.

    Args:
        tokens: Encoding of the batch; only `token_to_chars` is used, to map a
            token index to its character span in the original sentence.
        batch_index: Index of the sentence inside the batch.
        labels: Predicted class id for every token position of that sentence
            (including special/padding positions). Class id 0 is treated as
            "not an entity".
        ds_tag_label_mapping: Maps a class id to its tag string.

    Notes:
        * Token positions for which `token_to_chars` returns None are ignored,
          so a non-zero prediction on a padding or special token no longer
          crashes the function. No positional slicing of the label list is
          done any more; tokens are kept or dropped purely on whether they map
          to characters.
        * Directly adjacent tokens with the same entity type are merged into one
          span regardless of the tag prefix.
          NOTE(review): this also merges two back-to-back entities of the same
          type — confirm that is intended for this label scheme.
        * In the type row at least one space is always kept between two names,
          so a name longer than its span cannot run into the next one (later
          names are then shifted right of their span).
    """
    groups = _group_entity_spans(tokens, batch_index, labels, ds_tag_label_mapping)

    # Row 1: dashes under each entity span.
    underline_row = ''
    for start, end, _ in groups:
        underline_row += ' ' * (start - len(underline_row))
        underline_row += '-' * (end - start)
    print(underline_row)

    # Row 2: entity type names, each starting at its span's first character
    # whenever there is room.
    type_row = ''
    for start, _, label in groups:
        gap = start - len(type_row)
        if type_row and gap < 1:
            gap = 1
        type_row += ' ' * gap
        type_row += ds_tag_label_mapping[label][2:]
    print(type_row)


# For every input sentence print the sentence itself, then the two annotation
# rows produced by `print_labels_texts`, then blank lines as a separator.
for i, (sentence, labels) in enumerate(zip(text, predicted_token_class_ids)):
    print(sentence)
    print_labels_texts(
        tokens=tokens,
        batch_index=i,
        labels=labels,
        ds_tag_label_mapping=ds_tag_label_mapping,
    )
    print('\n')