In [227]:
import pandas as pd
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from gensim.utils import simple_preprocess
from elasticsearch import Elasticsearch

# Load data
collections = pd.read_csv('collection.tar_1/collection.tsv', sep='\t', nrows=30000, header=None, names=['passage_id', 'passage_text'])

# Preprocessing


In [241]:
import pandas as pd
chunksize = 5000
chunks = pd.read_csv('./top1000.dev/top1000.dev', sep='\t', header=None, names=['query_id', 'passage_id', 'query', 'passage'], chunksize=chunksize)
top1000_dev_df = pd.DataFrame()

for chunk in chunks:
    filtered_chunk = chunk[chunk['passage_id'] < 10000]
    top1000_dev_df = pd.concat([top1000_dev_df, filtered_chunk])

    # If we have 20000 rows already, stop processing
#     if len(top1000_dev_df) >= 5000:
#         break
# if len(top1000_dev_df) > 5000:
#     top1000_dev_df = top1000_dev_df.iloc[:5000]

In [242]:
stop_words = set(stopwords.words('english'))
ps = PorterStemmer()

def preprocess(document):
    tokens = simple_preprocess(document)
    tokens = [ps.stem(token) for token in tokens if token not in stop_words]
    return ' '.join(tokens)

collections['passage_text_pro'] = [preprocess(document) for document in collections['passage_text']]

# Create Elasticsearch client


In [243]:
es = Elasticsearch([{'host': 'localhost', 'port': 9200, 'scheme': 'http'}])

# Delete the old index if it exists
if es.indices.exists(index='document_bm'):
    es.indices.delete(index='document_bm')

# Create a new index with BM25 settings
es.indices.create(index='document_bm', body={
    'settings': {
        'number_of_shards' : 1,
        "index" : {
            "similarity" : {
                "default" : {
                    "type" : "BM25",
                    "b" : 0,
                    "k1" : 0
                }
            }
        }
    },
    'mappings': {
        'properties': {
            'content': {
                'type': 'text',
                'analyzer': 'english'
            }
        }
    }
})

# Index documents


  es.indices.create(index='document_bm', body={


ObjectApiResponse({'acknowledged': True, 'shards_acknowledged': True, 'index': 'document_bm'})

In [244]:
from tqdm import tqdm
from elasticsearch.helpers import parallel_bulk
from collections import deque

actions = []

for i, passage_text in tqdm(enumerate(collections['passage_text_pro'])):
    action = {
        "_op_type": "index",
        "_index": "document_bm",
        "_id": i,
        "_source": {
            'content': passage_text,
        }
    }
    actions.append(action)

deque(parallel_bulk(es, actions, chunk_size=10000, thread_count=8, queue_size=8), maxlen=0)


30000it [00:00, 696188.56it/s]


deque([])

In [245]:
count = 0
fetched = []
relevant = []
for query in top1000_dev_df['query']:
    query_pro = preprocess(query)
    response = es.search(index='document_bm', body={
        'query': {
            'match': {
                'content': query_pro
            }
        }
    })
    fetched.append([int(hit['_id']) for hit in response['hits']['hits']])
    filtered_df = top1000_dev_df[top1000_dev_df['query'] == query]
    relevant.append(filtered_df["passage_id"].to_list())
    if count == 100:
        break
    count += 1

  response = es.search(index='document_bm', body={


In [246]:
print(fetched)
print(relevant)

[[10007, 10009, 10010, 10011, 10012, 10016, 2628, 2632, 1084, 3234], [10057, 10056, 10060, 10061, 10065, 730, 10058, 10059, 10062, 10063], [10056, 10057, 10058, 10059, 10060, 10061, 10062, 10063, 10064, 10065], [10056, 10057, 10058, 10059, 10060, 10061, 10062, 10063, 10064, 10065], [10056, 10057, 10058, 10059, 10060, 10061, 10062, 10063, 10064, 10065], [10093, 2324, 2214, 5054, 995, 3360, 3401, 3860, 5298, 6049], [10113, 10114, 2736], [10108, 10113, 10149, 5785, 5862, 2811, 3202, 10109, 10110, 10114], [10113, 10114, 2736], [10113, 10114, 10115, 3145, 10109, 10110, 10112, 10106, 10107, 10108], [10155, 10152, 10146, 10149, 2354, 2356, 2895, 10150, 10154, 2901], [2149, 5210, 5213, 3768, 1674, 4335, 472, 473, 478, 1286], [10195, 2161, 10186, 10188, 10189, 10190, 10193, 1939, 2163, 2164], [3281, 437, 2056, 4032, 5627, 10104, 10154, 10186, 10193, 749], [542, 991, 995, 3785, 4483, 5471, 5472, 5477, 10189, 2123], [10217, 10219, 5985, 10218, 10220, 10222, 10225, 3258, 10216, 5295], [1025, 1932,

In [247]:
def precision_recall_at_k(retrieved, relevant, k):
    retrieved_set = set(retrieved[:k])
    relevant_set = set(relevant)
    retrieved_relevant = retrieved_set.intersection(relevant_set)
    
    precision = len(retrieved_relevant) / len(retrieved_set)
    recall = len(retrieved_relevant) / len(relevant_set)
    
    return precision, recall

pk = [precision_recall_at_k(fetched[i], relevant[i], len(relevant)) for i in range(len(relevant))]
pk


[(0.1, 1.0),
 (0.1, 1.0),
 (0.3, 1.0),
 (0.3, 1.0),
 (0.3, 1.0),
 (0.1, 1.0),
 (0.6666666666666666, 0.6666666666666666),
 (0.1, 1.0),
 (0.6666666666666666, 0.6666666666666666),
 (0.1, 1.0),
 (0.1, 1.0),
 (0.0, 0.0),
 (0.1, 1.0),
 (0.1, 1.0),
 (0.1, 1.0),
 (0.1, 1.0),
 (0.1, 1.0),
 (0.0, 0.0),
 (0.1, 1.0),
 (0.1, 1.0),
 (0.1, 1.0),
 (0.1, 1.0),
 (0.1, 1.0),
 (0.1, 1.0),
 (0.1, 1.0),
 (0.1, 1.0),
 (0.1, 1.0),
 (0.1, 1.0),
 (0.1, 1.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.1, 1.0),
 (0.0, 0.0),
 (0.1, 1.0),
 (0.1, 1.0),
 (0.

In [248]:
def average_precision(retrieved, relevant):
    retrieved_set = set(retrieved)
    relevant_set = set(relevant)
    ap = 0.0
    for i in range(len(retrieved)):
        if retrieved[i] in relevant_set:
            precision, _ = precision_recall_at_k(retrieved, relevant, len(relevant))
            ap += precision
    ap /= len(relevant_set)
    return ap

def mean_average_precision(retrieved_list, relevant_list):
    ap_sum = 0.0
    for retrieved, relevant in zip(retrieved_list, relevant_list):
        ap_sum += average_precision(retrieved, relevant)
    map = ap_sum / len(retrieved_list)
    return map


meanAP = mean_average_precision(fetched, relevant)

In [249]:
from sklearn.metrics import ndcg_score

def compute_ndcg(retrieved, relevant):
    # Transform retrieved and relevant into binary relevance scores
    y_true = [1 if i in relevant else 0 for i in retrieved]
    y_score = [1 if i in relevant else 0 for i in retrieved]
    return ndcg_score([y_true], [y_score])

[compute_ndcg(fetched[i], relevant[i]) for i in range(len(fetched))]


[1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0]

In [250]:
from tabulate import tabulate


headers = ["Query no", "Precision", "Recall"]
data_with_ndgc = [(i, *data_point) for i, (data_point) in enumerate(pk)]
table = tabulate(data_with_ndgc, headers, tablefmt="grid")

print(table)
print("Mean AP : ", end = '')
print(meanAP)

+------------+-------------+----------+
|   Query no |   Precision |   Recall |
|          0 |    0.1      | 1        |
+------------+-------------+----------+
|          1 |    0.1      | 1        |
+------------+-------------+----------+
|          2 |    0.3      | 1        |
+------------+-------------+----------+
|          3 |    0.3      | 1        |
+------------+-------------+----------+
|          4 |    0.3      | 1        |
+------------+-------------+----------+
|          5 |    0.1      | 1        |
+------------+-------------+----------+
|          6 |    0.666667 | 0.666667 |
+------------+-------------+----------+
|          7 |    0.1      | 1        |
+------------+-------------+----------+
|          8 |    0.666667 | 0.666667 |
+------------+-------------+----------+
|          9 |    0.1      | 1        |
+------------+-------------+----------+
|         10 |    0.1      | 1        |
+------------+-------------+----------+
|         11 |    0        | 0        |


In [251]:
from tabulate import tabulate
from termcolor import colored

# Assuming you have your data in the 'headers' and 'data_with_ndgc' variables

# Define colors for header and data rows
header_color = 'yellow'
data_color = 'cyan'

# Colorize the headers
colored_headers = [colored(header, header_color) for header in headers]

# Colorize the data rows
colored_data = [
    [colored(str(cell), data_color) for cell in row]
    for row in data_with_ndgc
]

# Create the colorful table
table = tabulate(colored_data, colored_headers, tablefmt="pretty")

print(table)
print("Mean AP: " + colored(str(meanAP), 'red'))  # You can choose a color for the meanAP as well


+----------+--------------------+--------------------+
| Query no |     Precision      |       Recall       |
+----------+--------------------+--------------------+
|    0     |        0.1         |        1.0         |
|    1     |        0.1         |        1.0         |
|    2     |        0.3         |        1.0         |
|    3     |        0.3         |        1.0         |
|    4     |        0.3         |        1.0         |
|    5     |        0.1         |        1.0         |
|    6     | 0.6666666666666666 | 0.6666666666666666 |
|    7     |        0.1         |        1.0         |
|    8     | 0.6666666666666666 | 0.6666666666666666 |
|    9     |        0.1         |        1.0         |
|    10    |        0.1         |        1.0         |
|    11    |        0.0         |        0.0         |
|    12    |        0.1         |        1.0         |
|    13    |        0.1         |        1.0         |
|    14    |        0.1         |        1.0         |
|    15   