In [None]:
import os

# Read the HPC account id once (and close the file). Strip it so that a trailing
# newline in the file does not end up inside the paths built below; the HF token
# cell strips its file the same way.
with open('../tokens/HPC_ACCOUNT_ID.txt', 'r') as account_file:
    HPC_ACCOUNT_ID = account_file.read().strip()

# HF_HOME has to be in the environment BEFORE `datasets` (and, through it,
# `huggingface_hub`) is imported: those libraries resolve their default cache
# locations from the environment at import time, so setting it afterwards is a no-op.
os.environ['HF_HOME'] = f'/scratch/{HPC_ACCOUNT_ID}'
# Explicit cache directory passed to every `load_dataset` call in this notebook.
cache_dir = f'/scratch/{HPC_ACCOUNT_ID}/cache'

# Imported after the environment is configured on purpose (see above).
from datasets import load_dataset
import torch
from tqdm import tqdm

In [None]:
DATASET: str = "LeoZotos/immu_full"  # Hub dataset repo: questions are loaded from it and results pushed back to it

In [None]:
# Hugging Face access token, used below both for downloading and for pushing datasets.
with open("../tokens/HF_TOKEN.txt", "r") as token_file:
    hf_api_key = token_file.read().strip()

# Compute on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def retrieve_relevant_docs(question_set, passages, batch_size=10000, top_k=20):
    """Retrieves the top k relevant documents for each question by embedding dot product.

    Passages are consumed in batches of ``batch_size``. After each batch, the running
    top-k per question is merged with that batch's scores, so at most
    ``top_k + batch_size`` scores per question are held at once.
    Reduce batch_size if out-of-memory/process is killed.

    Scores are computed in bfloat16 on the module-level ``device``, and progress is shown
    with the module-level ``tqdm``.

    Args:
        question_set: Anything where ``question_set['emb']`` gives one embedding per
            question (e.g. a ``datasets.Dataset``). Presumably the same dimensionality
            as the passage embeddings; verify against the data.
        passages: Sized iterable of mappings with ``'emb'``, ``'title'`` and ``'text'``.
        batch_size: Number of passages scored per matrix multiplication; must be >= 1.
        top_k: Number of passages kept per question. Fewer are returned if the corpus
            holds fewer than ``top_k`` passages.

    Returns:
        One list per question of ``"<title> <text>"`` strings, highest score first.
        The order among equal scores is unspecified (``torch.topk``).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    question_embs = torch.tensor(question_set['emb'], dtype=torch.bfloat16, device=device)
    num_questions = len(question_embs)
    if num_questions == 0:
        return []

    # Running best scores / documents per question; starts with zero columns.
    top_scores = torch.empty((num_questions, 0), dtype=torch.bfloat16, device=device)
    top_docs = [[] for _ in range(num_questions)]

    num_passages = len(passages)
    batch_docs = []
    batch_embs = []
    for passage_id, passage in enumerate(tqdm(passages)):
        batch_embs.append(passage['emb'])
        batch_docs.append({"title": passage['title'], "text": passage['text']})

        is_batch_full = (passage_id + 1) % batch_size == 0
        is_last_passage = (passage_id + 1) == num_passages
        if not (is_batch_full or is_last_passage):
            continue

        passage_embs = torch.tensor(batch_embs, dtype=torch.bfloat16, device=device)
        batch_scores = torch.mm(question_embs, passage_embs.T)

        # Candidate columns are [current top-k | this batch]; columns < num_prev refer
        # to each question's own top_docs, the rest to the shared batch_docs.
        num_prev = top_scores.shape[1]
        candidate_scores = torch.cat((top_scores, batch_scores), dim=1)
        # Never ask topk for more columns than exist (small corpus / small first batch).
        k = min(top_k, candidate_scores.shape[1])
        top_scores, top_idx = torch.topk(candidate_scores, k, dim=1)

        # One device->host transfer instead of a sync per indexed element.
        top_docs = [
            [prev_docs[j] if j < num_prev else batch_docs[j - num_prev] for j in row]
            for prev_docs, row in zip(top_docs, top_idx.tolist())
        ]

        batch_docs = []
        batch_embs = []

    return [[doc['title'] + " " + doc['text'] for doc in docs] for docs in top_docs]

In [None]:
# Questions whose retrieval context we want to build.
question_set = load_dataset(DATASET, split='train', cache_dir=cache_dir, token=hf_api_key)

# Passage corpus: the 'simple' config of Cohere's 2023-11 Wikipedia dump, whose rows
# carry precomputed embeddings in the 'emb' field used by the retriever.
passages = load_dataset(
    "Cohere/wikipedia-2023-11-embed-multilingual-v3",
    'simple',
    split="train",
    cache_dir=cache_dir,
    token=hf_api_key,
)

# One list of "title text" strings per question, best match first.
relevant_docs = retrieve_relevant_docs(question_set, passages)


In [None]:
# Drop any previous "Relevant_Docs_Simple" column first, so re-running this cell
# replaces the retrieval results instead of trying to add the column a second time.
retrieval_column = "Relevant_Docs_Simple"
if retrieval_column in question_set.column_names:
    question_set = question_set.remove_columns(retrieval_column)
question_set = question_set.add_column(retrieval_column, relevant_docs)

# Publish the augmented dataset back to the same Hub repo.
push_kwargs = dict(
    repo_id=DATASET,
    commit_message="Added relevant documents from Wiki Simple",
    token=hf_api_key,
    private=True,
)
question_set.push_to_hub(**push_kwargs)

In [None]:
# Spot-check one example: the question followed by the passages retrieved for it.
sample_row = question_set[15]
print(sample_row['Question_With_Options'])
print("-----")
print(sample_row['Relevant_Docs_Simple'])