## Source: Wikipedia
## Retriever: Google Search
## Model: BioLinkBERT (BioASQ)

In [1]:
import torch
from sklearn.metrics import classification_report, roc_auc_score

import sys
sys.path.append("../../") # use utils

import utils
import importlib
importlib.reload(utils)

from utils import calc_auc

In [2]:
from transformers import AutoTokenizer, BertForSequenceClassification, BertModel, DataCollatorWithPadding

model_name = "<path_to_bio-linkbert-large__bioasq_hf>"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name)
model.eval()
print("OK")

OK


In [3]:
def predict(question, passage):
    sequence = tokenizer.encode_plus(
        question, 
        passage, 
        return_tensors="pt",
        max_length=512,
        truncation=True
    )['input_ids']

    logits = model(sequence)[0]
    probabilities = torch.softmax(logits, dim=1).detach().cpu().tolist()[0]
    proba_yes = probabilities[1]
    
    return proba_yes

In [3]:
init_data = pd.read_csv("../../../data/data_to_process.csv")
print(len(init_data))
init_data.head(3)

113


Unnamed: 0,data_source,query_id,description,query,label
0,2019,1,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0
1,2019,3,Can acupuncture be effective for people with e...,acupuncture epilepsy,0.0
2,2019,5,Can acupuncture prevent migraines?,acupuncture migraine,1.0


## Keywords

In [4]:
raw_df = pd.read_csv("../../../data/google_search_wiki/wikipedia_articles_keywords.csv")
df = (
    raw_df
    .merge(init_data, on=["data_source", "query_id", "label"])
    .dropna()
)
print(len(df))
df.head(1)

519


Unnamed: 0,query_processed,article_title,paragraph,serp_position,title_common_tokens,num_title_common_tokens,tokens_to_find,paragraph_common_tokens,num_paragraph_common_tokens,paragraph_number_in_article,query_id,data_source,label,description,query
0,cranberries urinary tract infections,Cranberry,A comprehensive review in 2012 of available re...,1.0,['cranberry'],1.0,"['urinary', 'infection', 'tract']","['urinary', 'infection', 'tract']",3.0,82.0,1,2019,0.0,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections


In [11]:
results = []

for _, row in pb(df.iterrows(), max_value=len(df)):
    results.append(predict(row.description, row.paragraph))

100% (519 of 519) |######################| Elapsed Time: 0:14:49 Time:  0:14:49


In [12]:
df['prediction'] = results
df.head(1)

Unnamed: 0,query_processed,article_title,paragraph,serp_position,title_common_tokens,num_title_common_tokens,tokens_to_find,paragraph_common_tokens,num_paragraph_common_tokens,paragraph_number_in_article,query_id,data_source,label,description,query,prediction
0,cranberries urinary tract infections,Cranberry,A comprehensive review in 2012 of available re...,1.0,['cranberry'],1.0,"['urinary', 'infection', 'tract']","['urinary', 'infection', 'tract']",3.0,82.0,1,2019,0.0,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.04461


In [13]:
df.to_csv("predictions/biolinkbert_bioasq_google_search_wiki_preds_keywords.csv", index=0)

## Calc metrics

In [14]:
init_data = pd.read_csv("../data_to_process.csv").query("data_source != 'wh_topics'")
print(len(init_data))
init_data.head(3)

113


Unnamed: 0,data_source,query_id,description,query,label
0,2019,1,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0
1,2019,3,Can acupuncture be effective for people with e...,acupuncture epilepsy,0.0
2,2019,5,Can acupuncture prevent migraines?,acupuncture migraine,1.0


In [6]:
df = pd.read_csv("predictions/biolinkbert_bioasq_google_search_wiki_preds_keywords.csv")
print(len(df))
df.head(2)

519


Unnamed: 0,query_processed,article_title,paragraph,serp_position,title_common_tokens,num_title_common_tokens,tokens_to_find,paragraph_common_tokens,num_paragraph_common_tokens,paragraph_number_in_article,query_id,data_source,label,description,query,prediction
0,cranberries urinary tract infections,Cranberry,A comprehensive review in 2012 of available re...,1.0,['cranberry'],1.0,"['urinary', 'infection', 'tract']","['urinary', 'infection', 'tract']",3.0,82.0,1,2019,0.0,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.04461
1,cranberries urinary tract infections,Cranberry juice,Cranberry juice is the liquid juice of the cra...,2.0,['cranberry'],1.0,"['urinary', 'infection', 'tract']","['urinary', 'infection', 'tract']",3.0,0.0,1,2019,0.0,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.077757


In [7]:
agg_types = ["avg", "top1", "norm_linear", "norm_log"]
data_source_types = sorted(df.data_source.unique().tolist()) + ["all"]
data_source_types

['2019', '2021', 'health_belief', 'misbelief', 'all']

In [8]:
df_filled = (
    df
    .merge(init_data, how='outer', on=["query", "description", 'query_id', 'data_source', "label"])
    .fillna({"prediction": 0.5})
)
print(len(df_filled))
df_filled.tail(5)

538


Unnamed: 0,query_processed,article_title,paragraph,serp_position,title_common_tokens,num_title_common_tokens,tokens_to_find,paragraph_common_tokens,num_paragraph_common_tokens,paragraph_number_in_article,query_id,data_source,label,description,query,prediction
533,,,,,,,,,,,206,misbelief,0.0,Does aloe help for a runny nose?,aloe help runny nose,0.5
534,,,,,,,,,,,212,misbelief,1.0,Does echinacea boost immunity?,echinacea boost immunity,0.5
535,,,,,,,,,,,213,misbelief,1.0,Does honey boost immunity?,honey boost immunity,0.5
536,,,,,,,,,,,214,misbelief,1.0,Does garlic kill viruses?,garlic kill viruses,0.5
537,,,,,,,,,,,12_h,health_belief,1.0,Does cialis treat enlarged prostrate?,cialis treats enlarged prostrate,0.5


In [9]:
assert 113 == len(df_filled.drop_duplicates(["query_id", "data_source"]))

In [10]:
metrics = {item: [] for item in data_source_types}

for data_source in data_source_types:
    for agg_type in agg_types:
        if data_source == "all":
            metrics[data_source].append(calc_auc(df_filled, agg_type))
        else:
            df_cut = df_filled.query(f"data_source == '{data_source}'")
            metrics[data_source].append(calc_auc(df_cut, agg_type))

metrics_df = pd.DataFrame(metrics, index=agg_types).round(4)
metrics_df

Unnamed: 0,2019,2021,health_belief,misbelief,all
avg,0.5848,0.5792,0.7083,0.5536,0.5678
top1,0.699,0.4464,0.5833,0.4643,0.5489
norm_linear,0.6021,0.5408,0.7083,0.5536,0.5659
norm_log,0.609,0.528,0.7083,0.5357,0.5665


## Question

In [12]:
raw_df = pd.concat((
    pd.read_csv("../../../data/google_search_wiki/wikipedia_articles_question.csv"),
))

df = (
    raw_df
    .merge(init_data, on=["data_source", "query_id", "label"])
    .dropna()
)

print(len(df))
df.head(1)

627


Unnamed: 0,query_processed,article_title,paragraph,serp_position,title_common_tokens,num_title_common_tokens,tokens_to_find,paragraph_common_tokens,num_paragraph_common_tokens,paragraph_number_in_article,query_id,data_source,label,description,query
0,Can cranberries prevent urinary tract infections?,Cranberry,A comprehensive review in 2012 of available re...,1.0,['cranberry'],1.0,"['prevent', 'tract', 'urinary', 'can', 'infect...","['urinary', 'infection', 'tract']",3.0,82.0,1,2019,0.0,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections


In [26]:
results = []

for _, row in pb(df.iterrows(), max_value=len(df)):
    results.append(predict(row.description, row.paragraph))

100% (627 of 627) |######################| Elapsed Time: 0:18:58 Time:  0:18:58


In [27]:
df['prediction'] = results
df.head(1)

Unnamed: 0,query_processed,article_title,paragraph,serp_position,title_common_tokens,num_title_common_tokens,tokens_to_find,paragraph_common_tokens,num_paragraph_common_tokens,paragraph_number_in_article,query_id,data_source,label,description,query,prediction
0,Can cranberries prevent urinary tract infections?,Cranberry,A comprehensive review in 2012 of available re...,1.0,['cranberry'],1.0,"['prevent', 'tract', 'urinary', 'can', 'infect...","['urinary', 'infection', 'tract']",3.0,82.0,1,2019,0.0,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.04461


In [28]:
df.to_csv("predictions/biolinkbert_bioasq_google_search_wiki_preds_question.csv", index=0)

## Calc Metrics

In [15]:
df = pd.read_csv("predictions/biolinkbert_bioasq_google_search_wiki_preds_question.csv").drop("label", axis=1)
print(len(df))
df.head(2)

627


Unnamed: 0,query_processed,article_title,paragraph,serp_position,title_common_tokens,num_title_common_tokens,tokens_to_find,paragraph_common_tokens,num_paragraph_common_tokens,paragraph_number_in_article,query_id,data_source,description,query,prediction
0,Can cranberries prevent urinary tract infections?,Cranberry,A comprehensive review in 2012 of available re...,1.0,['cranberry'],1.0,"['prevent', 'tract', 'urinary', 'can', 'infect...","['urinary', 'infection', 'tract']",3.0,82.0,1,2019,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.04461
1,Can cranberries prevent urinary tract infections?,Cranberry,Cranberries are a group of evergreen dwarf shr...,1.0,['cranberry'],1.0,"['prevent', 'tract', 'urinary', 'can', 'infect...",['can'],1.0,0.0,1,2019,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.123621


In [16]:
df_filled = (
    df
    .merge(init_data, how='outer', on=["query", "description", 'query_id', 'data_source'])
    .fillna({"prediction": 0.5})
)
print(len(df_filled))
df_filled.tail(5)

643


Unnamed: 0,query_processed,article_title,paragraph,serp_position,title_common_tokens,num_title_common_tokens,tokens_to_find,paragraph_common_tokens,num_paragraph_common_tokens,paragraph_number_in_article,query_id,data_source,description,query,prediction,label
638,,,,,,,,,,,205,misbelief,Does garlic help with thrush?,garlic help thrush,0.5,0.0
639,,,,,,,,,,,206,misbelief,Does aloe help for a runny nose?,aloe help runny nose,0.5,0.0
640,,,,,,,,,,,212,misbelief,Does echinacea boost immunity?,echinacea boost immunity,0.5,1.0
641,,,,,,,,,,,213,misbelief,Does honey boost immunity?,honey boost immunity,0.5,1.0
642,,,,,,,,,,,12_h,health_belief,Does cialis treat enlarged prostrate?,cialis treats enlarged prostrate,0.5,1.0


In [17]:
assert 113 == len(df_filled.drop_duplicates(["query_id", "data_source"]))

In [18]:
metrics = {item: [] for item in data_source_types}

for data_source in data_source_types:
    for agg_type in agg_types:
        if data_source == "all":
            metrics[data_source].append(calc_auc(df_filled, agg_type))
        else:
            df_cut = df_filled.query(f"data_source == '{data_source}'")
            metrics[data_source].append(calc_auc(df_cut, agg_type))

metrics_df = pd.DataFrame(metrics, index=agg_types).round(4)
metrics_df

Unnamed: 0,2019,2021,health_belief,misbelief,all
avg,0.526,0.5712,0.875,0.4286,0.5296
top1,0.5363,0.5376,0.5,0.4286,0.5167
norm_linear,0.5433,0.5856,0.875,0.4286,0.5517
norm_log,0.5294,0.5872,0.875,0.4286,0.5564
