Installing Libraries and Dependencies

In [None]:


!pip install transformers
!pip install torchmetrics
!pip install pytorch-lightning
!pip install seqeval
!pip install datasets

from math import ceil
import random
import pandas as pd
from torch.utils.data import DataLoader
from torch.optim import Adam 
from torch.nn.functional import cross_entropy
import pytorch_lightning as pl
import json
import numpy as np
import random
from transformers import BertModel, BertForTokenClassification, BertTokenizer, AutoTokenizer
from torch.nn import Linear, Sigmoid
import torch
import torchmetrics
import requests
import datasets
import spacy
nlp = spacy.load("en_core_web_sm")
from pytorch_lightning.loggers import TensorBoardLogger

Setting parameters

In [None]:
from google.colab import drive

#Mounting Google Drive; GDRIVE_PATH below points inside the mounted drive
drive.mount('/content/drive')

#Seeding the random number generators so that data shuffling (and therefore the
#train/validation/test split) is repeatable between runs
SEED = 42
pl.seed_everything(SEED)

#Setting Batch Size for Named Entity Recognition task
BATCH_SIZE_NER = 16

#Setting maximum length of BERT tokens for NER
MAX_LENGTH_NER = 500

#Setting maximum length of BERT tokens for Relation Extraction
MAX_LENGTH_RE = 200

#Setting max epochs for NER
NER_EPOCHS = 3

#Setting Path
GDRIVE_PATH='/content/drive/MyDrive/Hackathon: European Patent Office (EPO)'

#Choosing BERT model for tokenization and training
BERT_MODEL = 'dmis-lab/biobert-base-cased-v1.2'


Mounted at /content/drive


In [None]:
#Entity types annotated in the dataset, in label order
_ENTITY_TYPES = (
    "STARTING_MATERIAL",
    "REAGENT_CATALYST",
    "REACTION_PRODUCT",
    "SOLVENT",
    "OTHER_COMPOUND",
    "TIME",
    "TEMPERATURE",
    "YIELD_PERCENT",
    "YIELD_OTHER",
    "EXAMPLE_LABEL",
    "WORKUP",
    "REACTION_STEP",
)

#Creating list of B-I-O tags: "O" followed by a B-/I- pair for every entity type
LIST_TAGS = ["O"] + [f"{prefix}-{name}" for name in _ENTITY_TYPES for prefix in ("B", "I")]

#Setting indices of labeled entities: each type maps to the position of its B- tag
#in LIST_TAGS (the matching I- tag sits at that position + 1)
DICT_TAGS = {name: 2 * position + 1 for position, name in enumerate(_ENTITY_TYPES)}

#Mapping every tag index back to its bare entity type (B- and I- share the type)
REVERSE_LIST_TAGS = ["O"] + [name for name in DICT_TAGS for _ in range(2)]

## Training NER

Creating Data Module

In [None]:

#Creating Data Module with PyTorch Lightning
class NERDataModule(pl.LightningDataModule):
  """Lightning data module that builds token-level NER examples from an annotated JSON file.

  Every training example is a dict with "labels", "tokens" and "attention_mask"
  tensors of length ``max_length``. Label positions that the loss must skip (the
  leading position, the closing position and the padding - presumably the
  [CLS]/[SEP]/[PAD] tokens of the tokenizer) carry ``IGNORE_INDEX``.

  Args:
    batch_size: batch size used by every DataLoader this module creates.
    max_length: number of tokens each text is padded/truncated to.
    dict_tags: entity type -> index of its B- tag.
    list_tags: list of all B-I-O tag names.
    tokenizer: tokenizer used both for encoding and for label alignment.
  """

  #Label value used for positions that Cross Entropy Loss should ignore
  IGNORE_INDEX = -100

  def __init__(self,
               batch_size = BATCH_SIZE_NER,
               max_length = MAX_LENGTH_NER,
               dict_tags = DICT_TAGS,
               list_tags = LIST_TAGS,
               tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL),
               ):

    super().__init__()
    self.train_batch_size=batch_size
    self.max_length=max_length
    self.dict_tags=dict_tags
    self.list_tags=list_tags
    self.tokenizer=tokenizer
    self.data = []
    self.dataset={}


  def _encode(self, text):
    """Tokenize one text, padded/truncated to ``max_length``, as PyTorch tensors."""
    return self.tokenizer(text,
                          max_length=self.max_length,
                          padding='max_length',
                          truncation = True,
                          return_tensors='pt')


  def prepare_data(self, path_json = GDRIVE_PATH + '/ee_train.json'):
    """Load the annotated JSON file, build the examples and store them shuffled in ``self.data``."""

    #Opening json data file with provided path; the context manager closes it even on errors
    with open(path_json, 'r') as file_json:
      load_json = json.load(file_json)

    #Extracting entity tags from data
    entries = self.ner_tags(load_json)

    #Structuring data in preparation for training
    data = []
    for entry in entries:
      encoding = self._encode(entry['text'])
      data.append({
          "labels": torch.Tensor(entry["ner_tags"]),
          "tokens": torch.squeeze(encoding.input_ids),
          "attention_mask": torch.squeeze(encoding.attention_mask)
      })

    #Populating class self.data variable with edited data
    random.shuffle(data)
    self.data = data


  def setup(self, stage=None):
    """Split ``self.data`` into 70% train, 20% validation and 10% test."""

    train_end = ceil(len(self.data)*0.7)
    val_end = ceil(len(self.data)*0.9)

    #Preparing train-validation-test datasets
    self.dataset["train"]= self.data[:train_end]
    self.dataset["val"]= self.data[train_end:val_end]
    self.dataset["test"]= self.data[val_end:]

  def train_dataloader(self):

    #Creating DataLoader object that will be used for training
    return DataLoader(self.dataset["train"], batch_size=self.train_batch_size)

  def val_dataloader(self):

    #Creating DataLoader object that will be used for validation
    return DataLoader(self.dataset["val"], batch_size=self.train_batch_size)

  def test_dataloader(self):

    #Creating DataLoader object that will be used for testing
    return DataLoader(self.dataset["test"], batch_size=self.train_batch_size)

  def ner_predict_dataloader(self, input_ner):
    """Build a DataLoader of unlabeled examples from an iterable of raw texts."""

    data = []
    for inp in input_ner:
      encoding = self._encode(inp)
      data.append({
          "tokens": torch.squeeze(encoding.input_ids),
          "attention_mask": torch.squeeze(encoding.attention_mask)
      })

    return DataLoader(data, batch_size = self.train_batch_size)


  def ner_tags(self, input_json):
    """Turn the raw annotations into ``{'id', 'text', 'ner_tags'}`` dicts.

    Entries whose label sequence does not line up with the tokenizer output are
    left out. The comparison counts only real labels (everything except
    ``IGNORE_INDEX``), because ``get_ner`` always returns ``max_length`` items.
    """

    outputs=[]
    errors=[]
    content_budget = self.max_length - 2

    #Extracting entity labels from provided json data and turning into ordinal labels
    for key, record in input_json.items():
      tags = self.get_ner(record['text'], record['entities'])

      #Number of labels that belong to real tokens
      n_labels = sum(1 for tag in tags if tag != self.IGNORE_INDEX)

      #Number of tokens of the whole text that survive truncation; min() keeps the
      #count correct for tokenizers whose tokenize() does not truncate
      n_tokens = min(
          len(self.tokenizer.tokenize(record['text'], max_length=content_budget, truncation=True)),
          content_budget)

      #Removing text extracts that have error in entity labels
      if n_labels != n_tokens:
        errors.append(key)
        continue

      outputs.append({'id': key, 'text': record['text'], 'ner_tags': tags})

    if errors:
      print(f"Skipped {len(errors)} entries whose entity labels do not align with the tokens")

    return outputs

  def get_ner(self, text, entities):
    """Build the list of ``max_length`` ordinal labels for one text.

    Text outside entities is labelled 0. The first word of an entity gets the
    entity's index from ``dict_tags`` on all of its sub-tokens, every further
    word gets that index + 1. Entities that start before the previous accepted
    entity ended are skipped.
    """

    #Using -100 as ignore token for Cross Entropy Loss
    concat_list=[self.IGNORE_INDEX]
    old_start=0

    #Creating sequence of entity labels based on input text
    for entity in entities:

      start=entity['span'][0]
      end=entity['span'][1]

      #Skipping entities that overlap with text that is already labelled
      if old_start > start:
        continue

      #Appending zeroes where no entities are present in text
      empty = self.tokenizer.tokenize(text[old_start:start])
      concat_list.extend([0]*len(empty))

      words = entity['text'].split()
      if not words:
        print(f"Error in tags for {entity['text']}!")

      #Appending relevant B-I-O entity labels, word by word
      for position, word in enumerate(words):
        tag = self.dict_tags[entity['type']] + (0 if position == 0 else 1)
        concat_list.extend([tag]*len(self.tokenizer.tokenize(word)))

      old_start=end

    #Appending zeroes where no entities are present in text
    last_part = self.tokenizer.tokenize(text[old_start:len(text)])
    concat_list.extend([0]*len(last_part))

    #Appending -100 to ignore the closing and padding positions
    if len(concat_list)>self.max_length-1:
      return concat_list[:self.max_length-1]+[self.IGNORE_INDEX]
    return concat_list + [self.IGNORE_INDEX]*(self.max_length-len(concat_list))


#Instantiating the data module with the defaults from the parameter cell
ner_datamodule=NERDataModule()

Creating Classification Module

In [None]:

#Creating Classification Module with PyTorch Lightning for Named Entity Recognition
class NERClassificationModule(pl.LightningModule):
  """Lightning module for token-level NER that also prepares its predictions for Relation Extraction.

  Args:
    module: token classification model; called as ``module(tokens, mask, labels=...)``
      and expected to return an object with ``.loss`` and ``.logits``.
    lr: learning rate for the Adam optimizer.
    n_tags: number of tags (stored, not otherwise used by this class).
    list_tags: tag names indexed by label id, used when computing metrics.
    max_length_re: length of the subject/object masks built for Relation Extraction.
    prepare_for_relex: stored flag (not read by any method of this class).
    reverse_list: label id -> entity type (stored; ``predictions_to_relex`` takes
      its own ``reverse_dict`` argument).
    tokenizer: tokenizer used to decode predictions and to build masks.
    loss_fn: stored loss function (the loss actually used comes from ``module``).
  """

  def __init__(self,
               module,
               lr,
               n_tags,
               list_tags=LIST_TAGS,
               max_length_re = MAX_LENGTH_RE,
               prepare_for_relex = True,
               reverse_list = REVERSE_LIST_TAGS,
               tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL),
               loss_fn=torch.nn.CrossEntropyLoss(ignore_index=-100)):

    super().__init__()
    self.module = module
    self.lr = lr
    self.list_tags = list_tags
    self.n_tags = n_tags
    self.loss_fn = loss_fn
    self.max_length_re = max_length_re
    self.tokenizer=tokenizer
    self.reverse_list = reverse_list
    self.prepare_for_relex = prepare_for_relex
    self.accuracy = torchmetrics.Accuracy(task='binary', multidim_average='samplewise')
    self.recall = torchmetrics.Recall(task='binary', multidim_average='samplewise')
    self.f1 = torchmetrics.F1Score(task='binary', multidim_average='samplewise')
    #NOTE(review): `datasets.load_metric` is deprecated in recent `datasets` releases - confirm the installed version
    self.metric = datasets.load_metric('seqeval')

  def forward(self, tokens, mask, labels=None):

    #Forward BERT function
    return self.module(tokens, mask, labels=labels)


  def training_step(self, batch, batch_index):

    #Extracting relevant variables from batch and putting through forward function
    target_tags = batch['labels'].to(self.device, dtype=torch.long)
    y_hat = self(batch['tokens'].to(self.device), batch['attention_mask'].to(self.device), labels=target_tags)

    #Calculating loss
    loss=y_hat.loss

    #Logging loss for progress tracking
    self.log_dict({'train_loss':loss}, prog_bar=True)

    return loss

  def _evaluate_batch(self, batch, prefix):
    """Run one labelled batch, log loss and seqeval metrics under ``prefix`` and return the loss."""

    #Extracting relevant variables from batch and putting through forward function
    target_tags = batch['labels'].type(dtype=torch.long)
    y_hat = self(batch['tokens'], batch['attention_mask'], labels=target_tags)

    #Calculating loss
    loss = y_hat.loss

    #Getting predicted label from logits calculated during forward step
    y_hat_labels = torch.argmax(y_hat.logits, dim=2)

    #Computing metrics for entity classification
    metrics = self.compute_metrics(y_hat_labels, target_tags)

    #Logging results
    self.log_dict({f'{prefix}_loss':loss,
                   f'{prefix}_f1':metrics['f1'],
                   f'{prefix}_accuracy':metrics['accuracy'],
                   f'{prefix}_precision':metrics['precision'],
                   f'{prefix}_recall':metrics['recall']}, prog_bar=True)
    return loss

  def validation_step(self, batch, batch_idx):

    #Validation metrics are logged with the "val_" prefix
    return self._evaluate_batch(batch, 'val')

  def test_step(self, batch, batch_idx):

    #Test metrics get their own "test_" prefix so they do not overwrite validation metrics
    return self._evaluate_batch(batch, 'test')


  def predict_step(self, batch, batch_idx):

    #Putting predict input through forward step
    y_hat = self.module(batch['tokens'],batch['attention_mask']).logits

    #Returning results and logits
    return {'logits':y_hat,
            'input_ids':batch['tokens'],
            'attention_mask':batch['attention_mask']}


  def configure_optimizers(self):

    #Setting optimizer to be used
    return Adam(self.parameters(), lr=self.lr)

  def compute_metrics(self, preds, labs):
    """Compute seqeval precision/recall/F1/accuracy, skipping positions labelled -100."""

    #Detaching true labels and predictions from GPU
    predictions, labels = preds.detach().cpu().numpy(), labs.detach().cpu().numpy()

    #Removing ignore index that was used for Cross Entropy Loss
    true_predictions = [
        [self.list_tags[int(p)] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    #Removing ignore index that was used for Cross Entropy Loss
    true_labels = [
        [self.list_tags[int(l)] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    #Computing entity classification performance metrics
    results = self.metric.compute(predictions=true_predictions, references=true_labels)

    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

  def get_mask(self, text, entities, relations, a):
    """Build subject and object token masks for one candidate relation.

    Args:
      text: sentence in which both entities were replaced by '@TYPE#' (see ``adjust_text``).
      entities: the two entity dicts, ordered by span start.
      relations: dict whose 'subject' and 'object' hold the entity texts.
      a: character offset of the sentence start in the full text.

    Returns:
      dict with 'subject_mask' and 'object_mask'; each is a list or "OVERFLOW".
    """

    ent0 = entities[0]
    ent1 = entities[1]

    #Change in length caused by replacing the first entity with '@' + type + '#'
    diff = ent0['span'][0]+len(ent0['type'])+2 - ent0['span'][1]

    #Adjusting spans of words so that they point at the placeholders inside `text`
    #NOTE(review): spans are keyed by entity text, so two entities with identical text share one entry
    spans = {ent0['entity']: [ent0['span'][0] - a,ent0['span'][0]+len(ent0['type'])+2 - a],
             ent1['entity']: [diff + ent1['span'][0] - a,diff + ent1['span'][0]+len(ent1['type'])+2 - a]}

    #Setting new spans
    subject_span = spans[relations['subject']]
    object_span = spans[relations['object']]

    #Setting masks of object and subject in relation
    subject_mask = self.mask_relation(text, subject_span)
    object_mask = self.mask_relation(text, object_span)

    return {'subject_mask':subject_mask,'object_mask':object_mask}

  def mask_relation(self, text, span_text, max_length = None):
    """Return a 0/1 token mask of length ``max_length`` marking the tokens of ``span_text``.

    ``max_length`` defaults to ``self.max_length_re``. Returns the string
    "OVERFLOW" when the marked tokens do not fit into ``max_length``.
    """

    if max_length is None:
      max_length = self.max_length_re

    #Creating masks for objects and subjects; the leading 0 stands for the first position of the encoded sequence
    mask = [0]
    empty = self.tokenizer.tokenize(text[0:span_text[0]])
    mask.extend([0]*len(empty))

    #Setting subject's and object's token positions to 1
    excerpt=self.tokenizer.tokenize(text[span_text[0]:span_text[1]])
    mask.extend([1]*len(excerpt))

    #Removing relationship if it's too long or returning mask
    if len(mask) > max_length-1:
      return "OVERFLOW"

    mask.extend([0]*(max_length-len(mask)))
    return mask


  def adjust_text(self, old_text, entities_list, start, end):
    """Return ``old_text[start:end]`` with both entities replaced by '@' + type + '#'."""

    ent0 = entities_list[0]
    ent1 = entities_list[1]

    #Replacing objects and subjects in text with their entity types
    new_text = (old_text[start:ent0['span'][0]] + '@' + ent0['type'] + '#'
                + old_text[ent0['span'][1]:ent1['span'][0]] + '@' + ent1['type'] + '#'
                + old_text[ent1['span'][1]:end])

    return new_text


  def predictions_to_relex(self, input_ids, predictions, reverse_dict):
    """Turn the predicted labels of one sequence into Relation Extraction candidates.

    Args:
      input_ids: token ids of one sequence.
      predictions: predicted label id for every position of that sequence.
      reverse_dict: label id -> entity type.

    Returns:
      list produced by ``prepare_for_relation_extraction``.
    """

    #Dropping positions whose token id is 0 (presumably the padding token - confirm for the tokenizer in use)
    cleaned_predictions = [pred for (pred, token) in zip(predictions, input_ids) if token !=0]
    cleaned_inputs = [token for (pred, token) in zip(predictions, input_ids) if token !=0]

    collected_predictions = []
    #-2 means that no entity is currently open
    previous = -2
    token_list=[]

    #Walking through the sequence while looking one prediction ahead: the current token is
    #collected once an entity is open, and the entity is closed when the next label changes
    for token_num in range(len(cleaned_predictions)-1):

      upcoming = cleaned_predictions[token_num+1]

      #No entity open: open one if the next label is not "O"
      if previous == -2:
        if upcoming != 0:
          previous = upcoming
        continue

      #The current token belongs to the open entity
      token_list.append(cleaned_inputs[token_num])

      #The entity continues with the same label or with the following label
      if upcoming == previous or upcoming == previous + 1:
        continue

      #The entity ends at the current token
      entity_text = "".join(self.tokenizer.decode(token_list).split("##"))
      text_length = len(self.tokenizer.decode(cleaned_inputs[1:token_num+1]))

      collected_predictions.append({"type": reverse_dict[previous],
                                    "entity": entity_text,
                                    "span":[text_length-len(entity_text), text_length]})

      #Either the next entity starts right away or no entity is open anymore
      previous = upcoming if upcoming != 0 else -2
      token_list=[]

    #Preparing predicted NER for input into Relation Extraction
    relex_data = self.prepare_for_relation_extraction(self.tokenizer.decode(cleaned_inputs[1:-1]), collected_predictions)

    return relex_data


  def prepare_for_relation_extraction(self, total_text, entities_list):
    """Pair every WORKUP/REACTION_STEP entity with every other entity of the same sentence.

    Returns a list of dicts with 'text', 'entities', 'relations', 'subject_mask'
    and 'object_mask'; pairs whose masks overflow are left out.
    """

    total = []
    next_num=0

    #Creating list of entities and actions
    trigger_list = [d for d in entities_list if d['type'] in ['WORKUP','REACTION_STEP']]
    entity_list = [c for c in entities_list if c['type'] not in ['WORKUP','REACTION_STEP']]

    #Dividing text into sentences
    sents = [(s.start_char, s.end_char) for s in nlp(total_text).sents]

    #Creating relations based on all possible entity combinations
    for trigger in trigger_list:
      for entity in entity_list:
        for (a,b) in sents:

          #Both spans have to lie strictly inside the sentence
          #NOTE(review): the strict comparisons exclude spans that touch a sentence boundary - confirm this is intended
          if not (entity['span'][0]>a and entity['span'][1]<b and trigger['span'][0]>a and trigger['span'][1]<b):
            continue

          #Creating relation based on subject and object
          relations = {'rel_id':f"R{next_num}",
                       'subject': trigger['entity'],
                       'object': entity['entity'],
                       }

          #Sorting entities based on span
          entities = sorted([trigger, entity], key=lambda f: f['span'][0])

          #Adjusting text by replacing entities with their types
          text = self.adjust_text(total_text, entities, a , b)

          #Getting subject and object masks
          masks = self.get_mask(text,entities, relations, a)

          if masks['subject_mask'] == "OVERFLOW" or masks['object_mask'] == "OVERFLOW":
            continue

          #Appending data in relation extraction format
          total.append({'text': text, 'entities': entities, 'relations': relations,
                        'subject_mask': masks['subject_mask'], 'object_mask': masks['object_mask'] })

          next_num+=1

    return total


In [None]:
#Choosing model to be used: loading the BERT_MODEL checkpoint for token classification
#with one output label per entry of LIST_TAGS
model_NER = BertForTokenClassification.from_pretrained(BERT_MODEL, num_labels = len(LIST_TAGS))

In [None]:
#Learning rate for the NER model
lr=0.00002

#Training and logging to TensorBoard under GDRIVE_PATH/NER-bert
logger = TensorBoardLogger(save_dir=GDRIVE_PATH, name="NER-bert")
ner = NERClassificationModule(model_NER, lr, len(LIST_TAGS))

#`accelerator`/`devices` replace the deprecated `gpus` argument of the Trainer
#NOTE(review): max_epochs is hard-coded to 1; the NER_EPOCHS constant of the parameter cell is not used - confirm which is intended
trainer_ner = pl.Trainer(accelerator="gpu", devices=1, max_epochs = 1, logger=logger)
trainer_ner.fit(ner, datamodule=ner_datamodule)

#Evaluating on the held-out test split
trainer_ner.test(dataloaders=ner_datamodule.test_dataloader())

  rank_zero_deprecation(
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name     | Type                       | Params
--------------------------------------------------------
0 | module   | BertForTokenClassification | 107 M 
1 | loss_fn  | CrossEntropyLoss           | 0     
2 | accuracy | BinaryAccuracy             | 0     
3 | recall   | BinaryRecall               | 0     
4 | f1       | BinaryF1Score              | 0     
--------------------------------------------------------
107 M     Trainable params
0         Non-trainable params
107 M     Total params
430.956   Total est

Sanity Checking: 0it [00:00, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.
  rank_zero_warn(
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/drive/MyDrive/Hackathon: European Patent Office (EPO)/NER-bert/version_24/checkpoints/epoch=0-step=40.ckpt
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from checkpoint at /content/drive/MyDrive/Hackathon: European Patent Office (EPO)/NER-bert/version_24/checkpoints/epoch=0-step=40.ckpt


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.8889172077178955
         val_f1             0.8065020891643887
        val_loss            0.4264299273490906
      val_precision         0.7667858397227154
       val_recall           0.8508710766088126
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss': 0.4264299273490906,
  'val_f1': 0.8065020891643887,
  'val_accuracy': 0.8889172077178955,
  'val_precision': 0.7667858397227154,
  'val_recall': 0.8508710766088126}]

In [None]:
#Sample text used below to try out the prediction pipeline
tt = "Bio-techno-sustainable industrial mortar suitable for use in construction as a base, sub-base or tread layer in all types of paving, in the manufacture of lightweight prefabricated elements and in the construction of floors and ceilings collaborating among others, formed preferably by waste and by-products, such as plastics, rubber, textile and vegetable fibers, etc., mixed with cement-type binder materials and as a binder, adherent, breathable and waterproofing material uses polymers and acrylic copolymers. The invention makes it possible to reduce or even completely dispense with the use of conventional aggregates of quarry or ground mountain stone. (Machine-translation by Google Translate, not legally binding)"

In [None]:
def predict_entities_relations(text_list, reverse_list=REVERSE_LIST_TAGS):
  """Run NER and then Relation Extraction on a list of raw texts.

  Relies on the module-level objects ``ner_datamodule``, ``trainer_ner``, ``ner``,
  ``relex_datamodule``, ``trainer_re`` and ``relex``; all of them have to exist
  before this function is called.

  Args:
    text_list: iterable of texts to analyse.
    reverse_list: label id -> entity type, passed on to ``ner.predictions_to_relex``.

  Returns:
    list of candidate relation dicts, each extended with 'predicted_relation'.
  """

  #Predicting entities for the texts that were passed in
  pred_loader = ner_datamodule.ner_predict_dataloader(text_list)
  predicts = trainer_ner.predict(ner,pred_loader)

  #Turning predicted entities into candidate relations, row by row
  relex_dataset = []
  for batch_output in predicts:
    preds = torch.argmax(batch_output['logits'], dim=2).numpy()
    inps = batch_output['input_ids'].numpy()
    for row in range(preds.shape[0]):
      relex_dataset.extend(ner.predictions_to_relex(inps[row], preds[row], reverse_list))

  #Nothing to classify if no candidate relation was found
  if not relex_dataset:
    return relex_dataset

  #Predicting the relation type of every candidate
  final_dataloader = relex_datamodule.relex_predict_dataloader(relex_dataset)
  predictions_re = trainer_re.predict(relex, final_dataloader)

  relation_preds = []
  for batch_output in predictions_re:
    relation_preds.extend(torch.argmax(batch_output['logits'], dim=1).numpy())

  #Attaching the predicted label to its candidate (same order as the dataloader)
  for candidate, encoding in zip(relex_dataset, relation_preds):
    candidate['predicted_relation'] = relex_datamodule.encoding_to_label(encoding)

  return relex_dataset
    


  


In [None]:
#Running the NER + Relation Extraction pipeline on the sample text
#NOTE: predict_entities_relations reads relex_datamodule, trainer_re and relex, which are
#created in the Relation Extraction section further down - those cells have to be run first
result = predict_entities_relations([tt])

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 40it [00:00, ?it/s]

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 518it [00:00, ?it/s]

## Training Relation Extraction

In [None]:
#Setting Batch Size for Relation Extraction task
BATCH_SIZE_RE = 32 

#Setting maximum length of BERT tokens for Relation Extraction
#(same value as the one assigned in the NER parameter cell)
MAX_LENGTH_RE = 200

#Setting max epochs for Relation Extraction
RE_EPOCHS = 3

#Number of relations for Relation Extraction task
NUM_RELATIONS = 3

#Setting Path (same value as the one assigned in the NER parameter cell)
GDRIVE_PATH='/content/drive/MyDrive/Hackathon: European Patent Office (EPO)'

#Choosing BERT model for tokenization and training
#(same value as the one assigned in the NER parameter cell)
BERT_MODEL = 'dmis-lab/biobert-base-cased-v1.2'

In [None]:
class REDataModule(pl.LightningDataModule):
  """Lightning data module for Relation Extraction.

  Every training example is a dict with "labels", "tokens", "attention_mask",
  "subject_mask" and "object_mask" tensors.

  Args:
    batch_size: batch size used by every DataLoader this module creates.
    max_length: number of tokens each text is padded/truncated to.
    tokenizer: tokenizer used for encoding the texts.
  """

  #Relation labels; the position of a label in this list is its numerical encoding
  RELATION_LABELS = ["O", "ARGM", "ARG1"]

  def __init__(self,
               batch_size = BATCH_SIZE_RE,
               max_length = MAX_LENGTH_RE,
               tokenizer=BertTokenizer.from_pretrained(BERT_MODEL)
               ):

    super().__init__()
    self.train_batch_size=batch_size
    self.max_length=max_length
    self.tokenizer=tokenizer
    self.data = []
    self.dataset={}


  def _encode(self, text):
    """Tokenize one text, padded/truncated to ``max_length``, as PyTorch tensors."""
    return self.tokenizer(text,
                          max_length=self.max_length,
                          padding='max_length',
                          truncation = True,
                          return_tensors='pt'
                          )


  def prepare_data(self, path_json=f"{GDRIVE_PATH}/ee_train_relex.json"):
    """Load the relation JSON file, build the examples and store them shuffled in ``self.data``."""

    #Opening file with training data and extracting; the context manager closes it even on errors
    with open(path_json, 'r') as file_json:
      data = json.load(file_json)

    edited_data=[]

    for entry in data.values():

      #Getting relation for the row
      relation=entry['relations']

      #Tokenizing the text and getting attention mask
      encoding = self._encode(entry['text'])

      #Creating dataset for Relation Extraction
      edited_data.append({
          "labels": torch.Tensor(self.label_to_encoding(relation['type'])),
          "tokens": torch.squeeze(encoding.input_ids),
          "attention_mask": torch.squeeze(encoding.attention_mask),
          "subject_mask": torch.Tensor(relation['subject_mask']),
          "object_mask": torch.Tensor(relation['object_mask'])
      })

    random.shuffle(edited_data)

    self.data = edited_data


  def label_to_encoding(self,label):
    """Return the numerical encoding of ``label`` wrapped in a one-element list (KeyError if unknown)."""

    #Turning labels into numerical encodings
    encodings = {name: position for position, name in enumerate(self.RELATION_LABELS)}

    return [encodings[label]]

  def encoding_to_label(self,encoding):
    """Return the label that belongs to a numerical encoding."""

    #Turning encodings back into labels
    return self.RELATION_LABELS[encoding]

  def setup(self, stage=None):
    """Split ``self.data`` into 60% train, 20% validation and 20% test."""

    train_end = ceil(len(self.data)*0.6)
    val_end = ceil(len(self.data)*0.8)

    #Creating train-validate-test datasets
    self.dataset["train"]= self.data[:train_end]
    self.dataset["val"]= self.data[train_end:val_end]
    self.dataset["test"]= self.data[val_end:]

  def train_dataloader(self):

    #Creating DataLoader for training
    return DataLoader(self.dataset["train"], batch_size=self.train_batch_size)

  def val_dataloader(self):

    #Creating DataLoader for validation
    return DataLoader(self.dataset["val"], batch_size=self.train_batch_size)

  def test_dataloader(self):

    #Creating DataLoader for testing
    return DataLoader(self.dataset["test"], batch_size=self.train_batch_size)

  def relex_predict_dataloader(self, relex_input):
    """Build a DataLoader of unlabeled examples from dicts with 'text', 'subject_mask' and 'object_mask'."""

    #Creating DataLoader for predictions
    predict_input = []
    for candidate in relex_input:

      encoding = self._encode(candidate['text'])
      predict_input.append({
          "tokens": torch.squeeze(encoding.input_ids),
          "attention_mask": torch.squeeze(encoding.attention_mask),
          "subject_mask": torch.Tensor(candidate['subject_mask']),
          "object_mask": torch.Tensor(candidate['object_mask'])
      })

    return DataLoader(predict_input, batch_size = self.train_batch_size)

#Instantiating the data module with the defaults from the parameter cell
relex_datamodule=REDataModule()

In [None]:

#Creating linear layer to be trained on top of BERT for Relation Extraction
class LinearLayer(torch.nn.Module):
    """Dropout, optional tanh and a linear projection, applied in that order.

    Args:
        input_dim: size of the incoming feature vector.
        output_dim: size of the projected feature vector.
        dropout_rate: probability handed to ``torch.nn.Dropout``.
        activation: when true, tanh is applied between dropout and projection.
    """

    def __init__(self, input_dim, output_dim, dropout_rate=0.1, activation=True):
        super().__init__()

        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.linear = torch.nn.Linear(input_dim, output_dim)
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        """Project ``x`` after dropout and, if enabled, tanh."""
        hidden = self.dropout(x)
        hidden = self.tanh(hidden) if self.activation else hidden
        return self.linear(hidden)

#Classification Module for Relation Extraction
class REClassificationModule(pl.LightningModule):
  """Lightning module that classifies the relation between a subject and an object entity.

  The pooled output of the encoder and the averaged embeddings of the subject
  and object tokens each pass through a linear layer, are concatenated and then
  classified into ``n_relations`` classes.

  Args:
    module: encoder; called as ``module(tokens, attention_mask=mask)``, with the
      token embeddings at index 0 and the pooled output at index 1 of its result.
    lr: learning rate for the Adam optimizer.
    n_relations: number of relation classes.
    loss_fn: loss applied to the logits and the target labels.
  """

  def __init__(self, module, lr, n_relations, loss_fn=torch.nn.CrossEntropyLoss()):
    super().__init__()
    self.module = module
    self.lr = lr
    self.loss_fn = loss_fn

    #Metrics follow the number of relation classes instead of a fixed value
    self.accuracy = torchmetrics.Accuracy(task = 'multiclass', average = 'macro', num_classes = n_relations)
    self.recall = torchmetrics.Recall(task = 'multiclass', average = 'macro', num_classes = n_relations)
    self.f1 = torchmetrics.F1Score(task = 'multiclass', average = 'macro', num_classes = n_relations)
    self.n_relations = n_relations

    #Embedding size of the encoder; falls back to 768 when the encoder exposes no config.hidden_size
    hidden_size = getattr(getattr(module, 'config', None), 'hidden_size', 768)

    self.cls_linear_layer = LinearLayer(hidden_size, hidden_size)
    self.entity_linear_layer = LinearLayer(hidden_size, hidden_size)
    self.classification_layer = LinearLayer(
            hidden_size * 3,
            n_relations,
            0.1,
            activation=True,
        )


  def forward(self, tokens, mask, sub_mask, obj_mask):
    """Return the relation logits for a batch."""

    #Running data through BERT and retrieving Pooled and Final output
    result = self.module(tokens, attention_mask=mask)
    pooler = result[1]
    outputs = result[0]

    #Averaging embeddings of entities
    subject_average = self.averaging_entities(outputs, sub_mask)
    object_average = self.averaging_entities(outputs, obj_mask)

    #Passing Pooled and Averaged Embeddings through linear layers
    pooled_embeddings = self.cls_linear_layer(pooler)
    subject_embeddings = self.entity_linear_layer(subject_average)
    object_embeddings = self.entity_linear_layer(object_average)

    #Concatenating embeddings
    concat_embeddings = torch.cat([pooled_embeddings, subject_embeddings, object_embeddings], dim=-1)

    #Running through final linear layer
    return self.classification_layer(concat_embeddings)


  @staticmethod
  def averaging_entities(output, entity_mask):
    """Average the embeddings of the positions where ``entity_mask`` is non-zero.

    An all-zero mask yields a zero vector instead of NaN.
    """

    #Calculating length of entity embeddings
    entity_mask_unsqueeze = entity_mask.unsqueeze(1)
    length_tensor = (entity_mask != 0).sum(dim=1).unsqueeze(1)

    #Taking average of entity embeddings; clamping avoids a division by zero for empty masks
    sum_embeddings = torch.bmm(entity_mask_unsqueeze.float(), output).squeeze(1)
    avg_embeddings = sum_embeddings.float() / length_tensor.float().clamp(min=1)

    return avg_embeddings


  def _logits_and_loss(self, batch):
    """Forward one labelled batch and return ``(logits, flat target labels, loss)``."""

    #Extracting data from batch
    target_labels = batch['labels'].type(torch.long).view(-1)

    #Forward propagating
    y_hat = self(batch['tokens'],
                 batch['attention_mask'],
                 batch['subject_mask'],
                 batch['object_mask'])

    #Calculating loss
    loss = self.loss_fn(y_hat.view(-1, self.n_relations), target_labels)

    return y_hat, target_labels, loss


  def _evaluate_batch(self, batch, prefix):
    """Log loss, accuracy, recall and F1 of one batch under ``prefix`` and return the loss."""

    y_hat, target_labels, loss = self._logits_and_loss(batch)

    #Calculating additional metrics
    predicted_labels = torch.argmax(y_hat, dim=1).view(-1)
    accuracy = self.accuracy(predicted_labels, target_labels)
    recall = self.recall(predicted_labels, target_labels)
    f1 = self.f1(predicted_labels, target_labels)

    self.log_dict({f"{prefix}_loss": loss,
                   f"{prefix}_accuracy": accuracy,
                   f"{prefix}_recall": recall,
                   f"{prefix}_f1": f1}, prog_bar=True, logger=True)

    return loss


  def training_step(self, batch, batch_index):

    _, _, loss = self._logits_and_loss(batch)

    self.log_dict({'train_loss':loss}, prog_bar=True)

    return loss

  def validation_step(self, batch, batch_idx):

    #Validation metrics are logged with the "val_" prefix
    return self._evaluate_batch(batch, 'val')

  def test_step(self, batch, batch_idx):

    #Test metrics get their own "test_" prefix so they do not overwrite validation metrics
    return self._evaluate_batch(batch, 'test')


  def predict_step(self,batch, batch_idx):

    #Forward propagating
    y_hat = self(batch['tokens'],
                 batch['attention_mask'],
                 batch['subject_mask'],
                 batch['object_mask'])

    return {'logits':y_hat}


  def configure_optimizers(self):

    #Setting optimization algorithm
    return Adam(self.parameters(), lr=self.lr)

    

In [None]:
#Choosing BERT model: load the checkpoint named by the BERT_MODEL config constant
#so the encoder cannot silently diverge from the configured model name
model_RE=BertModel.from_pretrained(BERT_MODEL)

#Setting only last layers as trainable: freeze every parameter tensor except the final 8
for p in list(model_RE.parameters())[:-8]:
   p.requires_grad = False

Downloading:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of the model checkpoint at dmis-lab/biobert-base-cased-v1.2 were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
#Learning rate passed to the relation-extraction module
lr=0.0002

#Training and logging
logger=TensorBoardLogger(save_dir=GDRIVE_PATH,name="RelEx-bert")
relex=REClassificationModule(model_RE, lr, NUM_RELATIONS)
#`gpus=1` was deprecated in PyTorch Lightning 1.7 and removed in 2.0;
#accelerator/devices selects the same single GPU and works on both old and new releases
trainer_re = pl.Trainer(accelerator="gpu", devices=1, max_epochs = 1, logger=logger)
trainer_re.fit(relex, datamodule=relex_datamodule)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name                 | Type               | Params
------------------------------------------------------------
0 | module               | BertModel          | 108 M 
1 | loss_fn              | CrossEntropyLoss   | 0     
2 | accuracy             | MulticlassAccuracy | 0     
3 | recall               | MulticlassRecall   | 0     
4 | f1                   | MulticlassF1Score  | 0     
5 | cls_linear_layer     | LinearLayer        | 590 K 
6 | entity_linear_layer  | LinearLayer        | 590 K 
7 | classification_layer | LinearLayer      

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.


## Predictions

In [None]:
#Input passed to predict_entities_relations further down; left empty here
#(presumably to be filled in by the user before predicting - TODO confirm expected format)
text = []

In [None]:

def predict_entities_relations(text_list, reverse_list=REVERSE_LIST_TAGS):
  """Run entity recognition and then relation extraction on `text_list`.

  Relies on the notebook-level objects `ner_datamodule`, `trainer_ner`, `ner`,
  `relex_datamodule`, `trainer_re` and `relex`.

  Args:
      text_list: input handed, wrapped in a list, to
          `ner_datamodule.ner_predict_dataloader`.
      reverse_list: tag lookup forwarded to `ner.predictions_to_relex`.

  Returns:
      The list of items produced by `ner.predictions_to_relex`, each with an
      added 'predicted_relation' entry holding the decoded relation label.
  """
  #Predicting entities on the inputs served by the NER predict dataloader
  ner_loader = ner_datamodule.ner_predict_dataloader([text_list])
  ner_batches = trainer_ner.predict(ner, ner_loader)

  #Turning every predicted row into candidate examples for Relation Extraction
  relex_dataset = []
  for batch_out in ner_batches:
    tag_ids = torch.argmax(batch_out['logits'], dim=2).numpy()
    token_ids = batch_out['input_ids'].numpy()
    n_rows = batch_out['logits'].shape[0]
    for row in range(n_rows):
      relex_dataset.extend(ner.predictions_to_relex(token_ids[row], tag_ids[row], reverse_list))

  #Predicting relations for the extracted candidates
  relex_loader = relex_datamodule.relex_predict_dataloader(relex_dataset)
  relation_batches = trainer_re.predict(relex, relex_loader)

  #Flattening the per-batch class predictions into one list aligned with relex_dataset
  relation_preds = []
  for batch_out in relation_batches:
    relation_preds.extend(torch.argmax(batch_out['logits'], dim=1).numpy())

  #Changing predicted relation encoding to textual label
  for idx, example in enumerate(relex_dataset):
    example['predicted_relation'] = relex_datamodule.encoding_to_label(relation_preds[idx])

  return relex_dataset
    



In [None]:
#Running the NER + relation-extraction pipeline on `text`
preds = predict_entities_relations(text)

APPENDIX: Parsing Brat Files

In [None]:
!pip install mendelai-brat-parser
from brat_parser import get_entities_relations_attributes_groups as brat

import os
#Listing the training folder and keeping only the names whose last three characters are 'txt'
data_names = [
    fname
    for fname in os.listdir(f"{GDRIVE_PATH}/ee_train/ee_train")
    if fname[-3:] == 'txt'
]

def parse_ent(entities):
  """Convert parsed BRAT entities into plain dicts ordered by start offset.

  Args:
      entities: mapping whose values expose `id`, `type`, `span` and `text`.

  Returns:
      A list of dicts with keys 'ent_id', 'type', 'span' and 'text', sorted by
      the first element of 'span'. Only the first item of each entity's
      `span` is kept.
  """
  def _as_record(entity):
    return {"ent_id": entity.id,
            "type": entity.type,
            "span": entity.span[0],
            "text": entity.text}

  ents = [_as_record(entities[key]) for key in entities.keys()]
  ents.sort(key=lambda record: record['span'][0])
  return ents

def parse_rel(relation):
  """Convert parsed BRAT relations into plain dicts.

  Args:
      relation: mapping whose values expose `id`, `type`, `subj` and `obj`.

  Returns:
      A list of dicts with keys 'rel_id', 'type', 'subject' and 'object', in
      the iteration order of the mapping.
  """
  def _as_record(rel):
    return {"rel_id": rel.id,
            "type": rel.type,
            "subject": rel.subj,
            "object": rel.obj}

  return [_as_record(relation[key]) for key in relation.keys()]


#Building one record per document: raw text plus its parsed entities and relations
data={}
source_dir = f"{GDRIVE_PATH}/ee_train/ee_train"
for name in data_names:
  with open(f"{source_dir}/{name}", 'rt') as f:
      raw_text = f.read()
  #The annotation file shares the name of the text file, with 'ann' in place of the last three characters
  ent, rel, _attributes, _groups = brat(f"{source_dir}/{name[:-3]}ann")
  data[name[:-4]] = {"text": raw_text,
                     "entities": parse_ent(ent),
                     "relations": parse_rel(rel)}


#Writing the whole collection to a single indented JSON file
json_data = json.dumps(data, indent=4)
with open(f"{GDRIVE_PATH}/ee_train.json", 'w') as write_json:
  write_json.write(json_data)

In [None]:
#Tokenizer for the checkpoint named by the BERT_MODEL config constant
tokenizer=BertTokenizer.from_pretrained(BERT_MODEL)

#Loading the parsed annotations; the context manager closes the file even if parsing fails
with open(f"{GDRIVE_PATH}/ee_train.json", 'r') as file_json:
  data = json.load(file_json)

def get_mask(unit):
    """Build subject and object masks for one sentence-level relation example.

    The spans are computed for a text in which each of the two entities
    has been replaced by a placeholder of length ``len(type) + 2`` and whose
    offsets are relative to ``unit['sent'][0]``.

    Args:
        unit: dict with 'text', 'relations' (a single relation dict with
            'rel_id', 'subject', 'object'), 'entities' (two entity dicts with
            'ent_id', 'type', 'span') and 'sent' (first item is the offset
            subtracted from the spans).

    Returns:
        ``{rel_id: {'subject_mask': ..., 'object_mask': ...}}`` where each mask
        is whatever `mask_relation` returns for the corresponding span.
    """
    text = unit['text']
    relation = unit['relations']
    entities = unit['entities']

    first = entities[0]
    first_start = first['span'][0]
    first_len = len(first['type']) + 2
    #How far the text after the first entity moved once its placeholder replaced it
    shift = first_start + first_len - first['span'][1]
    sent_start = unit['sent'][0]

    second = entities[1]
    second_start = shift + second['span'][0] - sent_start

    placeholder_spans = {
        first['ent_id']: [first_start - sent_start, first_start + first_len - sent_start],
        second['ent_id']: [second_start, second_start + len(second['type']) + 2],
    }

    subject_mask = mask_relation(text, placeholder_spans[relation['subject']])
    object_mask = mask_relation(text, placeholder_spans[relation['object']])

    return {relation['rel_id']: {'subject_mask': subject_mask, 'object_mask': object_mask}}
    
def mask_relation(text, span_text, max_length=200):
    """Build a 0/1 token mask marking the tokens of one character span.

    The mask starts with a single leading 0, followed by one 0 per token of
    the text before the span and one 1 per token of the span itself, then is
    right-padded with 0s up to `max_length`. Token counts come from the
    notebook-level `tokenizer`.

    Args:
        text: the sentence text.
        span_text: sequence whose first two items are the start and end
            character offsets of the span within `text`.
        max_length: length of the returned mask.

    Returns:
        A list of `max_length` ints, or the string "OVERFLOW" when the
        unpadded mask already has `max_length` or more entries.
    """
    span_start, span_end = span_text[0], span_text[1]

    n_before = len(tokenizer.tokenize(text[:span_start]))
    n_inside = len(tokenizer.tokenize(text[span_start:span_end]))

    #One leading 0, then the prefix tokens, then the span tokens
    mask = [0] * (1 + n_before) + [1] * n_inside

    if len(mask) >= max_length:
      return "OVERFLOW"
    return mask + [0] * (max_length - len(mask))

def extend_relation(id,dataset):
    """Create "O" (no-relation) entries for trigger/entity pairs lacking a relation.

    Triggers are the entities typed 'WORKUP' or 'REACTION_STEP'; every other
    entity is a candidate argument. A pair is emitted only when
    (trigger id, entity id) is not already a (subject, object) pair in the
    document's relations.

    Args:
        id: key of the document inside `dataset`.
        dataset: mapping of documents, each with 'entities' and 'relations'.

    Returns:
        A list of new relation dicts with keys 'rel_id', 'type' ("O"),
        'subject' and 'object'. Ids are numbered upward from the number of
        existing relations.
    """
    trigger_types = ['WORKUP','REACTION_STEP']
    document = dataset[id]

    trigger_list = [(d['text'],d['ent_id']) for d in document['entities'] if d['type'] in trigger_types]
    entity_list = [(c['text'],c['ent_id']) for c in document['entities'] if c['type'] not in trigger_types]
    list_relations = [(y['subject'],y['object']) for y in document['relations']]

    #Every trigger/entity combination, trigger-major, that has no existing relation
    missing_pairs = [
        (trigger_id, entity_id)
        for (_trigger_text, trigger_id) in trigger_list
        for (_entity_text, entity_id) in entity_list
        if (trigger_id, entity_id) not in list_relations
    ]

    first_number = len(list_relations)
    return [
        {'rel_id': f"R{first_number + position}",
         'type': "O",
         'subject': subject_id,
         'object': object_id,
         }
        for position, (subject_id, object_id) in enumerate(missing_pairs)
    ]


def adjust_text(old_text, entities_list, start, end):
  """Cut out `old_text[start:end]` with two entities replaced by type placeholders.

  Each of the first two entities in `entities_list` is replaced by
  '@' + its type + '#'; the text between and around them is kept.

  Args:
      old_text: full document text.
      entities_list: entity dicts with 'span' (start, end offsets into
          `old_text`) and 'type'; the first must precede the second.
      start: offset where the excerpt begins.
      end: offset where the excerpt ends.

  Returns:
      The excerpt with both entities replaced by their placeholders.
  """
  first, second = entities_list[0], entities_list[1]

  pieces = [
      old_text[start:first['span'][0]],
      '@', first['type'], '#',
      old_text[first['span'][1]:second['span'][0]],
      '@', second['type'], '#',
      old_text[second['span'][1]:end],
  ]
  return "".join(pieces)



In [None]:
#Building one example per relation whose two entities lie inside the same sentence
single_data = {}

for key in data.keys():
  #Adding "O" (no-relation) pairs to the document's relations
  data[key]['relations'].extend(extend_relation(key, data))
  #Character offsets (start, end) of every sentence found by spaCy
  spans = [(s.start_char, s.end_char) for s in nlp(data[key]['text']).sents]
  for k, rel in enumerate(data[key]['relations']):
    entities_included = [ent for ent in data[key]['entities'] if ent['ent_id'] in [rel['subject'], rel['object']]]
    #A relation that does not resolve to two entities cannot be masked; skip it instead of failing on sort_ent[1]
    if len(entities_included) < 2:
      continue
    sort_ent = sorted(entities_included, key = lambda x: x['span'][0])
    first_ent, second_ent = sort_ent[0], sort_ent[1]
    for (a,b) in spans:
      #Inclusive bounds: an entity starting exactly at the sentence start or ending exactly
      #at the sentence end is still inside the sentence (the strict >/< test dropped those pairs)
      if a <= first_ent['span'][0] and first_ent['span'][1] <= b \
          and a <= second_ent['span'][0] and second_ent['span'][1] <= b:
        single_data[key+str(k)] = {
            'relations': rel,
            'entities': sort_ent,
            'text': adjust_text(data[key]['text'], sort_ent,a,b),
            'sent': [a,b]
        }

    


In [None]:

#Saving the sentence-level relation examples to the intermediate JSON file
intermediate_path = f"{GDRIVE_PATH}/ee_train_intermediate.json"
json_dumps = json.dumps(single_data)
with open(intermediate_path, 'w') as out:
  out.write(json_dumps)

In [None]:
#Saving the sentence-level relation examples to the relation-extraction JSON file
relex_path = f"{GDRIVE_PATH}/ee_train_relex.json"
json_dumps = json.dumps(single_data)
with open(relex_path, 'w') as out:
  out.write(json_dumps)

In [None]:
def extend_relation(id,dataset):
    """Create an "O" (no-relation) entry for every trigger/entity pair of a document.

    NOTE: this redefines the `extend_relation` declared earlier in the notebook.
    This variant reads the 'entity' key of each entity and does not filter out
    pairs that already have a relation.

    Args:
        id: key of the document inside `dataset`.
        dataset: mapping of documents, each with an 'entities' list of dicts
            holding 'entity' and 'type'; a 'relations' list is optional.

    Returns:
        A list of relation dicts with keys 'rel_id', 'type' ("O"), 'subject'
        and 'object'. Ids are numbered upward from the number of relations
        already present in the document (0 when there are none).
    """
    trigger_list = [(d['entity'],d['type']) for d in dataset[id]['entities'] if d['type'] in ['WORKUP','REACTION_STEP']]
    entity_list = [(c['entity'],c['type']) for c in dataset[id]['entities'] if c['type'] not in ['WORKUP','REACTION_STEP']]

    #The id offset used to come from an undefined `list_relations` (NameError on the first pair);
    #it is now derived from the document itself, defaulting to no existing relations
    list_relations = dataset[id]['relations'] if 'relations' in dataset[id] else []

    next_num=0
    temporary=[]

    for trigger in trigger_list:
      for entity in entity_list:
        #NOTE(review): index 1 of these tuples is the entity *type*, whereas the earlier
        #definition stores ent_id there - confirm this is the intended subject/object value
        temporary.append({'rel_id':f"R{len(list_relations)+next_num}",
                            'type': "O", 
                            'subject': trigger[1], 
                            'object': entity[1],
                            })
        next_num+=1

    return temporary