## Source: PubMed
## Retriever: BioMed Explorer
## Model: BioLinkBERT (BioASQ)

In [1]:
import torch
from sklearn.metrics import classification_report, roc_auc_score

import sys
sys.path.append("../../") # use utils

import utils
import importlib
importlib.reload(utils)

from utils import calc_auc

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import AutoTokenizer, BertForSequenceClassification, BertModel, DataCollatorWithPadding

model_name = "<path_to_bio-linkbert-large__bioasq_hf>"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name)
model.eval()
print("OK")

OK


In [3]:
def predict(question, passage):
    sequence = tokenizer.encode_plus(
        question, 
        passage, 
        return_tensors="pt",
        max_length=512,
        truncation=True
    )['input_ids']

    logits = model(sequence)[0]
    probabilities = torch.softmax(logits, dim=1).detach().cpu().tolist()[0]
    proba_yes = probabilities[1]
    
    return proba_yes

In [4]:
init_data = pd.read_csv("../../../data/data_to_process.csv")
print(len(init_data))
init_data.head(3)

113


Unnamed: 0,data_source,query_id,description,query,label
0,2019,1,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0
1,2019,3,Can acupuncture be effective for people with e...,acupuncture epilepsy,0.0
2,2019,5,Can acupuncture prevent migraines?,acupuncture migraine,1.0


## Keywords

In [7]:
df = (  
  pd.read_csv("../../../data/biomed_explorer/biomed_explorer_abstracts_keywords.csv")
  .dropna()
)

print(len(df))
df.head(1)

1929


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description,query,label,query_processed
0,1,2019,23396043,Despite considerable controversy about their e...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?


In [23]:
results = []

for _, row in pb(df.iterrows(), max_value=len(df)):
    results.append(predict(row.description, row.abstract))

100% (1929 of 1929) |####################| Elapsed Time: 1:08:02 Time:  1:08:02


In [24]:
df['prediction'] = results
df.head(1)

Unnamed: 0,query_id,data_source,pubmed_id,abstract,description,query,label,query_processed,prediction
0,1,2019,23396043,Despite considerable controversy about their e...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?,0.295032


In [25]:
df.to_csv("predictions/biolinkbert_bioasq_biomed_explorer_preds_keywords.csv", index=0)

## Calc Metrics

In [13]:
df = pd.read_csv("predictions/biolinkbert_bioasq_biomed_explorer_preds_keywords.csv")
print(len(df))
df.head(2)

1066


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description,query,label,query_processed,prediction
0,1,2019,23396043,Despite considerable controversy about their e...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?,0.295032
1,1,2019,22760907,Lower urinary tract infections are very common...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?,0.42042


In [14]:
agg_types = ["avg", "top1", "norm_linear", "norm_log"]
data_source_types = sorted(df.data_source.unique().tolist()) + ["all"]
data_source_types

['2019', '2021', 'health_belief', 'misbelief', 'all']

In [15]:
df_filled = (
    df
    .merge(init_data, how='outer', on=['query_id', 'data_source', "label"])
    .fillna({"prediction": 0.5})
)
print(len(df_filled))
df_filled.tail(2)

1067


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description_x,query_x,label,query_processed,prediction,description_y,query_y
1065,14_h,health_belief,16351605.0,"Listeria monocytogenes is a Gram-positive, wea...",Does listeria cause miscarriage?,listeria causes miscarriage,1.0,Does listeria cause miscarriage?,0.351155,Does listeria cause miscarriage?,listeria causes miscarriage
1066,202,misbelief,,,,,0.0,,0.5,Can hemorrhoids be cured with leeches?,hemorrhoids cured leeches


In [16]:
df_filled = (
    df
    .merge(init_data, how='outer', on=['query_id', 'data_source', "label"])
    .fillna({"prediction": 0.5})
)
print(len(df_filled))
df_filled.tail(2)

1067


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description_x,query_x,label,query_processed,prediction,description_y,query_y
1065,14_h,health_belief,16351605.0,"Listeria monocytogenes is a Gram-positive, wea...",Does listeria cause miscarriage?,listeria causes miscarriage,1.0,Does listeria cause miscarriage?,0.351155,Does listeria cause miscarriage?,listeria causes miscarriage
1066,202,misbelief,,,,,0.0,,0.5,Can hemorrhoids be cured with leeches?,hemorrhoids cured leeches


In [17]:
assert 113 == len(df_filled.drop_duplicates(["query_id", "data_source"]))

In [19]:
metrics = {item: [] for item in data_source_types}

for data_source in data_source_types:
    for agg_type in agg_types:
        if data_source == "all":
            metrics[data_source].append(calc_auc(df_filled, agg_type))
        else:
            try:
                df_cut = df_filled.query(f"data_source == '{data_source}'")
                metrics[data_source].append(calc_auc(df_cut, agg_type))
            except ValueError:
                print(f"Can't calc auc for {data_source} {agg_type}")
            
for key, value in metrics.items():
    if not value:
        metrics[key] = [None] * 4

metrics_df = pd.DataFrame(metrics, index=agg_types).round(4)
metrics_df

Unnamed: 0,2019,2021,health_belief,misbelief,all
avg,0.5917,0.8288,0.5,0.7857,0.7369
top1,0.3841,0.704,0.4167,0.4464,0.5721
norm_linear,0.5536,0.8272,0.5,0.8036,0.7185
norm_log,0.5017,0.808,0.5,0.7679,0.7008


## Question

In [20]:
df = (
    pd.read_csv("../../../data/biomed_explorer/biomed_explorer_abstracts_question.csv")
    .dropna()
)

print(len(df))
df.head(1)

1063


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description,query,label,query_processed
0,1,2019,19219097,Background: Cranberries have been used for pre...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?


In [6]:
results = []

for _, row in pb(df.iterrows(), max_value=len(df)):
    results.append(predict(row.description, row.abstract))

100% (1906 of 1906) |####################| Elapsed Time: 1:05:34 Time:  1:05:34


In [7]:
df['prediction'] = results
df.head(1)

Unnamed: 0,query_id,data_source,pubmed_id,abstract,description,query,label,query_processed,prediction
0,1,2019,19219097,Background: Cranberries have been used for pre...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?,0.325672


In [8]:
df.to_csv("predictions/biolinkbert_bioasq_biomed_explorer_preds_question.csv", index=0)

## Calc Metrics

In [21]:
df = pd.read_csv("predictions/biolinkbert_bioasq_biomed_explorer_preds_question.csv")
print(len(df))
df.head(2)

1063


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description,query,label,query_processed,prediction
0,1,2019,19219097,Background: Cranberries have been used for pre...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?,0.325672
1,1,2019,28288837,Purpose: We sought to clarify the association ...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?,0.542461


In [22]:
df_filled = (
    df
    .merge(init_data, how='outer', on=['query_id', 'data_source', "label"])
    .fillna({"prediction": 0.5})
)
print(len(df_filled))
df_filled.tail(2)

1064


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description_x,query_x,label,query_processed,prediction,description_y,query_y
1062,14_h,health_belief,25681385.0,Recurrent miscarriage is frustrating for the p...,Does listeria cause miscarriage?,listeria causes miscarriage,1.0,Does listeria cause miscarriage?,0.258255,Does listeria cause miscarriage?,listeria causes miscarriage
1063,202,misbelief,,,,,0.0,,0.5,Can hemorrhoids be cured with leeches?,hemorrhoids cured leeches


In [23]:
assert 113 == len(df_filled.drop_duplicates(["query_id", "data_source"]))

In [24]:
metrics = {item: [] for item in data_source_types}

for data_source in data_source_types:
    for agg_type in agg_types:
        if data_source == "all":
            metrics[data_source].append(calc_auc(df_filled, agg_type))
        else:
            try:
                df_cut = df_filled.query(f"data_source == '{data_source}'")
                metrics[data_source].append(calc_auc(df_cut, agg_type))
            except ValueError:
                print(f"Can't calc auc for {data_source} {agg_type}")
            
for key, value in metrics.items():
    if not value:
        metrics[key] = [None] * 4

metrics_df = pd.DataFrame(metrics, index=agg_types).round(4)
metrics_df

Unnamed: 0,2019,2021,health_belief,misbelief,all
avg,0.654,0.808,0.5,0.8036,0.7119
top1,0.5536,0.8848,0.625,0.5893,0.7166
norm_linear,0.6574,0.8336,0.5417,0.7857,0.728
norm_log,0.6401,0.8448,0.5417,0.7857,0.729
