## Source: PubMed
## Retriever: Google Search
## Model: BioLinkBERT (PubMedQA)

In [1]:
import torch
from sklearn.metrics import classification_report, roc_auc_score

import sys
sys.path.append("../../") # use utils

import utils
import importlib
importlib.reload(utils)

from utils import calc_auc

In [2]:
from transformers import AutoTokenizer, BertForSequenceClassification, BertModel, DataCollatorWithPadding

model_name = "<path_to_bio-linkbert-large__pubmedqa_hf>"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name)
model.eval()
print("OK")

OK


In [2]:
def predict(question, passage):
    sequence = tokenizer.encode_plus(
        question, 
        passage, 
        return_tensors="pt",
        max_length=512,
        truncation=True
    )['input_ids']

    logits = model(sequence)[0]
    probabilities = torch.softmax(logits, dim=1).detach().cpu().tolist()[0]
    proba_yes = probabilities[2]
    
    return proba_yes

In [3]:
init_data = pd.read_csv("../../../data/data_to_process.csv")
print(len(init_data))
init_data.head(3)

113


Unnamed: 0,data_source,query_id,description,query,label
0,2019,1,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0
1,2019,3,Can acupuncture be effective for people with e...,acupuncture epilepsy,0.0
2,2019,5,Can acupuncture prevent migraines?,acupuncture migraine,1.0


## Keywords

In [4]:
raw_df = pd.concat((
    pd.read_csv("../../../data/google_search/google_search_abstracts_keywords_part_1.csv"),
    pd.read_csv("../../../data/google_search/google_search_abstracts_keywords_part_2.csv")
))
df = (
    raw_df
    .merge(init_data, on=["data_source", "query_id", "label", "description", "query"])
    .dropna()
)
print(len(df))
df.head(1)

999


Unnamed: 0,query_id,data_source,pubmed_id,description,query,label,abstract
0,1,2019,19441868,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Urinary tract infection (UTI) refers to the pr...


In [6]:
results = []

for _, row in pb(df.iterrows(), max_value=len(df)):
    results.append(predict(row.description, row.abstract))

100% (9 of 9) |##########################| Elapsed Time: 0:00:27 Time:  0:00:27


In [11]:
df['prediction'] = results
df.head(1)

Unnamed: 0,query_id,data_source,pubmed_id,description,query,label,abstract,prediction
0,1,2019,19441868,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Urinary tract infection (UTI) refers to the pr...,0.696868


In [12]:
df.to_csv("predictions/biolinkbert_pubmedqa_google_search_preds_keywords.csv", index=0)

## Calc metrics

In [3]:
init_data = pd.read_csv("../data_to_process.csv")
print(len(init_data))
init_data.head(3)

113


Unnamed: 0,data_source,query_id,description,query,label
0,2019,1,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0
1,2019,3,Can acupuncture be effective for people with e...,acupuncture epilepsy,0.0
2,2019,5,Can acupuncture prevent migraines?,acupuncture migraine,1.0


In [5]:
df = pd.read_csv("predictions/biolinkbert_pubmedqa_google_search_preds_keywords.csv").drop(['label'], axis=1)
print(len(df))
df.head(2)

1021


Unnamed: 0,query_id,data_source,pubmed_id,description,query,abstract,prediction
0,1,2019,19441868,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,Urinary tract infection (UTI) refers to the pr...,0.696868
1,1,2019,23076891,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,Background: Cranberries have been used widely ...,0.180846


In [6]:
agg_types = ["avg", "top1", "norm_linear", "norm_log"]
data_source_types = sorted(df.data_source.unique().tolist()) + ["all"]
data_source_types

['2019', '2021', 'health_belief', 'misbelief', 'all']

In [7]:
df_filled = (
    df
    .merge(init_data, how='outer', on=['query_id', 'data_source'])
    .fillna({"prediction": 0.5})
)
print(len(df_filled))
df_filled.tail(5)

1022


Unnamed: 0,query_id,data_source,pubmed_id,description_x,query_x,abstract,prediction,description_y,query_y,label
1017,14_h,health_belief,12561675.0,Does listeria cause miscarriage?,listeria causes miscarriage,"Listeria monocytogenes, an intracellular facul...",0.997213,Does listeria cause miscarriage?,listeria causes miscarriage,1.0
1018,14_h,health_belief,29720597.0,Does listeria cause miscarriage?,listeria causes miscarriage,Listeria monocytogenes is a mammalian pathogen...,0.638575,Does listeria cause miscarriage?,listeria causes miscarriage,1.0
1019,14_h,health_belief,19542009.0,Does listeria cause miscarriage?,listeria causes miscarriage,Listeria monocytogenes is a ubiquitous bacteri...,0.71211,Does listeria cause miscarriage?,listeria causes miscarriage,1.0
1020,14_h,health_belief,28367407.0,Does listeria cause miscarriage?,listeria causes miscarriage,Listeria monocytogenes is a known cause of gas...,0.535308,Does listeria cause miscarriage?,listeria causes miscarriage,1.0
1021,123,2021,,,,,0.5,Can I get rid of a pimple overnight by applyin...,toothpaste pimple overnight,0.0


In [8]:
assert 113 == len(df_filled.drop_duplicates(["query_id", "data_source"]))

In [9]:
metrics = {item: [] for item in data_source_types}

for data_source in data_source_types:
    for agg_type in agg_types:
        if data_source == "all":
            metrics[data_source].append(calc_auc(df_filled, agg_type))
        else:
            df_cut = df_filled.query(f"data_source == '{data_source}'")
            metrics[data_source].append(calc_auc(df_cut, agg_type))

metrics_df = pd.DataFrame(metrics, index=agg_types).round(4)
metrics_df

Unnamed: 0,2019,2021,health_belief,misbelief,all
avg,0.436,0.8176,0.7083,0.6607,0.6592
top1,0.3945,0.7072,0.4583,0.6786,0.5943
norm_linear,0.4498,0.8304,0.75,0.6964,0.6702
norm_log,0.4498,0.8384,0.75,0.7143,0.6781


## Question

In [10]:
df = (
    pd.concat((
        pd.read_csv("../../../data/google_search/google_search_abstracts_question_part_1.csv"),
        pd.read_csv("../../../data/google_search/google_search_abstracts_question_part_2.csv")
    ))
    .dropna()
)

print(len(df))
df.head(1)

1017


Unnamed: 0,query_id,data_source,pubmed_id,description,query,label,abstract
0,1,2019,21788542,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Background: The increasing prevalence of uropa...


In [27]:
results = []

for _, row in pb(df.iterrows(), max_value=len(df)):
    results.append(predict(row.description, row.abstract))

100% (10 of 10) |########################| Elapsed Time: 0:00:28 Time:  0:00:28


In [31]:
df['prediction'] = results
df.head(1)

Unnamed: 0,query_id,data_source,pubmed_id,description,query,label,abstract,prediction
0,1,2019,21788542,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Background: The increasing prevalence of uropa...,0.46422


In [32]:
df.to_csv("predictions/biolinkbert_pubmedqa_google_search_preds_question.csv", index=0)

## Calc Metrics

In [11]:
df = pd.read_csv("predictions/biolinkbert_pubmedqa_google_search_preds_question.csv").drop("label", axis=1)
print(len(df))
df.head(2)

1017


Unnamed: 0,query_id,data_source,pubmed_id,description,query,abstract,prediction
0,1,2019,21788542,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,Background: The increasing prevalence of uropa...,0.46422
1,1,2019,18253990,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,Background: Cranberries have been used widely ...,0.219321


In [12]:
df_filled = (
    df
    .merge(init_data, how='outer', on=['query_id', 'data_source'])
    .fillna({"prediction": 0.5})
)
print(len(df_filled))
df_filled.tail(5)

1019


Unnamed: 0,query_id,data_source,pubmed_id,description_x,query_x,abstract,prediction,description_y,query_y,label
1014,14_h,health_belief,35786468.0,Does listeria cause miscarriage?,listeria causes miscarriage,Listeria monocytogenes (LM) is a food-borne pa...,0.799275,Does listeria cause miscarriage?,listeria causes miscarriage,1.0
1015,14_h,health_belief,30675327.0,Does listeria cause miscarriage?,listeria causes miscarriage,Background and objectives: Listeria monocytoge...,0.423971,Does listeria cause miscarriage?,listeria causes miscarriage,1.0
1016,14_h,health_belief,28367407.0,Does listeria cause miscarriage?,listeria causes miscarriage,Listeria monocytogenes is a known cause of gas...,0.535308,Does listeria cause miscarriage?,listeria causes miscarriage,1.0
1017,128,2021,,,,,0.5,Does steam from a shower help croup?,steam shower croup,0.0
1018,134,2021,,,,,0.5,Can I remove a tick by covering it with Vaseline?,remove tick with vaseline,0.0


In [13]:
assert 113 == len(df_filled.query("data_source != 'wh_topics'").drop_duplicates(["query_id", "data_source"]))

In [14]:
metrics = {item: [] for item in data_source_types}

for data_source in data_source_types:
    for agg_type in agg_types:
        if data_source == "all":
            metrics[data_source].append(calc_auc(df_filled, agg_type))
        else:
            df_cut = df_filled.query(f"data_source == '{data_source}'")
            metrics[data_source].append(calc_auc(df_cut, agg_type))

metrics_df = pd.DataFrame(metrics, index=agg_types).round(4)
metrics_df

Unnamed: 0,2019,2021,health_belief,misbelief,all
avg,0.4083,0.8208,0.625,0.625,0.6356
top1,0.5986,0.7056,0.4583,0.4821,0.6422
norm_linear,0.4567,0.824,0.625,0.6786,0.6624
norm_log,0.4948,0.832,0.5833,0.7321,0.675
