**1. Set up environment**

In [None]:
!pip install torch==1.11.0 transformers==4.19.1 sacremoses==0.0.53

**2. Setup code**

In [47]:
import torch
from transformers import *
import operator
from collections import OrderedDict
import sys
import traceback
import argparse
import string


import logging

#DEFAULT_MODEL_PATH='bert-large-cased'
#DEFAULT_MODEL_PATH='bert-base-cased' #works best for names
#DEFAULT_MODEL_PATH='bert-base-uncased'
DEFAULT_MODEL_PATH='./'
DEFAULT_TO_LOWER=False
DEFAULT_TOP_K = 20
ACCRUE_THRESHOLD = 1

def init_model(model_path,to_lower):
    logging.basicConfig(level=logging.INFO)
    print("******* MODEL[path] is:",model_path," lower casing is set to:",to_lower)
    tokenizer = BertTokenizer.from_pretrained(model_path,do_lower_case=to_lower)
    model = BertForMaskedLM.from_pretrained(model_path)
    #tokenizer = RobertaTokenizer.from_pretrained(model_path,do_lower_case=to_lower)
    #model = RobertaForMaskedLM.from_pretrained(model_path)
    model.eval()
    return model,tokenizer




def perform_task(model,tokenizer,top_k,accrue_threshold,text,patched):
    text = '[CLS] ' + text + ' [SEP]' 
    tokenized_text = tokenizer.tokenize(text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

    # Create the segments tensors.
    segments_ids = [0] * len(tokenized_text)

    print(tokenized_text)

    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])


    with torch.no_grad():
        predictions = model(tokens_tensor, segments_tensors)
        for i in range(len(tokenized_text)):
                #if (i != 0 and i != len(tokenized_text) - 1):
                #    continue
                results_dict = {}
                masked_index = i
                neighs_dict = {}
                if (patched):
                    for j in range(len(predictions[0][0][0,masked_index])):
                        if (float(predictions[0][0][0,masked_index][j].tolist()) > accrue_threshold):
                            tok = tokenizer.convert_ids_to_tokens([j])[0]
                            results_dict[tok] = float(predictions[0][0][0,masked_index][j].tolist())
                        tok = tokenizer.convert_ids_to_tokens([j])[0]
                        if (tok in tokenized_text):
                            neighs_dict[tok] = float(predictions[0][0][0,masked_index][j].tolist())
                else:
                    for j in range(len(predictions[0][0][masked_index])):
                        if (float(predictions[0][0][masked_index][j].tolist()) > accrue_threshold):
                            tok = tokenizer.convert_ids_to_tokens([j])[0]
                            results_dict[tok] = float(predictions[0][0][masked_index][j].tolist())
                        tok = tokenizer.convert_ids_to_tokens([j])[0]
                        if (tok in tokenized_text):
                            neighs_dict[tok] = float(predictions[0][0][masked_index][j].tolist())
                k = 0
                sorted_d = OrderedDict(sorted(results_dict.items(), key=lambda kv: kv[1], reverse=True))
                print("********* Top predictions for token: ",tokenized_text[i])
                for index in sorted_d:
                    if (index in string.punctuation or index.startswith('##') or len(index) == 1 or index.startswith('.') or index.startswith('[')):
                        continue
                    print(index,round(float(sorted_d[index]),4))
                    k += 1
                    if (k > top_k):
                        break
                print("********* Closest sentence neighbors in output to the token :  ",tokenized_text[i])
                sorted_d = OrderedDict(sorted(neighs_dict.items(), key=lambda kv: kv[1], reverse=True))
                for index in sorted_d:
                    if (index in string.punctuation or index.startswith('##') or len(index) == 1 or index.startswith('.') or index.startswith('[')):
                        continue
                    print(index,round(float(sorted_d[index]),4))
                print()
                print()
                #break





**3. Load model of choice once** 

In [None]:
#model = "bert-base-cased"
model = "ajitrajasekharan/biomedical"
tolower=False
model,tokenizer = init_model(model,tolower)

**4. Modify input and execute for results**

_**Note:** Results shown for all tokenized terms including masked term_

In [50]:
topk = DEFAULT_TOP_K
patched = False
threshold = ACCRUE_THRESHOLD
#text = "John flew from New York to [MASK]"
text = "Parkinson who works for XCorp suffers from progressive [MASK]"
perform_task(model,tokenizer,topk,threshold,text,patched)

['[CLS]', 'Parkinson', 'who', 'works', 'for', 'X', '##Co', '##rp', 'suffer', '##s', 'from', 'progressive', '[MASK]', '[SEP]']
********* Top predictions for token:  [CLS]
surgery 7.5769
who 7.3091
Surgery 7.2822
levodopa 6.9597
of 6.9088
Rehabilitation 6.6983
and 6.6847
Methods 6.683
history 6.6558
Currently 6.6367
symptoms 6.5641
one 6.4077
Stroke 6.3296
progresses 6.3196
degeneration 6.2876
20 6.2857
waiting 6.2578
Patients 6.2323
PD 6.1677
Motor 6.135
One 6.1258
********* Closest sentence neighbors in output to the token :   [CLS]
who 7.3091
Parkinson 5.6706
progressive 5.531
for 4.6574
suffer 4.3762
from 3.9757
works 3.128


********* Top predictions for token:  Parkinson
Parkinson 15.0947
Patients 14.5075
Patient 13.7469
Subject 13.0993
Those 12.4626
patient 12.395
Participant 12.2981
One 12.2895
Individuals 11.9122
Subjects 11.9035
People 11.6763
Participants 11.6334
and 11.3784
subject 11.249
patients 10.8569
or 10.671
The 10.602
Persons 10.394
person 10.3795
PD 10.1144
Investiga