<a href="https://colab.research.google.com/github/IVN-RIN/bio-med-BIT/blob/main/notebooks/BioBERT_Relation_Extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **BioBIT Fine-Tuning Experiment For <u>Relation Extraction</u>**

*Tommaso Buonocore, University of Pavia, 2022*

*Last edited: 16/11/2022*

# Initialization

Short string describing the current run

In [None]:
# Short label identifying this run; it is recorded in session_info
experiment_name: str = "Chemprot-RE reg3plus only"

## Imports

In [None]:
%%capture
# If running on colab, install first
!pip install datasets evaluate sklearn transformers

# Google Colab only
from IPython.display import display, HTML
from google.colab import files

# General
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import sklearn.metrics
import evaluate
import pandas as pd
import numpy as np
import json
import os
from io import StringIO
import time
from tqdm import tqdm

# HuggingFace Transformers
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW, TrainingArguments, Trainer, EarlyStoppingCallback, set_seed

# Choose the compute device: the CUDA GPU when one is available, the CPU otherwise
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Session Info

In [None]:
# Collect metadata about this run: public IP info, GPUs, start time and experiment name.
# The dict is completed and saved at the end of the notebook.
try:
  # A single URL argument; the previous "curl curl ipinfo.io" made curl also try a host named "curl".
  # -s silences curl's progress meter.
  session_info = json.loads(os.popen("curl -s ipinfo.io").read())
except json.JSONDecodeError:
  # No network or an unexpected response: the IP info is optional, so continue without it
  session_info = {}
if device == 'cuda':
  # nvidia-smi prints a CSV header followed by one "name, memory" row per GPU
  gpu_info = pd.read_csv(StringIO(os.popen("nvidia-smi --query-gpu=gpu_name,memory.total --format=csv").read()), names=["name", "memory"], header=0)
  session_info['gpus'] = [{'name': row["name"], 'memory': row["memory"]} for _, row in gpu_info.iterrows()]
else:
  session_info['gpus'] = []
session_info['time_start'] = time.strftime("%H:%M:%S", time.localtime())
session_info['experiment_name'] = experiment_name
session_info

# Data Preprocessing

## Expected Input Format

At this point, we expect nine files (three per split) to be already loaded in the current session (e.g., by drag and drop):

*   "/content/*_CORPUS.json" : document id and document string
*   "/content/*_REL.txt": details about the relation (e.g., arg1 id, arg2 id, relation label, etc.) separated by "\t" (tab)
*   "/content/*_ANN.txt": details about the annotation (e.g., start index, end index, value, tag label, etc.) separated by "\t" (tab)

(* = "train", "test", "dev")


## Desired Data Format

Data Format Example

```python
{'guid': 'document_00000',
 'sentence': 'I polimorfismi a singolo nucleotide del gene HNF4alpha sono associati alla conversione in diabete mellito di tipo 2.',
 'subject_entity': {'word': 'HNF4alpha',
                    'start_idx': 45,
                    'end_idx': 53,  # inclusive
                    'type': 'GENEORGENEPRODUCT'},
 'object_entity': {'word': 'diabete mellito di tipo 2',
                   'start_idx': 90,
                   'end_idx': 114,  # inclusive
                   'type': 'DISEASEORPHENOTYPICFEATURE'},
  'label': 1,
  'source': 'wikipedia'}

```

Related Labels Example

```python
{0:"No Relation",
 1:"Association",
 2:"Positive Correlation",
 3:"Negative Correlation"}
```

## Dataset Preparation

The indices are no longer correct after translation, so we must recompute them.
To recompute the correct index, we start looking for an exact match in close proximity to the 'old' indices, expanding the search window if the attempt fails. The closest correspondence we get defines the new start and end indices. If no correspondence is found, the example will be dropped.

In [None]:
def recompute_indices(sentence, match, start, end, window = 20):
  """Relocate `match` inside `sentence` near its old (pre-translation) position.

  The search starts in `sentence[start-window : end+window]` (clamped to the
  sentence bounds). The window is doubled until the text is found or the whole
  sentence has been searched. Among the occurrences inside the first window that
  contains any, the one whose start is closest to the old `start` is chosen; ties
  go to the leftmost.

  Args:
    sentence: text to search in.
    match: entity text to locate.
    start, end: old character offsets of the entity.
    window: initial half-width of the search window; values < 1 are treated as 1
      so the doubling always makes progress.

  Returns:
    `(new_start, new_end)` with `new_end` inclusive (`new_start + len(match) - 1`),
    or `(-1, -1)` when `match` does not occur in `sentence`.
  """
  window = max(int(window), 1)
  length = len(sentence)
  while True:
    lo = max(start - window, 0)
    hi = min(end + window, length)
    substring = sentence[lo:hi]

    # Scan every occurrence in the window and keep the one nearest to the old start.
    # The old code used only the first (leftmost) hit, which could ignore an exact match.
    best = -1
    pos = substring.find(match)
    while pos != -1:
      candidate = lo + pos
      if best == -1 or abs(candidate - start) < abs(best - start):
        best = candidate
      pos = substring.find(match, pos + 1)

    if best != -1:
      return best, best + len(match) - 1
    # The window already covers the whole sentence and nothing was found
    if lo == 0 and hi == length:
      return -1, -1
    window *= 2

We generate the final dictionary in the requested format, addressing all the errors that the translation process may have introduced, namely:


*   No correct indices found for a given annotation
*   Relation pointing to a non-existing document
*   The entity id does not exist in the document indicated by the relation

If the same relationship occurs multiple times, i.e., the same ann1-ann2 id pair appears multiple times in the annotation list for the same document, we add a different entry for each occurrence in the final dataset.
We don't combine each occurrence of ann1 with each occurrence of ann2; instead, we create ann1-ann2 pairs according to the closest correspondence in the text.

In [None]:
import pdb
import warnings

def get_final_dict(sentence, a1, a2, label, guid=None):
  """Build one relation example in the notebook's dataset format.

  Args:
    sentence: full document text.
    a1: subject annotation row (read via ["text"], ["start"], ["end"], ["type"]).
    a2: object annotation row, with the same fields.
    label: integer label id stored under 'label'.
    guid: identifier stored under 'guid'. When None, it falls back to
      `a1.get("pmid")`, the document id column of the annotation table.
      Previously the code stored the builtin `id` function here, because the
      `id` variable lives in `format_inputs`, not in this function.

  Returns:
    The example dict, or None (with a warning) when either entity text is NaN or
    cannot be relocated in `sentence`. 'end_idx' values are inclusive.
  """
  # If entities are NaN, drop
  if pd.isna(a1["text"]) or pd.isna(a2["text"]):
    warnings.warn("Example dropped: NaN entity")
    return None

  idx1 = recompute_indices(sentence, a1["text"], a1["start"], a1["end"])
  idx2 = recompute_indices(sentence, a2["text"], a2["start"], a2["end"])

  if idx1[0] == -1 or idx2[0] == -1:
    warnings.warn("Example dropped: impossible to recompute the correct indices")
    return None

  if guid is None:
    guid = a1.get("pmid")

  return {'guid': guid, 'sentence': sentence,
          'subject_entity': {'word': a1["text"],
                             'start_idx': idx1[0],
                             'end_idx': idx1[1],
                             'type': a1["type"]},
          'object_entity': {'word': a2["text"],
                            'start_idx': idx2[0],
                            'end_idx': idx2[1],
                            'type': a2["type"]},
          'label': label,
          'source': 'BioRED'}

def _pair_annotations(ann1, ann2):
    """Yield (a1, a2) annotation-row pairs for one relation.

    Rows of the argument with more occurrences are iterated, and each is paired
    with the nearest row of the other argument. Only pairs where a2 starts strictly
    after a1 ends are considered. Rows with no such partner are skipped.
    NOTE(review): pairs where the object precedes the subject in the text are
    therefore never produced; confirm this is intended for the dataset in use.
    """
    if ann1.shape[0] > ann2.shape[0]:
        for j in range(ann1.shape[0]):
            a1 = ann1.iloc[j]
            diff = ann2['start'] - a1['end']
            valid_idx = np.where(diff > 0)[0]  # positional indices of a2 rows after a1
            if len(valid_idx) == 0:
                continue
            yield a1, ann2.iloc[valid_idx[diff.iloc[valid_idx].argmin()]]
    else:
        for j in range(ann2.shape[0]):
            a2 = ann2.iloc[j]
            diff = ann1['end'] - a2['start']
            valid_idx = np.where(diff < 0)[0]  # positional indices of a1 rows ending before a2
            if len(valid_idx) == 0:
                continue
            yield ann1.iloc[valid_idx[diff.iloc[valid_idx].argmax()]], a2


def format_inputs(df_rel, df_ann, df_corpus, label_mapping):
    """Convert the relation/annotation/corpus tables of one split into RE examples.

    Args:
        df_rel: relations, read via "pmid", "relation", "arg1_id", "arg2_id".
        df_ann: entity annotations, read via "pmid", "entity_id", "start", "end", ...
        df_corpus: documents, read via "PMID" (compared as strings) and "Testo".
        label_mapping: relation string -> integer label id.

    Returns:
        A list of example dicts built by `get_final_dict`. A relation can yield
        several examples when its arguments occur multiple times. The percentage
        of relations producing no example is printed.
    """
    formatted_inputs = []
    drop_count = 0
    n_relations = df_rel.shape[0]

    for i in tqdm(range(n_relations)):
        # Relation info
        rel = df_rel.iloc[i]
        label = rel["relation"]
        pmid = rel["pmid"]

        # Corpus info: if the relation points to a document that does not exist, drop it
        doc_texts = df_corpus.loc[df_corpus["PMID"] == str(pmid)]["Testo"]
        if len(doc_texts) == 0:
            warnings.warn(f"Example dropped: relation points to non-existing document (pmid: {str(pmid)})")
            drop_count += 1
            continue
        sentence = doc_texts.iloc[0]

        # Entity info
        arg1_id = rel["arg1_id"].replace("Arg1:", "")
        arg2_id = rel["arg2_id"].replace("Arg2:", "")
        same_doc = df_ann['pmid'] == rel["pmid"]
        ann1 = df_ann.loc[same_doc & (df_ann['entity_id'] == arg1_id)]
        ann2 = df_ann.loc[same_doc & (df_ann['entity_id'] == arg2_id)]

        # One of the arguments no longer exists (probably dropped during the
        # auto-translation of the dataset): skip this relation
        if ann1.shape[0] == 0 or ann2.shape[0] == 0:
            drop_count += 1
            continue

        success = False
        for a1, a2 in _pair_annotations(ann1, ann2):
            formatted_input = get_final_dict(sentence, a1, a2, label_mapping[label])
            if formatted_input is not None:
                formatted_inputs.append(formatted_input)
                success = True

        # No example was generated from any ann1-ann2 pair. get_final_dict has already
        # warned where it dropped a pair.
        if not success:
            drop_count += 1

    # Divide by the number of relations (not the last loop index) and avoid division by zero
    drop_pct = round(100 * drop_count / n_relations, 2) if n_relations else 0.0
    print(f"\ndropped: {drop_pct}%")
    return formatted_inputs

In [None]:
# One list of formatted examples per split, filled by format_inputs below
formatted_datasets = {
    "train": [],
    "test": [],
    "dev": []
}

#CHEMPROT
colnames = {"rel": ["pmid", "cpr", "eval_type", "relation", "arg1_id", "arg2_id"],
            "ann": ["pmid", "entity_id", "start", "end", "text", "type"]}
#BIORED
#colnames = {"rel":["pmid","relation","arg1_id","arg2_id","novel"],
#            "ann":["pmid","start","end","text","type","entity_id"]}


# The label vocabulary is built from the training relations only
df_rel = pd.read_csv("train_REL.txt", sep='\t', names=colnames["rel"], header=None)

# Map label strings to integers (in np.unique's sorted order) and vice versa
labels = np.unique(df_rel[["relation"]])
num_to_labels = dict(enumerate(labels))
labels_to_num = {v: k for k, v in num_to_labels.items()}

for key in formatted_datasets.keys():
  # REL DATAFRAME
  df_rel = pd.read_csv(key + "_REL.txt", sep='\t', names=colnames["rel"], header=None)

  # ANN DATAFRAME
  df_ann = pd.read_csv(key + "_ANN.txt", sep='\t', names=colnames["ann"], header=None)

  # CORPUS DATAFRAME; the `with` closes the file even if json.load raises
  with open(key + "_CORPUS.json", encoding='utf-8') as f:
    df_corpus = pd.DataFrame(json.load(f)["data"])

  formatted_datasets[key] = format_inputs(df_rel, df_ann, df_corpus, labels_to_num)

Check the correspondence between label ids and label strings, and the number of examples for each label in the training set

In [None]:
# Count training examples per label id
counts_by_id = {}
for data in formatted_datasets["train"]:
    counts_by_id[data['label']] = counts_by_id.get(data['label'], 0) + 1

# Key each count as "<id>) <label string>". Order by the numeric id, so that
# "10) ..." does not sort before "2) ..." as it would with string keys.
label_count = {str(label_id) + ") " + num_to_labels[label_id]: count
               for label_id, count in sorted(counts_by_id.items())}
label_count

Add the special tokens <subj> and <obj> to tag the sentence with the corresponding annotation 1 (subject) and annotation 2 (object) of the relation

For instance, this entry

```python
{'guid': 'document_00000',
 'sentence': 'I polimorfismi a singolo nucleotide del gene HNF4alpha sono associati alla conversione in diabete mellito di tipo 2.',
 'subject_entity': {'word': 'HNF4alpha',
                    'start_idx': 45,
                    'end_idx': 53,  # inclusive
                    'type': 'GENEORGENEPRODUCT'},
 'object_entity': {'word': 'diabete mellito di tipo 2',
                   'start_idx': 90,
                   'end_idx': 114,  # inclusive
                   'type': 'DISEASEORPHENOTYPICFEATURE'},
  'label': 1,
  'source': 'wikipedia'}

```

becomes this:


```python
'I polimorfismi a singolo nucleotide del gene <subj>HNF4alpha</subj> sono associati alla conversione in <obj>diabete mellito di tipo 2</obj>.'
```


In [None]:
def add_entity_tokens(sentence, object_entity, subject_entity):
    """Wrap the subject and object spans of `sentence` in <subj>/<obj> marker tags.

    Both entity dicts provide 'start_idx' and an *inclusive* 'end_idx'. The
    entity that starts first is tagged first. When both start at the same
    index, the subject is placed first.
    """
    obj_span = (object_entity['start_idx'], object_entity['end_idx'], '<obj>', '</obj>')
    subj_span = (subject_entity['start_idx'], subject_entity['end_idx'], '<subj>', '</subj>')

    if obj_span[0] < subj_span[0]:
        first, second = obj_span, subj_span
    else:
        first, second = subj_span, obj_span

    a_start, a_end, a_open, a_close = first
    b_start, b_end, b_open, b_close = second
    pieces = [
        sentence[:a_start], a_open, sentence[a_start:a_end + 1], a_close,
        sentence[a_end + 1:b_start], b_open, sentence[b_start:b_end + 1], b_close,
        sentence[b_end + 1:],
    ]
    return ''.join(pieces)

def parse_re_dataset(dataset):
    """Turn formatted examples into a HuggingFace `Dataset` with 'text' and 'label' columns.

    Each 'text' is the example sentence with its entities wrapped in marker tags
    by `add_entity_tokens`.
    """
    # One pass over the input, producing (tagged text, label) rows in order
    rows = [
        (add_entity_tokens(example['sentence'], example['object_entity'], example['subject_entity']),
         example['label'])
        for example in dataset
    ]
    frame = pd.DataFrame({
        'text': [text for text, _ in rows],
        'label': [label_id for _, label_id in rows],
    })
    return Dataset.from_pandas(frame)

In [None]:
# Build a tagged HuggingFace Dataset for each split (order: train, dev, test)
train_ds, dev_ds, test_ds = (parse_re_dataset(formatted_datasets[split]) for split in ("train", "dev", "test"))

In [None]:
# Display one training example to check the entity tags
train_ds[2]

# Training

The task consists of classifying N different types of relations. To do so, we create a classification model that passes the [CLS] token representation through a linear layer with output dimension N and outputs the most probable of the N classes.

![Imgur](https://i.imgur.com/qaUObkV.png)


In [None]:
# Mount Google Drive; the model checkpoints listed below are read from /content/gdrive/MyDrive/...
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Checkpoints to fine-tune; uncomment an entry to include it in the run
model_checkpoints = [
    # "dbmdz/bert-base-italian-xxl-cased",
    # "/content/gdrive/MyDrive/Colab Environments/biobert_models/bio-full",
    "/content/gdrive/MyDrive/Colab Environments/biobert_models/med-reg-v3",
    # "/content/gdrive/MyDrive/Colab Environments/biobert_models/med-reg-v12",
    # "/content/gdrive/MyDrive/Colab Environments/biobert_models/med-reg-v3-enriched",
]

# Every checkpoint is fine-tuned once per seed
seeds = [
    # 3407,
    # 6,
    11,
    61,
    1,
]

# Hyperparameters: these may be tuned for the downstream dataset, but must stay
# identical across *ALL* the models being compared
batch_size = 16
learning_rate = 3e-5
epochs = 7
weight_decay = 0.01

## Preprocessing functions

In [None]:
def prepare_tokenizer(model_name):
  """Load the tokenizer of `model_name` and register the entity-marker tokens.

  The markers <obj>, </obj>, <subj> and </subj> are added as *special* tokens so
  they are not treated as ordinary text of the sequence. The vocabulary grows by
  the number of markers that were new to it; the model's embeddings must be
  resized to `len(tokenizer)` afterwards.
  """
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  entity_special_tokens = {'additional_special_tokens': ['<obj>', '</obj>', '<subj>', '</subj>']}
  # The returned count of added tokens was previously stored in an unused local
  tokenizer.add_special_tokens(entity_special_tokens)
  return tokenizer

def preprocess_function(examples):
    """Tokenize a batch's "text" column with truncation enabled.

    Uses the module-level `tokenizer`, which the training loop rebinds for every
    checkpoint/seed before calling `Dataset.map`.
    """
    texts = examples["text"]
    return tokenizer(texts, truncation=True)

## Metrics

In [None]:
# Metric objects from the `evaluate` library, keyed by metric name (loaded once, reused every evaluation)
m_calc = {metric_name: evaluate.load(metric_name)
          for metric_name in ("precision", "recall", "accuracy", "f1")}
def compute_metrics(p):
  """Trainer metric hook: accuracy plus weighted precision/recall/F1.

  `p` unpacks into (logits, labels); the predicted class is the argmax over axis 1.
  Returns a dict with keys 'acc', 'prec', 'recall', 'f1'.
  """
  logits, references = p
  preds = np.argmax(logits, axis=1)

  scores = {'acc': m_calc["accuracy"].compute(references=references, predictions=preds)["accuracy"]}
  # The remaining metrics share the same call shape and use class-frequency weighting
  for out_key, metric_name in (('prec', 'precision'), ('recall', 'recall'), ('f1', 'f1')):
    result = m_calc[metric_name].compute(references=references, predictions=preds, average="weighted")
    scores[out_key] = result[metric_name]
  return scores

## Training Loop

In [None]:
for model_checkpoint in model_checkpoints:
  # Test-set metrics for this checkpoint, one entry per seed
  results_collector = []
  for seed in seeds:
    # Seed must be set before creating the model, otherwise the random head will be initialized in a different way every time and the results will not be replicable
    # From now on, the seed is set for *all* the random processes, including numpy, sklearn, etc...not only for transformers!
    set_seed(seed)

    # Initialize the tokenizer (including the entity-marker special tokens)
    tokenizer = prepare_tokenizer(model_checkpoint)

    # Initialize the SequenceClassification transformer with checkpoint weights.
    # num_labels must cover every label id in num_to_labels (0..len-1).
    # len(label_count) only counts the classes that still have training examples
    # after preprocessing, so it can be smaller and leave label ids outside the head.
    model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=len(num_to_labels))

    # We added 4 tokens (special, but still new vocabulary entries), so the embedding
    # layer must be resized to the new tokenizer length (e.g. 32000 --> 32004).
    model.resize_token_embeddings(len(tokenizer))

    # Tokenize the inputs into a model-compatible format; preprocess_function uses
    # the global `tokenizer` bound just above
    train_dataset = train_ds.map(preprocess_function, batched=True)
    test_dataset = test_ds.map(preprocess_function, batched=True)
    dev_dataset = dev_ds.map(preprocess_function, batched=True)

    training_args = TrainingArguments(
        output_dir=f"/content/{os.path.basename(model_checkpoint)}_ft_RE/{seed}",
        evaluation_strategy="epoch",
        logging_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=3,
        load_best_model_at_end = True,
        metric_for_best_model = "f1",
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        weight_decay=weight_decay
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=dev_dataset,
        tokenizer=tokenizer,
        compute_metrics = compute_metrics,
        callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
    )

    trainer.train()

    # Evaluate on the held-out test split; with load_best_model_at_end this is the
    # checkpoint with the best dev-set f1
    predictions, label_ids, metrics = trainer.predict(test_dataset)
    print(metrics)
    metrics["seed"] = seed
    results_collector.append(metrics)

  df_results = pd.DataFrame(results_collector)
  display(df_results)
  results_path = f'/content/RE_results_{os.path.basename(model_checkpoint)}.csv'
  df_results.to_csv(results_path)
  files.download(results_path)

Finalize session info and download

In [None]:
# Record what was run (checkpoint names, seeds, last TrainingArguments, end time),
# then write the session metadata to disk and download it
session_info['checkpoints'] = list(map(os.path.basename, model_checkpoints))
session_info['seeds'] = seeds
session_info['training_arguments'] = training_args.to_dict()
session_info['time_end'] = time.strftime("%H:%M:%S", time.localtime())

session_path = '/content/session_info.json'
serialized = json.dumps(session_info, indent=4)
with open(session_path, "w") as outfile:
    outfile.write(serialized)
files.download(session_path)