In [1]:
import json
import os
import warnings
from getpass import getpass
from time import time

import ir_measures
import pandas as pd
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk, parallel_bulk
from ir_measures import *
from tqdm import tqdm

# Installed before sentence_transformers is imported, as in the original cell order.
warnings.filterwarnings("ignore")
from sentence_transformers import SentenceTransformer, util


# Connection

In [2]:
# Credentials are taken from the environment so that no secret is stored in
# the notebook: ES_USER (defaults to 'elastic') and ES_PASSWORD (prompted for
# interactively when unset).
# NOTE(review): a password was previously committed in this cell; it should be
# rotated on the cluster.
# NOTE(review): verify_certs=False disables TLS certificate verification;
# acceptable only against a local development cluster.
es = Elasticsearch(
    hosts='https://localhost:9200',
    basic_auth=(
        os.environ.get('ES_USER', 'elastic'),
        os.environ.get('ES_PASSWORD') or getpass('Elasticsearch password: '),
    ),
    verify_certs=False,
)


## WikIR

In [3]:
# Document collection; the indexing code reads its `id_right` and `text_right` columns.
DOCUMENTS_CSV = 'wikIR1k/documents.csv'
df = pd.read_csv(DOCUMENTS_CSV)


### Index Configuration

In [4]:
# Index definition without stemming: the single `text` field is analyzed by a
# custom analyzer ("white") that consists of the whitespace tokenizer only,
# and documents are scored with BM25 as the default similarity.

index = 'wiki'

settings = {
    'number_of_shards': 5,
    'index': {
        'similarity': {
            'default': {'type': 'BM25'},
        },
    },
    'analysis': {
        'analyzer': {
            'white': {'tokenizer': 'whitespace'},
        },
    },
}

mappings = {
    'properties': {
        'text': {'type': 'text', 'analyzer': 'white'},
    },
}

# Start from a clean index so that re-running the cell never mixes old and
# new documents or settings.
index_already_exists = es.indices.exists(index=index)
if index_already_exists:
    es.indices.delete(index=index)
es.indices.create(index=index, settings=settings, mappings=mappings)

ObjectApiResponse({'acknowledged': True, 'shards_acknowledged': True, 'index': 'wiki'})

In [5]:
# Indexing documents

def create_es_action(index, doc_id, document):
    """Build the bulk-action dict for a single document.

    Args:
        index: value stored under ``'_index'``.
        doc_id: value stored under ``'_id'``.
        document: value stored under ``'_source'``.

    Returns:
        A dict with exactly the keys ``'_index'``, ``'_id'`` and ``'_source'``.
    """
    action = dict(_index=index, _id=doc_id, _source=document)
    return action


def es_action_generator(df):
    """Yield one bulk-index action per row of ``df``, with a progress bar.

    Args:
        df: DataFrame providing an ``id_right`` column (used as the document
            ``_id``) and a ``text_right`` column (stored as the ``text`` field).

    Yields:
        Action dicts produced by ``create_es_action`` for the module-level
        ``index``.
    """
    # Iterate only the two columns that are needed. DataFrame.iterrows()
    # builds a Series for every row, which is far slower on a large frame,
    # and the row label it yields was never used.
    id_text_pairs = zip(df['id_right'], df['text_right'])
    progress = tqdm(id_text_pairs, total=df.shape[0], bar_format='{l_bar}{bar:30}{r_bar}{bar:-10b}')
    for doc_id, text in progress:
        yield create_es_action(index, doc_id, {'text': text})


# Bulk-load the collection with 4 worker threads in chunks of 1000 actions and
# measure the wall-clock time of the whole load.
start = time()
bulk_results = parallel_bulk(
    es,
    es_action_generator(df),
    thread_count=4,
    queue_size=4,
    chunk_size=1000,
)
for ok, result in bulk_results:
    if ok:
        continue
    # Only failed actions are reported.
    print(result)
stop = time()
print('Indexing time:', stop - start)

# Refresh the index after the bulk load, before the search cells run.
es.indices.refresh(index=index)

100%|██████████████████████████████| 369721/369721 [00:35<00:00, 10531.45it/s]


Indexing time: 35.836485147476196


ObjectApiResponse({'_shards': {'total': 10, 'successful': 5, 'failed': 0}})

## Search

In [7]:
def pretty_print_result(search_result, fields=()):
    """Print the total hit count and the id and score of every returned hit.

    Args:
        search_result: mapping shaped like a search response; this function
            reads ``['hits']['total']['value']`` and ``['hits']['hits']``.
        fields: iterable of ``_source`` field names to print under each hit.
            The default is an immutable empty tuple rather than a shared
            mutable list.

    Returns:
        None. Everything is written to standard output.

    Raises:
        KeyError: if a requested field is missing from a hit's ``_source``.
    """
    res = search_result['hits']
    print(f'Total documents: {res["total"]["value"]}')
    for hit in res['hits']:
        print(f'Doc {hit["_id"]}, score is {hit["_score"]}')
        for field in fields:
            print(f'{field}: {hit["_source"][field]}')
    
def search(query, *args):
    """Run ``query`` against the module-level ``index`` and print the top 20 hits.

    Args:
        query: query body passed to ``es.search``.
        *args: names of ``_source`` fields to print for every hit.

    Returns:
        Whatever ``pretty_print_result`` returns.
    """
    response = es.search(index=index, query=query, size=20)
    return pretty_print_result(response, args)

def get_doc_by_id(doc_id):
    """Fetch one document from the module-level ``index`` and return its ``_source``.

    Args:
        doc_id: id passed to ``es.get``.
    """
    response = es.get(index=index, id=doc_id)
    return response['_source']

## Queries

In [8]:
# Test queries; later cells read the `id_left` and `text_left` columns.
TEST_QUERIES_CSV = 'wikIR1k/test/queries.csv'
test_queries = pd.read_csv(TEST_QUERIES_CSV)


In [9]:
def make_query(text):
    """Build a bool query whose single ``must`` clause matches ``text`` on the ``text`` field.

    Args:
        text: query string placed in the ``match`` clause.

    Returns:
        The query body as a nested dict.
    """
    match_clause = {'match': {'text': text}}
    return {'bool': {'must': match_clause}}

# Smoke test: run the query labelled 0 in `test_queries` and print its top hits.
first_query_text = test_queries['text_left'][0]
search(make_query(first_query_text))


Total documents: 10000
Doc 1880296, score is 17.230719
Doc 607552, score is 17.198406
Doc 2261272, score is 17.183655
Doc 1957435, score is 16.908918
Doc 625257, score is 16.856976
Doc 635537, score is 16.771313
Doc 1774491, score is 16.640131
Doc 663828, score is 16.487574
Doc 158491, score is 15.997955
Doc 1956922, score is 15.973572
Doc 1180246, score is 15.590252
Doc 1170039, score is 15.534702
Doc 945068, score is 15.526761
Doc 589549, score is 15.501228
Doc 360918, score is 15.501228
Doc 685181, score is 15.335788
Doc 2411344, score is 15.325968
Doc 1158969, score is 15.273922
Doc 1093529, score is 15.163386
Doc 742912, score is 15.109789


In [13]:
# BM25 baseline run: for every test query keep the top 20 hits as
# {query id (str): {document id: score}}.
run = {}
for _, row in test_queries.iterrows():
    response = es.search(index=index, query=make_query(row['text_left']), size=20)
    run[str(row['id_left'])] = {
        hit['_id']: hit['_score'] for hit in response['hits']['hits']
    }

# Score the run against the test relevance judgements.
qrels = ir_measures.read_trec_qrels('wikIR1k/test/qrels')
ir_measures.calc_aggregate([P@10, P@20, MAP@20], qrels, run)


{P@20: 0.14800000000000005,
 AP@20: 0.14619425811737782,
 P@10: 0.20599999999999988}

In [16]:
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util

def get_run_from_model(model, run):
    """Re-score every (query, document) pair of ``run`` by embedding cosine similarity.

    Query texts are looked up in the module-level ``test_queries`` frame
    (``id_left`` / ``text_left``) and document texts in the module-level ``df``
    frame (``id_right`` / ``text_right``).

    Args:
        model: object whose ``encode`` accepts a string or a list of strings
            (a ``SentenceTransformer`` in this notebook).
        run: mapping of query id to a mapping of document id to score. Only
            the ids are used; both kinds of id must be convertible with
            ``int``.

    Returns:
        A dict with the same query ids and, per query, the same document ids
        as ``run``, whose values are cosine similarities as Python floats.
        A query without candidate documents maps to an empty dict.

    Raises:
        KeyError: if an id from ``run`` is absent from the lookup frames.
    """
    # Build the id -> text lookups once. Filtering the full DataFrame with a
    # boolean mask for every single document made each query take seconds.
    # drop_duplicates keeps the first occurrence, matching the earlier
    # `.iloc[0]` selection.
    unique_queries = test_queries.drop_duplicates(subset='id_left')
    query_text_by_id = dict(zip(unique_queries['id_left'], unique_queries['text_left']))
    unique_docs = df.drop_duplicates(subset='id_right')
    doc_text_by_id = dict(zip(unique_docs['id_right'], unique_docs['text_right']))

    q_ids = list(run.keys())
    run_cosine = {}

    for q_id in tqdm(q_ids, total=len(q_ids), bar_format='{l_bar}{bar:30}{r_bar}{bar:-10b}'):
        doc_ids = list(run[q_id])
        run_cosine[q_id] = {}

        # Nothing to re-rank: skip before spending time on any encoding.
        if not doc_ids:
            continue

        query_embedding = model.encode(query_text_by_id[int(q_id)])
        docs_embedding = model.encode([doc_text_by_id[int(doc_id)] for doc_id in doc_ids])

        # cos_sim returns a (1, n_docs) matrix for a single query; take its only row.
        cos_sim = util.cos_sim(query_embedding, docs_embedding)[0]
        run_cosine[q_id] = dict(zip(doc_ids, cos_sim.tolist()))

    return run_cosine



In [18]:
# Load the sentence-embedding model and re-score the existing `run` with it.
MODEL_NAME = 'msmarco-distilbert-cos-v5'
model = SentenceTransformer(MODEL_NAME)
run_cosine = get_run_from_model(model=model, run=run)


  0%|                              | 0/100 [00:00<?, ?it/s][A
  1%|▎                             | 1/100 [00:11<18:46, 11.38s/it][A
  2%|▌                             | 2/100 [00:24<19:15, 11.80s/it][A
  3%|▉                             | 3/100 [00:34<18:20, 11.35s/it][A
  4%|█▏                            | 4/100 [00:44<17:41, 11.06s/it][A
  5%|█▌                            | 5/100 [00:55<17:25, 11.00s/it][A
  6%|█▊                            | 6/100 [01:08<17:52, 11.40s/it][A
  7%|██                            | 7/100 [01:19<17:42, 11.43s/it][A
  8%|██▍                           | 8/100 [01:31<17:53, 11.67s/it][A
  9%|██▋                           | 9/100 [01:43<17:50, 11.77s/it][A
 10%|███                           | 10/100 [01:57<18:43, 12.48s/it][A
 11%|███▎                          | 11/100 [02:09<18:16, 12.32s/it][A
 12%|███▌                          | 12/100 [02:20<17:31, 11.95s/it][A
 13%|███▉                          | 13/100 [02:33<17:27, 12.04s/it][A
 14%|████

In [19]:
# Evaluate the re-scored run with the same measures as the baseline.
# NOTE(review): qrels is loaded again here, presumably because
# read_trec_qrels yields a one-shot iterator that the previous
# calc_aggregate call consumed; confirm against the ir_measures docs.
# P@20 is expected to match the baseline when `run_cosine` holds the same
# (at most 20) documents per query as `run`: re-scoring only reorders them.
qrels = ir_measures.read_trec_qrels('wikIR1k/test/qrels')
rerank_measures = [P@10, P@20, MAP@20]
ir_measures.calc_aggregate(rerank_measures, qrels, run_cosine)

{P@20: 0.14800000000000005,
 AP@20: 0.17025855229129153,
 P@10: 0.2339999999999999}