In [1]:
import tempfile
import sys
import subprocess
import shutil
from os.path import join
from elasticsearch import Elasticsearch, helpers
from nir.utils import create_filter_query_function, change_bm25_parameters
from mmnrm.text import TREC_goldstandard_transform, TREC_queries_transform, TREC_results_transform
from mmnrm.dataset import TestCollectionV2
from mmnrm.evaluation import TREC_Evaluator
import pickle

import json

import os

# Elasticsearch cluster used for every search and BM25 update in this notebook.
# Override it with the ES_URL environment variable instead of editing the address;
# the default keeps the original endpoint.
es = Elasticsearch([os.environ.get("ES_URL", "http://193.136.175.98:8125")])

index_name = "trec-pm-2020-synonym"


In [2]:
def load_TREC_qrels(q_rels_file, transform=None):
    """Load a TREC qrels file and convert it into the goldstandard structure.

    Each non-blank line is split on any whitespace (spaces or tabs) and must
    have at least four columns, as in the standard TREC qrels layout
    ``topic iteration docno relevance``. Columns 0, 2 and 3 are kept, and the
    (docno, relevance) pairs are grouped by topic id. All values stay strings.

    Args:
        q_rels_file: path to the qrels text file.
        transform: callable applied to the ``{topic_id: [(doc_id, rel), ...]}``
            mapping; defaults to ``TREC_goldstandard_transform``.

    Returns:
        Whatever ``transform`` returns for the grouped judgments.
    """
    from collections import defaultdict

    if transform is None:
        transform = TREC_goldstandard_transform

    goldstandard = defaultdict(list)

    with open(q_rels_file) as f:
        for line_number, raw_line in enumerate(f, start=1):
            # split() (not split(" ")) so tab-separated files and repeated
            # spaces do not shift the columns
            parts = raw_line.split()
            if not parts:
                continue  # blank line, nothing to report
            if len(parts) < 4:
                # best-effort: report the malformed line and keep going, but
                # without a bare except that would also hide unrelated errors
                print("skipping malformed qrels line {}: {}".format(line_number, parts))
                continue
            goldstandard[parts[0]].append((parts[2], parts[3]))

    return transform(goldstandard)


import xmltodict
import ctypes
from collections import defaultdict

def load_TREC_topics(topics_file):
    """Parse a TREC topics XML file and turn it into query records.

    Each ``<topic number="...">`` element under the ``<topics>`` root supplies
    its ``disease``, ``gene`` and ``demographic`` children. The text handed to
    ``TREC_queries_transform`` for each topic is ``"<disease> <gene>"``.

    Args:
        topics_file: path to the topics XML file.

    Returns:
        Whatever ``TREC_queries_transform`` returns for the list of topic dicts
        (keys ``id``, ``disease``, ``gene``, ``demographic``).
    """
    # Read bytes so the XML declaration, not the platform locale, decides the encoding.
    with open(topics_file, "rb") as f:
        # force_list: with a single <topic>, xmltodict would return a dict rather
        # than a list, and iterating it below would walk the dict's keys.
        topic_nodes = xmltodict.parse(f.read(), force_list=("topic",))["topics"]["topic"]

    topics_json = [
        {
            "id": topic["@number"],
            "disease": topic["disease"],
            "gene": topic["gene"],
            "demographic": topic["demographic"],
        }
        for topic in topic_nodes
    ]

    return TREC_queries_transform(topics_json, number_parameter="id", fn=lambda x: x["disease"] + " " + x["gene"])


In [3]:
# Load the evaluation topics once; every configuration of the sweep reuses them.
topics = load_TREC_topics(topics_file="topics2019.xml")
# The qrels are not loaded here: the evaluator in the sweep reads
# "qrels-treceval-abstracts.2019.txt" directly.
#qrels = load_TREC_qrels("qrels-treceval-abstracts.2019.txt")

In [4]:
def execute_queries(queries, fields, analyzer, top_k=1000, es_client=None, index=None):
    """Run each query against Elasticsearch and collect the ranked hits.

    Each query string is first passed through the filter returned by
    ``create_filter_query_function()``. It is then issued as a ``query_string``
    query over ``fields``, analysed with ``analyzer``.

    Args:
        queries: iterable of dicts with at least ``"id"`` and ``"query"`` keys
            (presumably as produced by ``TREC_queries_transform``).
        fields: list of document fields searched by the ``query_string`` query.
        analyzer: name of the analyzer applied to the query text.
        top_k: number of hits requested per query.
        es_client: Elasticsearch client to use; defaults to the module-level ``es``.
        index: index to search; defaults to the module-level ``index_name``.

    Returns:
        dict mapping each query id to a list of
        ``{"id", "text", "title", "score"}`` dicts. The lists keep the order
        returned by Elasticsearch, which is verified to be non-increasing in score.

    Raises:
        AssertionError: if the hits for a query are not sorted by score.
    """
    client = es if es_client is None else es_client
    target_index = index_name if index is None else index
    filter_query_string = create_filter_query_function()

    documents = {}

    for j, query_data in enumerate(queries):
        # lightweight progress indicator, overwritten in place every 10 queries
        if not j % 10:
            print(j, end="\r")

        query = filter_query_string(query_data["query"])
        query_es = {
            "query": {
                "bool": {
                    "must": [
                        {
                            "query_string": {
                                "query": query,
                                "analyzer": analyzer,
                                "fields": fields,
                            }
                        }
                    ],
                    "filter": [],
                    "should": [],
                    "must_not": [],
                }
            }
        }

        retrieved = client.search(index=target_index, body=query_es, size=top_k, request_timeout=200)
        hits = retrieved["hits"]["hits"]

        # Ensure the Elasticsearch order is maintained (scores non-increasing).
        # Explicit raise instead of ``assert`` so the check also runs under ``python -O``.
        scores = [hit["_score"] for hit in hits]
        if any(prev < nxt for prev, nxt in zip(scores, scores[1:])):
            raise AssertionError(
                "Elasticsearch hits for query {} are not sorted by score".format(query_data["id"])
            )

        documents[query_data["id"]] = [
            {
                "id": hit["_source"]["id"],
                "text": hit["_source"]["text"],
                "title": hit["_source"]["title"],
                "score": hit["_score"],
            }
            for hit in hits
        ]

    return documents


In [5]:
# --- BM25 grid-search configuration -------------------------------------
# Finer grids tried earlier, kept for reference:
#K1 = [ round(0.5 + (i*0.2),1) for i in range(int((3.0-0.5)/0.2)+1)]
#B = [ round(0.3 + (i*0.1),1) for i in range(int((0.9-0.3)/0.1)+1)]

# candidate BM25 k1 values
K1 = [0.5, 1.5, 2.5]
# candidate BM25 b values
B = [0.3, 0.7]

# field combinations to search (another option tried: ["text", "mesh_terms"])
FIELDS = [
    ["text"],
]

# search analyzers to compare (an earlier list also had
# "gene_synonym_all" and "gene_synonym_symbol_ortho")
ANALYSER = [
    "english",
    "gene_synonym_symbol",
    "gene_synonym_complete_symbols",
    "gene_synonym_NCBI",
]

# trec_eval measures reported per configuration; the name is kept as-is because
# it is also passed as the ``output_metris`` keyword to evaluate_pre_rerank
output_metris = [
    "recall_100",
    "recall_500",
    "recall_1000",
    "map_cut_1000",
    "ndcg_cut_10",
    "P_5",
]

def save_answers_to_file(answers, prefix = None, out_file = None):
    """Write each string in ``answers`` on its own line of a UTF-8 text file.

    Args:
        answers: iterable of strings; each is written followed by a newline.
        prefix: path whose final extension is replaced by ``_answer.txt`` to
            build the output name (e.g. ``"runs/bm25.txt"`` ->
            ``"runs/bm25_answer.txt"``). Ignored when ``out_file`` is given.
        out_file: explicit output path; takes precedence over ``prefix``.

    Returns:
        The path that was written.

    Raises:
        ValueError: if neither ``prefix`` nor ``out_file`` is given.
    """
    from os.path import splitext

    if out_file is not None:
        _name = out_file
    elif prefix is not None:
        # Derive the name from ``prefix``. splitext strips only the final
        # extension, so dots in directory names are preserved.
        _name = splitext(prefix)[0] + "_answer.txt"
    else:
        raise ValueError("set prefix or out_file")

    with open(_name, "w", encoding="utf-8") as f:
        f.writelines(line + "\n" for line in answers)

    return _name

import itertools

# BM25 grid search: for every (fields, analyser, k1, b) combination, update the
# BM25 parameters on the index, run all topics and append one CSV row of
# trec_eval metrics. Iteration order matches the original nested loops
# (fields outermost, b innermost).
QRELS_FILE = "qrels-treceval-abstracts.2019.txt"
TREC_EVAL_BIN = "/backup/TREC/TestSet/trec_eval-9.0.7/trec_eval"  # machine-specific path

grid = list(itertools.product(FIELDS, ANALYSER, K1, B))
# 4 leading columns (k1, b, fields, analyser) then one 4-decimal column per metric
row_template = "{},{},{},{}," + ",".join(["{:.4f}"] * len(output_metris)) + "\n"

with open("bm25_" + index_name + ".csv", "w") as f:
    f.write("k1,b,fields,analyser," + ",".join(output_metris) + "\n")

    for c, (fields, analyser, k, b) in enumerate(grid):
        print(c, "/", len(grid))

        # update bm25
        change_bm25_parameters(k, b, index_name, es)

        retrieved = execute_queries(topics, fields, analyser)

        # NOTE(review): rebuilt per configuration as before; could be hoisted
        # if TREC_Evaluator turns out to be stateless — confirm first.
        trec_evaluator = TREC_Evaluator(QRELS_FILE, TREC_EVAL_BIN)
        test_collection = TestCollectionV2(topics, retrieved, trec_evaluator).batch_size(100)
        metrics = test_collection.evaluate_pre_rerank(output_metris=output_metris)

        list_metrics = [metrics[m] for m in output_metris]
        f.write(row_template.format(k, b, "_".join(fields), analyser, *list_metrics))
        # flush each row so progress is visible while the sweep runs and
        # completed rows survive a kernel crash
        f.flush()

0 / 24
Remove /tmp/tmpnjqnfzo2
1 / 24
Remove /tmp/tmpyeaa4jpe
2 / 24
Remove /tmp/tmp80_7d39x
3 / 24
Remove /tmp/tmpewi7_9kb
4 / 24
Remove /tmp/tmpl84ahjci
5 / 24
Remove /tmp/tmp6b1jdbl7
6 / 24
Remove /tmp/tmpsz6y1a35
7 / 24
Remove /tmp/tmpl17rtilf
8 / 24
Remove /tmp/tmp86updp5x
9 / 24
Remove /tmp/tmpmwplngz1
10 / 24
Remove /tmp/tmps5gzbvzw
11 / 24
Remove /tmp/tmpiwxz2b32
12 / 24
Remove /tmp/tmp83gzhnf3
13 / 24
Remove /tmp/tmpazvye5er
14 / 24
Remove /tmp/tmphi7h1j9v
15 / 24
Remove /tmp/tmpsn6hyl91
16 / 24
Remove /tmp/tmpayo09t7o
17 / 24
Remove /tmp/tmptfzcrtry
18 / 24
Remove /tmp/tmp2aalpqp4
19 / 24
Remove /tmp/tmpt7hr1ahs
20 / 24
Remove /tmp/tmpfbaye6rl
21 / 24
Remove /tmp/tmpym0ntxxc
22 / 24
Remove /tmp/tmp2zot1_0n
23 / 24
Remove /tmp/tmpq_ex8bg_


In [13]:
import pandas as pd

# Load the sweep results. The file name is built exactly as the grid-search cell
# builds it, so changing index_name keeps reader and writer in sync.
#df = pd.read_csv("bm25_trec-pm-2020-synonym-no-k1-b.csv")
df = pd.read_csv("bm25_" + index_name + ".csv")

In [7]:
# Sanity check on the loaded results: column dtypes and non-null counts.
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 24 entries, 0 to 23
Data columns (total 10 columns):
k1              24 non-null float64
b               24 non-null float64
fields          24 non-null object
analyser        24 non-null object
recall_100      24 non-null float64
recall_500      24 non-null float64
recall_1000     24 non-null float64
map_cut_1000    24 non-null float64
ndcg_cut_10     24 non-null float64
P_5             24 non-null float64
dtypes: float64(8), object(2)
memory usage: 2.0+ KB


In [22]:
# Configurations ranked by recall_500, best first.
# Author's note: "0.9 0.7" — presumably a k1/b pair of interest; TODO confirm.
# NOTE(review): the saved output below has an "Unnamed: 0" column and k1 values
# (0.7, 0.9) that are not in K1, so it seems to come from a different CSV/run
# than the current sweep — re-run to refresh.
df.sort_values("recall_500", ascending=False)

Unnamed: 0,k1,b,fields,analyser,recall_100,recall_500,recall_1000,map_cut_1000,ndcg_cut_10,P_5
17,0.9,0.6,text,english,0.3751,0.6994,0.7911,0.2929,0.5190,0.525
1,0.5,0.4,text,english,0.3766,0.6987,0.7901,0.2915,0.5139,0.545
3,0.5,0.6,text,english,0.3788,0.6986,0.7906,0.2915,0.5250,0.535
18,0.9,0.7,text,english,0.3751,0.6985,0.7916,0.2917,0.5187,0.530
0,0.5,0.3,text,english,0.3758,0.6982,0.7902,0.2913,0.5216,0.545
9,0.7,0.5,text,english,0.3772,0.6982,0.7902,0.2931,0.5195,0.540
2,0.5,0.5,text,english,0.3796,0.6977,0.7908,0.2920,0.5184,0.545
8,0.7,0.4,text,english,0.3752,0.6975,0.7913,0.2925,0.5142,0.540
15,0.9,0.4,text,english,0.3725,0.6970,0.7893,0.2918,0.5101,0.530
19,0.9,0.8,text,english,0.3746,0.6968,0.7905,0.2906,0.5303,0.545


In [15]:
# Configurations ranked by recall_100, best first.
# NOTE(review): the saved output lists field combinations (text_keywords, ...)
# that are not in FIELDS, so it appears stale — re-run to refresh.
df.sort_values("recall_100", ascending=False)

Unnamed: 0,k1,b,fields,analyser,recall_100,recall_500,recall_1000,map_cut_1000,ndcg_cut_10,P_5
73,0.5,0.7,text_keywords_mesh_terms,english,0.3816,0.6759,0.7757,0.2943,0.5232,0.560
49,0.5,0.7,text_keywords,english,0.3816,0.6759,0.7761,0.2944,0.5232,0.560
1,0.5,0.7,text,english,0.3816,0.6765,0.7761,0.2945,0.5232,0.560
25,0.5,0.7,text_mesh_terms,english,0.3816,0.6765,0.7757,0.2944,0.5232,0.560
0,0.5,0.3,text,english,0.3770,0.6799,0.7815,0.2971,0.5300,0.545
24,0.5,0.3,text_mesh_terms,english,0.3770,0.6799,0.7813,0.2971,0.5300,0.545
72,0.5,0.3,text_keywords_mesh_terms,english,0.3770,0.6796,0.7813,0.2971,0.5300,0.545
48,0.5,0.3,text_keywords,english,0.3770,0.6796,0.7815,0.2971,0.5300,0.545
74,1.5,0.3,text_keywords_mesh_terms,english,0.3658,0.6606,0.7673,0.2856,0.5072,0.540
2,1.5,0.3,text,english,0.3658,0.6606,0.7678,0.2856,0.5072,0.540


In [16]:
# Configurations ranked by P_5 (precision of the top 5 results), best first.
df.sort_values("P_5", ascending=False)

Unnamed: 0,k1,b,fields,analyser,recall_100,recall_500,recall_1000,map_cut_1000,ndcg_cut_10,P_5
53,2.5,0.7,text_keywords,english,0.3341,0.6298,0.7248,0.2599,0.4859,0.580
51,1.5,0.7,text_keywords,english,0.3644,0.6572,0.7569,0.2804,0.5084,0.580
77,2.5,0.7,text_keywords_mesh_terms,english,0.3341,0.6298,0.7248,0.2599,0.4859,0.580
3,1.5,0.7,text,english,0.3644,0.6572,0.7569,0.2804,0.5084,0.580
27,1.5,0.7,text_mesh_terms,english,0.3644,0.6572,0.7571,0.2804,0.5084,0.580
5,2.5,0.7,text,english,0.3341,0.6298,0.7248,0.2599,0.4859,0.580
75,1.5,0.7,text_keywords_mesh_terms,english,0.3644,0.6572,0.7571,0.2804,0.5084,0.580
29,2.5,0.7,text_mesh_terms,english,0.3341,0.6298,0.7248,0.2599,0.4859,0.580
1,0.5,0.7,text,english,0.3816,0.6765,0.7761,0.2945,0.5232,0.560
49,0.5,0.7,text_keywords,english,0.3816,0.6759,0.7761,0.2944,0.5232,0.560


In [17]:
# Configurations ranked by map_cut_1000, best first.
df.sort_values("map_cut_1000", ascending=False)

Unnamed: 0,k1,b,fields,analyser,recall_100,recall_500,recall_1000,map_cut_1000,ndcg_cut_10,P_5
0,0.5,0.3,text,english,0.3770,0.6799,0.7815,0.2971,0.5300,0.545
24,0.5,0.3,text_mesh_terms,english,0.3770,0.6799,0.7813,0.2971,0.5300,0.545
72,0.5,0.3,text_keywords_mesh_terms,english,0.3770,0.6796,0.7813,0.2971,0.5300,0.545
48,0.5,0.3,text_keywords,english,0.3770,0.6796,0.7815,0.2971,0.5300,0.545
1,0.5,0.7,text,english,0.3816,0.6765,0.7761,0.2945,0.5232,0.560
25,0.5,0.7,text_mesh_terms,english,0.3816,0.6765,0.7757,0.2944,0.5232,0.560
49,0.5,0.7,text_keywords,english,0.3816,0.6759,0.7761,0.2944,0.5232,0.560
73,0.5,0.7,text_keywords_mesh_terms,english,0.3816,0.6759,0.7757,0.2943,0.5232,0.560
50,1.5,0.3,text_keywords,english,0.3658,0.6606,0.7678,0.2856,0.5072,0.540
26,1.5,0.3,text_mesh_terms,english,0.3658,0.6606,0.7673,0.2856,0.5072,0.540
