## Source: PubMed
## Retriever: BioMed Explorer
## Model: Roberta Large BoolQ

In [1]:
import torch
from sklearn.metrics import classification_report, roc_auc_score

import sys
sys.path.append("../../") # use utils

import utils
import importlib
importlib.reload(utils)

from utils import calc_auc
import pandas as pd
from progressbar import progressbar as pb

In [2]:
from transformers import AutoTokenizer, RobertaForSequenceClassification

# Tokenizer and classifier are loaded from the same checkpoint id.
model_name = "apugachev/roberta-large-boolq-finetuned"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# `.to()` and `.eval()` both return the module itself, so the calls can be chained.
model = RobertaForSequenceClassification.from_pretrained(model_name).to(device).eval()
print("OK")

OK


In [4]:
def predict(question, passage, device=device):
    """Score a (question, passage) pair with the sequence classifier.

    Args:
        question: Question text, encoded as the first segment.
        passage: Supporting text, encoded as the second segment. The pair is
            truncated to ``tokenizer.model_max_length`` tokens.
        device: Torch device the input ids are moved to before the forward
            pass. Defaults to the module-level ``device`` as it was bound
            when this function was defined.

    Returns:
        float: softmax probability of class index 1, which this notebook
        treats as the "yes" answer.
    """
    input_ids = tokenizer.encode_plus(
        question,
        passage,
        return_tensors="pt",
        max_length=tokenizer.model_max_length,
        truncation=True
    )['input_ids']

    # Inference only: disable autograd so no graph is recorded per call.
    with torch.no_grad():
        logits = model(input_ids.to(device))[0]

    # Single example in the batch, so take row 0 of the (1, n_classes) output.
    probabilities = torch.softmax(logits, dim=1).cpu().tolist()[0]
    proba_yes = probabilities[1]

    return proba_yes

In [2]:
# Query table that the prediction frames are merged against further down.
init_data = pd.read_csv("../../../data/data_to_process.csv")
print(init_data.shape[0])
init_data.head(3)

113


Unnamed: 0,data_source,query_id,description,query,label
0,2019,1,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0
1,2019,3,Can acupuncture be effective for people with e...,acupuncture epilepsy,0.0
2,2019,5,Can acupuncture prevent migraines?,acupuncture migraine,1.0


## Keywords

In [3]:
# Abstracts file for the keyword variant (per the file name); any row with a
# missing value in any column is dropped.
df = pd.read_csv("../../../data/biomed_explorer/biomed_explorer_abstracts_keywords.csv")
df = df.dropna()

print(df.shape[0])
df.head(1)

1066


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description,query,label,query_processed
0,1,2019,23396043,Despite considerable controversy about their e...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?


In [7]:
# One score per row, in row order; `pb` wraps the iterator to show progress.
results = [
    predict(record.description, record.abstract)
    for _, record in pb(df.iterrows(), max_value=len(df))
]

100% (1929 of 1929) |####################| Elapsed Time: 0:01:06 Time:  0:01:06


In [8]:
# Attach the scores positionally as a new `prediction` column.
df = df.assign(prediction=results)
df.head(1)

Unnamed: 0,query_id,data_source,pubmed_id,abstract,description,query,label,query_processed,prediction
0,1,2019,23396043,Despite considerable controversy about their e...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?,0.99987


In [4]:
KEYWORDS_SAVE_PRED_FILE = "predictions/roberta_large_boolq_biomed_explorer_preds_keywords.csv"

# Guard the cached predictions: if the inference cells were skipped, `df` is
# just the raw abstracts frame and writing it would overwrite the saved scores.
assert "prediction" in df.columns, (
    "df has no 'prediction' column - run the inference cells before saving"
)
df.to_csv(KEYWORDS_SAVE_PRED_FILE, index=False)

## Calc Metrics

In [5]:
# Continue from the saved predictions file rather than the in-memory frame.
df = pd.read_csv(KEYWORDS_SAVE_PRED_FILE)
print(df.shape[0])

1066


In [6]:
# Aggregation modes passed to `calc_auc`, and the data-source slices to report
# ("all" is the pseudo-source used for the unfiltered frame).
agg_types = ["avg", "top1", "norm_linear", "norm_log"]
data_source_types = [*sorted(df["data_source"].unique().tolist()), "all"]
data_source_types

['2019', '2021', 'health_belief', 'misbelief', 'all']

In [7]:
# Outer-join with the query table so rows that exist only in `init_data`
# are kept; their missing `prediction` is filled with 0.5.
merge_keys = ['query_id', 'data_source', "label"]
df_filled = pd.merge(df, init_data, how='outer', on=merge_keys)
df_filled = df_filled.fillna({"prediction": 0.5})
print(df_filled.shape[0])
df_filled.tail(2)

1067


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description_x,query_x,label,query_processed,prediction,description_y,query_y
1065,14_h,health_belief,16351605.0,"Listeria monocytogenes is a Gram-positive, wea...",Does listeria cause miscarriage?,listeria causes miscarriage,1.0,Does listeria cause miscarriage?,0.98257,Does listeria cause miscarriage?,listeria causes miscarriage
1066,202,misbelief,,,,,0.0,,0.5,Can hemorrhoids be cured with leeches?,hemorrhoids cured leeches


In [8]:
# The merged frame must contain exactly the (query_id, data_source) pairs of
# `init_data`; derive the expected count instead of hard-coding it.
query_keys = ["query_id", "data_source"]
expected_queries = len(init_data.drop_duplicates(query_keys))
found_queries = len(df_filled.drop_duplicates(query_keys))
assert expected_queries == found_queries, (
    f"expected {expected_queries} unique queries after the merge, found {found_queries}"
)

In [9]:
# AUC table: one column per data source, one row per aggregation type.
metrics = {item: [] for item in data_source_types}

for data_source in data_source_types:
    if data_source == "all":
        # The pseudo-source "all" scores the unfiltered frame; errors here
        # are not caught, as before.
        for agg_type in agg_types:
            metrics[data_source].append(calc_auc(df_filled, agg_type))
        continue

    # The slice does not depend on agg_type, so compute it once per source.
    df_cut = df_filled[df_filled["data_source"] == data_source]

    for agg_type in agg_types:
        try:
            auc = calc_auc(df_cut, agg_type)
        except ValueError:
            print(f"Can't calc auc for {data_source} {agg_type}")
            # Record a placeholder so every column keeps one entry per
            # aggregation type; otherwise a partial failure leaves a short
            # list and the DataFrame construction below fails.
            auc = None
        metrics[data_source].append(auc)

metrics_df = pd.DataFrame(metrics, index=agg_types).round(4)
metrics_df

Unnamed: 0,2019,2021,health_belief,misbelief,all
avg,0.7197,0.8688,0.8333,0.8214,0.753
top1,0.6332,0.8448,0.625,0.6607,0.7163
norm_linear,0.6713,0.8752,0.7917,0.7679,0.7416
norm_log,0.6678,0.8768,0.7917,0.7679,0.7423


## Question

In [10]:
# Abstracts file for the question variant (per the file name); any row with a
# missing value in any column is dropped.
df = pd.read_csv("../../../data/biomed_explorer/biomed_explorer_abstracts_question.csv")
df = df.dropna()

print(df.shape[0])
df.head(1)

1063


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description,query,label,query_processed
0,1,2019,19219097,Background: Cranberries have been used for pre...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?


In [16]:
# One score per row, in row order; `pb` wraps the iterator to show progress.
results = [
    predict(record.description, record.abstract)
    for _, record in pb(df.iterrows(), max_value=len(df))
]

100% (1906 of 1906) |####################| Elapsed Time: 0:01:05 Time:  0:01:05


In [17]:
# Attach the scores positionally as a new `prediction` column.
df = df.assign(prediction=results)
df.head(1)

Unnamed: 0,query_id,data_source,pubmed_id,abstract,description,query,label,query_processed,prediction
0,1,2019,19219097,Background: Cranberries have been used for pre...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?,0.999636


In [11]:
QUESTION_SAVE_PRED_FILE = "predictions/roberta_large_boolq_biomed_explorer_preds_question.csv"

# Guard the cached predictions: if the inference cells were skipped, `df` is
# just the raw abstracts frame and writing it would overwrite the saved scores.
assert "prediction" in df.columns, (
    "df has no 'prediction' column - run the inference cells before saving"
)
df.to_csv(QUESTION_SAVE_PRED_FILE, index=False)

## Calc Metrics

In [12]:
# Continue from the saved predictions file rather than the in-memory frame.
df = pd.read_csv(QUESTION_SAVE_PRED_FILE)
print(df.shape[0])
df.head(2)

1063


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description,query,label,query_processed,prediction
0,1,2019,19219097,Background: Cranberries have been used for pre...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?,0.999636
1,1,2019,28288837,Purpose: We sought to clarify the association ...,Can cranberries prevent urinary tract infections?,cranberries urinary tract infections,0.0,Can cranberries prevent urinary tract infections?,0.999668


In [13]:
# Outer-join with the query table so rows that exist only in `init_data`
# are kept; their missing `prediction` is filled with 0.5.
merge_keys = ['query_id', 'data_source', "label"]
df_filled = pd.merge(df, init_data, how='outer', on=merge_keys)
df_filled = df_filled.fillna({"prediction": 0.5})
print(df_filled.shape[0])
df_filled.tail(2)

1064


Unnamed: 0,query_id,data_source,pubmed_id,abstract,description_x,query_x,label,query_processed,prediction,description_y,query_y
1062,14_h,health_belief,25681385.0,Recurrent miscarriage is frustrating for the p...,Does listeria cause miscarriage?,listeria causes miscarriage,1.0,Does listeria cause miscarriage?,0.00059,Does listeria cause miscarriage?,listeria causes miscarriage
1063,202,misbelief,,,,,0.0,,0.5,Can hemorrhoids be cured with leeches?,hemorrhoids cured leeches


In [14]:
# The merged frame must contain exactly the (query_id, data_source) pairs of
# `init_data`; derive the expected count instead of hard-coding it.
query_keys = ["query_id", "data_source"]
expected_queries = len(init_data.drop_duplicates(query_keys))
found_queries = len(df_filled.drop_duplicates(query_keys))
assert expected_queries == found_queries, (
    f"expected {expected_queries} unique queries after the merge, found {found_queries}"
)

In [15]:
# AUC table: one column per data source, one row per aggregation type.
metrics = {item: [] for item in data_source_types}

for data_source in data_source_types:
    if data_source == "all":
        # The pseudo-source "all" scores the unfiltered frame; errors here
        # are not caught, as before.
        for agg_type in agg_types:
            metrics[data_source].append(calc_auc(df_filled, agg_type))
        continue

    # The slice does not depend on agg_type, so compute it once per source.
    df_cut = df_filled[df_filled["data_source"] == data_source]

    for agg_type in agg_types:
        try:
            auc = calc_auc(df_cut, agg_type)
        except ValueError:
            print(f"Can't calc auc for {data_source} {agg_type}")
            # Record a placeholder so every column keeps one entry per
            # aggregation type; otherwise a partial failure leaves a short
            # list and the DataFrame construction below fails.
            auc = None
        metrics[data_source].append(auc)

metrics_df = pd.DataFrame(metrics, index=agg_types).round(4)
metrics_df

Unnamed: 0,2019,2021,health_belief,misbelief,all
avg,0.7301,0.8816,0.9583,0.7857,0.7663
top1,0.6713,0.8368,0.8333,0.6964,0.7385
norm_linear,0.699,0.8784,0.9583,0.7857,0.7619
norm_log,0.7024,0.8864,0.9583,0.7857,0.7676
