In [None]:
import os

# Read the HPC account id once (closing the file) and strip() it: a trailing
# newline in the token file would otherwise end up inside the scratch paths below.
with open('../tokens/HPC_ACCOUNT_ID.txt', 'r') as f:
    hpc_account_id = f.read().strip()

# HF_HOME has to be set BEFORE the Hugging Face libraries are imported: huggingface_hub
# and datasets resolve their cache locations from it at import time, so assigning it
# after the imports leaves the model cache in the default (home) directory.
os.environ['HF_HOME'] = '/scratch/' + hpc_account_id
cache_dir = '/scratch/' + hpc_account_id + '/cache'

from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import torch
from tqdm import tqdm
import numpy as np
from scipy.stats import pearsonr, spearmanr
import pandas as pd

with open("../tokens/HF_TOKEN.txt", "r") as f:
    hf_api_key = f.read().strip()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

sentence_bert_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

In [None]:
# --- Which data and which retrieval run to analyse ---
DATASET = "LeoZotos/bio_full"
WIKI = "Simple"  # 'En' or 'Simple'; only used to build the retrieved-docs column name
SOURCE_TEXT = "_Only_Options"  # '_Only_Options', or '' for the full-text retrieval column
NUM_DOCS_RETRIEVED = 60  # 20 or 60

# --- Row filters ---
HAS_CONTENT_DISTRACTORS = 2  # keep only rows with this value (0, 1 or 2); any other value, e.g. -1, keeps all rows
SHORT_OPTIONS_THRESHOLD = 10000  # >= 500 disables the filter; otherwise (e.g. ~20) rows whose mean option length is below it are dropped

RETRIEVED_DOCS_COL_NAME = f"Relevant_Docs_{WIKI}{SOURCE_TEXT}_{NUM_DOCS_RETRIEVED}"

In [None]:
def is_nan_or_none(x):
    """Tell whether `x` is a missing value: either None or a float NaN.

    Non-float values (strings, ints, lists, ...) are never considered missing.
    """
    if x is None:
        return True
    return isinstance(x, float) and np.isnan(x)

def has_short_options(list_of_options, threshold=20):
    """Return True if the average length of the usable options is below `threshold`.

    Only non-empty strings count as options; empty strings, None and NaN entries are
    skipped. If no usable option remains, the average is undefined and False is
    returned instead of dividing by zero.

    Args:
        list_of_options: the answer options of one question (e.g. A-D).
        threshold: average character length below which options count as "short".

    Returns:
        bool: whether the mean option length is below `threshold`.
    """
    non_empty_options = [option for option in list_of_options if isinstance(option, str) and option]
    if not non_empty_options:
        return False
    avg_length = sum(len(option) for option in non_empty_options) / len(non_empty_options)
    return avg_length < threshold

# Load the training split, then drop every row where any of the four answer rates is missing.
data = load_dataset(DATASET, split='train', token=hf_api_key, cache_dir=cache_dir)
print("Before filtering, dataset size:", len(data))

data = data.filter(
    lambda x: all(not is_nan_or_none(x[f'Answer_{letter}_Rate']) for letter in 'ABCD')
)

# Keep only rows with the configured Has_Content_Distractors value; any value
# outside 0/1/2 (e.g. -1) disables this filter.
if HAS_CONTENT_DISTRACTORS in (0, 1, 2):
    data = data.filter(lambda x: x['Has_Content_Distractors'] == HAS_CONTENT_DISTRACTORS)

# Drop rows whose options are short on average; thresholds >= 500 disable this filter.
if SHORT_OPTIONS_THRESHOLD < 500:
    data = data.filter(
        lambda x: not has_short_options(
            [x[f'Answer_{letter}'] for letter in 'ABCD'], SHORT_OPTIONS_THRESHOLD
        )
    )

print("After filtering, dataset size:", len(data))

# Assign each retrieved doc to the answer choice it is most similar to

In [None]:
def print_similarities(similarities, sentences1, sentences2):
    """Print each sentence of `sentences1` followed by its scores against `sentences2`.

    `similarities[i][j]` is read as the score between sentences1[i] and sentences2[j].
    Every score line shows the candidate sentence left-aligned to 30 characters and
    the score with four decimals.
    """
    for row_idx, anchor in enumerate(sentences1):
        print(anchor)
        for col_idx, candidate in enumerate(sentences2):
            score = similarities[row_idx][col_idx]
            print(f" - {candidate: <30}: {score:.4f}")


def classify_docs_per_distractor(row, sentence_bert_model, docs_col_name=None):
    """Assign each retrieved document of `row` to the answer choice it is most similar to.

    Only the `Answer_A`..`Answer_D` fields holding non-empty text take part; empty
    strings, None and NaN options are skipped. Each document is appended to the
    choice with the highest similarity (ties go to the earlier choice).

    Args:
        row: mapping with `Answer_A`..`Answer_D` and the retrieved-docs column.
        sentence_bert_model: object exposing `encode(list_of_texts)` and
            `similarity(a, b)` (e.g. a SentenceTransformer).
        docs_col_name: column holding the list of retrieved documents; defaults to
            the global RETRIEVED_DOCS_COL_NAME.

    Returns:
        dict mapping `Answer_X_Docs` (one key per usable choice, in A-D order) to the
        documents assigned to that choice. The lists are empty when the row has no
        documents; the dict is empty when no choice has text.
    """
    if docs_col_name is None:
        docs_col_name = RETRIEVED_DOCS_COL_NAME

    choice_keys = [
        key for key in ('Answer_A', 'Answer_B', 'Answer_C', 'Answer_D')
        if isinstance(row[key], str) and row[key] != ""
    ]
    docs_per_choice = {f"{key}_Docs": [] for key in choice_keys}

    docs = row[docs_col_name]
    docs = [] if docs is None else list(docs)
    # Nothing to compare: encoding an empty list and taking argmax over an empty axis would fail.
    if not choice_keys or not docs:
        return docs_per_choice

    embeddings_choices = sentence_bert_model.encode([row[key] for key in choice_keys])
    embeddings_docs = sentence_bert_model.encode(docs)
    # Rows index choices, columns index documents.
    similarities = sentence_bert_model.similarity(embeddings_choices, embeddings_docs)

    # One argmax over the choice axis for all docs at once; torch.argmax returns the
    # first maximal index, so ties resolve to the earlier choice.
    best_choice_per_doc = torch.as_tensor(similarities).argmax(dim=0).tolist()

    docs_keys = list(docs_per_choice)
    for doc, choice_idx in zip(docs, best_choice_per_doc):
        docs_per_choice[docs_keys[choice_idx]].append(doc)

    return docs_per_choice

In [None]:
# Classify every row's retrieved docs and attach one list-valued column per answer choice.
column_names = [f"Answer_{choice}_Docs" for choice in ['A', 'B', 'C', 'D']]
docs_by_choice = {name: [] for name in column_names}

for row in tqdm(data):
    docs_per_choice_for_row = classify_docs_per_distractor(row, sentence_bert_model)
    for name in column_names:
        # Choices without text have no key in the per-row result; store an empty list for them.
        docs_by_choice[name].append(docs_per_choice_for_row.get(name, []))

# Remove doc columns left over from a previous run of this cell so add_column does not
# collide. Check every column (not only the first) so a partially populated dataset
# is handled as well.
existing_doc_columns = [name for name in column_names if name in data.column_names]
if existing_doc_columns:
    data = data.remove_columns(existing_doc_columns)

for name, column_data in docs_by_choice.items():
    data = data.add_column(name, column_data)

In [None]:
# Inspect an instance manually to see if it makes sense.
# Named `inspect_idx` rather than `id` so the builtin id() stays usable in the kernel.
inspect_idx = 22
sample_row = data[inspect_idx]  # materialize the row once instead of once per field
print(sample_row['Question_With_Options'], ":",
      "\n A:", sample_row['Answer_A_Docs'],
      "\n B:", sample_row['Answer_B_Docs'],
      "\n C:", sample_row['Answer_C_Docs'],
      "\n D:", sample_row['Answer_D_Docs'])

# Correlate answer-choice rates with the number of docs assigned to each choice

In [None]:
def calc_correlation(type='pearson', dataset=None):
    """Correlate each answer choice's rate with the number of documents assigned to it.

    For every choice A-D, pairs `Answer_X_Rate` with `len(Answer_X_Docs)` across all
    rows, computes the requested correlation and prints a tab-separated summary.

    Args:
        type: 'pearson' or 'spearman'. (Parameter name kept for backward
            compatibility even though it shadows the builtin.)
        dataset: object supporting column access by name (e.g. a `datasets.Dataset`
            or a dict of lists). Defaults to the global `data`.

    Returns:
        Flat list [corr_A, p_A, corr_B, p_B, corr_C, p_C, corr_D, p_D], each rounded
        to 4 decimals. A choice whose rates contain a missing value, or that has
        fewer than two rows, yields NaN for both numbers.

    Raises:
        ValueError: if `type` is neither 'pearson' nor 'spearman'.
    """
    correlation_funcs = {'pearson': pearsonr, 'spearman': spearmanr}
    if type not in correlation_funcs:
        raise ValueError(f"Unknown correlation type {type!r}; expected 'pearson' or 'spearman'")
    correlation_func = correlation_funcs[type]
    if dataset is None:
        dataset = data

    simple_list = []
    print(type.capitalize(), "correlation between distractor rates and document lengths(A-D, p-values in between):")
    correlations_with_docs_len = {}

    for choice_name in [f"Answer_{choice}" for choice in ['A', 'B', 'C', 'D']]:
        # dtype=float turns None into NaN, so one isnan() check covers both kinds of missing value.
        rates = np.asarray(dataset[f'{choice_name}_Rate'], dtype=float)
        doc_lengths = np.asarray(
            [len(sentence_list) for sentence_list in dataset[f'{choice_name}_Docs']], dtype=float
        )
        missing_rates = int(np.isnan(rates).sum())
        missing_lengths = int(np.isnan(doc_lengths).sum())
        print(f"Choice: {choice_name}, Nones in rates: {missing_rates}, Nones in doc_lengths: {missing_lengths}")

        if missing_rates or len(rates) < 2:
            # Not computable: report NaN rather than (0, 0), whose p=0 would read as highly significant.
            correlation, p = float('nan'), float('nan')
        else:
            correlation, p = correlation_func(rates, doc_lengths)

        correlations_with_docs_len[choice_name] = (round(correlation, 4), round(p, 4))
        simple_list.append(float(round(correlation, 4)))
        simple_list.append(float(round(p, 4)))

    correlations_string = "\t".join(
        [f"{str(correlation)} {str(p)}"
         for (correlation, p) in correlations_with_docs_len.values()]
    )
    print(correlations_string)

    return simple_list

In [None]:
# Run both correlation types and export them as a single CSV row.
correlation_types = ['pearson', 'spearman']

all_results = []
for correlation_type in correlation_types:
    all_results.extend(calc_correlation(correlation_type))

# Column order mirrors calc_correlation's output: per choice A-D, the coefficient then its p-value.
result_columns = [
    f"{correlation_type.capitalize()}_{letter}_{suffix}"
    for correlation_type in correlation_types
    for letter in 'ABCD'
    for suffix in ('Correlation', 'p')
]
df = pd.DataFrame([all_results], columns=result_columns)
df.to_csv('last_correlations.csv', index=False)