## Source: PubMed
## Retriever: Pubmed Search
## Model: BioLinkBERT (BioASQ)

In [2]:
import torch
from sklearn.metrics import classification_report, roc_auc_score

import sys
sys.path.append("../../") # use utils

import utils
import importlib
importlib.reload(utils)

from utils import calc_auc

In [2]:
from transformers import AutoTokenizer, BertForSequenceClassification, BertModel, DataCollatorWithPadding

model_name = "<path_to_bio-linkbert-large__bioasq_hf>"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name)
model.eval()
print("OK")

OK


In [3]:
def predict(question, passage):
    sequence = tokenizer.encode_plus(
        question, 
        passage, 
        return_tensors="pt",
        max_length=512,
        truncation=True
    )['input_ids']

    logits = model(sequence)[0]
    probabilities = torch.softmax(logits, dim=1).detach().cpu().tolist()[0]
    proba_yes = probabilities[1]
    
    return proba_yes

In [1]:
init_data = pd.read_csv("../../../data/data_to_process.csv")
print(len(init_data))
init_data.head(3)

113


Unnamed: 0,data_source,query_id,description,query,label
0,2019,1,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0
1,2019,3,Can acupuncture be effective for people with e...,acupuncture epilepsy,0.0
2,2019,5,Can acupuncture prevent migraines?,acupuncture migraine,1.0


## Keywords

In [6]:
raw_df = pd.read_csv("../../../data/pubmed_search/pubmed_search_abstracts_keywords.csv")

df = (
    raw_df
    .merge(init_data, on=["data_source", "query_id", "label", "description", "query"])
    .dropna()
)
print(len(df))
df.head(1)

645


Unnamed: 0,query_id,data_source,pubmed_id,description,query,label,abstract
0,1,2019,35285701,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Aim: Cranberries ( Vaccinium macrocarpon ) are...


In [8]:
results = []

for _, row in pb(df.iterrows(), max_value=len(df)):
    results.append(predict(row.description, row.abstract))

100% (645 of 645) |######################| Elapsed Time: 0:33:38 Time:  0:33:38


In [9]:
df['prediction'] = results
df.head(1)

Unnamed: 0,query_id,data_source,pubmed_id,description,query,label,abstract,prediction
0,1,2019,35285701,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Aim: Cranberries ( Vaccinium macrocarpon ) are...,0.348066


In [10]:
df.to_csv("predictions/biolinkbert_bioasq_pubmed_search_preds_keywords.csv", index=0)

## Calc Metrics

In [19]:
df = pd.read_csv("predictions/biolinkbert_bioasq_pubmed_search_preds_keywords.csv")
print(len(df))
df.head(2)

645


Unnamed: 0,query_id,data_source,pubmed_id,description,query,label,abstract,prediction
0,1,2019,35285701,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Aim: Cranberries ( Vaccinium macrocarpon ) are...,0.348066
1,1,2019,34205292,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Consumption of cranberries is associated with ...,0.667579


In [20]:
agg_types = ["avg", "top1", "norm_linear", "norm_log"]
data_source_types = sorted(df.data_source.unique().tolist()) + ["all"]
data_source_types

['2019', '2021', 'health_belief', 'misbelief', 'all']

In [21]:
df_filled = (
    df
    .merge(init_data, how='outer', on=["query_id", "data_source", "label", "description", "query"])
    .fillna({"prediction": 0.5})
)
print(len(df_filled))
df_filled.tail(5)

678


Unnamed: 0,query_id,data_source,pubmed_id,description,query,label,abstract,prediction
673,215,misbelief,,Does onion kill viruses?,onion kill viruses,1.0,,0.5
674,4_h,health_belief,,Does oral sex cause throat cancer?,oral sex causes throat cancer,1.0,,0.5
675,11_h,health_belief,,Does nasal polyp cause nasal block?,nasal polyp causes nasal block,1.0,,0.5
676,12_h,health_belief,,Does cialis treat enlarged prostrate?,cialis treats enlarged prostrate,1.0,,0.5
677,13_h,health_belief,,Does diet cause bad breathe?,diet causes bad breathe,1.0,,0.5


In [7]:
assert 113 == len(df_filled.drop_duplicates(["query_id", "data_source"]))

In [8]:
metrics = {item: [] for item in data_source_types}

for data_source in data_source_types:
    for agg_type in agg_types:
        if data_source == "all":
            metrics[data_source].append(calc_auc(df_filled, agg_type))
        else:
            df_cut = df_filled.query(f"data_source == '{data_source}'")
            metrics[data_source].append(calc_auc(df_cut, agg_type))

metrics_df = pd.DataFrame(metrics, index=agg_types).round(4)
metrics_df

Unnamed: 0,2019,2021,health_belief,misbelief,all
avg,0.5675,0.6208,0.5833,0.5179,0.5766
top1,0.5467,0.5792,0.5833,0.6429,0.558
norm_linear,0.5467,0.608,0.6667,0.5179,0.5757
norm_log,0.5398,0.6,0.625,0.5179,0.5712


## Question

In [9]:
raw_df = pd.read_csv("../../../data/pubmed_search/pubmed_search_abstracts_question.csv")

df = (
    raw_df
    .merge(init_data, on=["data_source", "query_id", "label", "description", "query"])
    .dropna()
)
print(len(df))
df.head(1)

507


Unnamed: 0,query_id,data_source,pubmed_id,description,query,label,abstract
0,1,2019,34444706,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Cranberry is a fruit originally from New Engla...


In [18]:
results = []

for _, row in pb(df.iterrows(), max_value=len(df)):
    results.append(predict(row.description, row.abstract))

100% (507 of 507) |######################| Elapsed Time: 0:23:03 Time:  0:23:03


In [19]:
df['prediction'] = results
df.head(1)

Unnamed: 0,query_id,data_source,pubmed_id,description,query,label,abstract,prediction
0,1,2019,34444706,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Cranberry is a fruit originally from New Engla...,0.354085


In [20]:
df.to_csv("predictions/biolinkbert_bioasq_pubmed_search_preds_question.csv", index=0)

## Calc Metrics

In [10]:
df = pd.read_csv("predictions/biolinkbert_bioasq_pubmed_search_preds_question.csv")
print(len(df))
df.head(2)

507


Unnamed: 0,query_id,data_source,pubmed_id,description,query,label,abstract,prediction
0,1,2019,34444706,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Cranberry is a fruit originally from New Engla...,0.354085
1,1,2019,33751068,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Background: Cranberry proanthocyanidins (c-PAC...,0.281109


In [15]:
df_filled = (
    df
    .merge(init_data, how='outer', on=["query_id", "data_source", "label", "description", "query"])
    .fillna({"prediction": 0.5})
)
print(len(df_filled))
df_filled.tail(5)

550


Unnamed: 0,query_id,data_source,pubmed_id,description,query,label,abstract,prediction
545,206,misbelief,,Does aloe help for a runny nose?,aloe help runny nose,0.0,,0.5
546,211,misbelief,,Does licorice root help with cough?,licorice root help cough,1.0,,0.5
547,214,misbelief,,Does garlic kill viruses?,garlic kill viruses,1.0,,0.5
548,12_h,health_belief,,Does cialis treat enlarged prostrate?,cialis treats enlarged prostrate,1.0,,0.5
549,13_h,health_belief,,Does diet cause bad breathe?,diet causes bad breathe,1.0,,0.5


In [16]:
assert 113 == len(df_filled.drop_duplicates(["query_id", "data_source"]))

In [17]:
metrics = {item: [] for item in data_source_types}

for data_source in data_source_types:
    for agg_type in agg_types:
        if data_source == "all":
            metrics[data_source].append(calc_auc(df_filled, agg_type))
        else:
            df_cut = df_filled.query(f"data_source == '{data_source}'")
            metrics[data_source].append(calc_auc(df_cut, agg_type))

metrics_df = pd.DataFrame(metrics, index=agg_types).round(4)
metrics_df

Unnamed: 0,2019,2021,health_belief,misbelief,all
avg,0.6021,0.4944,0.6667,0.375,0.4757
top1,0.6713,0.5072,0.7083,0.5179,0.5675
norm_linear,0.5986,0.4944,0.6667,0.3929,0.4751
norm_log,0.5986,0.4688,0.6667,0.3929,0.471
