In [4]:
import pickle

model = "dunzhang/stella_en_400M_v5"
number = 5000
difficulty = "hard"

with open(f'embeddings/{model}/{difficulty}/{number}/df.pkl', 'rb') as file: hotpot_qa_df = pickle.load(file)
with open(f'embeddings/{model}/{difficulty}/{number}/contexts.pkl', 'rb') as file: contexts = pickle.load(file)

hotpot_qa_df['actual_contexts'] = hotpot_qa_df['actual_contexts'].apply(lambda x: [int(i) for i in x])
actual_contexts = hotpot_qa_df['actual_contexts'].tolist()

hotpot_qa_df.head()

Unnamed: 0,level,question,answer,actual_contexts
0,hard,"George Boscawen, 9th Viscount Falmouth is a fo...","the Guards Division, Foot Guards regiments","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]"
1,hard,When Vladimir Kashpur portrayed Baba Yaga she ...,trio of sisters,"[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]"
2,hard,Which musician has a solo punk rock project: T...,"Frank Anthony Iero, Jr.","[20, 21, 22, 23, 24, 25, 26, 27, 28, 29]"
3,hard,A Disney voice actor has won which Emmy award?,Outstanding Supporting Actor,"[30, 31, 32, 33, 34, 35, 36, 37, 38, 39]"
4,hard,Which north-western suburb of Adelaide lies wi...,Birkenhead,"[40, 41, 42, 43, 44, 45, 46, 47, 48, 49]"


In [11]:
from beir.retrieval.evaluation import EvaluateRetrieval

def beir_evaluation(actual_contexts,results):
    actual_contexts_dict = {
        str(i): {str(doc_id): 1 for doc_id in context} for i, context in enumerate(actual_contexts)
    }
    results_dict = {
        str(i): {str(doc_id): rank + 1 for rank, doc_id in enumerate(result)} for i, result in enumerate(results)
    }

    ndcg, map_score, recall, precision = EvaluateRetrieval.evaluate(
        actual_contexts_dict, results_dict, k_values=[10]
    )

    print("recall:", recall)
    print("precision:", precision)
    #print("acc:", EvaluateRetrieval.evaluate_custom(actual_contexts_dict, results_dict, [10], metric="acc"))

    print()

    print("ndcg:", ndcg)
    print("map:", map_score)
    print("mrr:", EvaluateRetrieval.evaluate_custom(actual_contexts_dict, results_dict, [10], metric="mrr"))



In [15]:
import chromadb
from tqdm import tqdm

documents = []
ids = []
embeddings = []
metadatas = []

for context in contexts.values():
    documents.append(context.text)
    ids.append(context.id_)
    embeddings.append(context.embedding)
    metadatas.append({'caption': context.metadata['caption']})

chroma_client = chromadb.Client()

try:
    chroma_client.delete_collection("my_collection")
except Exception:
    print("creating collection ...")
collection = chroma_client.create_collection(name="my_collection",metadata={"hnsw:space": "l2",
                                                                            "hnsw:construction_ef":1000, # Controls the number of neighbours in the HNSW graph to explore when adding new vectors. 
                                                                            "hnsw:M":1000, # maximum number of neighbour connections a vector can have
                                                                            "hnsw:search_ef":1000, # number of neighbours in the HNSW graph to explore when searching.
                                                                            }) 

batch_size = 41666
for i in range(0, len(documents), batch_size):
    collection.add(documents=documents[i:i+batch_size],
                   ids=ids[i:i+batch_size],
                   embeddings=embeddings[i:i+batch_size],
                   metadatas=metadatas[i:i+batch_size])

retrieved_contexts = []
for question in tqdm(hotpot_qa_df['question'], desc="Retrieving contexts", unit="question"):

    question_embedding = question.embedding
    result = collection.query(
                                query_embeddings=[question_embedding], # Chroma will embed this for you
                                n_results=10 # how many results to return
                              )
    retrieved_contexts.append([int(node) for node in result["ids"][0]] )

beir_evaluation(actual_contexts,retrieved_contexts)

Retrieving contexts: 100%|██████████| 5000/5000 [01:07<00:00, 74.38question/s]

recall: {'Recall@10': 0.38786}
precision: {'P@10': 0.3846}

ndcg: {'NDCG@10': 0.32416}
map: {'MAP@10': 0.1784}
mrr: {'MRR@10': 0.37181}



