<a href="https://colab.research.google.com/github/PavleSavic/MLM_consistency/blob/main/consistency.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import random
import logging
import pandas as pd
import numpy as np
import tensorflow as tf
#!pip install transformers datasets evaluate
from transformers import AutoTokenizer, TFAutoModelForMaskedLM

In [2]:
random.seed(123)
tf.keras.mixed_precision.set_global_policy('mixed_float16')
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.ERROR)

## Relations used in analysis

In [None]:
relations = []
with open("final_19_relations.txt") as f:
    lines = f.readlines()
    for l in lines:
        relations.append(l.strip())
print(len(relations))
relations.sort()
print(relations)

19
['associated_morphology_of', 'disease_has_abnormal_cell', 'disease_has_associated_anatomic_site', 'disease_has_normal_cell_origin', 'disease_has_normal_tissue_origin', 'disease_mapped_to_gene', 'disease_may_have_associated_disease', 'disease_may_have_finding', 'disease_may_have_molecular_abnormality', 'gene_associated_with_disease', 'gene_encodes_gene_product', 'gene_product_encoded_by_gene', 'gene_product_has_associated_anatomy', 'gene_product_has_biochemical_function', 'gene_product_plays_role_in_biological_process', 'has_physiologic_effect', 'may_prevent', 'may_treat', 'occurs_after']


## Prompts

In [None]:
prompts = pd.read_csv('prompts.csv')

In [None]:
prompts

Unnamed: 0,pid,default_prompt,human_prompt
0,associated_morphology_of,[X] associated morphology of [Y] .,[X] is associated morphology of [Y] .
1,disease_has_abnormal_cell,[X] disease has abnormal cell [Y] .,[X] has the abnormal cell [Y] .
2,disease_has_associated_anatomic_site,[X] disease has associated anatomic site [Y] .,The disease [X] can stem from the associated a...
3,disease_has_normal_cell_origin,[X] disease has normal cell origin [Y] .,The disease [X] stems from the normal cell [Y] .
4,disease_has_normal_tissue_origin,[X] disease has normal tissue origin [Y] .,The disease [X] stems from the normal tissue [...
5,disease_mapped_to_gene,[X] disease mapped to gene [Y] .,The disease [X] is mapped to gene [Y] .
6,disease_may_have_associated_disease,[X] disease may have associated disease [Y] .,The disease [X] might have the associated dise...
7,disease_may_have_finding,[X] disease may have finding [Y] .,[X] may have [Y] .
8,disease_may_have_molecular_abnormality,[X] disease may have molecular abnormality [Y] .,The disease [X] may have molecular abnormality...
9,gene_associated_with_disease,[X] gene associated with disease [Y] .,The gene [X] is associatied with disease [Y] .


## Masked Language Models

In [None]:
# uncased
bert_models = {'BERT_base' : "google-bert/bert-base-uncased", 'BERT_large': "google-bert/bert-large-uncased",
                'BERT_large_wwm': "google-bert/bert-large-uncased-whole-word-masking"}
# cased
roberta_models = {'RoBERTa_base': "FacebookAI/roberta-base", 'RoBERTa_large': "FacebookAI/roberta-large"}
# uncased
albert_models = {'ALBERT_base': "albert/albert-base-v2", 'ALBERT_xxlarge': "albert/albert-xxlarge-v2"}
# cased
biobert_models = {'BioBERT': "dmis-lab/biobert-base-cased-v1.2"}
# uncased
biomedbert_models = {'BioMedBERT_base_abstract' : "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract",
                     'BioMedBERT_base_full': "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
                     'BioMedBERT_large_abstract': "microsoft/BiomedNLP-BiomedBERT-large-uncased-abstract"}

## Example

In [None]:
# List of input texts with masked tokens
texts = ["She is from the city of [MASK].", "This is a great [MASK].", "He is an excellent [MASK]."]

In [None]:
def change_input_format(input):
  new_input = input.replace('[MASK]','<mask>')
  return new_input

In [None]:
def analyze_tokenizer(model_checkpoint, inputs):

  if 'roberta' in model_checkpoint:
    inputs = [change_input_format(input) for input in inputs]

  tokz = AutoTokenizer.from_pretrained(model_checkpoint)
  tokenization = tokz(inputs, return_tensors='tf', padding=True)  # truncation=True, max_length=tokz.model_max_length
  print(f"Tokenization example: {tokenization['input_ids']}")
  for l in tokenization['input_ids']:
    print(f"Decoded tokens: {tokz.decode(l)}")
  print(f"End of sequence token: {tokz.eos_token}")
  print(f"Mask token id: {tokz.mask_token_id}")
  print(f"All special tokens ids: {tokz.all_special_ids}")
  print(f"All special tokens: {tokz.decode(tokz.all_special_ids)}")
  print(f"Maximum model input length: {tokz.model_max_length}")

In [None]:
# One model for each group
print('BERT_base')
analyze_tokenizer(bert_models['BERT_base'], texts)
print('------------------------------------------------------------------------------')
print('BERT_large_wwm')
analyze_tokenizer(bert_models['BERT_large_wwm'], texts)
print('------------------------------------------------------------------------------')
print('RoBERTa_base')
analyze_tokenizer(roberta_models['RoBERTa_base'], texts)
print('------------------------------------------------------------------------------')
print('ALBERT_base')
analyze_tokenizer(albert_models['ALBERT_base'], texts)
print('------------------------------------------------------------------------------')
print('BioBERT')
analyze_tokenizer(biobert_models['BioBERT'], texts)
print('------------------------------------------------------------------------------')
print('BioMedBERT_base_full')
analyze_tokenizer(biomedbert_models['BioMedBERT_base_full'], texts)
print('------------------------------------------------------------------------------')

BERT_base
Tokenization example: [[ 101 2016 2003 2013 1996 2103 1997  103 1012  102]
 [ 101 2023 2003 1037 2307  103 1012  102    0    0]
 [ 101 2002 2003 2019 6581  103 1012  102    0    0]]
Decoded tokens: [CLS] she is from the city of [MASK]. [SEP]
Decoded tokens: [CLS] this is a great [MASK]. [SEP] [PAD] [PAD]
Decoded tokens: [CLS] he is an excellent [MASK]. [SEP] [PAD] [PAD]
End of sequence token: None
Mask token id: 103
All special tokens ids: [100, 102, 0, 101, 103]
All special tokens: [UNK] [SEP] [PAD] [CLS] [MASK]
Maximum model input length: 512
------------------------------------------------------------------------------
BERT_large_wwm
Tokenization example: [[ 101 2016 2003 2013 1996 2103 1997  103 1012  102]
 [ 101 2023 2003 1037 2307  103 1012  102    0    0]
 [ 101 2002 2003 2019 6581  103 1012  102    0    0]]
Decoded tokens: [CLS] she is from the city of [MASK]. [SEP]
Decoded tokens: [CLS] this is a great [MASK]. [SEP] [PAD] [PAD]
Decoded tokens: [CLS] he is an excellen

In [None]:
def get_model_predictions(model_checkpoint:str, inputs:list[str], top_n=5, verbose=0):
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
  model = TFAutoModelForMaskedLM.from_pretrained(model_checkpoint, from_pt=True)

  # Adjusting inputs for RoBERTa models
  if 'roberta' in model_checkpoint:
    inputs = [change_input_format(input) for input in inputs]

  if verbose:
    print(f'Choosen model: {model_checkpoint}')
    model.summary()

  # Tokenizing the inputs
  tokenized_inputs = tokenizer(inputs, return_tensors="tf", padding=True, truncation=True, max_length=128)

  # Getting the token logits from the model
  token_logits = model(**tokenized_inputs).logits

  outputs = []

  for i, input in enumerate(inputs):
                                                                                   # not necessary (all tokenizers have mask_token_id defined)
    mask_token_id = tokenizer.mask_token_id if tokenizer.mask_token_id is not None else tokenizer.convert_tokens_to_ids(["[MASK]"])[0]

    mask_token_index = np.argwhere(tokenized_inputs["input_ids"].numpy()[i] == mask_token_id)[0, 0]

    mask_token_logits = token_logits[i, mask_token_index, :]

    top_tokens = np.argsort(-mask_token_logits.numpy()).tolist()

    predictions = []
    if verbose:
      print(f"Input: {input}")

    for token_id in top_tokens:
        # Skip special tokens
        if token_id in tokenizer.all_special_ids:
          continue

        predicted_token = tokenizer.decode([token_id])

        predictions.append(predicted_token)
        if verbose:
          print(f">>> {input.replace(tokenizer.mask_token, predicted_token)}")

        if len(predictions) == top_n:
          break
    if verbose:
      print()

    outputs.append(predictions)

  return np.array(outputs)

In [None]:
pred = get_model_predictions(model_checkpoint=roberta_models['RoBERTa_base'], inputs=texts, top_n=10)
pred

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/501M [00:00<?, ?B/s]

array([[' Chicago', ' London', ' Toronto', ' Seattle', ' Boston',
        ' Portland', ' Philadelphia', ' Vancouver', ' Minneapolis',
        ' Houston'],
       [' idea', ' example', ' article', ' video', ' post', ' story',
        ' read', ' question', ' book', ' game'],
       [' player', ' writer', ' athlete', ' student', ' shooter',
        ' coach', ' hitter', ' defender', ' guy', ' broadcaster']],
      dtype='<U13')

# Relations

In [None]:
occurs_data = pd.read_csv('occurs_after_1000.csv', usecols=["head_name", "rel", "tail_names"])
occurs_data.head(10)

Unnamed: 0,head_name,rel,tail_names
0,Post influenza vaccination encephalitis,occurs_after,Administration of influenza vaccine
1,Basal cell carcinoma recurrent following cryos...,occurs_after,Cryosurgery
2,Adverse effect from PUVA photochemotherapy,occurs_after,Light therapy || Photochemotherapy with psoral...
3,Allergy to pea,occurs_after,Allergic sensitization
4,Bite of unidentified snake with neurological s...,occurs_after,Animal bite
5,Allergy to hypothalamic hormone,occurs_after,Allergic sensitization
6,Late effect of accidental injury,occurs_after,Traumatic injury
7,Radiotherapy scar,occurs_after,Procedure || Radiation oncology AND/OR radioth...
8,Atonic postpartum hemorrhage,occurs_after,Delivery procedure
9,Late effect of skin and subcutaneous tissue in...,occurs_after,Injury || Traumatic injury


In [None]:
# Prompts to use
rel_name = occurs_data['rel'][0]
default_prompt = prompts.loc[prompts['pid'] == rel_name]['default_prompt'].tolist()[0]
human_prompt  = prompts.loc[prompts['pid'] == rel_name]['human_prompt'].tolist()[0]
print(f"Default prompt: {default_prompt}\nHuman prompt: {human_prompt}")

Default prompt: [X] occurs after [Y] .
Human prompt: [X] occurs after [Y] .


In [None]:
# Preparing inputs
def prepare_inputs(data, prompt:str):
  # number of rows
  n = len(data)

  inputs = [prompt for _ in range(n)]
  heads = data['head_name'].tolist()

  inputs = [input.replace('[X]', head) for input, head in zip(inputs, heads)]
  inputs = [input.replace('[Y] ', '[MASK]') for input in inputs]
  return inputs

In [None]:
inputs = prepare_inputs(occurs_data, default_prompt)
print(inputs[:10])

['Post influenza vaccination encephalitis occurs after [MASK].', 'Basal cell carcinoma recurrent following cryosurgery occurs after [MASK].', 'Adverse effect from PUVA photochemotherapy occurs after [MASK].', 'Allergy to pea occurs after [MASK].', 'Bite of unidentified snake with neurological signs occurs after [MASK].', 'Allergy to hypothalamic hormone occurs after [MASK].', 'Late effect of accidental injury occurs after [MASK].', 'Radiotherapy scar occurs after [MASK].', 'Atonic postpartum hemorrhage occurs after [MASK].', 'Late effect of skin and subcutaneous tissue injury occurs after [MASK].']


In [None]:
# Getting predictions
pred_1 = get_model_predictions(model_checkpoint=biomedbert_models['BioMedBERT_base_abstract'], inputs=inputs, top_n=1)
pred_1

tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/225k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

array([['vaccination'],
       ['years'],
       ['transplantation'],
       ['birth'],
       ['birth'],
       ['surgery'],
       ['birth'],
       ['radiotherapy'],
       ['delivery'],
       ['surgery'],
       ['trauma'],
       ['stroke'],
       ['surgery'],
       ['trauma'],
       ['injection'],
       ['surgery'],
       ['surgery'],
       ['transplantation'],
       ['surgery'],
       ['trauma'],
       ['surgery'],
       ['transplantation'],
       ['surgery'],
       ['immunization'],
       ['chemotherapy'],
       ['cholecystectomy'],
       ['pregnancy'],
       ['esophagectomy'],
       ['varicella'],
       ['treatment'],
       ['surgery'],
       ['surgery'],
       ['cesarean'],
       ['laminectomy'],
       ['surgery'],
       ['trauma'],
       ['colonoscopy'],
       ['catheterization'],
       ['surgery'],
       ['surgery'],
       ['surgery'],
       ['ingestion'],
       ['surgery'],
       ['injection'],
       ['delivery'],
       ['surgery'],
     

In [None]:
tails = occurs_data['tail_names'].tolist()
tails = list(map(lambda x: x.split(' || '), tails))
tails

[['Administration of influenza vaccine'],
 ['Cryosurgery'],
 ['Light therapy', 'Photochemotherapy with psoralens and ultraviolet A'],
 ['Allergic sensitization'],
 ['Animal bite'],
 ['Allergic sensitization'],
 ['Traumatic injury'],
 ['Procedure', 'Radiation oncology AND/OR radiotherapy'],
 ['Delivery procedure'],
 ['Injury', 'Traumatic injury'],
 ['Fat necrosis'],
 ['Spontaneous cerebral hemorrhage'],
 ['Transplantation',
  'Implantation of prosthetic device',
  'Surgical construction of arteriovenous shunt'],
 ['Injury of knee', 'Traumatic event'],
 ['Procedure', 'Injection'],
 ['Colostomy', 'Procedure'],
 ['Procedure'],
 ['Corneal transplant'],
 ['Allergic sensitization'],
 ['Traumatic injury', 'Traumatic event'],
 ['Extraction of cataract', 'Implantation of phakic intraocular lens implant'],
 ['Transplantation of bone marrow', 'Grafting procedure'],
 ['Allergic sensitization'],
 ['Active or passive immunization'],
 ['Allergic sensitization'],
 ['Implantation of prosthetic device', 

In [None]:
def compute_accuracy(predictions, tails):
  hits = 0
  n = len(predictions)

  for i in range(n):
    preds = [prediction.strip().lower() for prediction in predictions[i]]
    tls = [tail.strip().lower() for tail in tails[i]]
    if set(preds).intersection(tls):
      hits += 1

  return (hits/n)*100

In [6]:
def cosine_similarity(v1, v2):
    v1 = np.array(v1)
    v2 = np.array(v2)

    dot_product = np.dot(v1, v2)

    norm_v1 = np.linalg.norm(v1)
    norm_v2 = np.linalg.norm(v2)

    cosine_sim = dot_product / (norm_v1 * norm_v2)

    return cosine_sim

vector1 = [1, 2, 3]
vector2 = [4, 5, 6]

similarity = cosine_similarity(vector1, vector2)
print("Cosine Similarity:", similarity)


Cosine Similarity: 0.9746318461970762


In [None]:
top_1_acc = compute_accuracy(pred_1, tails)
print(f'Top 1 accuracy: {top_1_acc:.2f} %')

Top 1 accuracy: 4.00 %


In [None]:
pred_10 = get_model_predictions(model_checkpoint=biomedbert_models['BioMedBERT_base_full'], inputs=inputs, top_n=10)
pred_10

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertForMaskedLM: ['cls.predictions.decoder.bias']
- This IS expected if you are initializing TFBertForMaskedLM from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertForMaskedLM from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertForMaskedLM were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForMaskedLM for predictions without further training.


Choosen model: microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext
Model: "tf_bert_for_masked_lm_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 bert (TFBertMainLayer)      multiple                  108891648 
                                                                 
 mlm___cls (TFBertMLMHead)   multiple                  24459834  
                                                                 
Total params: 109514298 (417.76 MB)
Trainable params: 109514298 (417.76 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
>>> Post-radiation stricture of intestine occurs after chemotherapy.
>>> Post-radiation stricture of intestine occurs after chemoradiotherapy.
>>> Post-radiation stricture of intestine occurs after chemoradiation.
>>> Post-radiation stricture of intestine occurs after surgery.
>>> Post-radiation stricture of intestine occurs after rt.
>>> Post-radiation stricture of intestine occurs after brachytherapy.
>>> Post-radiation stricture of intestine occurs after treatment.

Input: Allergy to tree resin occurs after [MASK].
>>> Allergy to tree resin occurs after surgery.
>>> Allergy to tree resin occurs after transplantation.
>>> Allergy to tree resin occurs after exposure.
>>> Allergy to tree resin occurs after childhood.
>>> Allergy to tree resin occurs after extraction.
>>> Allergy to tree resin occurs after birth.
>>> Allergy to tree resin occurs after treatment.
>>> Allergy to tree resin occurs after ingestion.
>>> Allergy

array([['vaccination', 'birth', 'immunization', ..., 'delivery',
        'influenza', 'seroconversion'],
       ['radiotherapy', 'years', 'surgery', ..., 'treatment', 'excision',
        'irradiation'],
       ['surgery', 'radiotherapy', 'treatment', ..., 'birth',
        'radiation', 'therapy'],
       ...,
       ['menopause', 'childbirth', 'birth', ..., 'thyroidectomy',
        'puberty', 'splenectomy'],
       ['trauma', 'surgery', 'stroke', ..., 'earthquake', 'operation',
        'accidents'],
       ['surgery', 'transplantation', 'exposure', ..., 'grafting',
        'injection', 'dialysis']], dtype='<U32')

In [None]:
top_10_acc = compute_accuracy(pred_10, tails)
print(f'Top 10 accuracy: {top_10_acc:.2f} %')

Top 10 accuracy: 12.10 %


# MLM acccuracy measuring function

In [None]:
def compute_mml_top_n_accuracy(model_checkpoint:str, relation_dataset:str, dataset_frac=1 , top_n=5, random_state=123, verbose=0):

  data = pd.read_csv(relation_dataset, usecols=["head_name", "rel", "tail_names"])
  # For quicker testing due to resource limitations
  data_chunk = data.sample(frac=dataset_frac, random_state=random_state).reset_index(drop=True)

  rel_name = data_chunk['rel'][0]
  default_prompt = prompts.loc[prompts['pid'] == rel_name]['default_prompt'].tolist()[0]

  inputs = prepare_inputs(data_chunk, default_prompt)

  predicted_objects = get_model_predictions(model_checkpoint=model_checkpoint, inputs=inputs, top_n=top_n, verbose=verbose)

  true_objects = data_chunk['tail_names'].tolist()
  true_objects = list(map(lambda x: x.split(' || '), true_objects))

  return compute_accuracy(predicted_objects, true_objects)

In [None]:
# Comparing BERT models top 10 acc on a chunk of data (10%)
print('Top 10 accuracy')

relation = 'occurs_after'
print(f"Relation: \033[1m{relation}\033[0m")

for k, v in bert_models.items():
  acc = compute_mml_top_n_accuracy(v, f'{relation}_1000.csv', dataset_frac=0.1, top_n=10)
  print(f"{k}: {acc:.2f}\n")

for k, v in roberta_models.items():
  acc = compute_mml_top_n_accuracy(v, f'{relation}_1000.csv', dataset_frac=0.1, top_n=10)
  print(f"{k}: {acc:.2f}\n")

for k, v in albert_models.items():
  acc = compute_mml_top_n_accuracy(v, f'{relation}_1000.csv', dataset_frac=0.1, top_n=10)
  print(f"{k}: {acc:.2f}\n")

for k, v in biobert_models.items():
  acc = compute_mml_top_n_accuracy(v, f'{relation}_1000.csv', dataset_frac=0.1, top_n=10)
  print(f"{k}: {acc:.2f}\n")

for k, v in biomedbert_models.items():
  acc = compute_mml_top_n_accuracy(v, f'{relation}_1000.csv', dataset_frac=0.1, top_n=10)
  print(f"{k}: {acc:.2f}\n")

## Restricted candidate set & Multi-token issue

In [None]:
# Conditional MLM
def fill_masks_independently(model_checkpoint: str, input_query: str, mask_token="[MASK]", top_n=5):

    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    tokenized_input = tokenizer(input_query, return_tensors="tf", padding=True, truncation=True)

    mask_token_id = tokenizer.mask_token_id if tokenizer.mask_token_id is not None else tokenizer.convert_tokens_to_ids([mask_token])[0]
    mask_indices = tf.where(tf.equal(tokenized_input["input_ids"], mask_token_id))

    model = TFAutoModelForMaskedLM.from_pretrained(model_checkpoint)
    token_logits = model(**tokenized_input).logits

    predictions = []
    for mask_index in mask_indices:
        mask_position = mask_index[1].numpy()
        mask_logits = token_logits[0, mask_position, :]
        top_token_ids = tf.argsort(-mask_logits)[:top_n].numpy()
        top_tokens = tokenizer.decode(top_token_ids)
        predictions.append(top_tokens)

    return predictions

In [None]:
def fill_masks_autoregressively(model_checkpoint: str, input_query: str, mask_token="[MASK]", top_n=5):

    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    tokenized_input = tokenizer(input_query, return_tensors="tf", padding=True, truncation=True)

    mask_token_id = tokenizer.mask_token_id if tokenizer.mask_token_id is not None else tokenizer.convert_tokens_to_ids([mask_token])[0]
    mask_indices = tf.where(tf.equal(tokenized_input["input_ids"], mask_token_id))

    model = TFAutoModelForMaskedLM.from_pretrained(model_checkpoint)
    #token_logits = model(**tokenized_input).logits

    predictions = []
    for mask_index in mask_indices:
        mask_position = mask_index[1].numpy()
        context = tokenized_input["input_ids"][:, :mask_position]

        for _ in range(top_n):
            mask_logits = model(input_ids=context).logits[0, -1, :]
            predicted_token_id = tf.argmax(mask_logits, axis=-1).numpy()
            predicted_token = tokenizer.decode(predicted_token_id)

            context = tf.concat([context, [[predicted_token_id]]], axis=-1)

            if tokenizer.decode(predicted_token_id) in tokenizer.all_special_tokens or predicted_token in ['.', '!', '?']:
              break

        predictions.append(predicted_token)

    return predictions

In [None]:
# TO DO
def fill_masks_by_confidence(model_checkpoint: str, input_query: str, mask_token="[MASK]", top_n=5):

    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    tokenized_input = tokenizer(input_query, return_tensors="tf", padding=True, truncation=True)

    mask_token_id = tokenizer.mask_token_id if tokenizer.mask_token_id is not None else tokenizer.convert_tokens_to_ids([mask_token])[0]
    mask_indices = tf.where(tf.equal(tokenized_input["input_ids"], mask_token_id))

    model = TFAutoModelForMaskedLM.from_pretrained(model_checkpoint)
    token_logits = model(**tokenized_input).logits

    predictions = []
    for mask_index in mask_indices:
        mask_position = mask_index[1].numpy()
        mask_logits = token_logits[0, mask_position, :]
        top_token_ids = tf.argsort(-mask_logits)[:top_n].numpy()
        top_tokens = tokenizer.decode(top_token_ids)
        predictions.append(top_tokens)

    return predictions

In [None]:
fill_masks_independently(bert_models['BERT_base'], "Paris is [MASK][MASK] to visit.")

['also not always a definitely', 'fun easy pleasant welcome hard']

In [None]:
fill_masks_autoregressively(bert_models['BERT_base'], "Paris is [MASK][MASK] to visit.")

['and', '"']

In [None]:
fill_masks_by_confidence(bert_models['BERT_base'], "Paris is [MASK][MASK] to visit.")

['also not always a definitely', 'fun easy pleasant welcome hard']