In [12]:
import sys
import os
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import torch
from tqdm import tqdm
import numpy as np

os.environ['HF_HOME'] = '/scratch/' + str(open('../tokens/HPC_ACCOUNT_ID.txt', 'r').read())
cache_dir = '/scratch/' + str(open('../tokens/HPC_ACCOUNT_ID.txt', 'r').read()) + '/cache'

In [13]:
hf_api_key = ""
with open("../tokens/HF_TOKEN.txt", "r") as f:
    hf_api_key = f.read().strip()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [14]:
def classify_docs_per_distractor(row, sentence_bert_model):
    docs_per_choice = {}
    for key in ['Answer_A', 'Answer_B', 'Answer_C', 'Answer_D']:
        if row[key] != "":
                docs_per_choice[key + '_docs'] = []
                
    embeddings_choices = sentence_bert_model.encode([key for key in docs_per_choice.keys() if key.endswith('_docs')])
    embeddings_docs = sentence_bert_model.encode(row['relevant_docs_simple'])
    similarities = sentence_bert_model.similarity(embeddings_choices, embeddings_docs)
    
    # We now add each doc to the choice with the highest similarity
    for i, doc in enumerate(row['relevant_docs_simple']):
        max_sim_index = np.argmax(similarities[:, i])
        max_choice = list(docs_per_choice.keys())[max_sim_index]
        docs_per_choice[max_choice].append(doc)
    
    return docs_per_choice


In [15]:
sentence_bert_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
data = load_dataset("LeoZotos/immu_full", split='train', token = hf_api_key, cache_dir=cache_dir)

In [None]:
Answer_A_docs, Answer_B_docs, Answer_C_docs, Answer_D_docs = [], [], [], [] # each will be a column

for row in tqdm(data):
    docs_per_choice_for_row = classify_docs_per_distractor(row, sentence_bert_model)
    Answer_A_docs.append(docs_per_choice_for_row.get('Answer_A_docs', []))
    Answer_B_docs.append(docs_per_choice_for_row.get('Answer_B_docs', []))
    Answer_C_docs.append(docs_per_choice_for_row.get('Answer_C_docs', []))
    Answer_D_docs.append(docs_per_choice_for_row.get('Answer_D_docs', []))
    
data = data.add_column("Answer_A_docs", Answer_A_docs)
data = data.add_column("Answer_B_docs", Answer_B_docs)
data = data.add_column("Answer_C_docs", Answer_C_docs)
data = data.add_column("Answer_D_docs", Answer_D_docs)
# print 1st row of data, column 'Qu'
id = 14
print(data[0]['Question_with_options'], ":", "A:", data[id]['Answer_A_docs'], "B:", data[id]['Answer_B_docs'], "C:", data[id]['Answer_C_docs'], "D:", data[id]['Answer_D_docs'])
# export to csv in ../data/immu/all_full_docs.csv
# data.to_csv('../data/immu/all_full_docs.csv', index=False)


100%|██████████| 843/843 [00:13<00:00, 64.61it/s]


What has the LEAST effect on the value of R0 of SARS-CoV-2?
A) Outdoor temperature
B) Virulence of the virus
C) Herd immunity of the population
D) Physical distancing within the population
 : A: ['Volcanic winter The effects of recent volcanic eruptions on winters are modest in scale but historically their effects have been significant.', 'Field (mathematics) There exists an element 0 in R, such that for all elements a in R, the equation 0 + a = a + 0 = a holds.'] B: ['Basic reproduction number A very important number for describing whether a disease can become an epidemic or not is R0, pronounced "R naught" or "R zero". It refers to how many people a person who has this disease is expected to infect on average if there are no people immune to the disease. It is an abbreviation for basic reproduction number.', 'Basic reproduction number If R0 > 1, a disease can become an epidemic. If R0 < 1, it cannot. Most commonly known diseases have R0 > 1. However, vaccines can be used to make enou

Creating CSV from Arrow format: 100%|██████████| 1/1 [00:00<00:00,  1.96ba/s]


10136690

In [24]:
id = 22
print(data[id]['Question_with_options'], ":", "\n A:", data[id]['Answer_A_docs'], "\n B:", data[id]['Answer_B_docs'], "\n C:", data[id]['Answer_C_docs'], "\n D:", data[id]['Answer_D_docs'])

M cells are crucial for the initiation of mucosal immunity. Where can these M cells be found?
A) In Peyer's patches
B) In GALT and in between intestinal epithelial cells
C) In parts of the entire ileum that have contact with mesenteric lymph nodes
D) In lymph nodes
 : 
 A: ['Mucous membrane A mucous membrane (or mucosae; singular mucosa) is a skin-like lining. A mucus membrane is covered in epithelium. They secrete mucus, and in the alimentary canal they absorb nutrients. They line cavities that are exposed to the external environment and internal organs.', 'Thymus In the thymus, T cells or T lymphocytes mature. T cells are critical to the adaptive immune system, where the body adapts specifically to foreign invaders.', 'Immune system Vertebrates, including humans, have much more sophisticated defense mechanisms.  The innate immune system is found in all metazoa, but the adaptive immune system is only found in vertebrates.'] 
 B: ['Goblet cell Goblet cells are specialized cells found i