In [1]:
import torch

In [2]:
# Location handed to every `from_pretrained` call in this notebook.
# Alternative kept for reference: the hub model 'surrey-nlp/roberta-base-finetuned-abbr'.
path = "checkpoints/checkpoint-2000"

In [3]:
import transformers
from datasets import load_dataset
# Tag string -> integer class id. Every tag string occurring in the
# "ner_tags" column must be a key here, otherwise the mapping below raises KeyError.
TEXT2ID = {tag: class_id for class_id, tag in enumerate(("B-O", "B-AC", "B-LF", "I-LF"))}

# Load PLOD-CW and replace the string tags of every example by their class ids.
datasets = load_dataset("surrey-nlp/PLOD-CW").map(
    lambda example: {"ner_tags": [TEXT2ID[name] for name in example["ner_tags"]]}
)
from transformers import AutoTokenizer
# Tokenizer loaded from the same location (`path`) as the model below; it is
# used both by tokenize_and_align_labels() and by the inference pipeline.
tokenizer = AutoTokenizer.from_pretrained(path)
def tokenize_and_align_labels(examples, label_all_tokens=True):
    """Tokenize pre-split sentences and align word-level tags with the tokens.

    Args:
        examples: Batch providing "tokens" (one list of words per example) and
            "ner_tags" (one list of tag ids per example, indexed by word position).
        label_all_tokens: When True (the default, matching the previously
            hard-coded behaviour) every token of a word receives the word's tag.
            When False only the first token of each word is labelled and the
            remaining tokens of that word get -100.

    Returns:
        The tokenizer output for the batch, extended with a "labels" entry that
        holds one list of label ids per example, parallel to the tokens.
    """
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for batch_index, word_labels in enumerate(examples["ner_tags"]):
        previous_word_idx = None
        label_ids = []
        for word_idx in tokenized_inputs.word_ids(batch_index=batch_index):
            if word_idx is None:
                # Special tokens have no word id; -100 makes the loss function ignore them.
                label_ids.append(-100)
            elif word_idx != previous_word_idx or label_all_tokens:
                # First token of a word, or a follow-up token when all tokens are labelled.
                label_ids.append(word_labels[word_idx])
            else:
                # Follow-up token of a word while only first tokens are labelled.
                label_ids.append(-100)
            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs
# Tokenize every split in batches; each example gains an aligned "labels" column.
tokenized_datasets = datasets.map(
    tokenize_and_align_labels,
    batched=True,
)
from transformers import AutoModelForTokenClassification

In [4]:
# Token-classification model restored from `path`, configured for four labels.
model = AutoModelForTokenClassification.from_pretrained(
    path,
    num_labels=4,
)

In [5]:
# Inference pipeline; the empty ignore_labels list means no predicted label is
# filtered out of the results.
pipeline = transformers.pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    ignore_labels=[],
)

def choose(i=None):
    """Run the NER pipeline on one sentence of the test split.

    Args:
        i: Index of the test example, given as an int or as a single-element
            torch tensor. When None, an index is drawn uniformly at random
            with torch.randint.

    Returns:
        A tuple (words, output, truth): the example's "tokens", the pipeline
        result for those words joined by single spaces, and the example's
        "ner_tags".
    """
    test_split = datasets["test"]
    if i is None:
        i = torch.randint(0, len(test_split), (1,)).item()

    # Fetch the row once with a plain int index instead of materialising the
    # complete "tokens" / "ner_tags" columns on every access; int() also
    # accepts the single-element tensors produced by torch.randint.
    example = test_split[int(i)]
    words = example["tokens"]
    truth = example["ner_tags"]
    output = pipeline(" ".join(words))

    return words, output, truth

def choose_multiple(nb=5):
    """Run the NER pipeline on `nb` randomly drawn test sentences.

    Indices are drawn independently with torch.randint, so the same sentence
    can be selected more than once.

    Args:
        nb: Number of sentences to draw.

    Returns:
        A tuple (words, outputs, truths) of three parallel lists with one
        entry per drawn sentence, each entry as returned by choose().
    """
    # .tolist() turns the index tensor into plain Python ints, so choose()
    # never has to index the dataset with a tensor.
    indices = torch.randint(0, len(datasets["test"]), (nb,)).tolist()

    words, outputs, truths = [], [], []
    for index in indices:
        sentence_words, prediction, sentence_truth = choose(index)
        words.append(sentence_words)
        outputs.append(prediction)
        truths.append(sentence_truth)
    return words, outputs, truths

In [6]:
from colorama import Back, Style

# Entity name -> class id, looked up by vizu() when called with type == 1.
# NOTE: this rebinds the TEXT2ID defined in an earlier cell, which spelled
# class 0 as "B-O" instead of "O".
TEXT2ID = dict(zip(("O", "B-AC", "B-LF", "I-LF"), range(4)))

def vizu(words, output, truth, type=None):
    """Print predictions and ground truth of one sentence with coloured backgrounds.

    Two lines are printed, "Output:" (the predicted tokens) and "Truth:" (the
    reference words), followed by an empty line. Each piece of text is printed
    on the background colour of its class id.

    Args:
        words: Words of the sentence; they are joined with single spaces to
            rebuild the text that the 'start'/'end' offsets refer to.
        output: Token-level predictions; each item provides 'start', 'end',
            'word' and 'entity'.
        truth: Class id of every word, parallel to `words`.
        type: Selects how tokens and entity names are interpreted. When equal
            to 1, a token whose 'word' contains 'Ġ' starts a new word and the
            class id is TEXT2ID[entity]. Otherwise a token whose 'word' does
            not start with '##' starts a new word and the class id is the last
            character of the entity name converted to int (so only single-digit
            ids are supported).
            NOTE(review): presumably 1 is meant for byte-level BPE tokenizers
            and the other case for WordPiece tokenizers - confirm.
    """
    sentence = " ".join(words)
    col = {0: Back.BLACK, 1: Back.RED, 2: Back.GREEN, 3: Back.BLUE, 4: Back.MAGENTA}

    out_words = []
    out_label = []
    for token in output:
        piece = token['word']
        entity = token['entity']
        if type == 1:
            starts_word = 'Ġ' in piece
            label = TEXT2ID[entity]
        else:
            # Only a '##' prefix marks a continuation; a literal '#' token is a
            # word of its own, and startswith() is safe on an empty string.
            starts_word = not piece.startswith('##')
            label = int(entity[-1])

        # Put a separator in front of every word start except the very first
        # token. Emitting one unconditionally and slicing off element 0
        # afterwards would drop a real token whenever the first token is not
        # recognised as a word start.
        if starts_word and out_words:
            out_words.append(' ')
            out_label.append(0)
        out_words.append(sentence[token['start']:token['end']])
        out_label.append(label)

    print('Output:  ', end='')
    for piece, label in zip(out_words, out_label):
        print(col[label] + piece + Style.RESET_ALL, end='')
    print()

    print('Truth:   ', end='')
    for position, word in enumerate(words):
        print(col[truth[position]] + word + ' ' + Style.RESET_ALL, end='')
    print()
    print()

In [41]:
# Draw a random batch of test sentences and show prediction vs. truth for each.
# type=0 makes vizu() take its non-1 branch (continuation-prefix tokens and
# entity names ending in the class id).
words, outputs, truths = choose_multiple()
for i, sentence_words in enumerate(words):
    vizu(sentence_words, outputs[i], truths[i], type=0)

Output:  [40mWe[0m[40m [0m[40mincluded[0m[40m [0m[40mthose[0m[40m [0m[40mwith[0m[40m [0m[40mfollow[0m[40m [0m[40m-[0m[40m [0m[40mup[0m[40m [0m[40mvisits[0m[40m [0m[40mand[0m[40m [0m[42mpulmonary[0m[40m [0m[44mfunction[0m[40m [0m[44mtests[0m[40m [0m[40m([0m[40m [0m[41mP[0m[41mFT[0m[41ms[0m[40m [0m[40m)[0m[40m [0m[40mavailable[0m[40m [0m[40mafter[0m[40m [0m[40mhospital[0m[40mization[0m[40m [0m[40m.[0m
Truth:   [40mWe [0m[40mincluded [0m[40mthose [0m[40mwith [0m[40mfollow [0m[40m- [0m[40mup [0m[40mvisits [0m[40mand [0m[42mpulmonary [0m[44mfunction [0m[44mtests [0m[40m( [0m[41mPFTs [0m[40m) [0m[40mavailable [0m[40mafter [0m[40mhospitalization [0m[40m. [0m

Output:  [42mPartial[0m[40m [0m[44mleast[0m[40m [0m[44msquares[0m[40m [0m[44mregression[0m[40m [0m[40m([0m[40m [0m[41mPL[0m[41mSR[0m[40m [0m[40m)[0m[40m [0m[40m,[0m[40m [0m[40ma[0m[40m