<a href="https://colab.research.google.com/github/IVN-RIN/bio-med-BIT/blob/main/notebooks/BioBIT_Language_Modeling_Evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **BioBIT <u>Language Modeling</u> Evaluation**

*Tommaso M Buonocore, University of Pavia, 2022*

*Last edited: 05/12/2022*

*Related paper: [Localising In-Domain Adaptation of Transformer-Based Biomedical Language Models](https://www.medrxiv.org/content/XXXXXX)*

---

This notebook contains some tests carried out on the **BioBIT** and **MedBIT** models on the *Masked Language Modelling* (MLM) task, which consists of masking part of the input text and checking whether the model correctly predicts the missing portion of the sequence.

The tests use the HuggingFace `fill-mask` pipeline, which replaces each `[MASK]` with exactly one token. This means that if you mask the word "asma" (asthma), which BERT represents as two tokens, `["as", "##ma"]`, the best the pipeline can return is "as", not "asma".
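Because only one token is predicted per `[MASK]`, each prediction can be compared directly with the target word. In `get_results` below, a masked sentence scores the reciprocal rank of the target among the top-k predictions (1 for rank 1, 0.5 for rank 2, ..., 0 if absent), and the average over sentences is the MRR. A minimal, model-free sketch of that scoring, assuming predictions are plain ranked lists of strings (the real pipeline returns dicts with a `token_str` field); the predictions below are made up for illustration:

```python
def reciprocal_rank(target, ranked_predictions):
    """Return 1/rank of the first case-insensitive match, or 0.0 if the target is absent."""
    for rank, prediction in enumerate(ranked_predictions, start=1):
        if prediction.lower() == target.lower():
            return round(1 / rank, 2)  # rounded to 2 decimals, as in get_results
    return 0.0


def mean_reciprocal_rank(targets, predictions):
    """Average the reciprocal ranks over all masked sentences (MRR)."""
    scores = [reciprocal_rank(t, p) for t, p in zip(targets, predictions)]
    return sum(scores) / len(scores), scores


# Hypothetical top-5 predictions, not real model output
targets = ["memoria", "as", "testa"]
predictions = [
    ["memoria", "funzione", "salute", "vita", "qualità"],
    ["malattia", "tosse", "as", "allergia", "febbre"],
    ["parte", "zona", "coda", "regione", "sede"],
]
mrr, scores = mean_reciprocal_rank(targets, predictions)
print(scores)  # [1.0, 0.33, 0.0]
print(f"MRR = {mrr:.2f}")  # MRR = 0.44
```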

## Init


In [None]:
%%capture
!pip3 install datasets transformers seqeval
from transformers import pipeline, AutoModel, AutoTokenizer, AutoModelForMaskedLM, logging

### Imports

In [None]:
import pandas as pd
import numpy as np
import torch
from random import sample
import math 
import json
import string
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import os

We rely on Google Drive for file management.

In [None]:
from google.colab import drive
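drive.mount('/content/gdrive', force_remount=True)
models_path = 'your/path/biobert_models'  # folder with the local model checkpoints, relative to "My Drive"
os.chdir(os.path.join('/content/gdrive/MyDrive', models_path))  # absolute path, independent of the current directory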

In [None]:
logging.set_verbosity_error() # reduce verbosity

In [None]:
checkpoints = {'baseBIT':'dbmdz/bert-base-italian-xxl-cased',
               'bioBIT_xs':'bio-tiny', 
               'bioBIT_s':'bio-small', 
               'bioBIT_m':'bio-medium', 
               'bioBIT_l':'bmi-labmedinfo/bioBIT', 
               #'medBIT_OR':'med-reg-original',
               #'medBIT':'bmi-labmedinfo/medBIT',
               #'medBIT_r0':'med-reg-v0',
               #'medBIT_r1':'med-reg-v1',
               #'medBIT_r2':'med-reg-v2',
               #'medBIT_r3':'med-reg-v3',
               #'medBIT_rf':'med-frozen',
               #'medBIT_r3f':'med-reg-v3-clean-fixed-10',
               #'medBIT_r3+':'bmi-labmedinfo/medBIT-r3-plus',
               #'medBIT_r12+':'med-reg-v12'
               } 

tokenizers = {}
for k,v in checkpoints.items():
  tokenizers[k] = AutoTokenizer.from_pretrained(v, truncation=True)

### Utilities

In [None]:
# Evaluate MLM of model on a list of sentences
def evaluate_mlm(model, tokenizer, masked_sentences, top_k = 5, print_progression = True):
    # Define pipeline
    fill_mask = pipeline(
      task='fill-mask',
      model=model,
      tokenizer=tokenizer,
      top_k=top_k
    )
    # Apply pipeline to every sentence
    if print_progression:
      results = [fill_mask(el) for el in tqdm(masked_sentences, position=0, leave=True)]
    else:
      results = [fill_mask(el) for el in masked_sentences]
    return results

In [None]:
# Get value of [MASK] in masked sentence
def get_masked_word(original_sentences, masked_sentences):
    masked = []
    exclude = set(string.punctuation)
    for i, el in enumerate(masked_sentences):
        begin = el.find('[MASK]')
        end = begin+original_sentences[i][begin:].find(' ') # Look for first white space after begin of masked word
        if end<begin:
            end = begin+original_sentences[i][begin:].find('.') # Look for first period after begin of masked word
        token = original_sentences[i][begin:end]
        token = ''.join(ch for ch in token if ch not in exclude) # Remove punctuation
        masked.append(token)
    return masked

In [None]:
# Print results of MLM in a convenient way
def get_results(results, original_sentences, masked_sentences, tokenizer, verbose=True):
    refs = get_masked_word(original_sentences, masked_sentences)
    scores = [0]*len(refs)
    for i, sentence in enumerate(results):
        if verbose: print(f"""------------------- Sentence #{i+1} -------------------\nMasked sentence = {masked_sentences[i]}\nContext size = {len(tokenizer(original_sentences[i])['input_ids'])} tokens\nTarget = {refs[i]}\nPredicted:\n""")
        for j,el in enumerate(sentence):
            if scores[i]==0 and refs[i].lower()==el['token_str'].lower(): # keep the rank of the first match only
              scores[i]=round(1/(j+1),2) # reciprocal rank of the target
            if verbose: print(f"{j+1}) {el['token_str']} [{(el['score']*100):.2f}]")
    mlm_score = sum(scores)/len(scores)
    if verbose: print(f"""\nMLM score = {mlm_score:.2f}\n-- score details: {scores}\n""")
    return {'avg':mlm_score,'scores':scores}

In [None]:
def get_dataframe(checkpoints, tokenizers, original_sentences, masked_sentences, top_k=5, verbose = False):
  targets = get_masked_word(original_sentences, masked_sentences)
  columns = ['avg_score']+targets
  df = pd.DataFrame(columns=columns, index = checkpoints.keys())
  for k,v in checkpoints.items():
    r = evaluate_mlm(v, tokenizers[k], masked_sentences, top_k=top_k)
    s = get_results(r, original_sentences, masked_sentences, tokenizers[k], verbose=verbose)
    row = [s['avg']]+s['scores']
    df.loc[k]=row
  return(df)

## MLM test on biomedical concepts

This test evaluates the lexical comprehension of the language models, focusing on biomedical concepts. The masked words are mainly biomedical nouns that can be inferred from the surrounding context.

Sentences have different lengths and belong to different categories.

Input data format (one object in `data` per original sentence, each with one or more masked versions):

```json
{
  "data": [
    {
      "source": "<url>",
      "type": "<category>",
      "original": "<original sentence>",
      "masked": ["<masked sentence 1>", "<masked sentence 2>", "..."]
    }
  ]
}
```
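Since an original sentence can have several masked versions, the evaluation repeats each original once per masked version, so that the two lists stay aligned index by index. A minimal sketch of that flattening on a toy input in the format above (the records are illustrative, not taken from the dataset); it uses nested list comprehensions, which give the same result as the `sum([...], [])` idiom used later in the notebook:

```python
import json

# Toy input in the format above (illustrative content only)
raw = """
{"data": [
  {"source": "https://example.org/a", "type": "Neurologia",
   "original": "La memoria è compromessa.",
   "masked": ["La [MASK] è compromessa.", "La memoria è [MASK]."]},
  {"source": "https://example.org/b", "type": "Cardiologia",
   "original": "Il cuore pompa sangue.",
   "masked": ["Il [MASK] pompa sangue."]}
]}
"""
sentences = json.loads(raw)["data"]

# Repeat each original once per masked version, so that
# original_sentences[i] is the unmasked text of masked_sentences[i]
original_sentences = [s["original"] for s in sentences for _ in s["masked"]]
masked_sentences = [m for s in sentences for m in s["masked"]]

for original, masked in zip(original_sentences, masked_sentences):
    print(f"{masked!r} <- {original!r}")
```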

### Load MLM dataset

In [None]:
with open("/content/mlm_data.json", encoding="utf-8") as f:
  json_data = json.load(f)

In [None]:
sentences = json_data["data"]
#uncomment next line for downsampling
#sentences = [sentences[i] for i in sample(range(len(json_data["data"])),10)]
masked_count = sum([len(sentence["masked"]) for sentence in sentences])
num_tok = [len(t) for t in tokenizers[list(tokenizers.keys())[0]]([sentence["original"] for sentence in sentences])['input_ids']]

print(f"Original sentences: {len(sentences)}")
print(f"Masked sentences: {masked_count}")
print(f"Sentence length: {min(num_tok)}-{max(num_tok)} tokens (avg: {np.mean(num_tok):.0f})")

### Categories Distribution

In [None]:
# categories are labelled in Italian: build an it2en dict to get the English names
cat_it2en = {"Allergologia":"Allergy","Altro":"Other","Bioetica":"Bioethics","Biologia Cellulare":"Cell Biology","Cardiologia":"Cardiology","Chirurgia":"Surgery","Diabetologia":"Diabetology","Ematologia":"Hematology","Endocrinologia":"Endocrinology","Epidemiologia":"Epidemiology","Farmacologia":"Pharmacology","Fisiologia":"Physiology","Malattie Rare":"Rare Diseases","Nefrologia":"Nephrology","Neurologia":"Neurology","Odontoiatria":"Dentistry","Oncologia":"Oncology","Ortopedia":"Orthopedics","Pediatria":"Pediatrics","Pneumologia":"Pneumology","Psichiatria":"Psychiatry","Radiologia":"Radiology"}
# textbooks categories have been identified and collected manually
textbooks_cat = {"Allergologia":1, "Altro":45, "Bioetica":5,"Biologia Cellulare":2, "Cardiologia":4,"Chirurgia":4,"Diabetologia":0,"Ematologia":4,"Endocrinologia":0,"Epidemiologia":1,"Farmacologia":14,"Fisiologia":6,"Malattie Rare":1,"Nefrologia":1,"Neurologia":10,"Odontoiatria":3,"Oncologia":5,"Ortopedia":4,"Pediatria":5,"Pneumologia":5,"Psichiatria":21,"Radiologia":8}

textbooks_cat_fig = sum([[cat_it2en[k]]*v for k,v in textbooks_cat.items()],[])
mlm_cat_fig = sum([[cat_it2en[sentence["type"]]]*len(sentence["masked"]) for sentence in sentences],[])

#plot
font = {'size'   : 18}
plt.rc('font', **font)
fig, ax = plt.subplots(figsize=(18, 8))
ax.tick_params(axis='x', rotation=90)
ax.set_xlabel('\nCategory')
ax.set_ylabel('% Total')
plt.hist([sorted(mlm_cat_fig), sorted(textbooks_cat_fig)], density=True, bins = np.arange(len(set(mlm_cat_fig))+1)-0.5)
ax.legend(loc='upper right', labels=['MLM Dataset', 'Textbooks Corpus'])
plt.savefig('/content/mlm_data_cat_distribution.png')

### Results

Note: `[MASK]` replacement is token-based. If you mask a word that WordPiece splits into multiple tokens, the prediction can never match the whole word.

Example:

```
INPUT: "L'[MASK] è una condizione in cui le vie respiratorie si restringono e si gonfiano." 
(EN: "[MASK] is a condition in which your airways narrow and swell.")

MASKED WORD: "asma" (EN: "asthma")

EXPECTED OUTPUT: "asma"
ACTUAL OUTPUT: "as"

CORRECT = False
```

This happens because "asma" is tokenized as `["as", "##ma"]`: the model replaces `[MASK]` with a single token, while we are masking a whole multi-token word.
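WordPiece splits a word by repeatedly taking the longest vocabulary entry that matches the start of the remaining characters, with continuation pieces prefixed by `##`. The sketch below implements that greedy longest-match-first rule on a tiny, hypothetical vocabulary (the real vocabulary of the Italian BERT tokenizer is far larger and is loaded by `AutoTokenizer`), only to show how a word such as "asma" can end up as two tokens:

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first WordPiece tokenization of a single word."""
    tokens = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, then shrink it from the right
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry the ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no piece matches: the whole word becomes unknown
        tokens.append(piece)
        start = end
    return tokens


# Hypothetical toy vocabulary: "asma" itself is not an entry
vocab = {"as", "##ma", "memoria", "tes", "##ta", "testa"}
print(wordpiece("asma", vocab))     # ['as', '##ma']
print(wordpiece("memoria", vocab))  # ['memoria']
print(wordpiece("testa", vocab))    # ['testa'] (longest match wins over 'tes' + '##ta')
```

Since the `fill-mask` pipeline fills each `[MASK]` with exactly one vocabulary entry, a target like "asma" can at best be matched through its first piece.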


In [None]:
# inspect how the bioBIT tokenizer splits "asma"
tokenizers['bioBIT_l'].convert_ids_to_tokens(tokenizers['bioBIT_l']("asma")["input_ids"])

Collect the results for each model checkpoint and save everything into a dataframe.

In [None]:
#warning: it takes approx 10 mins for each model on a CPU (1s/it worst case)
original_sentences = sum([[sentence["original"]]*len(sentence["masked"]) for sentence in sentences],[])
masked_sentences = sum([sentence["masked"] for sentence in sentences],[])
df = get_dataframe(checkpoints, tokenizers, original_sentences, masked_sentences, top_k=5)
df.to_csv("/content/dataframeMLM.csv")
df

If you want to test a single sentence:

In [None]:
masked_sentence = "L'ipotesi più in voga è che nell’Alzheimer la regione dell’ippocampo riduca la capacità di gestire la dopamina andando a compromettere la [MASK] che è il principale sintomo della patologia."
original_sentence = "L'ipotesi più in voga è che nell’Alzheimer la regione dell’ippocampo riduca la capacità di gestire la dopamina andando a compromettere la memoria che è il principale sintomo della patologia."
get_dataframe(checkpoints, tokenizers, [original_sentence], [masked_sentence], top_k=5, verbose = True)

Most predictions are either always wrong or always correct across all models.
We now focus only on the examples where the prediction changes from one model to another.
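The filtering in the next cell keeps a target word only if its scores are not all 0 (never in the top-k) and not all 1 (always ranked first). The same rule on plain Python data, with hypothetical per-model scores:

```python
# Hypothetical per-model reciprocal-rank scores for three target words
scores_by_word = {
    "memoria": [1.0, 1.0, 1.0],   # always ranked first by every model
    "colon":   [0.0, 0.33, 1.0],  # prediction changes across models
    "testa":   [0.0, 0.0, 0.0],   # never in the top-k for any model
}


def changing_predictions(scores_by_word):
    """Keep the words whose scores are neither all zero nor all one."""
    return {
        word: scores
        for word, scores in scores_by_word.items()
        if any(s != 0 for s in scores) and any(s != 1 for s in scores)
    }


kept = changing_predictions(scores_by_word)
print(list(kept))  # ['colon']
print(f"Changing predictions: {100 * len(kept) / len(scores_by_word):.1f}%")  # 33.3%
```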

In [None]:
df = pd.read_csv("/content/dataframeMLM.csv", index_col=0)

# drop the checkpoints excluded from this analysis (ignored if not present)
df = df.drop(["medBIT_ro","medBIT_s","medBIT_r3"], axis=0, errors="ignore")

df_fig = df.drop("avg_score",axis=1).T
df_fig = df_fig.astype(float)
n_targets = len(df_fig)  # one row per masked word
print(f"Never-predicted entries: {100*len(df_fig[(df_fig == 0).all(axis=1)])/n_targets:.1f}%")
print(f"Always-correct predictions: {100*len(df_fig[(df_fig == 1).all(axis=1)])/n_targets:.1f}%")
df_fig = df_fig[(df_fig != 0).any(axis=1)] #remove the always-zero rows
df_fig = df_fig[(df_fig != 1).any(axis=1)] #remove the always-one rows
print(f"Changing predictions: {100*len(df_fig)/n_targets:.1f}%")
df_fig.to_csv("/content/dataframeMLM_filtered.csv")

#### Heatmap

In [None]:
%matplotlib inline
newcmp = ListedColormap(plt.get_cmap("Blues",100)([0]*19+[15]*5+[30]*8+[45]*17+[60]*50+[75]*1))
plt.figure(figsize=(6, 46))
plt.tick_params(axis='both', which='major', labelsize=10, labelbottom = False, bottom=False, top = False, labeltop=True)
ax = sns.heatmap(df_fig, linewidths=.5, cmap = newcmp, cbar = False)
plt.savefig('/content/heatmapMLM.png')
plt.show()
print(f"\nAverage Score:\n{round((df_fig.sum(axis=0))/len(df_fig),2)}")
print(f"size: {len(df_fig)}")

In [None]:
def plot_rank_progression(word):
  #if more than one appearance of the same word, average
  df_selection = df_fig.loc[word] if len(df_fig.loc[word].shape)<=1 else df_fig.loc[word].mean(axis=0)
  
  g = sns.relplot(
      data=df_selection, kind="line"  #change range if you want to include more/less models
  )
  (g.set_axis_labels("\n"+word.upper(), "MRR")
    .set(ylim=(0, 1))
    .set_xticklabels(rotation=90)
    .set_titles("Region")
    .tight_layout(w_pad=0))

#### Rank Progression

In [None]:
from matplotlib.lines import Line2D
df_fig2 = df_fig.groupby(level=0).agg('mean')
custom_lines = [Line2D([0], [0], color='blue',lw=4),
                Line2D([0], [0], color='orange',lw=4)]
fig, ax = plt.subplots(figsize=(16, 12))
ax.tick_params(axis='x', rotation=90)
ax.set_xlabel('Model')
ax.set_ylabel('MRR')
ax.legend(custom_lines, ['Single Words', 'Average'], loc="upper right")
p1 = sns.lineplot(data=df_fig2.T, ax=ax, legend=False, alpha=0.05, linestyle=":", palette = ['blue']*len(df_fig2))
p2 = sns.lineplot(data=df_fig2.mean(axis=0).T, linewidth = 3, ax=ax, marker='o')
for col in df_fig2.columns:
  ax.annotate(round(df_fig2.mean(axis=0).T[col],2), (col,df_fig2.mean(axis=0).T[col]-0.025),color="orange", ha='center')
plt.savefig('/content/trajectoryMLM.png')

For a single word:

In [None]:
plot_rank_progression("colon")

#### Sankey Plot

In [None]:
def getSankeyDataFrame(sourcename, targetname, df, label_offset = 0):
  dfsankey = df
  dfsankey = dfsankey.loc[[sourcename,targetname]].T
  dfsankey = dfsankey.drop("avg_score")
  dfsankey = dfsankey.rename(columns={sourcename: "source", targetname: "target"})
  dfsankey["word"] = dfsankey.index
  dfsankey = dfsankey.groupby(['source','target']).agg({'source':'size','word': lambda x: ', '.join(x)}).rename(columns={'source':'value'}).reset_index()
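  # drop the first (0 -> 0, never predicted) and last (1 -> 1, always ranked first) groups, which show no change
  dfsankey = dfsankey.drop([0,dfsankey.shape[0]-1])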
  dfsankey["increment"] = dfsankey["source"]<=dfsankey["target"]
  dfsankey["source"] = dfsankey["source"].map({0.0 : 5+label_offset, 0.2 : 4+label_offset, 0.25 : 3+label_offset,  0.33 : 2+label_offset, 0.5 : 1+label_offset, 1.0 : 0+label_offset})
  dfsankey["target"] = dfsankey["target"].map({0.0 : 11+label_offset, 0.2 : 10+label_offset, 0.25 : 9+label_offset,  0.33 : 8+label_offset, 0.5 : 7+label_offset, 1.0 : 6+label_offset})
  dfsankey["increment"] = dfsankey["increment"].map({True : 'rgba(86,180,86, 0.7)', False: 'rgba(222,82,83, 0.7)'})
  dfsankey = dfsankey.sort_values(by=['target'])
  return dfsankey

In [None]:
import plotly.graph_objects as go

names = ["baseBIT","bioBIT_l","medBIT_r0","medBIT_r1"] # consecutive rows of df to compare; the medBIT checkpoints must be enabled in `checkpoints`
dfs_list = []
for i in range(1,len(names)):
  dfs_list.append(getSankeyDataFrame(names[i-1],names[i],df,6*i))
dfs = pd.concat(dfs_list, axis=0)

fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 10,
      thickness = 10,
      line = dict(color = "black", width = 0.5),
      label = ["1st","2nd","3rd","4th","5th","None"]*2*len(dfs_list),
      color = ["#1F6FB3","#4896C8","#7EB8DA","#B5D4E9","#D9E8F5","#F7FBFF"]*2*len(dfs_list),#,"#DCE2F0","#AFAEDA","#9C81C4","#9856AB","#833E75","#59283A"],
    ),
    link = dict(
      source = dfs["source"], # indices correspond to labels, eg A1, A2, A1, B1, ...
      target = dfs["target"],
      value = dfs["value"],
      color = dfs["increment"],
      label = dfs["word"]
  ))])

for x_coordinate, column_name in enumerate(names):
  fig.add_annotation(
          x=x_coordinate / (len(names) - 1),
          y=1.1,
          xref="paper",
          yref="paper",
          text=column_name,
          showarrow=False,
          font=dict(
              size=12,
              color="black"
              ),
          align="center",
          )
fig.show()

#### Standard Scatter+Line plot for Paper

In [None]:
dftrend = pd.read_csv("mlm_trend.csv")

In [None]:
#the 28GB jump between BioBIT_M and BioBIT_L is an issue when plotting the MRR progression
#convert pretraining size from int to str to make it categorical and add an axis break manually
dftrend["total_pt_size"]=[str(x) for x in dftrend["total_pt_size"]]

In [None]:
font = {'size'   : 18}
plt.rc('font', **font)

f, ax = plt.subplots(figsize=(14, 10))
ax.set_xlabel('Total pretraining size (GB)')
ax.set_ylabel('MRR')
f = sns.scatterplot(data=dftrend, x="total_pt_size", y="mrr", style="latest_pt_corpus_type", s=180)
#[ax.text(p[0], p[1], p[2], color='#2078B4', ha='left') for p in zip(dftrend["total_pt_size"], dftrend["mrr"]+0.0013,dftrend["model"])]
dfline = dftrend[((dftrend['model'] != "BioBIT_XS") & (dftrend['model'] != "MedBIT") &  (dftrend['model'] != "MedBIT_R3+"))]
leg = plt.legend(title='Latest pretraining corpus\n')
leg._legend_box.align = "left"
plt.plot(dfline["total_pt_size"], dfline["mrr"], color='orange', linewidth=2)
plt.savefig('/content/trend.png')

## Context Window Test

This test measures how the size of the context window around the masked word influences the predictions.

We consider a selection of biomedical texts, each around 512 tokens long, and repeat the same process as before while progressively widening the context window around the sentence that contains the masked word.

Example:

```

ORIGINAL TEXT: "(1) È difficile consigliare una prevenzione efficace. (2) Senza dubbio è importante non fumare e seguire una dieta povera di alcol. (3) Circa il 70% dei tumori del pancreas si sviluppa nella [MASK] dell'organo. (4) Nella maggior parte dei casi, il tumore ha origine nei dotti che trasportano gli enzimi della digestione. (5) Tale neoplasia prende il nome di adenocarcinoma duttale del pancreas."

N.B.: (N) indicates the sentence index and is not included in the text.

MASKED WORD: "testa"

FIRST ITERATION INPUT ([3]): "(3) Circa il 70% dei tumori del pancreas si sviluppa nella [MASK] dell'organo."

SECOND ITERATION INPUT ([2,3,4]): "(2) Senza dubbio è importante non fumare e seguire una dieta povera di alcol. (3) Circa il 70% dei tumori del pancreas si sviluppa nella [MASK] dell'organo. (4) Nella maggior parte dei casi, il tumore ha origine nei dotti che trasportano gli enzimi della digestione."

THIRD ITERATION INPUT ([1,2,3,4,5]): "(1) È difficile consigliare una prevenzione efficace. (2) Senza dubbio è importante non fumare e seguire una dieta povera di alcol. (3) Circa il 70% dei tumori del pancreas si sviluppa nella [MASK] dell'organo. (4) Nella maggior parte dei casi, il tumore ha origine nei dotti che trasportano gli enzimi della digestione. (5) Tale neoplasia prende il nome di adenocarcinoma duttale del pancreas."
```
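The iterations above grow the window by one sentence on each side of the central one until the text runs out. A standalone sketch of that expansion on placeholder sentences, written to stop as soon as the window would go past either end of the text:

```python
def symmetric_windows(sents, ind_center):
    """Return the windows [c], [c-1..c+1], [c-2..c+2], ... around sentence index c."""
    windows = []
    i = 0
    # stop when the window would extend beyond the first or the last sentence
    while ind_center - i >= 0 and ind_center + i < len(sents):
        windows.append(" ".join(sents[ind_center - i : ind_center + i + 1]))
        i += 1
    return windows


sents = ["(1)", "(2)", "(3)", "(4)", "(5)"]
center = len(sents) // 2  # index 2, i.e. sentence (3)
for window in symmetric_windows(sents, center):
    print(window)
# (3)
# (2) (3) (4)
# (1) (2) (3) (4) (5)
```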

In [None]:
with open("/content/mlm_context_window.json", encoding="utf-8") as g:
  json_data = json.load(g)
contexts = json_data["contexts"]

In [None]:
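# utility to create different versions of the input text varying in length:
# the i-th version contains the central sentence plus i sentences on each side
def generate_input_sents(sents, ind_center, replacement=""):
  sents = list(sents)  # work on a copy so the caller's list is not modified
  if replacement!="":
    sents[ind_center] = replacement
  sentences = []
  i = 0
  # grow the window until it would extend beyond the first or the last sentence
  while ind_center-i >= 0 and ind_center+i < len(sents):
    sentences.append(' '.join(sents[ind_center-i:ind_center+i+1]))
    i+=1
  return sentences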

In [None]:
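# score every window for one checkpoint; relies on the global `original_sentences`
# (the unmasked windows), which is set in the Results loop below
def get_progression(checkpoint_name, sents, ind_center, mask):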
  m = generate_input_sents(sents, ind_center, replacement=mask)
  r = evaluate_mlm(checkpoints[checkpoint_name], tokenizers[checkpoint_name], m, print_progression=False)
  s = get_results(r, original_sentences, m, tokenizers[checkpoint_name], verbose=False)
  return s

#### Results

In [None]:
for context in [contexts[0]]:
#for context in contexts: 

  original_text = context["original"]
  masks = context["masks"]
  sents = original_text.split("\n")
  ind_center = math.floor(len(sents)/2)
  original_sentences = generate_input_sents(sents, ind_center)

  print("=============================")
  print(f"* context: '{original_text}'")
  print(f"\n* length: {len(tokenizers['baseBIT'](original_text)['input_ids'])} tokens")
  print(f"* num sentences: {len(sents)}")
  print(f"* central sentence: '{sents[ind_center]}'")
  print("-----------------------------\n")

  for mask in masks:
    m = generate_input_sents(sents, ind_center, replacement=mask)
    print(m[0])
    print(f"[MASK] = {get_masked_word(original_sentences, m)[0]}\n")
    for k,v in checkpoints.items():
      s = get_progression(k, sents, ind_center, mask)
      print(f" -{k}: MRR = {s['avg']:.2f}, {s['scores']}")
    print("-----------------------------\n")

## Pseudo-Perplexity

The paper [Masked Language Model Scoring](https://arxiv.org/abs/1910.14659) explores pseudo-perplexity (PPPL) for masked language models and shows that, although it is not theoretically well justified, it still performs well for comparing the "naturalness" of texts.

Here we compute the PPPL of every original sentence in the MLM dataset for each model checkpoint and report the average. Sentences longer than 170 tokens are skipped to avoid out-of-memory errors in Colab.
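For a sentence $W$ whose $N$ non-special tokens are masked one at a time, the pseudo-log-likelihood is $\mathrm{PLL}(W)=\sum_{t=1}^{N}\log P(w_t \mid W_{\setminus t})$, and the per-sentence value returned by `score` below corresponds to $\exp(-\mathrm{PLL}(W)/N)$. The sketch reproduces the two steps in plain Python: building one masked copy per token (skipping the first and last positions, which hold `[CLS]` and `[SEP]`) with `-100` labels everywhere except the masked position, and turning per-token probabilities into PPPL. The token ids, mask id, and probabilities are made up for illustration; in the notebook the probabilities come from the model's loss.

```python
import math


def masked_copies(input_ids, mask_token_id):
    """One copy of the sequence per inner token, with that token replaced by the mask id.
    Labels are -100 (ignored by the loss) everywhere except the masked position."""
    copies, labels = [], []
    for pos in range(1, len(input_ids) - 1):  # skip [CLS] (first) and [SEP] (last)
        copy = list(input_ids)
        copy[pos] = mask_token_id
        label = [-100] * len(input_ids)
        label[pos] = input_ids[pos]
        copies.append(copy)
        labels.append(label)
    return copies, labels


def pseudo_perplexity(token_probs):
    """exp of the mean negative log-probability of each token given all the others."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))


ids = [101, 7, 8, 9, 102]  # hypothetical ids: [CLS] w1 w2 w3 [SEP]
copies, labels = masked_copies(ids, mask_token_id=103)
print(copies[0])  # [101, 103, 8, 9, 102]
print(labels[0])  # [-100, 7, -100, -100, -100]

# Hypothetical P(w_t | rest) for the three masked positions
print(round(pseudo_perplexity([0.5, 0.25, 0.125]), 3))  # 4.0
```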

In [None]:
import gc
def score(model, tokenizer, sentence):
  """
  Compute the pseudo-perplexity of a sentence: every token except [CLS] and [SEP] is masked in turn, its negative log-probability given all the other tokens is averaged over the sentence, and the result is exponentiated.
  In masked language models this does not amount to the probability of the whole sentence (the conditional probabilities do not chain together as in a left-to-right model), but it is still a useful measure of the "naturalness" of a sentence.
  """
  tensor_input = tokenizer.encode(sentence, return_tensors='pt')
  repeat_input = tensor_input.repeat(tensor_input.size(-1)-2, 1)
  mask = torch.ones(tensor_input.size(-1) - 1).diag(1)[:-2]
  masked_input = repeat_input.masked_fill(mask == 1, tokenizer.mask_token_id)
  labels = repeat_input.masked_fill( masked_input != tokenizer.mask_token_id, -100)
  del tensor_input, repeat_input, mask
  with torch.inference_mode():
      loss = model(masked_input.to(device), labels=labels.to(device)).loss
  del masked_input, labels
  torch.cuda.empty_cache()
  gc.collect()
  return np.exp(loss.item())

In [None]:
def get_pseudo_perplexity_total(model, tokenizer, sentences):
  pppl = 0
  skipped = 0
  for i in tqdm(range(len(sentences))):
    sentence = sentences[i]
    if tokenizer.encode(sentence, return_tensors='pt').shape[1]>170: ###OOM Error in Colab if text is too long
      skipped += 1
      continue
    pppl += score(model,tokenizer,sentence)
    torch.cuda.empty_cache()
    gc.collect()
  return pppl/(len(sentences)-skipped)

In [None]:
original_sentences = [sentence["original"] for sentence in sentences]
device = 'cuda' if torch.cuda.is_available() else 'cpu'

for k,v in checkpoints.items():
  model = AutoModelForMaskedLM.from_pretrained(v).to(device)  # keep the model on the same device as the inputs built in score()
  tokenizer=tokenizers[k]
  pppl_total = get_pseudo_perplexity_total(model, tokenizer, original_sentences)
  print(f"Model Name: {k}")
  print(f"Average PPPL: {pppl_total}")
  del model, tokenizer