## Source: PubMed
## Retriever: BioMed Explorer
## Model: BioLinkBERT (PubMedQA)

In [2]:
import torch
from sklearn.metrics import classification_report, roc_auc_score

import sys
sys.path.append("../../") # use utils

import utils
import importlib
importlib.reload(utils)

from utils import calc_auc

In [2]:
from transformers import AutoTokenizer, BertForSequenceClassification, BertModel, DataCollatorWithPadding

model_name = "<path_to_bio-linkbert-large__pubmedqa_hf>"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name)
model.eval()
print("OK")

OK


In [4]:
def predict(question, passage):
    sequence = tokenizer.encode_plus(
        question, 
        passage, 
        return_tensors="pt",
        max_length=512,
        truncation=True
    )['input_ids']

    logits = model(sequence)[0]
    probabilities = torch.softmax(logits, dim=1).detach().cpu().tolist()[0]
    proba_yes = probabilities[2]
    
    return proba_yes

In [1]:
init_data = pd.read_csv("../../../data/data_to_process.csv")
print(len(init_data))
init_data.head(3)

113


Unnamed: 0,data_source,query_id,description,query,label
0,2019,1,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0
1,2019,3,Can acupuncture be effective for people with e...,acupuncture epilepsy,0.0
2,2019,5,Can acupuncture prevent migraines?,acupuncture migraine,1.0


## Keywords

In [3]:
df = ( 
  pd.read_csv("../../../data/biomed_explorer/biomed_explorer_abstracts_keywords.csv")
  .dropna()
)

print(len(df))
df.head(1)

1066


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description,query,label,query_processed
0,1,2019,23396043,Despite considerable controversy about their e...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?


In [8]:
results = []

for _, row in pb(df.iterrows(), max_value=len(df)):
    results.append(predict(row.description, row.abstract))

100% (10 of 10) |########################| Elapsed Time: 0:00:25 Time:  0:00:25


In [9]:
df['prediction'] = results
df.head(1)

Unnamed: 0,query_id,data_source,pubmed_id,abstract,description,query,label,query_processed,prediction
329,51,2019,8087387,This study analyzes the role of dehumidifiers ...,Can dehumidifiers be used to control asthma?,dehumidifiers asthma,0.0,Can dehumidifiers be used to control asthma?,0.999365


In [19]:
df.to_csv("predictions/biolinkbert_pubmedqa_biomed_explorer_preds_keywords.csv", index=0)

## Calc Metrics

In [4]:
df = pd.read_csv("predictions/biolinkbert_pubmedqa_biomed_explorer_preds_keywords.csv")
print(len(df))
df.head(2)

1066


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description,query,label,query_processed,prediction
0,1,2019,23396043,Despite considerable controversy about their e...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?,0.91687
1,1,2019,22760907,Lower urinary tract infections are very common...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?,0.898625


In [5]:
agg_types = ["avg", "top1", "norm_linear", "norm_log"]
data_source_types = sorted(df.data_source.unique().tolist()) + ["all"]
data_source_types

['2019', '2021', 'health_belief', 'misbelief', 'all']

In [6]:
df_filled = (
    df
    .merge(init_data, how='outer', on=['query_id', 'data_source', "label"])
    .fillna({"prediction": 0.5})
)
print(len(df_filled))
df_filled.tail(2)

1066


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description_x,query_x,label,query_processed,prediction,description_y,query_y
1064,14_h,health_belief,30675327,Background and objectives: Listeria monocytoge...,Does listeria cause miscarriage?,listeria causes miscarriage,1.0,Does listeria cause miscarriage?,0.423971,Does listeria cause miscarriage?,listeria causes miscarriage
1065,14_h,health_belief,16351605,"Listeria monocytogenes is a Gram-positive, wea...",Does listeria cause miscarriage?,listeria causes miscarriage,1.0,Does listeria cause miscarriage?,0.208132,Does listeria cause miscarriage?,listeria causes miscarriage


In [7]:
assert 113 == len(df_filled.drop_duplicates(["query_id", "data_source"]))

In [9]:
metrics = {item: [] for item in data_source_types}

for data_source in data_source_types:
    for agg_type in agg_types:
        if data_source == "all":
            metrics[data_source].append(calc_auc(df_filled, agg_type))
        else:
            try:
                df_cut = df_filled.query(f"data_source == '{data_source}'")
                metrics[data_source].append(calc_auc(df_cut, agg_type))
            except ValueError:
                print(f"Can't calc auc for {data_source} {agg_type}")
            
for key, value in metrics.items():
    if not value:
        metrics[key] = [None] * 4

metrics_df = pd.DataFrame(metrics, index=agg_types).round(4)
metrics_df

Unnamed: 0,2019,2021,health_belief,misbelief,all
avg,0.4671,0.8048,0.7917,0.6786,0.6393
top1,0.4048,0.8096,0.5833,0.6429,0.6274
norm_linear,0.474,0.8208,0.75,0.6964,0.6573
norm_log,0.4775,0.8368,0.75,0.6786,0.6636


## Question

In [10]:
df = (
    pd.read_csv("../../../data/biomed_explorer/biomed_explorer_abstracts_question.csv")
    .dropna()
)

print(len(df))
df.head(1)

1063


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description,query,label,query_processed
0,1,2019,19219097,Background: Cranberries have been used for pre...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?


In [22]:
results = []

for _, row in pb(df.iterrows(), max_value=len(df)):
    results.append(predict(row.description, row.abstract))

100% (10 of 10) |########################| Elapsed Time: 0:00:26 Time:  0:00:26


In [30]:
df['prediction'] = results
df.head(1)

Unnamed: 0,query_id,data_source,pubmed_id,abstract,description,query,label,query_processed,prediction
0,1,2019,19219097,Background: Cranberries have been used for pre...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?,0.677117


In [31]:
df.to_csv("predictions/biolinkbert_pubmedqa_biomed_explorer_preds_question.csv", index=0)

## Calc Metrics

In [11]:
df = pd.read_csv("predictions/biolinkbert_pubmedqa_biomed_explorer_preds_question.csv")
print(len(df))
df.head(2)

1063


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description,query,label,query_processed,prediction
0,1,2019,19219097,Background: Cranberries have been used for pre...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?,0.677117
1,1,2019,28288837,Purpose: We sought to clarify the association ...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?,0.936215


In [12]:
df_filled = (
    df
    .merge(init_data, how='outer', on=['query_id', 'data_source', "label"])
    .fillna({"prediction": 0.5})
)
print(len(df_filled))
df_filled.tail(2)

1063


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description_x,query_x,label,query_processed,prediction,description_y,query_y
1061,14_h,health_belief,19775781,Objective: to explore midwives' perceptions of...,Does listeria cause miscarriage?,listeria causes miscarriage,1.0,Does listeria cause miscarriage?,0.972747,Does listeria cause miscarriage?,listeria causes miscarriage
1062,14_h,health_belief,25681385,Recurrent miscarriage is frustrating for the p...,Does listeria cause miscarriage?,listeria causes miscarriage,1.0,Does listeria cause miscarriage?,0.446948,Does listeria cause miscarriage?,listeria causes miscarriage


In [13]:
assert 113 == len(df_filled.drop_duplicates(["query_id", "data_source"]))

In [14]:
metrics = {item: [] for item in data_source_types}

for data_source in data_source_types:
    for agg_type in agg_types:
        if data_source == "all":
            metrics[data_source].append(calc_auc(df_filled, agg_type))
        else:
            try:
                df_cut = df_filled.query(f"data_source == '{data_source}'")
                metrics[data_source].append(calc_auc(df_cut, agg_type))
            except ValueError:
                print(f"Can't calc auc for {data_source} {agg_type}")
            
for key, value in metrics.items():
    if not value:
        metrics[key] = [None] * 4

metrics_df = pd.DataFrame(metrics, index=agg_types).round(4)
metrics_df

Unnamed: 0,2019,2021,health_belief,misbelief,all
avg,0.4983,0.872,0.625,0.6607,0.6847
top1,0.436,0.7856,0.8333,0.5,0.6009
norm_linear,0.5121,0.8736,0.7917,0.6607,0.691
norm_log,0.5052,0.8752,0.75,0.6607,0.6885
