In [40]:
from datasets import load_dataset

In [2]:
dataset = load_dataset("ade-benchmark-corpus/ade_corpus_v2",'Ade_corpus_v2_drug_ade_relation') 


In [3]:
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'drug', 'effect', 'indexes'],
        num_rows: 6821
    })
})

In [4]:
test = dataset["train"]
test[0]

{'text': 'Intravenous azithromycin-induced ototoxicity.',
 'drug': 'azithromycin',
 'effect': 'ototoxicity',
 'indexes': {'drug': {'start_char': [12], 'end_char': [24]},
  'effect': {'start_char': [33], 'end_char': [44]}}}

In [6]:
from gliner import GLiNER

import os
import torch
from tqdm import tqdm

# text="The branding of the CSAT survey email has been made generic to ensure that end customers do not associate it with DevRev."
model=GLiNER.from_pretrained("urchade/gliner_large-v2.1")
# Force usage of GPU 1
device = torch.device("cuda:6" if torch.cuda.is_available() else "cpu")
model = model.to(device)

predicted_ner=[]
labels=['drug', 'effect','reaction','disease']
# labels=all_entities
results=[]
for i in tqdm(range(len(test))):
    text=test[i]['text']
    entities=model.predict_entities(text,labels=labels,threshold=0.5,multi_label=True)
    real=set()
    import re
    temp=[]
    
    # dtokens = re.split(r"[ ,._:;\[\](){}\/\?\r'\*\n|!-\"<>\-#@=%&+]+", text)

    for entity in entities:
        # ent_text=re.split(r"[ ,._:;\[\](){}\/\?\r'\*\n|!-\"<>\-#@=%&+]+", entity['text'])
        if entity['text'] not in real:
            temp.append([entity['start'], entity['end'], entity['text'], entity['label']])

        real.add(entity['text'])

    real=list(real)
    predicted_ner.append(temp)
    results.append(real)

Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 40920.04it/s]
  0%|          | 0/6821 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 6821/6821 [07:21<00:00, 15.45it/s]


In [7]:
results[1]

['rifampicin', "Paget's bone disease"]

In [8]:
true=[]
for i in range(len(test)):
    temp=[test[i]['drug'],test[i]['effect']]
    true.append(temp)

In [9]:
test[6]['text']

'RESULTS: A 44-year-old man taking naproxen for chronic low back pain and a 20-year-old woman on oxaprozin for rheumatoid arthritis presented with tense bullae and cutaneous fragility on the face and the back of the hands.'

In [10]:
true[6]

['naproxen', 'cutaneous fragility']

In [11]:
results[6]

['rheumatoid arthritis',
 'presented with tense bullae and cutaneous fragility',
 'naproxen',
 'oxaprozin']

In [12]:
recall_avg=[]
precision_avg=[]
common2=0
for i in range(len(results)):
    true_set=set(true[i])
    pred_set=set(results[i])
    
    if len(true_set) == 0 and len(pred_set) == 0:
        recall_avg.append(1.0)
        precision_avg.append(1.0)
    else:
        count=0
        for m in true_set:
            for n in pred_set:
                if m in n or n in m:
                    count+= 1
        recall = count / len(true_set) if len(true_set) > 0 else 0
        precision = count / len(pred_set) if len(pred_set) > 0 else 0
        common2 += count
        recall_avg.append(recall)
        precision_avg.append(precision)

In [13]:
sum(recall_avg) / len(recall_avg), sum(precision_avg) / len(precision_avg)

(0.8955431754874652, 0.6475514479656069)

In [14]:
import re
ner=[]
for i in range(len(test)):
    temp=[]
    tokens = re.split(r'[ -.,:;/?\]\[]+', test[i]['text'])
    drug=re.split(r'[ -.,:;/?\]\[]+', test[i]['drug'])
    start=tokens.index(drug[0])
    end=tokens.index(drug[-1])
    temp.append([start,end,'drug',test[i]['drug']])
    effect=re.split(r'[ -.,:;/?\[\]]+', test[i]['effect'])
    start=tokens.index(effect[0])
    end=tokens.index(effect[-1])
    temp.append([start,end,'effect',test[i]['effect']])
    ner.append(temp)
test[0]

{'text': 'Intravenous azithromycin-induced ototoxicity.',
 'drug': 'azithromycin',
 'effect': 'ototoxicity',
 'indexes': {'drug': {'start_char': [12], 'end_char': [24]},
  'effect': {'start_char': [33], 'end_char': [44]}}}

In [15]:
predicted_ner[0]

[[12, 32, 'azithromycin-induced', 'drug'], [33, 44, 'ototoxicity', 'disease']]

In [16]:
import re
pred_ner=[]
for i in range(len(test)):
    temp=[]
    for j in predicted_ner[i]:
        tokens = re.split(r'[ -.,:;/?\]\[]+', test[i]['text'])
        drug=re.split(r'[ -.,:;/?\]\[]+', j[2])
        start=tokens.index(drug[0])
        end=tokens.index(drug[-1])
        temp.append([start,end,'drug',j[2]])
    
    pred_ner.append(temp)


In [17]:
pred_ner[0]

[[1, 2, 'drug', 'azithromycin-induced'], [3, 3, 'drug', 'ototoxicity']]

In [18]:
ner[0]

[[1, 1, 'drug', 'azithromycin'], [3, 3, 'effect', 'ototoxicity']]

In [19]:
relations_true=[]
for i in range(len(ner)):
    temp=[]
    temp.append([ner[i][0][3], ner[i][1][3], 'has'])
    relations_true.append(temp)  

In [20]:
relations_true[3]

[['naproxen', 'pseudoporphyria', 'has']]

In [48]:
from pair2rel import Pair2Rel

from tqdm import tqdm
model = Pair2Rel.from_pretrained("chapalavamshi022/pair2rel")
import torch
import re
# Force usage of GPU 1
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.device = device 
relations_all=[]
labels = ['has']
for i in tqdm(range(len(test))):
    # required_labels = []
    # for token in processed_data[i]['tokens']:
    #     if token in rel_set:
    #         required_labels.append(token)
    tokens = re.split(r'[ -.,:;/?\]\[]+', test[i]['text'])

    relations = model.predict_relations(tokens, labels, threshold=0.0, ner=ner[i], top_k=1)

    sorted_data_desc = sorted(relations, key=lambda x: x['score'], reverse=True)
    temp=[]
    for item in sorted_data_desc:
        head=' '.join(item['head_text'])
        tail=' '.join(item['tail_text'])
        if head == tail:
            continue
        temp.append([head,tail,item['label']])

    relations_all.append(temp)
        
print("Success! ✅")

100%|██████████| 6821/6821 [07:29<00:00, 15.18it/s]

Success! ✅





In [49]:
relations_true[1]

[['dihydrotachysterol', 'increased calcium release', 'has']]

In [50]:
for i in range(len(relations_all)):
    if relations_true[i][0] not in relations_all[i]:
        print(i)

12
72
81
100
105
107
141
147
149
166
168
216
223
224
225
232
244
324
385
420
422
425
426
431
432
445
470
472
488
523
526
537
559
562
563
564
622
623
641
642
726
832
858
907
921
922
925
926
941
994
1002
1004
1060
1064
1078
1152
1160
1174
1218
1223
1242
1244
1246
1248
1283
1366
1367
1368
1444
1479
1480
1508
1518
1519
1521
1532
1597
1600
1624
1657
1658
1665
1685
1692
1694
1695
1732
1733
1736
1747
1765
1766
1786
1802
1828
1855
1866
1867
1868
1947
1948
1953
1980
1981
1993
2041
2042
2043
2118
2119
2132
2157
2199
2200
2211
2221
2231
2238
2274
2317
2329
2383
2436
2445
2447
2459
2460
2487
2490
2498
2522
2523
2588
2592
2660
2683
2711
2745
2748
2749
2774
2868
2869
2870
2926
2982
2984
3017
3018
3039
3077
3103
3123
3148
3198
3201
3247
3257
3317
3318
3319
3322
3386
3447
3458
3503
3504
3506
3523
3524
3525
3554
3558
3561
3607
3754
3776
3815
3866
3868
3874
3876
3889
3973
3990
4009
4010
4012
4020
4055
4056
4067
4068
4071
4074
4099
4186
4190
4263
4275
4414
4429
4430
4431
4465
4506
4509
4525
4526
4528
456

In [51]:
test[1]['text']

"Immobilization, while Paget's bone disease was present, and perhaps enhanced activation of dihydrotachysterol by rifampicin, could have led to increased calcium-release into the circulation."

In [52]:
relations_all[1]

[['increased calcium release', 'dihydrotachysterol', 'has'],
 ['dihydrotachysterol', 'increased calcium release', 'has']]

In [53]:
recall_avg=[]
precision_avg=[]
common=0
ours=0

for i in range(len(relations_all)):
    true_set=relations_true[i]
    pred_set=relations_all[i]
    d={}
    for item in pred_set:
        if (item[0],item[1]) not in d and (item[1],item[0]) not in d:
            d[(item[0],item[1])]=item[2]
    
    ours+=len(d)
    if len(true_set) == 0 and len(pred_set) == 0:
        recall_avg.append(1.0)
        precision_avg.append(1.0)
    else:
        count=0
        for m in true_set:
            for n in pred_set:
                
                if (m[0]==n[0]) and m[1]==n[1] and (m[2]==n[2]):
                    count+= 1
        recall = count/ len(true_set) if len(true_set) > 0 else 0
        precision = count / len(d) if len(d) > 0 else 0
        common += count
        recall_avg.append(recall)
        precision_avg.append(precision)

In [54]:
sum(recall_avg) / len(recall_avg), sum(precision_avg) / len(precision_avg)

(0.9524996334848262, 0.9524996334848262)

In [35]:
from pair2rel import Pair2Rel

from tqdm import tqdm
model = Pair2Rel.from_pretrained("chapalavamshi022/pair2rel")
import torch
import re
# Force usage of GPU 1
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.device = device 
relations_all=[]
labels = ['has']
for i in tqdm(range(len(test))):
    # required_labels = []
    # for token in processed_data[i]['tokens']:
    #     if token in rel_set:
    #         required_labels.append(token)
    tokens = re.split(r'[ -.,:;/?\]\[]+', test[i]['text'])
    try:

        relations = model.predict_relations(tokens, labels, threshold=0.0, ner=pred_ner[i], top_k=1)

        sorted_data_desc = sorted(relations, key=lambda x: x['score'], reverse=True)
        temp=[]
        for item in sorted_data_desc:
            head=' '.join(item['head_text'])
            tail=' '.join(item['tail_text'])
            if head == tail:
                continue
            temp.append([head,tail,item['label']])

        relations_all.append(temp)
    except:
        relations_all.append([])
        

print("Success! ✅")


100%|██████████| 6821/6821 [24:31<00:00,  4.64it/s]

Success! ✅





In [38]:
recall_avg=[]
precision_avg=[]
common=0
ours=0

for i in range(len(relations_all)):
    true_set=relations_true[i]
    pred_set=relations_all[i]
    d={}
    for item in pred_set:
        if (item[0],item[1]) not in d and (item[1],item[0]) not in d:
            d[(item[0],item[1])]=item[2]
    
    ours+=len(d)
    if len(true_set) == 0 and len(pred_set) == 0:
        recall_avg.append(1.0)
        precision_avg.append(1.0)
    else:
        count=0
        for m in true_set:
            for n in pred_set:
                if m[0] in n[0] and m[1] in n[1] and (m[2]==n[2]):
                    count+= 1
        recall = count/ len(true_set) if len(true_set) > 0 else 0
        precision = count / len(d) if len(d) > 0 else 0
        common += count
        recall_avg.append(recall)
        precision_avg.append(precision)

In [39]:
sum(recall_avg) / len(recall_avg), sum(precision_avg) / len(precision_avg)

(0.6116405219176074, 0.2600283059233326)