In [4]:
# Load the UMLS-BERT token-classification checkpoint, its tokenizer, and the instruction dataset.
import jsonlines
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

mt_path = "RohanVB/umlsbert_ner"
# NOTE(review): from_pretrained warns that bert.embeddings.tui_type_embeddings.weight is unused
# (see cell output) — confirm this is expected for this checkpoint.
model = AutoModelForTokenClassification.from_pretrained(mt_path)
tok = AutoTokenizer.from_pretrained(mt_path)
# Fall back to CPU so the notebook still runs on machines without a CUDA device.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

# One JSON object per line; the `with` block closes the file once every record has been read.
with jsonlines.open('data/instruction_dataall.jsonl') as reader:
    instructions = list(reader)

Some weights of the model checkpoint at RohanVB/umlsbert_ner were not used when initializing BertForTokenClassification: ['bert.embeddings.tui_type_embeddings.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [8]:
import random
import torch

SEED = 0  # fixed seed so Restart & Run All samples the same instruction every time
# Pick one instruction record and join its question and answer into a single NER test prompt.
i = random.Random(SEED).choice(instructions)
prompt = i['input'].strip() + ' ' + i['output'].strip()
# prompt = "Is ipsA , a novel LacI-type regulator , required for inositol-derived lipid formation in Corynebacteria and Mycobacteria?The development of new drugs against tuberculosis and diphtheria is focused on disrupting the biogenesis of the cell wall, the unique architecture of which confers resistance against current therapies. The enzymatic pathways involved in the synthesis of the cell wall by these pathogens are well understood, but the underlying regulatory mechanisms are largely unknown. This characterization of IpsA function and of its regulon sheds light on the complex transcriptional control of cell wall biogenesis in the mycolata taxon and generates novel targets for drug development."
# The same prompt twice, to exercise the batched (padded) code path of batch_ner.
prompts = [prompt] * 2
print(prompt)
# Show how the tokenizer splits the prompt (including '##' word-pieces), one token string per id.
print(tok.convert_ids_to_tokens(tok(prompt).input_ids))

def _tagged_spans(input_ids, tags, offset):
    """Group the tagged positions of one sequence into contiguous entity spans (BIO decoding).

    Each span is a dict with the token ids it covers, the B-<type> label id it is
    reported under, and its first/last token position (inclusive).
    """
    spans = []
    # flatten (not squeeze) keeps the index tensor 1-D even with zero or one tagged token.
    for pos in torch.nonzero(tags).flatten().tolist():
        tag = int(tags[pos])
        prev = spans[-1] if spans else None
        continues_prev = (
            tag <= offset                             # an I- tag ...
            and prev is not None
            and prev['label_id'] == tag + offset      # ... of the same type as the open span ...
            and prev['end'] == pos - 1                # ... on the very next token
        )
        if continues_prev:
            prev['ids'].append(int(input_ids[pos]))
            prev['end'] = pos
        else:
            # A B- tag, or an I- tag that cannot continue the open span, starts a new entity;
            # either way the entity is reported under its B- label.
            label_id = tag if tag > offset else tag + offset
            spans.append({'ids': [int(input_ids[pos])], 'label_id': label_id,
                          'start': pos, 'end': pos})
    return spans


def _spans_to_words(spans, tok, id2label):
    """Decode spans to strings, gluing '##' fragments onto the entity that ends right before them."""
    words, labels = [], []
    last_kept_end = None  # last token position of words[-1]; None when nothing may be glued on
    for span in spans:
        word = tok.decode(span['ids'], skip_special_tokens=True)
        if not word.startswith('##'):
            words.append(word)
            labels.append(id2label[span['label_id']])
            last_kept_end = span['end']
        elif last_kept_end == span['start'] - 1:
            words[-1] += word[2:]
            last_kept_end = span['end']
        else:
            # The fragment's head word-piece was not tagged: drop the fragment, and make sure a
            # later fragment cannot glue itself onto an unrelated earlier word.
            last_kept_end = None
    return words, labels


def batch_ner(prompts, model, tok):
    """Run token-classification NER on a batch of prompts and group tagged tokens into entities.

    The decoding assumes the checkpoint's label ids are laid out as 0 = no entity,
    1-3 = I-<type>, 4-6 = B-<type>, with ``B id == I id + 3``
    (TODO confirm against ``model.config.id2label`` for other checkpoints).

    Args:
        prompts: list of input strings.
        model: Hugging Face token-classification model (already on its target device).
        tok: the tokenizer matching ``model``.

    Returns:
        (batch_words, batch_labels, batch_ner_output):
        ``batch_words[b]`` is the list of entity strings found in ``prompts[b]``;
        ``batch_labels[b]`` holds the B-<type> label of each of those entities;
        ``batch_ner_output`` is the raw argmax label-id tensor, one row per prompt
        (padded length), on ``model.device``.
    """
    offset = 3  # B-<type> id minus I-<type> id under the layout assumed above
    inp = tok(prompts, return_tensors='pt', padding=True,
              return_special_tokens_mask=True).to(model.device)
    # Positions that must never become entity tokens ([CLS]/[SEP]/[PAD]). The mask is not a
    # model input, so it is popped before the forward pass.
    ignore = inp.pop('special_tokens_mask').bool()
    if 'attention_mask' in inp:
        ignore = ignore | (inp['attention_mask'] == 0)

    with torch.no_grad():  # inference only: don't build an autograd graph
        batch_ner_output = model(**inp).logits.argmax(-1)
    tags_per_row = batch_ner_output.masked_fill(ignore, 0)

    batch_words, batch_labels = [], []
    for input_ids, tags in zip(inp['input_ids'], tags_per_row):
        spans = _tagged_spans(input_ids, tags, offset)
        words, labels = _spans_to_words(spans, tok, model.config.id2label)
        batch_words.append(words)
        batch_labels.append(labels)
    return batch_words, batch_labels, batch_ner_output

# Run batched NER over `prompts`; since it repeats one prompt, the rows are expected to match.
ner_result = batch_ner(prompts, model, tok)
batch_words, batch_labels, ner_output = ner_result
for shown in (batch_words, batch_labels, ner_output):
    print(shown)

Which of the following drug is not used for overactive bladder? .
['[CLS]', 'which', 'of', 'the', 'following', 'drug', 'is', 'not', 'used', 'for', 'over', '##active', 'bladder', '?', '.', '[SEP]']
[['overactive bladder'], ['overactive bladder']]
[['B-problem'], ['B-problem']]
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 3, 3, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 3, 3, 0, 0, 0]], device='cuda:0')
