In [1]:
from elasticsearch import Elasticsearch
from pprint import pprint
import pandas as pd
import numpy as np
import string
PUNCTUATIONS = string.punctuation
from typing import Callable
import nltk
nltk.download("stopwords")
STOPWORDS = set(nltk.corpus.stopwords.words("english"))

In [39]:
import nltk
nltk.download("stopwords")
STOPWORDS = set(nltk.corpus.stopwords.words("english"))

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Magnus\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


### Settings for ElasticSearch

In [3]:
INDEX_NAME = "passage_index"
INDEX_SETTINGS = {
    'settings': {
        'index': {
            'number_of_shards': 1,
            'number_of_replicas': 1,
            'similarity': {
                'default': {
                    'type': 'BM25'
                }
            }
        },
        "analysis": {
            "analyzer": {
                "my_english_analyzer": {
                    "type": "custom",
                    "tokenizer": "standard",
                    "stopwords": "_english_",
                    "filter": [
                        "lowercase",
                        "english_stop",
                        "filter_english_minimal"
                    ]                
                }
            },
            "filter" : {
                "filter_english_minimal" : {
                    "type": "stemmer",
                    "name": "minimal_english"
                },
                "english_stop": {
                    "type": "stop",
                    "stopwords": "_english_"
                }
            },
        }
    }
}

### Create ElasticSearch object

In [6]:
es = Elasticsearch()
es.info()



{'name': 'DESKTOP-4KEQLR4',
 'cluster_name': 'elasticsearch',
 'cluster_uuid': 'qD7n4NS8S-SNKCPLlHHbhA',
 'version': {'number': '7.17.6',
  'build_flavor': 'default',
  'build_type': 'zip',
  'build_hash': 'f65e9d338dc1d07b642e14a27f338990148ee5b6',
  'build_date': '2022-08-23T11:08:48.893373482Z',
  'build_snapshot': False,
  'lucene_version': '8.11.1',
  'minimum_wire_compatibility_version': '6.8.0',
  'minimum_index_compatibility_version': '6.0.0-beta1'},
 'tagline': 'You Know, for Search'}

### Create indexes for the passages

In [8]:
if es.indices.exists(INDEX_NAME):
    es.indices.delete(index=INDEX_NAME)

es.indices.create(index=INDEX_NAME, body=INDEX_SETTINGS)

  if es.indices.exists(INDEX_NAME):
  es.indices.create(index=INDEX_NAME, body=INDEX_SETTINGS)


{'acknowledged': True, 'shards_acknowledged': True, 'index': 'passage_index'}

### Add documents to the index

In [9]:
# Bulk indexing
filename = "data/collection/collection.tsv"

bulk_data = []
bulk_size = 50000 # However many documents can be stored in memory
with open(filename, encoding="utf-8") as file:
    for line in file:
        if len(bulk_data) > bulk_size:
            es.bulk(index=INDEX_NAME, body=bulk_data, refresh=True, request_timeout=60)
            bulk_data.clear()

        l = line.split('\t')
        docid = int(l[0])
        text = l[1].strip()

        doc = {"doc_id": docid, "content": text}

        bulk_data.append({"index": {"_index": INDEX_NAME, "_id": doc.pop("doc_id")}})
        bulk_data.append(doc)
    
    es.bulk(index=INDEX_NAME, body=bulk_data, refresh=True, request_timeout=60)
    bulk_data.clear()

In [None]:
doc = es.get(index=INDEX_NAME, id=1)
pprint(doc)

### Search in the document

In [None]:
query = "atomic bomb"
res = es.search(index=INDEX_NAME, q=query, _source=False, size=10, request_timeout=60)

In [None]:
for hit in res["hits"]["hits"]:
    print("Doc ID: %3r  Score: %5.2f" % (hit["_id"], hit["_score"]))

In [None]:
top_k_scores = [hit["_id"] for hit in res["hits"]["hits"]]
top_k_scores

### Make QRELS evaluation-able
- "query_id": ["doc_id1", "doc_id2"...] -> Relevant ones, ground truth, Set() in this case

In [24]:
qrels_dev_df = pd.read_csv("data/qrels.dev.tsv", sep='\t', header=None)
qrels_dev_df.head(10)

Unnamed: 0,0,1,2,3
0,1102432,0,2026790,1
1,1102431,0,7066866,1
2,1102431,0,7066867,1
3,1090282,0,7066900,1
4,39449,0,7066905,1
5,76162,0,7066915,1
6,195512,0,7066971,1
7,1090280,0,7067004,1
8,331318,0,5309290,1
9,300674,0,7067032,1


In [29]:
len(qrels_dev_df.iloc[:, 0].unique())

55578

In [61]:
qrels = np.array(qrels_dev_df)[:, [0, 2]]
qrels[0]

array([1102432, 2026790], dtype=int64)

In [64]:
qrels_score = {}
for qrel in qrels:
    qid = str(qrel[0])
    pid = str(qrel[1])

    if qid in qrels_score.keys():
        qrels_score[qid].add(pid)
    else:
        qrels_score[qid] = set([pid])

In [65]:
qrels_score["1048578"]

{'7187234'}

In [66]:
unique_qrels = qrels_dev_df.iloc[:, 0].unique()

### Read queries

In [10]:
queries_dev_df = pd.read_csv("data/queries/queries.dev.tsv", sep='\t', header=None)

In [11]:
queries_id = np.array(queries_dev_df.iloc[:, 0])
queries = np.array(queries_dev_df.iloc[:, -1])
print(queries_id[0])
print(queries[0])

1048578
cost of endless pools/swim spa


In [49]:
def relevant_queries(queries, qrels):
    relevant_queries = []
    relevant_queries_id = []

    for idx, query in enumerate(queries):
        query_id = queries_id[idx]
        if int(query_id) in qrels:
            relevant_queries.append(query)
            relevant_queries_id.append(query_id)

    return relevant_queries, relevant_queries_id

In [50]:
# Keep only queries in the QRELS
queries, queries_id = relevant_queries(queries, unique_qrels)

In [52]:
def tokenize(queries):
    tokenized_queries = []

    for doc in queries:
        # Remove specific punctuations
        for punctuation in PUNCTUATIONS:
            doc = doc.replace(punctuation, " ")

        # Get only the words, not the whitespace
        words = [word for word in doc.split(" ") if word]

        # Remove specific stopwords
        words = [word for word in words if word not in STOPWORDS]

        # Add to the list of tokenized docs
        tokenized_queries.append(words)

    return tokenized_queries

In [53]:
# Tokenize the queries
tokenized_queries = tokenize(queries)

In [54]:
tokenized_queries[0]

['cost', 'endless', 'pools', 'swim', 'spa']

### Non bulk query search

In [56]:
# Non-bulk
query_topK = {}
for idx, query_id in enumerate(queries_id):
    if idx > 10:
        break

    query = tokenized_queries[idx]
    res = es.search(index=INDEX_NAME, q=query, _source=False, size=1000, request_timeout=60)
    top_k_scores = [hit["_id"] for hit in res["hits"]["hits"]]
    query_topK[str(query_id)] = top_k_scores

In [57]:
for qid, passage in query_topK.items():
    print(qid)
    print(passage)
    break

1048578
['7187236', '7471198', '5365326', '7187241', '6802210', '6750054', '7187239', '7187242', '5365329', '5365325', '1543821', '6794083', '5365323', '5365328', '5478742', '7187234', '5365324', '1543826', '5365322', '5989132', '8105762', '6802216', '2078221', '5948179', '8393323', '1139147', '4920368', '6802217', '2078215', '1139144', '6802211', '2833851', '5363468', '4332300', '7471199', '5177635', '4981275', '7704720', '3111290', '5363466', '3932264', '1890009', '8763665', '5989130', '4332303', '4615607', '6802214', '7471197', '3370397', '8363353', '7187235', '8365147', '8522049', '869511', '7471201', '1139146', '5989131', '6599701', '3111289', '1475488', '7307988', '6802215', '5512955', '7471204', '4332307', '5989136', '6729238', '326620', '7286588', '1139145', '3124007', '675755', '4332304', '6317753', '8022138', '5989134', '6794078', '6802212', '5365327', '8314418', '7676829', '8363351', '7471200', '6802208', '1139149', '6794079', '1295055', '290329', '2175059', '8363350', '3965

### Get Mean Average Precision for the system
- Precision score for each query

In [67]:
def get_average_precision(system_ranking, ground_truth) -> float:
    vals = []
    over = 1
    for rank_idx, rank in enumerate(system_ranking):
        under = rank_idx+1
        if rank in ground_truth:
            vals.append(over / under)
            over += 1
    AP = sum(vals) / len(ground_truth)

    return AP

In [68]:
system_ranking = query_topK["1048582"] # List
system_truth = qrels_score["1048582"] # Set
score = get_average_precision(system_ranking, system_truth)
score

1.0

In [69]:
def get_mean_eval_measure(system_rankings, ground_truths, eval_function: Callable) -> float:
    results = []
    for query in system_rankings:
        if query in ground_truths.keys():
            results.append(eval_function(system_rankings[query], ground_truths[query]))
        else:
            continue
            # results.append(0) -> ?
    return sum(results) / len(results)

In [70]:
avg_score = get_mean_eval_measure(query_topK, qrels_score, get_average_precision)

In [71]:
avg_score

0.4039270914270915