# ML-based Relation Extraction

While the rule-based method performs a heuristic search for subject–attribute entity pairs, the machine-learning method classifies candidate entity pairs into predefined relation classes.

We created annotated datasets of relations between entities in clinical trial eligibility criteria: the full data (data_re/allrealtions/train.tsv) and a subset of 2000 random samples (data_re/2000relations/train.tsv). The data can be used to train many types of machine learning models, including support vector machines, conditional random fields, and deep learning-based classifiers. We trained BERT-based models on our relation training data using a recent package (https://github.com/uf-hobi-informatics-lab/ClinicalTransformerRelationExtraction).

Here we show how to use a trained BERT classification model to classify candidate relations in a sample dataset.

In [1]:
import pandas as pd
import ast
import csv
from nltk import pos_tag
from nltk.chunk import conlltags2tree
from nltk.tree import Tree
from itertools import combinations

## Test Data Preparation

### Utility functions

In [2]:
# BIO tags
def tag(ner_result):
    """Turn token-level NER output into a list of ``(word, tag)`` pairs.

    Args:
        ner_result: iterable of dicts with at least the keys ``'word'`` and
            ``'tag'``.

    Returns:
        ``(word, tag)`` tuples in input order. Any token whose tag contains the
        substring ``'SEP'`` or ``'CLS'`` (the special separator/classifier
        tokens) is left out.
    """
    return [
        (item['word'], item['tag'])
        for item in ner_result
        if not ('SEP' in item['tag'] or 'CLS' in item['tag'])
    ]

In [3]:
# Create tree
def stanford_tree(bio_tagged):
    """Build an NLTK chunk tree from ``(token, ner_tag)`` pairs.

    Despite the name, part-of-speech tags come from NLTK's ``pos_tag``; they
    only fill the middle slot of the ``(word, pos, chunk_tag)`` triples that
    ``conlltags2tree`` expects. The NER tags (presumably BIO strings such as
    ``'B-age'`` / ``'O'``) become the chunk tags.

    Args:
        bio_tagged: sequence of ``(token, ner_tag)`` pairs.

    Returns:
        The tree built by ``conlltags2tree``, or ``None`` for empty input.
    """
    if len(bio_tagged) == 0:
        return None

    tokens, ne_tags = zip(*bio_tagged)
    triples = [
        (token, pos, ne)
        for token, (_, pos), ne in zip(tokens, pos_tag(tokens), ne_tags)
    ]
    return conlltags2tree(triples)

In [4]:
# Parse named entities from tree
def structure_ne(ne_tree):
    """Extract ``(entity_text, entity_label)`` pairs from a chunk tree.

    Args:
        ne_tree: tree returned by ``conlltags2tree`` (chunk subtrees are
            ``Tree`` nodes whose leaves are ``(token, pos)`` pairs; tokens
            outside a chunk are plain tuples and are skipped), or ``None``.

    Returns:
        One ``(" ".join(tokens), label)`` tuple per chunk, in tree order, or
        ``None`` when ``ne_tree`` is ``None``.
    """
    if ne_tree is None:
        return None
    return [
        (" ".join(token for token, _pos in subtree.leaves()), subtree.label())
        for subtree in ne_tree
        # isinstance (not an exact type comparison) also accepts Tree subclasses
        if isinstance(subtree, Tree)
    ]

### Functions to generate test data with the right format

In [5]:
def find_index(mentions, s):
    """Locate mentions in ``s`` left to right and return their character spans.

    Mentions are searched in the given order. Each search starts where the
    previous match ended, so repeated mentions map to successive occurrences.
    Within the remaining text, an occurrence bounded by whitespace (or the
    string edges) is preferred. This keeps e.g. the mention ``"age"`` from
    being placed inside ``"average"`` when a standalone ``"age"`` follows. If
    no such occurrence exists, the first raw occurrence is used.

    Args:
        mentions: iterable of mention strings, in textual order.
        s: text to search.

    Returns:
        List of ``(mention, start, end)`` with ``s[start:end] == mention``.
        Empty mentions, and mentions that do not occur at or after the
        current search position, are omitted.
    """
    cursor = 0
    indexes = []
    for mention in mentions:
        if not mention:
            # An empty string has no meaningful span.
            continue

        first = s.find(mention, cursor)
        pos = first
        while pos != -1:
            end = pos + len(mention)
            starts_word = pos == 0 or s[pos - 1].isspace()
            ends_word = end == len(s) or s[end].isspace()
            if starts_word and ends_word:
                break
            pos = s.find(mention, pos + 1)
        start = pos if pos != -1 else first

        if start == -1:
            continue
        end = start + len(mention)
        indexes.append((mention, start, end))
        cursor = end
    return indexes

In [6]:
def nerTokens_to_sentences(df_ner):
    """Turn token-level NER predictions into entity-annotated sentence strings.

    For each distinct ``('#nct_id', 'eligibility_type', 'criterion', 'NER')``
    row of ``df_ner``, these steps run in order:

    1. The tokens in ``NER`` are joined into a sentence, each followed by one
       space (so the sentence ends with a space).
    2. The entity chunks are extracted.
    3. Each chunk is encoded as ``"start:end:label,"`` (character offsets
       into the sentence), followed by a tab and the sentence, e.g.
       ``"3:6:age,7:14:upper_bound,<TAB>< age @NUMBER "``.

    Args:
        df_ner: DataFrame with at least the columns ``'#nct_id'``,
            ``'eligibility_type'``, ``'criterion'`` and ``'NER'``. Each ``NER``
            value is a list of ``{'word': ..., 'tag': ...}`` dicts, or the
            string repr of such a list (parsed with ``ast.literal_eval``).

    Returns:
        List of annotated strings, one per row with at least one located
        entity. Rows without entities, including rows left with no tokens
        after ``CLS``/``SEP`` removal, produce nothing.
    """
    df_ner = df_ner[['#nct_id', 'eligibility_type', 'criterion', 'NER']].drop_duplicates()

    stringlist = []
    for ner in df_ner['NER']:
        if isinstance(ner, str):
            # Stored as the repr of a list of dicts; literal_eval parses it
            # without executing code.
            ner = ast.literal_eval(ner)
        words = tag(ner)

        text = "".join(token + " " for token, _ in words)

        # stanford_tree returns None for an empty token list and structure_ne
        # passes None through, so fall back to "no entities".
        entities = structure_ne(stanford_tree(words)) or []

        # Locate mentions left to right, keeping each offset paired with its
        # own label, so a mention that cannot be located does not shift the
        # labels of the mentions after it.
        spans = []
        cursor = 0
        for mention, label in entities:
            found = find_index([mention], text[cursor:])
            if not found:
                continue
            _, start, end = found[0]
            spans.append((cursor + start, cursor + end, label))
            cursor += end

        if spans:  # don't emit sentences without entities
            prefix = "".join(
                str(start) + ":" + str(end) + ":" + label + ","
                for start, end, label in spans
            )
            stringlist.append(prefix + "\t" + text)

    return stringlist

In [7]:
# Annotate mentions with start [s] and end [e] positions in the sentence
def annotate_mentions(sent, doc_index):
    """Produce one copy of the sentence per entity, with that entity marked.

    Args:
        sent: ``"start:end:label,start:end:label,...<TAB>text"``; the offsets
            are character positions in ``text``.
        doc_index: sentence number, used to build the entity ids.

    Returns:
        One tuple per entity, in the order listed in ``sent``:
        ``(marked_text, label, "<doc_index>_<k>", mention)``. Here
        ``marked_text`` is ``text`` with the entity wrapped as
        ``"[s1] " + mention + " [e1]"`` and ``k`` counts entities from 0.
        A sentence without offsets yields an empty list.
    """
    # Split only on the first tab so a tab inside the sentence text survives.
    tags, _, text = sent.partition('\t')

    labels = []
    # rstrip drops the "," left after the last item; empty items are skipped.
    items = [item for item in tags.rstrip(',').split(',') if item]
    for entity, item in enumerate(items):
        start_str, end_str, label = item.split(':', 2)
        start, end = int(start_str), int(end_str)
        mention = text[start:end]
        labels.append((
            text[:start] + "[s1] " + mention + " [e1]" + text[end:],
            label,
            str(doc_index) + "_" + str(entity),
            mention,
        ))

    return labels

In [8]:
def generate_candidate_relations(stringlist):
    """Build every candidate entity pair for relation classification.

    Each annotated sentence is expanded with ``annotate_mentions``. Every pair
    of its entities (``combinations(entities, 2)``) becomes one row labelled
    ``'NonRel'``, a placeholder because the model predicts the real label.
    Sentences with fewer than two entities produce no rows.

    Columns of the returned frame, named ``'1'`` ... ``'8'``:
        1. relation label (always ``'NonRel'``)
        2. sentence with the first entity wrapped in ``[s1] ... [e1]``
        3. sentence with the second entity wrapped in ``[s2] ... [e2]``
        4. entity type of the first entity
        5. entity type of the second entity
        6. id of the first entity (``"<doc>_<k>"``)
        7. id of the second entity
        8. sentence number (position in ``stringlist``), as an int

    Each row's index is its sentence number.

    Args:
        stringlist: annotated strings ``"start:end:label,...<TAB>text"``.

    Returns:
        DataFrame with the columns above. It is empty, but still has those
        columns, when no sentence has two or more entities.
    """
    rows = []
    for doc, string in enumerate(stringlist):
        entities = annotate_mentions(string, doc)
        for (sent1, type1, id1, _), (sent2, type2, id2, _) in combinations(entities, 2):
            rows.append((
                'NonRel',
                sent1.replace("\n", ""),
                # Rename only the marker tokens; replacing bare "s1"/"e1" would
                # also rewrite sentence text such as "type1" -> "type2".
                sent2.replace("[s1]", "[s2]").replace("[e1]", "[e2]").replace("\n", ""),
                type1,
                type2,
                id1,
                id2,
                doc,
            ))

    return pd.DataFrame(
        rows,
        columns=['1', '2', '3', '4', '5', '6', '7', '8'],
        index=[row[-1] for row in rows],
    )

### Generate candidate relations (test data) using sample data

In [9]:
# Load sample NER data and display it. nerTokens_to_sentences reads its
# '#nct_id', 'eligibility_type', 'criterion' and 'NER' columns.
df_ner = pd.read_excel('data_ner/sample_trial_NER.xlsx')
df_ner

Unnamed: 0,#nct_id,eligibility_type,criterion,NER,Tags,Entity,Type
0,NCT00097734,inclusion,- At least three of the following signs or sym...,"[{'word': '-', 'tag': 'O', 'confidence': 0.999...","('three', 'lower_bound')",three,lower_bound
1,NCT00097734,inclusion,- At least three of the following signs or sym...,"[{'word': '-', 'tag': 'O', 'confidence': 0.999...","('sigmoid diverticulitis', 'chronic_disease')",sigmoid diverticulitis,chronic_disease
2,NCT00097734,inclusion,- At least three of the following signs or sym...,"[{'word': '-', 'tag': 'O', 'confidence': 0.999...","('Fever', 'chronic_disease')",Fever,chronic_disease
3,NCT00097734,inclusion,- At least three of the following signs or sym...,"[{'word': '-', 'tag': 'O', 'confidence': 0.999...","('body temperature', 'clinical_variable')",body temperature,clinical_variable
4,NCT00097734,inclusion,- At least three of the following signs or sym...,"[{'word': '-', 'tag': 'O', 'confidence': 0.999...","('38°C', 'lower_bound')",38°C,lower_bound
5,NCT00097734,inclusion,- At least three of the following signs or sym...,"[{'word': '-', 'tag': 'O', 'confidence': 0.999...","('Leukocytosis', 'chronic_disease')",Leukocytosis,chronic_disease
6,NCT00097734,inclusion,- At least three of the following signs or sym...,"[{'word': '-', 'tag': 'O', 'confidence': 0.999...","('leukocytes', 'clinical_variable')",leukocytes,clinical_variable
7,NCT00097734,inclusion,- At least three of the following signs or sym...,"[{'word': '-', 'tag': 'O', 'confidence': 0.999...","('10,000/µl', 'lower_bound')","10,000/µl",lower_bound
8,NCT00097734,inclusion,- At least three of the following signs or sym...,"[{'word': '-', 'tag': 'O', 'confidence': 0.999...","('differential blood count', 'clinical_variable')",differential blood count,clinical_variable
9,NCT00097734,inclusion,- At least three of the following signs or sym...,"[{'word': '-', 'tag': 'O', 'confidence': 0.999...","('1 %', 'lower_bound')",1 %,lower_bound


In [10]:
sentencelist = nerTokens_to_sentences(df_ner)  # "start:end:label,...<TAB>sentence" strings

In [11]:
candidate_relations = generate_candidate_relations(sentencelist)  # one row per entity pair

In [12]:
# Renumber rows 0..n-1 (the index repeats the sentence number per pair) so
# that row i can later be matched with line i of the prediction file.
candidate_relations = candidate_relations.reset_index(drop=True)
candidate_relations

Unnamed: 0,1,2,3,4,5,6,7,8
0,NonRel,- At least [s1] three [e1] of the following si...,- At least three of the following signs or sym...,lower_bound,chronic_disease,0_0,0_1,0.0
1,NonRel,- At least [s1] three [e1] of the following si...,- At least three of the following signs or sym...,lower_bound,chronic_disease,0_0,0_2,0.0
2,NonRel,- At least [s1] three [e1] of the following si...,- At least three of the following signs or sym...,lower_bound,clinical_variable,0_0,0_3,0.0
3,NonRel,- At least [s1] three [e1] of the following si...,- At least three of the following signs or sym...,lower_bound,lower_bound,0_0,0_4,0.0
4,NonRel,- At least [s1] three [e1] of the following si...,- At least three of the following signs or sym...,lower_bound,chronic_disease,0_0,0_5,0.0
...,...,...,...,...,...,...,...,...
79,NonRel,- Participation in [s1] another clinical study...,- Participation in another clinical study or u...,treatment,upper_bound,7_0,7_2,7.0
80,NonRel,- Participation in another clinical study or u...,- Participation in another clinical study or u...,treatment,upper_bound,7_1,7_2,7.0
81,NonRel,- Patients with a [s1] hematologic/oncologic d...,- Patients with a hematologic/oncologic diseas...,chronic_disease,cancer,9_0,9_1,9.0
82,NonRel,- Patients with a [s1] hematologic/oncologic d...,- Patients with a hematologic/oncologic diseas...,chronic_disease,cancer,9_0,9_2,9.0


In [13]:
# Save candidate relations (test data), which will be used to make predictions.
# The header row ('1'..'8') is written, the row index is not. data_dir in the
# commands below points at this folder. QUOTE_NONE writes fields verbatim; per
# the csv module, a field containing a tab would raise unless an escapechar is set.
candidate_relations.to_csv("data_re/2000relations/test.tsv", sep="\t", index=False, quoting=csv.QUOTE_NONE)

## ML-based Relation Extraction

You can train a BERT-based relation extraction model on our training data using this package (https://github.com/uf-hobi-informatics-lab/ClinicalTransformerRelationExtraction). Below are sample commands to train a model and then use it to make predictions on the test data.

For details, refer to the instructions in that repository.

### Train a model

1) set up environment variables

export CUDA_VISIBLE_DEVICES=1
data_dir=./data_re/2000relations
pof=./data_re/2000relations/predictions.txt
log=./data_re/2000relations/log.txt
nmd=./2000relations_model

2) train model
  
python ./src/relation_extraction.py \
    --model_type bert \
    --data_format_mode 0 \
    --classification_scheme 1 \
    --pretrained_model bert-base-uncased \
    --data_dir $data_dir \
    --new_model_dir $nmd \
    --predict_output_file $pof \
    --overwrite_model_dir \
    --seed 13 \
    --max_seq_length 256 \
    --cache_data \
    --do_train \
    --do_lower_case \
    --train_batch_size 4 \
    --eval_batch_size 4 \
    --learning_rate 1e-5 \
    --num_train_epochs 3 \
    --gradient_accumulation_steps 1 \
    --do_warmup \
    --warmup_ratio 0.1 \
    --weight_decay 0 \
    --max_num_checkpoints 0 \
    --log_file $log
  

### Predict relations in the test data

python ./src/relation_extraction.py \
  --model_type bert \
  --data_format_mode 0 \
  --classification_scheme 1 \
  --pretrained_model bert-base-uncased \
  --data_dir $data_dir \
  --new_model_dir $nmd \
  --predict_output_file $pof \
  --overwrite_model_dir \
  --seed 13 \
  --max_seq_length 256 \
  --cache_data \
  --do_predict \
  --do_lower_case \
  --eval_batch_size 4 \
  --log_file $log


### Generate prediction output

In [14]:
# Read ML predicted relation labels: one label per line, no header, so the
# resulting frame has a single column named 0.
prediction = pd.read_csv('data_re/2000relations/predictions.txt', sep='\t',header=None)
prediction

Unnamed: 0,0
0,NonRel
1,NonRel
2,NonRel
3,NonRel
4,NonRel
...,...
79,hasTemp
80,hasTemp
81,NonRel
82,NonRel


In [15]:
# Merge with input data: attach the predicted label to each candidate relation.
# Work on a copy so candidate_relations itself gains no 'prediction' column.
# Assign raw values (one label per row, in file order) so that a row-count
# mismatch with predictions.txt raises instead of silently filling NaN through
# index alignment.
predicted_relations = candidate_relations.copy()
predicted_relations['prediction'] = prediction.iloc[:, 0].to_numpy()
predicted_relations = predicted_relations[['1','prediction','2','3','4','5','6','7','8']]
predicted_relations

Unnamed: 0,1,prediction,2,3,4,5,6,7,8
0,NonRel,NonRel,- At least [s1] three [e1] of the following si...,- At least three of the following signs or sym...,lower_bound,chronic_disease,0_0,0_1,0.0
1,NonRel,NonRel,- At least [s1] three [e1] of the following si...,- At least three of the following signs or sym...,lower_bound,chronic_disease,0_0,0_2,0.0
2,NonRel,NonRel,- At least [s1] three [e1] of the following si...,- At least three of the following signs or sym...,lower_bound,clinical_variable,0_0,0_3,0.0
3,NonRel,NonRel,- At least [s1] three [e1] of the following si...,- At least three of the following signs or sym...,lower_bound,lower_bound,0_0,0_4,0.0
4,NonRel,NonRel,- At least [s1] three [e1] of the following si...,- At least three of the following signs or sym...,lower_bound,chronic_disease,0_0,0_5,0.0
...,...,...,...,...,...,...,...,...,...
79,NonRel,hasTemp,- Participation in [s1] another clinical study...,- Participation in another clinical study or u...,treatment,upper_bound,7_0,7_2,7.0
80,NonRel,hasTemp,- Participation in another clinical study or u...,- Participation in another clinical study or u...,treatment,upper_bound,7_1,7_2,7.0
81,NonRel,NonRel,- Patients with a [s1] hematologic/oncologic d...,- Patients with a hematologic/oncologic diseas...,chronic_disease,cancer,9_0,9_1,9.0
82,NonRel,NonRel,- Patients with a [s1] hematologic/oncologic d...,- Patients with a hematologic/oncologic diseas...,chronic_disease,cancer,9_0,9_2,9.0


In [16]:
predicted_relations.to_excel("data_re/sample_trial_relations_ML-prediction.xlsx", index=False)  # candidates + predicted labels