In [20]:
# from elasticsearch import Elasticsearch
import pandas as pd
from sentence_transformers import CrossEncoder
from typing import Callable
import tqdm

In [3]:
es = Elasticsearch()
# es.info()
INDEX_NAME = "passage_index"
doc = es.get(index=INDEX_NAME, id=1)
print(doc)

{'_index': 'passage_index', '_type': '_doc', '_id': '1', '_version': 1, '_seq_no': 1, '_primary_term': 1, 'found': True, '_source': {'content': 'The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science.'}}




In [5]:
queries_eval = pd.read_csv("data/queries/queries.eval.tsv", sep='\t', header=None)#, index_col=0)
queries_eval.head()

Unnamed: 0,0,1
0,786436,what is prescribed to treat thyroid storm
1,9,Refer to the data. Diminishing returns begin ...
2,786450,what is presentation software?
3,524308,treasury routing number
4,33,game called poem who wrote what occasion


In [6]:
collection_df = pd.read_csv("data/collection/collection.tsv", sep='\t', header=None) #, index_col=0)
print(len(collection_df))
collection_df.head()

6535846


Unnamed: 0,0,1
0,0,The presence of communication amid scientific ...
1,1,The Manhattan Project and its atomic bomb help...
2,2,Essay on The Manhattan Project - The Manhattan...
3,3,The Manhattan Project was the name for a proje...
4,4,versions of each volume as well as complementa...


In [9]:
qrelspath = "data/qrels/qrels.txt"

qrels_ids = set()
with open(qrelspath, encoding="utf-8") as file:
    for line in file:
        l = line.split(' ')
        qrels_ids.add(l[0])

In [10]:
queriesToUse = {}
for _, query in queries_eval.iterrows():
    if str(query[0]) in qrels_ids:
        queriesToUse[query[0]] = query[1]

In [8]:
query_topK = {}
for query_id, query in queriesToUse.items():
    # query = queries[idx]
    res = es.search(index=INDEX_NAME, q=query, _source=False, size=5000, request_timeout=60)
    top_k_scores = [hit["_id"] for hit in res["hits"]["hits"]]
    query_topK[str(query_id)] = top_k_scores



In [16]:
passages = collection_df[1].tolist()
passage_ids = collection_df[0].tolist()

# Lookups for passage ids
passage_lookup = {}
for i, passage in enumerate(passages):
    passage_lookup[passage_ids[i]] = passage
    
# Lookups for passages
passage_id_lookup = {}
for i, passage in enumerate(passages):
    passage_id_lookup[passage] = passage_ids[i]

In [24]:
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch

if not torch.cuda.is_available():
    print("Warning: No GPU found. Please add GPU to your notebook")

#We use the Bi-Encoder to encode all passages, so that we can use it with sematic search
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
bi_encoder.max_seq_length = 256     #Truncate long passages to 256 tokens

#The bi-encoder will retrieve 100 documents. We use a cross-encoder, to re-rank the results list to improve the quality
cross_encoder = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2')

# We encode all passages into our vector space. This takes about 5 minutes (depends on your GPU speed)
corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True, batch_size=16)

OutOfMemoryError: CUDA out of memory. Tried to allocate 46.00 MiB (GPU 0; 6.00 GiB total capacity; 5.15 GiB already allocated; 0 bytes free; 5.29 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Save the corpus embeddings to disk
torch.save(corpus_embeddings, "corpus_embeddings.pt")

In [None]:
# Load the corpus embeddings from disk
corpus_embeddings = torch.load("corpus_embeddings.pt")

In [22]:
tinybert_results = {}

top_k = 5000    #Number of passages we want to retrieve with the bi-encoder

for query_id, query in tqdm.tqdm(queriesToUse.items()):
    
     ##### Sematic Search #####
    # Encode the query using the bi-encoder and find potentially relevant passages
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    question_embedding = question_embedding.cuda()
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)

    ##### Re-Ranking #####
    # Now, score all retrieved passages with the cross_encoder
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)

    # Sort results by the cross-encoder scores
    for idx in range(len(cross_scores)):
        hits[idx]['cross-score'] = cross_scores[idx]
    
    # Top 1000 bi-encoder results
    # top_hits = sorted(hits, key=lambda x: x['score'], reverse=True)
    
    # Top 1000 cross-encoder results
    top_cross = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
    
    tinybert_results[query_id] = [passage_id_lookup[passages[hit['corpus_id']]] for hit in top_cross]

  0%|          | 0/43 [13:59<?, ?it/s]


KeyboardInterrupt: 

In [23]:
model = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2', max_length=512)

ranked = {}

# Rank the top 1000 passages for each query
for query_id, doc_ids in query_topK.items():
    for doc_id in doc_ids:
        doc_text = es.get(index=INDEX_NAME, id=doc_id)["_source"]["content"]
        score = model.predict([(queriesToUse[int(query_id)], doc_text)])[0]
        if query_id not in ranked:
            ranked[query_id] = []
        ranked[query_id].append((doc_id, score))
    ranked[query_id] = sorted(ranked[query_id], key=lambda x: x[1], reverse=True)[:1000]

# # Rank the top 5000 passages for each query
# for i, (query_id, doc_ids) in enumerate(query_topK.items()):
#     print(f"Processed: {round(i/len(query_topK)*100,2)}%")
#     # print(query_id, doc_ids)
#     # print(es.get(index=INDEX_NAME, id=doc_ids[0])["_source"]["content"])
#     # print(queriesToUse[query_id])

#     l = [(queriesToUse[int(query_id)], es.get(index=INDEX_NAME, id=doc_id)["_source"]["content"]) for doc_id in doc_ids]
#     score = model.predict(l)
#     # print(score)
#     # print(len(score))
#     d = dict(zip(doc_ids, score))
#     ranked[query_id] = [k for k in sorted(d, key=d.get, reverse=True)][0:1000]
# #     break
# # print(ranked)

# Evaluation

In [24]:
# Bulk indexing
qrelspath = "data/qrels/qrels.txt"

qrels = {}
with open(qrelspath, encoding="utf-8") as file:
    for line in file:
        l = line.split(' ')

        qid = l[0]
        pid = l[2]
        relevance = int(l[3])

        if relevance > 0:
            if qid in qrels.keys():
                qrels[qid].add(pid)
            else:
                qrels[qid] = set([pid])

In [25]:
def get_average_precision(system_ranking, ground_truth) -> float:
    vals = []
    over = 1
    for rank_idx, rank in enumerate(system_ranking):
        under = rank_idx+1
        if rank in ground_truth:
            vals.append(over / under)
            over += 1
    AP = sum(vals) / len(ground_truth)

    return AP

In [26]:
system_ranking = query_topK["527433"] # List
system_truth = qrels["527433"] # Set
score = get_average_precision(system_ranking, system_truth)
score

0.04992514686242206

In [27]:
def get_reciprocal_rank(system_ranking, ground_truth) -> float:
    AP = 0
    for rank_idx, rank in enumerate(system_ranking):
        under = rank_idx+1
        if rank in ground_truth:
            AP = 1 / under
            break
    
    return AP

In [28]:
system_ranking = query_topK["527433"] # List
system_truth = qrels["527433"] # Set
score = get_reciprocal_rank(system_ranking, system_truth)
score

1.0

In [29]:
def get_mean_eval_measure(system_rankings, ground_truths, eval_function: Callable) -> float:
    results = []
    for query in system_rankings:
        if query in ground_truths.keys():
            results.append(eval_function(system_rankings[query], ground_truths[query]))
        else:
            continue
            # results.append(0) -> ?
    return sum(results) / len(results)

In [30]:
map = get_mean_eval_measure(query_topK, qrels, get_average_precision)
mrr = get_mean_eval_measure(query_topK, qrels, get_reciprocal_rank)

## From BM25
map = 0.32872575816078825

mrr = 0.7265016684853105

In [31]:
print(map)
print(mrr)

0.2099804764139658
0.6890299549512343
